From 842bd347ae12aac1e1b3698f21dd157e241f2f1a Mon Sep 17 00:00:00 2001 From: jsonbailey Date: Mon, 6 Apr 2026 12:13:05 -0500 Subject: [PATCH 1/6] fix: Replace fan-out with handoff routing and fix OpenAI message ordering - Add _coalesce_tool_messages_for_openai to fix parallel fan-out causing OpenAI 400 errors when sibling branch tool_calls have no ToolMessages - Add LDMetricsCallbackHandler for per-node token/tool/latency tracking - Add build_structured_tools with async callable support via inspect - Replace fan-out static edges with LLM-driven handoff tools (Command(goto=)) - Switch WorkflowState.messages to add_messages reducer - Add parallel_tool_calls=False when handoff tools present - Remove _coalesce_tool_messages_for_openai (no longer needed with handoffs) - Demote diagnostic logs to debug; exclude handoff tools from LD tracking - Consolidate token extraction to use get_ai_usage_from_response - Fix TestBuildTools to import build_structured_tools; cover async callables Co-Authored-By: Claude Sonnet 4.6 --- .../src/ldai_langchain/langchain_helper.py | 41 +- .../langgraph_agent_graph_runner.py | 290 ++++++--- .../langgraph_callback_handler.py | 218 +++++++ .../tests/test_langchain_provider.py | 49 ++ .../test_langgraph_agent_graph_runner.py | 2 +- .../tests/test_langgraph_callback_handler.py | 478 ++++++++++++++ .../tests/test_tracking_langgraph.py | 614 ++++++++++++++++++ 7 files changed, 1600 insertions(+), 92 deletions(-) create mode 100644 packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py create mode 100644 packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py create mode 100644 packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py index 677d2435..2179f667 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Dict, List, Optional, Union from langchain_core.language_models.chat_models import BaseChatModel @@ -156,9 +157,10 @@ def build_structured_tools(ai_config: AIConfigKind, tool_registry: ToolRegistry) """ Build a list of LangChain StructuredTool instances from LD tool definitions and a registry. - Tools found in the registry are wrapped as StructuredTool with the name and description - from the LD config. Built-in provider tools and tools missing from the registry are - skipped with a warning. + Tools found in the registry are wrapped as StructuredTool using the LD config key as the + tool name so the model's tool calls match ToolNode lookup. Async callables use ``coroutine=`` + so LangGraph invokes them correctly. Built-in provider tools and tools missing from the + registry are skipped with a warning. :param ai_config: The LaunchDarkly AI configuration :param tool_registry: Registry mapping tool names to callable implementations @@ -171,14 +173,17 @@ def build_structured_tools(ai_config: AIConfigKind, tool_registry: ToolRegistry) parameters = dict(model_dict.get('parameters') or {}) tool_definitions = parameters.pop('tools', []) or [] - return [ - StructuredTool.from_function( - func=tool_registry[name], - name=name, - description=td.get('description', ''), - ) - for name, td in _iter_valid_tools(tool_definitions, tool_registry) - ] + tools = [] + for name, td in _iter_valid_tools(tool_definitions, tool_registry): + fn = tool_registry[name] + raw_desc = td.get('description') if isinstance(td.get('description'), str) else '' + description = raw_desc.strip() or (getattr(fn, '__doc__', None) or '').strip() or f'Tool {name}' + if inspect.iscoroutinefunction(fn): + tool = StructuredTool.from_function(coroutine=fn, name=name, description=description) + else: + tool = StructuredTool.from_function(fn, name=name, description=description) + tools.append(tool) + return tools def get_ai_usage_from_response(response: Any) -> Optional[TokenUsage]: @@ -234,6 +239,20 @@ def get_tool_calls_from_response(response: Any) -> List[str]: return names +def extract_last_message_content(messages: List[Any]) -> str: + """ + Extract the string content of the last message in a list. + + :param messages: List of LangChain message objects + :return: String content of the last message, or empty string if none or content is not a str + """ + if messages: + last = messages[-1] + if hasattr(last, 'content') and isinstance(last.content, str): + return last.content + return '' + + def sum_token_usage_from_messages(messages: List[Any]) -> Optional[TokenUsage]: """ Sum token usage across LangChain messages using get_ai_usage_from_response per message. diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index c0c0b5c0..6523388b 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -1,8 +1,7 @@ """LangGraph agent graph runner for LaunchDarkly AI SDK.""" -import operator import time -from typing import Annotated, Any, List +from typing import Annotated, Any, Dict, List, Tuple from ldai import log from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode @@ -10,16 +9,55 @@ from ldai.providers.types import LDAIMetrics from ldai_langchain.langchain_helper import ( + build_structured_tools, create_langchain_model, - get_ai_metrics_from_response, - get_ai_usage_from_response, - get_tool_calls_from_response, + extract_last_message_content, sum_token_usage_from_messages, ) +from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + +def _make_handoff_tool(child_key: str, description: str) -> Any: + """ + Create a tool that transfers control to ``child_key``. + + Uses the ``@tool`` decorator with ``InjectedState`` + ``InjectedToolCallId`` + so LangGraph's ToolNode handles the ``Command`` return value correctly. + The tool explicitly creates a ToolMessage in ``Command.update`` to satisfy + the LangChain/OpenAI message-chain contract. + """ + from typing import Annotated as _Annotated + + from langchain_core.messages import ToolMessage + from langchain_core.tools import tool + from langchain_core.tools.base import InjectedToolCallId + from langgraph.prebuilt import InjectedState + from langgraph.types import Command + + tool_name = f"transfer_to_{child_key.replace('-', '_')}" + + @tool(tool_name, description=description) + def handoff( + state: _Annotated[Any, InjectedState], # noqa: ARG001 + tool_call_id: _Annotated[str, InjectedToolCallId], + ) -> Command: + tool_message = ToolMessage( + content=f'Transferred to {child_key}', + name=tool_name, + tool_call_id=tool_call_id, + ) + return Command(goto=child_key, update={'messages': [tool_message]}) + + return handoff class LangGraphAgentGraphRunner(AgentGraphRunner): """ + CAUTION: + This feature is experimental and should NOT be considered ready for production use. + It may change or be removed without notice and is not subject to backwards + compatibility guarantees. + AgentGraphRunner implementation for LangGraph. Compiles and runs the agent graph with LangGraph and automatically records @@ -39,111 +77,203 @@ def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry): self._graph = graph self._tools = tools - async def run(self, input: Any) -> AgentGraphResult: + def _build_graph(self) -> Tuple[Any, Dict[str, str]]: """ - Run the agent graph with the given input. + Build and compile the LangGraph StateGraph from the AgentGraphDefinition. - Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles - it, and invokes it. Tracks latency and invocation success/failure. - - :param input: The string prompt to send to the agent graph - :return: AgentGraphResult with the final output and metrics + :return: Tuple of (compiled_graph, fn_name_to_config_key) where + fn_name_to_config_key maps tool function __name__ to LD config key. """ - tracker = self._graph.get_tracker() - start_ns = time.perf_counter_ns() - try: - from langchain_core.messages import AnyMessage, HumanMessage - from langgraph.graph import END, START, StateGraph - from typing_extensions import TypedDict - - class WorkflowState(TypedDict): - messages: Annotated[List[Any], operator.add] - - agent_builder: StateGraph = StateGraph(WorkflowState) - root_node = self._graph.root() - root_key = root_node.get_key() if root_node else None - tools_ref = self._tools - exec_path: List[str] = [] - - def handle_traversal(node: AgentGraphNode, ctx: dict) -> None: - node_config = node.get_config() - node_key = node.get_key() - node_tracker = node_config.tracker - - model = None - if node_config.model: - lc_model = create_langchain_model(node_config) - tool_defs = node_config.model.get_parameter('tools') or [] - tool_fns = [ - tools_ref[t.get('name', '')] - for t in tool_defs - if t.get('name', '') in tools_ref - ] - model = lc_model.bind_tools(tool_fns) if tool_fns else lc_model - - def invoke(state: WorkflowState) -> WorkflowState: - exec_path.append(node_key) - if not model: - return {'messages': []} - gk = tracker.graph_key if tracker is not None else None - if node_tracker: - response = node_tracker.track_metrics_of( - lambda: model.invoke(state['messages']), - get_ai_metrics_from_response, - graph_key=gk, - ) - node_tracker.track_tool_calls( - get_tool_calls_from_response(response), - graph_key=tracker.graph_key if tracker is not None else None, + from langchain_core.messages import SystemMessage + from langgraph.graph import END, START, StateGraph + from langgraph.graph.message import add_messages + from langgraph.prebuilt import ToolNode, tools_condition + from typing_extensions import TypedDict + + class WorkflowState(TypedDict): + messages: Annotated[List[Any], add_messages] + + agent_builder: StateGraph = StateGraph(WorkflowState) + root_node = self._graph.root() + root_key = root_node.get_key() if root_node else None + tools_ref = self._tools + graph_structure: List[str] = [] + fn_name_to_config_key: Dict[str, str] = {} + + def handle_traversal(node: AgentGraphNode, ctx: dict) -> None: + node_config = node.get_config() + node_key = node.get_key() + instructions = node_config.instructions if hasattr(node_config, 'instructions') else None + outgoing_edges = node.get_edges() + + lc_model = None + tool_fns: list = [] + if node_config.model: + # We send an empty tool registry to avoid binding tools to the model. + lc_model = create_langchain_model(node_config, tool_registry=None) + + tool_fns = build_structured_tools(node_config, tools_ref) + + # Map tool name -> LD config key for callback attribution. + # build_structured_tools returns StructuredTool instances with tool.name set + # to the LD config key, so tool.name IS the config key. + for tool in tool_fns: + tool_name = getattr(tool, 'name', None) + if tool_name: + fn_name_to_config_key[tool_name] = tool_name + + # For nodes with multiple children, create a handoff tool per child so the + # LLM decides which agent to route to. Uses Command(goto=child_key) so + # LangGraph routes to the target without looping back here. + handoff_fns: list = [] + if lc_model and len(outgoing_edges) > 1: + for edge in outgoing_edges: + child_node = self._graph.get_node(edge.target_config) + description = ( + (edge.handoff or {}).get('description') + or ( + child_node.get_config().instructions[:120] + if child_node and child_node.get_config().instructions + else None ) - else: - response = model.invoke(state['messages']) + or f"Transfer control to {edge.target_config}" + ) + handoff_fns.append(_make_handoff_tool(edge.target_config, description)) + + all_tools = tool_fns + handoff_fns + if lc_model and all_tools: + # When handoff tools are present, disable parallel tool calls so the LLM + # picks exactly one destination rather than routing to multiple children. + bind_kwargs = {'parallel_tool_calls': False} if handoff_fns else {} + model = lc_model.bind_tools(all_tools, **bind_kwargs) + else: + model = lc_model + def make_node_fn(bound_model: Any, node_instructions: Any, nk: str): + async def invoke(state: WorkflowState) -> dict: + if not bound_model: + return {'messages': []} + msgs = list(state['messages']) + if node_instructions: + msgs = [SystemMessage(content=node_instructions)] + msgs + response = await bound_model.ainvoke(msgs) return {'messages': [response]} - invoke.__name__ = node_key + invoke.__name__ = nk + return invoke + + invoke_fn = make_node_fn(model, instructions, node_key) + agent_builder.add_node(node_key, invoke_fn) - agent_builder.add_node(node_key, invoke) + if node_key == root_key: + agent_builder.add_edge(START, node_key) + + # Collect node info for graph structure log + tool_names = [str(getattr(t, 'name', None) or getattr(t, '__name__', t)) for t in tool_fns] + edge_targets = [e.target_config for e in outgoing_edges] + node_desc = node_key + if tool_names: + node_desc += f"[tools:{','.join(tool_names)}]" + if handoff_fns: + node_desc += f"[handoff:{','.join(edge_targets)}]" + elif edge_targets: + node_desc += f"→{','.join(edge_targets)}" + else: + node_desc += "(terminal)" + graph_structure.append(node_desc) - if node_key == root_key: - agent_builder.add_edge(START, node_key) + if all_tools: + # ToolNode handles Command returns from handoff tools, routing to the target + # node. For functional tools it returns normal ToolMessages and we loop back. + # tools_condition exits to END when no tool is called. + tools_node_key = f"{node_key}__tools" + agent_builder.add_node(tools_node_key, ToolNode(all_tools)) + if not handoff_fns: + # No handoff tools: standard loop-back after tool execution. + after_loop = outgoing_edges[0].target_config if outgoing_edges else END + agent_builder.add_edge(tools_node_key, node_key) + agent_builder.add_conditional_edges( + node_key, + tools_condition, + {"tools": tools_node_key, END: after_loop}, + ) + else: + # Handoff tools use Command(goto=child_key) — LangGraph routes to the + # target directly without any extra edge. The ToolNode does NOT loop + # back here. tools_condition exits to END when no tool is called. + agent_builder.add_conditional_edges( + node_key, + tools_condition, + {"tools": tools_node_key, END: END}, + ) + else: if node.is_terminal(): agent_builder.add_edge(node_key, END) - - for edge in node.get_edges(): + for edge in outgoing_edges: agent_builder.add_edge(node_key, edge.target_config) - return None + return None + + self._graph.traverse(fn=handle_traversal) + + tracker = self._graph.get_tracker() + graph_key_str = tracker.graph_key if tracker else 'unknown' + log.debug( + f"LangGraphAgentGraphRunner: graph='{graph_key_str}', root='{root_key}', " + f"structure: {' | '.join(graph_structure)}" + ) + + compiled = agent_builder.compile() + return compiled, fn_name_to_config_key + + async def run(self, input: Any) -> AgentGraphResult: + """ + Run the agent graph with the given input. + + Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles + it, and invokes it. Uses a LangChain callback handler to collect + per-node metrics, then flushes them to LaunchDarkly trackers. + + :param input: The string prompt to send to the agent graph + :return: AgentGraphResult with the final output and metrics + """ + tracker = self._graph.get_tracker() + start_ns = time.perf_counter_ns() - self._graph.traverse(fn=handle_traversal) - compiled = agent_builder.compile() + try: + from langchain_core.messages import HumanMessage + + compiled, fn_name_to_config_key = self._build_graph() + + node_keys = {node.get_key() for node in self._graph._nodes.values()} + handler = LDMetricsCallbackHandler(node_keys, fn_name_to_config_key) result = await compiled.ainvoke( # type: ignore[call-overload] - {'messages': [HumanMessage(content=str(input))]} + {'messages': [HumanMessage(content=str(input))]}, + config={'callbacks': [handler], 'recursion_limit': 25}, ) - duration = (time.perf_counter_ns() - start_ns) // 1_000_000 - output = '' + duration = (time.perf_counter_ns() - start_ns) // 1_000_000 messages = result.get('messages', []) - if messages: - last = messages[-1] - if hasattr(last, 'content'): - output = str(last.content) + output = extract_last_message_content(messages) + # Flush per-node metrics to LD trackers + handler.flush(self._graph, tracker) + + # Graph-level metrics if tracker: - tracker.track_path(exec_path) + tracker.track_path(handler.path) tracker.track_latency(duration) tracker.track_invocation_success() - tracker.track_total_tokens( - sum_token_usage_from_messages(messages) - ) + tracker.track_total_tokens(sum_token_usage_from_messages(messages)) return AgentGraphResult( output=output, raw=result, metrics=LDAIMetrics(success=True), ) + except Exception as exc: if isinstance(exc, ImportError): log.warning( diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py new file mode 100644 index 00000000..e2173662 --- /dev/null +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py @@ -0,0 +1,218 @@ +import time +from typing import Any, Dict, List, Optional, Set +from uuid import UUID + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult +from ldai.agent_graph import AgentGraphDefinition +from ldai.tracker import TokenUsage + +from ldai_langchain.langchain_helper import get_ai_usage_from_response + + +class LDMetricsCallbackHandler(BaseCallbackHandler): + """ + CAUTION: + This feature is experimental and should NOT be considered ready for production use. + It may change or be removed without notice and is not subject to backwards + compatibility guarantees. + + LangChain callback handler that collects per-node metrics during a LangGraph run. + + Records token usage, tool calls, and duration for each agent node in the graph, + then flushes them to LaunchDarkly trackers after the run completes via ``flush()``. + """ + + def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]): + """ + Initialize the handler. + + :param node_keys: Set of LangGraph node keys that represent agent nodes + (excludes ``__tools`` suffix nodes). + :param fn_name_to_config_key: Mapping from tool function ``__name__`` to + the LD config key for that tool (e.g. ``'fetch_weather'`` -> ``'get_weather_open_meteo'``). + """ + super().__init__() + self._node_keys = node_keys + self._fn_name_to_config_key = fn_name_to_config_key + + # run_id -> node_key for active chain runs + self._run_to_node: Dict[UUID, str] = {} + # accumulated token usage per node + self._node_tokens: Dict[str, TokenUsage] = {} + # tool config keys called per node + self._node_tool_calls: Dict[str, List[str]] = {} + # start time (ns) per node — only set while running + self._node_start_ns: Dict[str, int] = {} + # accumulated duration (ms) per node + self._node_duration_ms: Dict[str, int] = {} + # execution path in order (deduplicated) + self._path: List[str] = [] + self._path_set: Set[str] = set() + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def path(self) -> List[str]: + """Execution path through the graph in order.""" + return list(self._path) + + @property + def node_tokens(self) -> Dict[str, TokenUsage]: + """Accumulated token usage per node key.""" + return dict(self._node_tokens) + + @property + def node_tool_calls(self) -> Dict[str, List[str]]: + """Tool config keys called per node key.""" + return {k: list(v) for k, v in self._node_tool_calls.items()} + + @property + def node_durations_ms(self) -> Dict[str, int]: + """Accumulated duration in milliseconds per node key.""" + return dict(self._node_duration_ms) + + # ------------------------------------------------------------------ + # Callbacks + # ------------------------------------------------------------------ + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Record start of a chain run; attribute to the matching agent node.""" + if name is None: + return + + if name in self._node_keys: + self._run_to_node[run_id] = name + self._node_start_ns[name] = time.perf_counter_ns() + if name not in self._path_set: + self._path.append(name) + self._path_set.add(name) + elif name.endswith('__tools'): + stripped = name[: -len('__tools')] + if stripped in self._node_keys: + # Attribute tool events to the owning agent node + self._run_to_node[run_id] = stripped + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + """Record end of a chain run and accumulate elapsed duration.""" + node_key = self._run_to_node.get(run_id) + if node_key is None: + return + start_ns = self._node_start_ns.pop(node_key, None) + if start_ns is not None: + elapsed_ms = (time.perf_counter_ns() - start_ns) // 1_000_000 + self._node_duration_ms[node_key] = ( + self._node_duration_ms.get(node_key, 0) + elapsed_ms + ) + + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + """Accumulate token usage for the node that owns this LLM call.""" + if parent_run_id is None: + return + node_key = self._run_to_node.get(parent_run_id) + if node_key is None: + return + + try: + message = response.generations[0][0].message + except (IndexError, AttributeError, TypeError): + return + usage = get_ai_usage_from_response(message) + if usage is None: + return + + existing = self._node_tokens.get(node_key) + if existing is None: + self._node_tokens[node_key] = usage + else: + self._node_tokens[node_key] = TokenUsage( + total=existing.total + usage.total, + input=existing.input + usage.input, + output=existing.output + usage.output, + ) + + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Record a tool invocation for the owning agent node.""" + if parent_run_id is None or name is None: + return + node_key = self._run_to_node.get(parent_run_id) + if node_key is None: + return + + config_key = self._fn_name_to_config_key.get(name) + if config_key is None: + # Tool is not a registered functional tool (e.g. a handoff tool) — skip tracking. + return + if node_key not in self._node_tool_calls: + self._node_tool_calls[node_key] = [] + self._node_tool_calls[node_key].append(config_key) + + # ------------------------------------------------------------------ + # Flush + # ------------------------------------------------------------------ + + def flush(self, graph: AgentGraphDefinition, graph_tracker: Any) -> None: + """ + Emit all collected per-node metrics to the LaunchDarkly trackers. + + Call this once after the graph run completes. + + :param graph: The AgentGraphDefinition whose nodes hold the LD config trackers. + :param graph_tracker: The AIGraphTracker for the overall graph (may be None). + """ + gk = graph_tracker.graph_key if graph_tracker is not None else None + for node_key in self._path: + node = graph.get_node(node_key) + if not node: + continue + config_tracker = node.get_config().tracker + if not config_tracker: + continue + + usage = self._node_tokens.get(node_key) + if usage: + config_tracker.track_tokens(usage, graph_key=gk) + + duration = self._node_duration_ms.get(node_key) + if duration is not None: + config_tracker.track_duration(duration, graph_key=gk) + + config_tracker.track_success(graph_key=gk) + + for tool_key in self._node_tool_calls.get(node_key, []): + config_tracker.track_tool_call(tool_key, graph_key=gk) + diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py index 52cca655..95f9e10f 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py @@ -515,3 +515,52 @@ async def test_returns_failure_when_exception_thrown(self): assert result.output == "" assert result.metrics.success is False + + +class TestBuildTools: + """Tests for build_structured_tools (sync vs async registry callables).""" + + def test_registers_sync_callable_as_structured_tool_func(self): + from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig + from ldai_langchain.langchain_helper import build_structured_tools + + def sync_tool(x: str = '') -> str: + return 'ok' + + cfg = AIAgentConfig( + key='n', + enabled=True, + model=ModelConfig( + name='gpt-4', + parameters={'tools': [{'name': 'my_tool', 'type': 'function', 'parameters': {}}]}, + ), + provider=ProviderConfig(name='openai'), + instructions='', + tracker=MagicMock(), + ) + tools = build_structured_tools(cfg, {'my_tool': sync_tool}) + assert len(tools) == 1 + assert tools[0].func is sync_tool + assert getattr(tools[0], 'coroutine', None) is None + + def test_registers_async_callable_as_structured_tool_coroutine(self): + from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig + from ldai_langchain.langchain_helper import build_structured_tools + + async def async_tool(x: str = '') -> str: + return 'ok' + + cfg = AIAgentConfig( + key='n', + enabled=True, + model=ModelConfig( + name='gpt-4', + parameters={'tools': [{'name': 'my_tool', 'type': 'function', 'parameters': {}}]}, + ), + provider=ProviderConfig(name='openai'), + instructions='', + tracker=MagicMock(), + ) + tools = build_structured_tools(cfg, {'my_tool': async_tool}) + assert len(tools) == 1 + assert tools[0].coroutine is async_tool diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py index de5e8d97..07802cb2 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py @@ -124,7 +124,7 @@ async def test_langgraph_runner_run_success(): mock_model_response.tool_calls = None mock_llm = MagicMock() - mock_llm.invoke = MagicMock(return_value=mock_model_response) + mock_llm.ainvoke = AsyncMock(return_value=mock_model_response) mock_init_model = MagicMock() mock_init_model.return_value = mock_llm diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py new file mode 100644 index 00000000..79a4a213 --- /dev/null +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py @@ -0,0 +1,478 @@ +""" +Unit tests for LDMetricsCallbackHandler. + +Tests the callback handler directly by simulating the events that LangChain +fires during a graph run — without needing a real or mock LangGraph execution. +""" + +from collections import defaultdict +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, LLMResult + +from ldai.agent_graph import AgentGraphDefinition +from ldai.models import AIAgentConfig, AIAgentGraphConfig, ModelConfig, ProviderConfig +from ldai.tracker import AIGraphTracker, LDAIConfigTracker, TokenUsage +from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_graph(mock_ld_client: MagicMock, node_key: str = 'root-agent', graph_key: str = 'test-graph'): + """Build a minimal single-node AgentGraphDefinition for flush() tests.""" + context = MagicMock() + node_tracker = LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='v1', + config_key=node_key, + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + graph_tracker = AIGraphTracker( + ld_client=mock_ld_client, + variation_key='v1', + graph_key=graph_key, + version=1, + context=context, + ) + node_config = AIAgentConfig( + key=node_key, + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='Be helpful.', + tracker=node_tracker, + ) + graph_config = AIAgentGraphConfig( + key=graph_key, + root_config_key=node_key, + edges=[], + enabled=True, + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, {node_key: node_config}) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=graph_tracker, + ) + + +def _llm_result(total: int, prompt: int, completion: int) -> LLMResult: + return LLMResult( + generations=[[ChatGeneration( + message=AIMessage( + content='ok', + usage_metadata={'total_tokens': total, 'input_tokens': prompt, 'output_tokens': completion}, + ), + text='ok', + )]], + llm_output={}, + ) + + +def _events(mock_ld_client: MagicMock) -> dict: + result = defaultdict(list) + for call in mock_ld_client.track.call_args_list: + name, _ctx, data, value = call.args + result[name].append((data, value)) + return dict(result) + + +# --------------------------------------------------------------------------- +# on_chain_start tests +# --------------------------------------------------------------------------- + +def test_on_chain_start_records_agent_node(): + """Agent node name is recorded in path and run_to_node map.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') + assert handler.path == ['root-agent'] + + +def test_on_chain_start_deduplicates_path(): + """Multiple starts for the same node appear only once in path.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id1 = uuid4() + run_id2 = uuid4() + handler.on_chain_start({}, {}, run_id=run_id1, name='root-agent') + handler.on_chain_start({}, {}, run_id=run_id2, name='root-agent') + assert handler.path == ['root-agent'] + + +def test_on_chain_start_tools_node_attributed_to_agent(): + """A '__tools' chain start maps its run_id to the stripped agent node key.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + # Tool node should NOT appear in path + assert handler.path == [] + # But the run_id should be attributed to the agent node for tool event lookup + assert handler._run_to_node.get(tools_run_id) == 'root-agent' + + +def test_on_chain_start_unknown_name_ignored(): + """Names not in node_keys and not __tools suffixed are ignored.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_chain_start({}, {}, run_id=uuid4(), name='some-other-chain') + assert handler.path == [] + + +def test_on_chain_start_none_name_ignored(): + """None name does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_chain_start({}, {}, run_id=uuid4(), name=None) + assert handler.path == [] + + +def test_on_chain_start_tools_for_unknown_agent_ignored(): + """A '__tools' chain whose stripped name is not in node_keys is ignored.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='other-agent__tools') + assert run_id not in handler._run_to_node + + +def test_on_chain_start_records_path_order(): + """Multiple distinct agent nodes appear in path in order of first appearance.""" + handler = LDMetricsCallbackHandler({'node-a', 'node-b'}, {}) + handler.on_chain_start({}, {}, run_id=uuid4(), name='node-a') + handler.on_chain_start({}, {}, run_id=uuid4(), name='node-b') + assert handler.path == ['node-a', 'node-b'] + + +# --------------------------------------------------------------------------- +# on_chain_end / duration tests +# --------------------------------------------------------------------------- + +def test_on_chain_end_accumulates_duration(): + """Duration is computed and stored after chain_end.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') + handler.on_chain_end({}, run_id=run_id) + # Duration may be 0 on fast machines but the key must be present + assert 'root-agent' in handler.node_durations_ms + + +def test_on_chain_end_accumulates_across_multiple_runs(): + """Duration accumulates (not overwritten) when a node runs multiple times.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + + run1 = uuid4() + handler.on_chain_start({}, {}, run_id=run1, name='root-agent') + handler.on_chain_end({}, run_id=run1) + duration_after_first = handler.node_durations_ms.get('root-agent', 0) + + run2 = uuid4() + handler.on_chain_start({}, {}, run_id=run2, name='root-agent') + handler.on_chain_end({}, run_id=run2) + duration_after_second = handler.node_durations_ms.get('root-agent', 0) + + assert duration_after_second >= duration_after_first + + +def test_on_chain_end_unknown_run_id_ignored(): + """chain_end for an unknown run_id does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_chain_end({}, run_id=uuid4()) # should not raise + + +# --------------------------------------------------------------------------- +# on_llm_end / token tests +# --------------------------------------------------------------------------- + +def test_on_llm_end_accumulates_tokens(): + """Token usage from llm_output is recorded for the parent node.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + result = _llm_result(total=15, prompt=10, completion=5) + handler.on_llm_end(result, run_id=uuid4(), parent_run_id=node_run_id) + + tokens = handler.node_tokens.get('root-agent') + assert tokens is not None + assert tokens.total == 15 + assert tokens.input == 10 + assert tokens.output == 5 + + +def test_on_llm_end_accumulates_across_multiple_calls(): + """Multiple LLM calls for the same node accumulate token counts.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + result1 = _llm_result(total=10, prompt=7, completion=3) + result2 = _llm_result(total=6, prompt=4, completion=2) + handler.on_llm_end(result1, run_id=uuid4(), parent_run_id=node_run_id) + handler.on_llm_end(result2, run_id=uuid4(), parent_run_id=node_run_id) + + tokens = handler.node_tokens['root-agent'] + assert tokens.total == 16 + assert tokens.input == 11 + assert tokens.output == 5 + + +def test_on_llm_end_none_parent_run_id_ignored(): + """LLM end with parent_run_id=None does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + result = _llm_result(total=5, prompt=3, completion=2) + handler.on_llm_end(result, run_id=uuid4(), parent_run_id=None) + assert handler.node_tokens == {} + + +def test_on_llm_end_unknown_parent_run_id_ignored(): + """LLM end for a run_id not in _run_to_node is silently ignored.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + result = _llm_result(total=5, prompt=3, completion=2) + handler.on_llm_end(result, run_id=uuid4(), parent_run_id=uuid4()) + assert handler.node_tokens == {} + + +def test_on_llm_end_camel_case_token_keys(): + """camelCase token keys in response_metadata (e.g. some AWS Bedrock models) are parsed.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + msg = AIMessage(content='ok', response_metadata={ + 'tokenUsage': {'totalTokens': 20, 'promptTokens': 12, 'completionTokens': 8} + }) + result = LLMResult( + generations=[[ChatGeneration(message=msg, text='ok')]], + llm_output={}, + ) + handler.on_llm_end(result, run_id=uuid4(), parent_run_id=node_run_id) + + tokens = handler.node_tokens.get('root-agent') + assert tokens is not None + assert tokens.total == 20 + assert tokens.input == 12 + assert tokens.output == 8 + + +# --------------------------------------------------------------------------- +# on_tool_end tests +# --------------------------------------------------------------------------- + +def test_on_tool_end_records_tool_call(): + """Tool end event records config key for the owning agent node.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {'fetch_weather': 'get_weather_open_meteo'}) + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('sunny', run_id=uuid4(), parent_run_id=tools_run_id, name='fetch_weather') + assert handler.node_tool_calls.get('root-agent') == ['get_weather_open_meteo'] + + +def test_on_tool_end_skips_unregistered_tools(): + """Tool end is ignored for tools not in the fn_name_to_config_key map (e.g. handoff tools).""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('result', run_id=uuid4(), parent_run_id=tools_run_id, name='transfer_to_child') + assert handler.node_tool_calls.get('root-agent') is None + + +def test_on_tool_end_multiple_tools_accumulated(): + """Multiple tool calls are accumulated in order.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {'search': 'search', 'summarize': 'summarize'}) + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('r1', run_id=uuid4(), parent_run_id=tools_run_id, name='search') + handler.on_tool_end('r2', run_id=uuid4(), parent_run_id=tools_run_id, name='summarize') + assert handler.node_tool_calls.get('root-agent') == ['search', 'summarize'] + + +def test_on_tool_end_none_parent_run_id_ignored(): + """Tool end with parent_run_id=None does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_tool_end('result', run_id=uuid4(), parent_run_id=None, name='my_tool') + assert handler.node_tool_calls == {} + + +def test_on_tool_end_none_name_ignored(): + """Tool end with name=None does not raise.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') + handler.on_tool_end('result', run_id=uuid4(), parent_run_id=run_id, name=None) + assert handler.node_tool_calls == {} + + +# --------------------------------------------------------------------------- +# flush() tests +# --------------------------------------------------------------------------- + +def test_flush_emits_token_events_to_ld_tracker(): + """flush() calls track_tokens on the node's config tracker.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client, node_key='root-agent', graph_key='g1') + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + handler.on_llm_end(_llm_result(15, 10, 5), run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + assert ev['$ld:ai:tokens:total'][0][1] == 15 + assert ev['$ld:ai:tokens:input'][0][1] == 10 + assert ev['$ld:ai:tokens:output'][0][1] == 5 + assert ev['$ld:ai:generation:success'][0][1] == 1 + + +def test_flush_emits_duration(): + """flush() calls track_duration when duration was recorded.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + run_id = uuid4() + handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') + handler.on_chain_end({}, run_id=run_id) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + assert '$ld:ai:duration:total' in ev + + +def test_flush_emits_tool_calls(): + """flush() calls track_tool_call for each recorded tool invocation.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {'fn_search': 'search'}) + # The agent node must be started first so it appears in the path for flush() + agent_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=agent_run_id, name='root-agent') + # Tool calls are attributed via the __tools chain run_id + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('r', run_id=uuid4(), parent_run_id=tools_run_id, name='fn_search') + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + tool_events = ev.get('$ld:ai:tool_call', []) + assert len(tool_events) == 1 + assert tool_events[0][0]['toolKey'] == 'search' + + +def test_flush_includes_graph_key_in_node_events(): + """flush() passes graph_key to the node tracker so graphKey appears in events.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client, graph_key='my-graph') + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + token_data = ev['$ld:ai:tokens:total'][0][0] + assert token_data.get('graphKey') == 'my-graph' + + +def test_flush_with_none_tracker_uses_no_graph_key(): + """flush() with graph_tracker=None does not fail and omits graphKey.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, None) # graph_tracker=None + + ev = _events(mock_ld_client) + token_data = ev['$ld:ai:tokens:total'][0][0] + assert 'graphKey' not in token_data + + +def test_flush_skips_nodes_not_in_path(): + """flush() only emits events for nodes that were actually executed.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + tracker = graph.get_tracker() + + # Handler with 'root-agent' in node_keys but never started + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + assert '$ld:ai:tokens:total' not in ev + assert '$ld:ai:generation:success' not in ev + + +def test_flush_skips_node_without_tracker(): + """flush() silently skips nodes whose config has no tracker.""" + mock_ld_client = MagicMock() + context = MagicMock() + + node_config_no_tracker = AIAgentConfig( + key='no-track', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='', + tracker=None, + ) + graph_config = AIAgentGraphConfig( + key='g', root_config_key='no-track', edges=[], enabled=True + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, {'no-track': node_config_no_tracker}) + graph = AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=None, + ) + + handler = LDMetricsCallbackHandler({'no-track'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='no-track') + handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, None) # should not raise + + mock_ld_client.track.assert_not_called() + + +# --------------------------------------------------------------------------- +# properties +# --------------------------------------------------------------------------- + +def test_path_property_returns_copy(): + """Mutating the returned path does not affect the handler's internal state.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + handler.on_chain_start({}, {}, run_id=uuid4(), name='root-agent') + path = handler.path + path.append('extra') + assert handler.path == ['root-agent'] + + +def test_node_tokens_property_returns_copy(): + """Mutating the returned dict does not affect the handler's internal state.""" + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) + tokens = handler.node_tokens + tokens['other'] = TokenUsage(total=1, input=1, output=0) + assert 'other' not in handler.node_tokens diff --git a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py new file mode 100644 index 00000000..0f59214a --- /dev/null +++ b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py @@ -0,0 +1,614 @@ +""" +Integration tests for LangGraphAgentGraphRunner tracking pipeline. + +Uses real AIGraphTracker and LDAIConfigTracker backed by a mock LD client, +and a fake LangChain model to verify that the correct LD events are emitted +with the correct payloads — without making real API calls. +""" + +import pytest +from collections import defaultdict +from unittest.mock import AsyncMock, MagicMock, patch + +from ldai.agent_graph import AgentGraphDefinition +from ldai.models import AIAgentGraphConfig, AIAgentConfig, Edge, ModelConfig, ProviderConfig +from ldai.tracker import AIGraphTracker, LDAIConfigTracker +from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner + +pytestmark = pytest.mark.skipif( + pytest.importorskip('langgraph', reason='langgraph not installed') is None, + reason='langgraph not installed', +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_graph( + mock_ld_client: MagicMock, + node_key: str = 'root-agent', + graph_key: str = 'test-graph', + tool_names: list = None, +) -> AgentGraphDefinition: + """ + Build an AgentGraphDefinition backed by real tracker objects that record + events to a mock LD client. + """ + context = MagicMock() + + node_tracker = LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + config_key=node_key, + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + + graph_tracker = AIGraphTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + graph_key=graph_key, + version=1, + context=context, + ) + + tool_defs = ( + [{'name': name, 'type': 'function', 'description': '', 'parameters': {}} + for name in tool_names] + if tool_names else None + ) + + root_config = AIAgentConfig( + key=node_key, + enabled=True, + model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs} if tool_defs else {}), + provider=ProviderConfig(name='openai'), + instructions='You are a helpful assistant.', + tracker=node_tracker, + ) + + graph_config = AIAgentGraphConfig( + key=graph_key, + root_config_key=node_key, + edges=[], + enabled=True, + ) + + nodes = AgentGraphDefinition.build_nodes(graph_config, {node_key: root_config}) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=graph_tracker, + ) + + +def _make_fake_response( + content: str, + input_tokens: int = 10, + output_tokens: int = 5, + tool_call_names: list = None, +): + """Create a real AIMessage with usage metadata and optional tool calls.""" + from langchain_core.messages import AIMessage + + tool_calls = [ + {'name': name, 'args': {}, 'id': f'call_{i}', 'type': 'tool_call'} + for i, name in enumerate(tool_call_names or []) + ] + + return AIMessage( + content=content, + tool_calls=tool_calls, + usage_metadata={ + 'input_tokens': input_tokens, + 'output_tokens': output_tokens, + 'total_tokens': input_tokens + output_tokens, + }, + ) + + +def _events(mock_ld_client: MagicMock) -> dict: + """Return dict of event_name -> list of (data, value) from all track() calls.""" + result = defaultdict(list) + for call in mock_ld_client.track.call_args_list: + name, _ctx, data, value = call.args + result[name].append((data, value)) + return dict(result) + + +def _mock_model(response): + """Return a mock LangChain model that always returns response on ainvoke().""" + model = MagicMock() + model.ainvoke = AsyncMock(return_value=response) + model.bind_tools.return_value = model + return model + + +def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': + """Build a two-node AgentGraphDefinition (root-agent → child-agent).""" + context = MagicMock() + + root_tracker = LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + config_key='root-agent', + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + child_tracker = LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + config_key='child-agent', + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + graph_tracker = AIGraphTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + graph_key='two-node-graph', + version=1, + context=context, + ) + + root_config = AIAgentConfig( + key='root-agent', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You are root.', + tracker=root_tracker, + ) + child_config = AIAgentConfig( + key='child-agent', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You are child.', + tracker=child_tracker, + ) + + edge = Edge(key='root-to-child', source_config='root-agent', target_config='child-agent') + graph_config = AIAgentGraphConfig( + key='two-node-graph', + root_config_key='root-agent', + edges=[edge], + enabled=True, + ) + + nodes = AgentGraphDefinition.build_nodes(graph_config, { + 'root-agent': root_config, + 'child-agent': child_config, + }) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=graph_tracker, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tracks_node_and_graph_tokens_on_success(): + """Node-level and graph-level token events fire with the correct counts.""" + from uuid import uuid4 + from langchain_core.messages import AIMessage as _AIMsg + from langchain_core.outputs import LLMResult, ChatGeneration + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + fake_response = _make_fake_response('Sunny.', input_tokens=10, output_tokens=5) + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + return_value=_mock_model(fake_response)): + runner = LangGraphAgentGraphRunner(graph, {}) + result = await runner.run("What's the weather?") + + assert result.metrics.success is True + assert result.output == 'Sunny.' + + # Manually simulate what the callback handler would collect and flush + # (mock models don't fire LangChain callbacks, so we test flush directly) + mock_ld_client2 = MagicMock() + graph2 = _make_graph(mock_ld_client2) + tracker2 = graph2.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + llm_result = LLMResult( + generations=[[ChatGeneration( + message=_AIMsg(content='Sunny.', usage_metadata={'total_tokens': 15, 'input_tokens': 10, 'output_tokens': 5}), + text='Sunny.', + )]], + llm_output={}, + ) + handler.on_llm_end(llm_result, run_id=uuid4(), parent_run_id=node_run_id) + handler.on_chain_end({}, run_id=node_run_id) + handler.flush(graph2, tracker2) + + ev2 = _events(mock_ld_client2) + assert ev2['$ld:ai:tokens:total'][0][1] == 15 + assert ev2['$ld:ai:tokens:input'][0][1] == 10 + assert ev2['$ld:ai:tokens:output'][0][1] == 5 + assert ev2['$ld:ai:generation:success'][0][1] == 1 + assert '$ld:ai:duration:total' in ev2 + + # Graph-level events from the real run + ev = _events(mock_ld_client) + assert ev['$ld:ai:graph:total_tokens'][0][1] == 15 + assert ev['$ld:ai:graph:invocation_success'][0][1] == 1 + assert '$ld:ai:graph:latency' in ev + assert '$ld:ai:graph:path' in ev + + +@pytest.mark.asyncio +async def test_tracks_execution_path(): + """The path event contains the executed node key.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client, node_key='my-agent') + fake_response = _make_fake_response('Done.') + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + return_value=_mock_model(fake_response)): + runner = LangGraphAgentGraphRunner(graph, {}) + await runner.run('hello') + + ev = _events(mock_ld_client) + path_data = ev['$ld:ai:graph:path'][0][0] + assert 'my-agent' in path_data['path'] + + +@pytest.mark.asyncio +async def test_tracks_tool_calls(): + """A tool_call event fires for each tool name found in the model response.""" + from uuid import uuid4 + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client, tool_names=['get_weather']) + + # Model returns a tool call on the first invoke, then a final answer. + tool_response = _make_fake_response('Calling tool.', tool_call_names=['get_weather']) + final_response = _make_fake_response('It is sunny in NYC.') + + mock_model = MagicMock() + mock_model.ainvoke = AsyncMock(side_effect=[tool_response, final_response]) + mock_model.bind_tools.return_value = mock_model + + def get_weather(location: str = 'NYC') -> str: + """Return the current weather for a location.""" + return 'sunny' + + tool_registry = {'get_weather': get_weather} + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + return_value=mock_model): + runner = LangGraphAgentGraphRunner(graph, tool_registry) + await runner.run('What is the weather?') + + # Simulate tool call tracking via the callback handler directly + mock_ld_client2 = MagicMock() + graph2 = _make_graph(mock_ld_client2, tool_names=['get_weather']) + tracker2 = graph2.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {'get_weather': 'get_weather'}) + # Agent node must appear in path for flush() to emit its events + agent_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=agent_run_id, name='root-agent') + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('sunny', run_id=uuid4(), parent_run_id=tools_run_id, name='get_weather') + handler.flush(graph2, tracker2) + + ev2 = _events(mock_ld_client2) + tool_events = ev2.get('$ld:ai:tool_call', []) + assert len(tool_events) == 1 + assert tool_events[0][0]['toolKey'] == 'get_weather' + + +@pytest.mark.asyncio +async def test_tracks_multiple_tool_calls(): + """One tool_call event fires per tool name in the response.""" + from uuid import uuid4 + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client, tool_names=['search', 'summarize']) + + # Both tools called in one response; second invoke returns a final answer. + tool_response = _make_fake_response('Done.', tool_call_names=['search', 'summarize']) + final_response = _make_fake_response('Here is the summary.') + + mock_model = MagicMock() + mock_model.ainvoke = AsyncMock(side_effect=[tool_response, final_response]) + mock_model.bind_tools.return_value = mock_model + + def search(q: str = '') -> str: + """Search for information.""" + return q + + def summarize(text: str = '') -> str: + """Summarize the given text.""" + return text + + tool_registry = {'search': search, 'summarize': summarize} + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + return_value=mock_model): + runner = LangGraphAgentGraphRunner(graph, tool_registry) + await runner.run('Search and summarize.') + + # Simulate multiple tool calls via the callback handler directly + mock_ld_client2 = MagicMock() + graph2 = _make_graph(mock_ld_client2, tool_names=['search', 'summarize']) + tracker2 = graph2.get_tracker() + + fn_map = {'search': 'search', 'summarize': 'summarize'} + handler = LDMetricsCallbackHandler({'root-agent'}, fn_map) + # Agent node must appear in path for flush() to emit its events + agent_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=agent_run_id, name='root-agent') + tools_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') + handler.on_tool_end('result', run_id=uuid4(), parent_run_id=tools_run_id, name='search') + handler.on_tool_end('summary', run_id=uuid4(), parent_run_id=tools_run_id, name='summarize') + handler.flush(graph2, tracker2) + + ev2 = _events(mock_ld_client2) + tool_keys = [data['toolKey'] for data, _ in ev2.get('$ld:ai:tool_call', [])] + assert sorted(tool_keys) == ['search', 'summarize'] + + +@pytest.mark.asyncio +async def test_tracks_graph_key_on_node_events(): + """Node-level events include the graphKey so they can be correlated to the graph.""" + from uuid import uuid4 + from langchain_core.messages import AIMessage as _AIMsg + from langchain_core.outputs import LLMResult, ChatGeneration + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client, graph_key='my-graph') + tracker = graph.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent'}, {}) + node_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') + + llm_result = LLMResult( + generations=[[ChatGeneration( + message=_AIMsg(content='OK.', usage_metadata={'total_tokens': 8, 'input_tokens': 5, 'output_tokens': 3}), + text='OK.', + )]], + llm_output={}, + ) + handler.on_llm_end(llm_result, run_id=uuid4(), parent_run_id=node_run_id) + handler.flush(graph, tracker) + + ev = _events(mock_ld_client) + token_data = ev['$ld:ai:tokens:total'][0][0] + assert token_data.get('graphKey') == 'my-graph' + + +@pytest.mark.asyncio +async def test_tracks_failure_and_latency_on_model_error(): + """When the model raises, failure and latency events fire; success does not.""" + mock_ld_client = MagicMock() + graph = _make_graph(mock_ld_client) + + error_model = MagicMock() + error_model.ainvoke = AsyncMock(side_effect=RuntimeError('model error')) + error_model.bind_tools.return_value = error_model + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + return_value=error_model): + runner = LangGraphAgentGraphRunner(graph, {}) + result = await runner.run('fail') + + assert result.metrics.success is False + + ev = _events(mock_ld_client) + assert '$ld:ai:graph:invocation_failure' in ev + assert '$ld:ai:graph:latency' in ev + assert '$ld:ai:graph:invocation_success' not in ev + + +@pytest.mark.asyncio +async def test_multi_node_tracks_per_node_tokens_and_path(): + """Each node emits its own token events; path and graph total cover both nodes.""" + from uuid import uuid4 + from langchain_core.messages import AIMessage as _AIMsg + from langchain_core.outputs import LLMResult, ChatGeneration + from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler + + mock_ld_client = MagicMock() + graph = _make_two_node_graph(mock_ld_client) + + root_response = _make_fake_response('Root done.', input_tokens=10, output_tokens=5) + child_response = _make_fake_response('Child done.', input_tokens=3, output_tokens=2) + + def model_factory(node_config, **kwargs): + if node_config.key == 'root-agent': + return _mock_model(root_response) + return _mock_model(child_response) + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + side_effect=model_factory): + runner = LangGraphAgentGraphRunner(graph, {}) + result = await runner.run('hello') + + assert result.metrics.success is True + + # Simulate per-node token events via callback handler (mock models don't fire callbacks) + mock_ld_client2 = MagicMock() + graph2 = _make_two_node_graph(mock_ld_client2) + tracker2 = graph2.get_tracker() + + handler = LDMetricsCallbackHandler({'root-agent', 'child-agent'}, {}) + + root_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=root_run_id, name='root-agent') + root_llm_result = LLMResult( + generations=[[ChatGeneration( + message=_AIMsg(content='Root done.', usage_metadata={'total_tokens': 15, 'input_tokens': 10, 'output_tokens': 5}), + text='Root done.', + )]], + llm_output={}, + ) + handler.on_llm_end(root_llm_result, run_id=uuid4(), parent_run_id=root_run_id) + + child_run_id = uuid4() + handler.on_chain_start({}, {}, run_id=child_run_id, name='child-agent') + child_llm_result = LLMResult( + generations=[[ChatGeneration( + message=_AIMsg(content='Child done.', usage_metadata={'total_tokens': 5, 'input_tokens': 3, 'output_tokens': 2}), + text='Child done.', + )]], + llm_output={}, + ) + handler.on_llm_end(child_llm_result, run_id=uuid4(), parent_run_id=child_run_id) + + handler.flush(graph2, tracker2) + + ev2 = _events(mock_ld_client2) + + # Per-node token events identified by configKey + root_tokens = [(d, v) for d, v in ev2.get('$ld:ai:tokens:total', []) if d.get('configKey') == 'root-agent'] + child_tokens = [(d, v) for d, v in ev2.get('$ld:ai:tokens:total', []) if d.get('configKey') == 'child-agent'] + assert root_tokens[0][1] == 15 + assert child_tokens[0][1] == 5 + + # Graph-level total from the real runner run + ev = _events(mock_ld_client) + assert ev['$ld:ai:graph:total_tokens'][0][1] == 20 + + # Execution path includes both node keys (from real run) + path_data = ev['$ld:ai:graph:path'][0][0] + assert 'root-agent' in path_data['path'] + assert 'child-agent' in path_data['path'] + + +def _make_multi_child_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': + """Build a 3-node graph: orchestrator → agent-a, orchestrator → agent-b.""" + context = MagicMock() + + def _node_tracker(key: str) -> LDAIConfigTracker: + return LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + config_key=key, + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + + graph_tracker = AIGraphTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + graph_key='multi-child-graph', + version=1, + context=context, + ) + + configs = { + 'orchestrator': AIAgentConfig( + key='orchestrator', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='Route to the appropriate specialist agent.', + tracker=_node_tracker('orchestrator'), + ), + 'agent-a': AIAgentConfig( + key='agent-a', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You handle topic A.', + tracker=_node_tracker('agent-a'), + ), + 'agent-b': AIAgentConfig( + key='agent-b', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You handle topic B.', + tracker=_node_tracker('agent-b'), + ), + } + + edges = [ + Edge(key='orch-to-a', source_config='orchestrator', target_config='agent-a'), + Edge(key='orch-to-b', source_config='orchestrator', target_config='agent-b'), + ] + graph_config = AIAgentGraphConfig( + key='multi-child-graph', + root_config_key='orchestrator', + edges=edges, + enabled=True, + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, configs) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=graph_tracker, + ) + + +@pytest.mark.asyncio +async def test_multi_child_routes_via_handoff_not_fan_out(): + """Orchestrator with two children routes to exactly one child via handoff tool, + not a fan-out that invokes both children.""" + from langchain_core.messages import AIMessage + + mock_ld_client = MagicMock() + graph = _make_multi_child_graph(mock_ld_client) + + # Orchestrator calls transfer_to_agent_a (handoff tool name derived from child key) + orchestrator_response = AIMessage( + content='', + tool_calls=[{ + 'name': 'transfer_to_agent_a', + 'args': {}, + 'id': 'call_handoff_1', + 'type': 'tool_call', + }], + ) + agent_a_response = _make_fake_response('Agent A handled it.') + agent_b_model = _mock_model(_make_fake_response('Agent B handled it.')) + + def model_factory(node_config, **kwargs): + if node_config.key == 'orchestrator': + return _mock_model(orchestrator_response) + if node_config.key == 'agent-a': + return _mock_model(agent_a_response) + return agent_b_model + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + side_effect=model_factory): + runner = LangGraphAgentGraphRunner(graph, {}) + result = await runner.run('hello') + + assert result.metrics.success is True + assert 'Agent A' in result.output + # Agent B's model must never have been invoked — no fan-out + agent_b_model.ainvoke.assert_not_called() From 3110be6534ac97e010883165a28b9654b2d90b3a Mon Sep 17 00:00:00 2001 From: jsonbailey Date: Mon, 6 Apr 2026 12:42:55 -0500 Subject: [PATCH 2/6] fix: Remove accidentally resurrected code from pre-PR-115 merge conflict - Remove tool_registry parameter from create_langchain_model (matches main) - Remove _resolve_tools_for_langchain (no longer needed) - Remove BaseMessage import (unused) Co-Authored-By: Claude Sonnet 4.6 --- .../src/ldai_langchain/langchain_helper.py | 45 +++---------------- .../langgraph_agent_graph_runner.py | 2 +- 2 files changed, 6 insertions(+), 41 deletions(-) diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py index 55968231..3cded324 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from ldai import LDMessage, log from ldai.models import AIConfigKind from ldai.providers import ToolRegistry @@ -52,18 +52,12 @@ def convert_messages_to_langchain( return result -def create_langchain_model(ai_config: AIConfigKind, tool_registry: Optional[ToolRegistry] = None) -> BaseChatModel: +def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel: """ Create a LangChain BaseChatModel from a LaunchDarkly AI configuration. - If the config includes tool definitions and a tool_registry is provided, tools found - in the registry are bound to the model. Tools not found in the registry are skipped - with a warning. Built-in provider tools (e.g. code_interpreter) are not supported - via LangChain's bind_tools abstraction and are skipped with a warning. - :param ai_config: The LaunchDarkly AI configuration - :param tool_registry: Optional registry mapping tool names to callable implementations - :return: A configured LangChain BaseChatModel, with tools bound if applicable + :return: A configured LangChain BaseChatModel """ from langchain.chat_models import init_chat_model @@ -74,7 +68,7 @@ def create_langchain_model(ai_config: AIConfigKind, tool_registry: Optional[Tool model_name = model_dict.get('name', '') provider = provider_dict.get('name', '') parameters = dict(model_dict.get('parameters') or {}) - tool_definitions = parameters.pop('tools', []) or [] + parameters.pop('tools', None) mapped_provider = map_provider(provider) # Bedrock requires the foundation provider (e.g. Bedrock:Anthropic) passed in @@ -82,19 +76,12 @@ def create_langchain_model(ai_config: AIConfigKind, tool_registry: Optional[Tool if mapped_provider == 'bedrock_converse' and 'provider' not in parameters: parameters['provider'] = provider.removeprefix('bedrock:') - model = init_chat_model( + return init_chat_model( model_name, model_provider=mapped_provider, **parameters, ) - if tool_definitions and tool_registry is not None: - bindable = _resolve_tools_for_langchain(tool_definitions, tool_registry) - if bindable: - model = model.bind_tools(bindable) - - return model - def _iter_valid_tools( tool_definitions: List[Dict[str, Any]], @@ -131,28 +118,6 @@ def _iter_valid_tools( return valid -def _resolve_tools_for_langchain( - tool_definitions: List[Dict[str, Any]], - tool_registry: ToolRegistry, -) -> List[Dict[str, Any]]: - """ - Match LD tool definitions against a registry, returning function-calling tool dicts - for tools that have a callable implementation. Built-in provider tools and tools - missing from the registry are skipped with a warning. - """ - return [ - { - 'type': 'function', - 'function': { - 'name': name, - 'description': td.get('description', ''), - 'parameters': td.get('parameters', {'type': 'object', 'properties': {}}), - }, - } - for name, td in _iter_valid_tools(tool_definitions, tool_registry) - ] - - def build_tools(ai_config: AIConfigKind, tool_registry: ToolRegistry) -> List[Any]: """ Return callables from the registry for each tool defined in the AI config. diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index 6523388b..51e9d496 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -110,7 +110,7 @@ def handle_traversal(node: AgentGraphNode, ctx: dict) -> None: tool_fns: list = [] if node_config.model: # We send an empty tool registry to avoid binding tools to the model. - lc_model = create_langchain_model(node_config, tool_registry=None) + lc_model = create_langchain_model(node_config) tool_fns = build_structured_tools(node_config, tools_ref) From 9240394830bd706a15b9a7447e33316449e66a0b Mon Sep 17 00:00:00 2001 From: jsonbailey Date: Mon, 6 Apr 2026 12:50:13 -0500 Subject: [PATCH 3/6] fix: Resolve mypy lint errors - Use isinstance(gen, ChatGeneration) to type-narrow in callback handler - Declare model as Any to accommodate Runnable return type from bind_tools - Fix bind_kwargs type annotation Co-Authored-By: Claude Sonnet 4.6 --- .../src/ldai_langchain/langgraph_agent_graph_runner.py | 3 ++- .../src/ldai_langchain/langgraph_callback_handler.py | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index 51e9d496..8ef38393 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -141,10 +141,11 @@ def handle_traversal(node: AgentGraphNode, ctx: dict) -> None: handoff_fns.append(_make_handoff_tool(edge.target_config, description)) all_tools = tool_fns + handoff_fns + model: Any if lc_model and all_tools: # When handoff tools are present, disable parallel tool calls so the LLM # picks exactly one destination rather than routing to multiple children. - bind_kwargs = {'parallel_tool_calls': False} if handoff_fns else {} + bind_kwargs: Dict[str, Any] = {'parallel_tool_calls': False} if handoff_fns else {} model = lc_model.bind_tools(all_tools, **bind_kwargs) else: model = lc_model diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py index e2173662..2a04e51e 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py @@ -3,7 +3,7 @@ from uuid import UUID from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.outputs import LLMResult +from langchain_core.outputs import ChatGeneration, LLMResult from ldai.agent_graph import AgentGraphDefinition from ldai.tracker import TokenUsage @@ -140,9 +140,12 @@ def on_llm_end( return try: - message = response.generations[0][0].message - except (IndexError, AttributeError, TypeError): + gen = response.generations[0][0] + except (IndexError, TypeError): return + if not isinstance(gen, ChatGeneration): + return + message = gen.message usage = get_ai_usage_from_response(message) if usage is None: return @@ -215,4 +218,3 @@ def flush(self, graph: AgentGraphDefinition, graph_tracker: Any) -> None: for tool_key in self._node_tool_calls.get(node_key, []): config_tracker.track_tool_call(tool_key, graph_key=gk) - From f0cccd11707637a95fb881ebc7bb3c698c3cc134 Mon Sep 17 00:00:00 2001 From: jsonbailey Date: Mon, 6 Apr 2026 13:10:23 -0500 Subject: [PATCH 4/6] fix tool edge back to agent --- .../langgraph_agent_graph_runner.py | 6 +- .../tests/test_tracking_langgraph.py | 125 ++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index 8ef38393..acb7bd5b 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -201,8 +201,10 @@ async def invoke(state: WorkflowState) -> dict: ) else: # Handoff tools use Command(goto=child_key) — LangGraph routes to the - # target directly without any extra edge. The ToolNode does NOT loop - # back here. tools_condition exits to END when no tool is called. + # target directly without any extra edge. Functional tools (if any) + # return normal ToolMessages and must loop back so the LLM sees the result. + if tool_fns: + agent_builder.add_edge(tools_node_key, node_key) agent_builder.add_conditional_edges( node_key, tools_condition, diff --git a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py index 0f59214a..f8f0649a 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py @@ -612,3 +612,128 @@ def model_factory(node_config, **kwargs): assert 'Agent A' in result.output # Agent B's model must never have been invoked — no fan-out agent_b_model.ainvoke.assert_not_called() + + +def _make_multi_child_graph_with_tools(mock_ld_client: MagicMock, tool_names: list) -> 'AgentGraphDefinition': + """Build a 3-node graph where the orchestrator also has functional tools.""" + context = MagicMock() + + def _node_tracker(key: str) -> LDAIConfigTracker: + return LDAIConfigTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + config_key=key, + version=1, + model_name='gpt-4', + provider_name='openai', + context=context, + ) + + graph_tracker = AIGraphTracker( + ld_client=mock_ld_client, + variation_key='test-variation', + graph_key='multi-child-tools-graph', + version=1, + context=context, + ) + + tool_defs = [{'name': n, 'type': 'function', 'description': '', 'parameters': {}} for n in tool_names] + configs = { + 'orchestrator': AIAgentConfig( + key='orchestrator', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs}), + provider=ProviderConfig(name='openai'), + instructions='Route to a specialist after gathering info.', + tracker=_node_tracker('orchestrator'), + ), + 'agent-a': AIAgentConfig( + key='agent-a', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You handle topic A.', + tracker=_node_tracker('agent-a'), + ), + 'agent-b': AIAgentConfig( + key='agent-b', + enabled=True, + model=ModelConfig(name='gpt-4', parameters={}), + provider=ProviderConfig(name='openai'), + instructions='You handle topic B.', + tracker=_node_tracker('agent-b'), + ), + } + + edges = [ + Edge(key='orch-to-a', source_config='orchestrator', target_config='agent-a'), + Edge(key='orch-to-b', source_config='orchestrator', target_config='agent-b'), + ] + graph_config = AIAgentGraphConfig( + key='multi-child-tools-graph', + root_config_key='orchestrator', + edges=edges, + enabled=True, + ) + nodes = AgentGraphDefinition.build_nodes(graph_config, configs) + return AgentGraphDefinition( + agent_graph=graph_config, + nodes=nodes, + context=context, + enabled=True, + tracker=graph_tracker, + ) + + +@pytest.mark.asyncio +async def test_functional_tool_loops_back_when_handoff_tools_present(): + """When a node has both functional tools and handoff tools, calling a functional + tool must loop back to the node so the LLM sees the result — not silently terminate.""" + from langchain_core.messages import AIMessage + + mock_ld_client = MagicMock() + graph = _make_multi_child_graph_with_tools(mock_ld_client, tool_names=['search']) + + # Step 1: orchestrator calls functional tool 'search' + tool_call_response = AIMessage( + content='', + tool_calls=[{'name': 'search', 'args': {'query': 'topic A'}, 'id': 'call_search_1', 'type': 'tool_call'}], + ) + # Step 2: after seeing tool result, orchestrator hands off to agent-a + handoff_response = AIMessage( + content='', + tool_calls=[{'name': 'transfer_to_agent_a', 'args': {}, 'id': 'call_handoff_1', 'type': 'tool_call'}], + ) + agent_a_response = _make_fake_response('Agent A handled it.') + + orchestrator_model = MagicMock() + orchestrator_model.ainvoke = AsyncMock(side_effect=[tool_call_response, handoff_response]) + orchestrator_model.bind_tools.return_value = orchestrator_model + + agent_a_model = _mock_model(agent_a_response) + agent_b_model = _mock_model(_make_fake_response('Agent B handled it.')) + + def search(query: str = '') -> str: + """Search for information.""" + return f'results for {query}' + + tool_registry = {'search': search} + + def model_factory(node_config, **kwargs): + if node_config.key == 'orchestrator': + return orchestrator_model + if node_config.key == 'agent-a': + return agent_a_model + return agent_b_model + + with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model', + side_effect=model_factory): + runner = LangGraphAgentGraphRunner(graph, tool_registry) + result = await runner.run('Find info and route to the right agent.') + + assert result.metrics.success is True + assert 'Agent A' in result.output + # Orchestrator must have been called twice: once before tool result, once after + assert orchestrator_model.ainvoke.call_count == 2 + # Agent B must never have been invoked + agent_b_model.ainvoke.assert_not_called() From 9dec08511f603d508a9039943c2d285edbed0e80 Mon Sep 17 00:00:00 2001 From: jsonbailey Date: Mon, 6 Apr 2026 14:45:26 -0500 Subject: [PATCH 5/6] fix: Prevent fan-out when functional and handoff tools coexist on same node A static loop-back edge from the tools node conflicted with Command(goto=child) emitted by handoff tools, causing both to fire as a fan-out. Fix by replacing the static edge with a conditional edge in the mixed case: after the ToolNode runs, route back to the parent only when the last message is from a functional tool; route to END when it is from a handoff (Command already handles routing to the target agent). Co-Authored-By: Claude Sonnet 4.6 --- .../langgraph_agent_graph_runner.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index acb7bd5b..3ef14fea 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -184,9 +184,6 @@ async def invoke(state: WorkflowState) -> dict: graph_structure.append(node_desc) if all_tools: - # ToolNode handles Command returns from handoff tools, routing to the target - # node. For functional tools it returns normal ToolMessages and we loop back. - # tools_condition exits to END when no tool is called. tools_node_key = f"{node_key}__tools" agent_builder.add_node(tools_node_key, ToolNode(all_tools)) @@ -199,12 +196,34 @@ async def invoke(state: WorkflowState) -> dict: tools_condition, {"tools": tools_node_key, END: after_loop}, ) + elif not tool_fns: + # Only handoff tools: no loop-back needed. + # Command(goto=child_key) handles routing to the target. + agent_builder.add_conditional_edges( + node_key, + tools_condition, + {"tools": tools_node_key, END: END}, + ) else: - # Handoff tools use Command(goto=child_key) — LangGraph routes to the - # target directly without any extra edge. Functional tools (if any) - # return normal ToolMessages and must loop back so the LLM sees the result. - if tool_fns: - agent_builder.add_edge(tools_node_key, node_key) + # Both functional and handoff tools. A static loop-back edge would + # fan-out with Command(goto=child_key) from handoff tools, so use a + # conditional edge that only loops back for functional tool results. + handoff_names_set = frozenset(getattr(t, 'name', '') for t in handoff_fns) + + def make_after_tools_router(parent_key: str, ht_names: frozenset): + def route(state: WorkflowState) -> str: + for msg in reversed(state['messages']): + if hasattr(msg, 'name') and msg.name: + return END if msg.name in ht_names else parent_key + break + return parent_key + return route + + agent_builder.add_conditional_edges( + tools_node_key, + make_after_tools_router(node_key, handoff_names_set), + {node_key: node_key, END: END}, + ) agent_builder.add_conditional_edges( node_key, tools_condition, From f28fe5f0523888e6bff56f3adc71d9ce1de8f678 Mon Sep 17 00:00:00 2001 From: jsonbailey Date: Tue, 7 Apr 2026 08:11:52 -0500 Subject: [PATCH 6/6] fix: Address code review findings in graph runner and callback handler - Cache compiled graph in _ensure_compiled() so _build_graph is not called on every run() invocation - Collect node_keys during traversal instead of reaching into _graph._nodes - Fix unconditional break in make_after_tools_router that made reversed() a no-op; replace broken loop with a direct last-message check - Fix duration tracking key collision: key _node_start_ns by run_id instead of node_key so concurrent invocations of the same node don't clobber each other's start times - Warn when a functional-tool node has multiple outgoing edges since only the first edge is reachable after the tool loop exits Co-Authored-By: Claude Sonnet 4.6 --- .../langgraph_agent_graph_runner.py | 49 +++++++++++++------ .../langgraph_callback_handler.py | 8 +-- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index 3ef14fea..a82b2029 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -1,7 +1,7 @@ """LangGraph agent graph runner for LaunchDarkly AI SDK.""" import time -from typing import Annotated, Any, Dict, List, Tuple +from typing import Annotated, Any, Dict, List, Optional, Set, Tuple from ldai import log from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode @@ -76,13 +76,25 @@ def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry): """ self._graph = graph self._tools = tools - - def _build_graph(self) -> Tuple[Any, Dict[str, str]]: + self._compiled: Any = None + self._fn_name_to_config_key: Dict[str, str] = {} + self._node_keys: Set[str] = set() + + def _ensure_compiled(self) -> None: + """Build and cache the compiled graph if not already done.""" + if self._compiled is None: + compiled, fn_name_to_config_key, node_keys = self._build_graph() + self._compiled = compiled + self._fn_name_to_config_key = fn_name_to_config_key + self._node_keys = node_keys + + def _build_graph(self) -> Tuple[Any, Dict[str, str], Set[str]]: """ Build and compile the LangGraph StateGraph from the AgentGraphDefinition. - :return: Tuple of (compiled_graph, fn_name_to_config_key) where - fn_name_to_config_key maps tool function __name__ to LD config key. + :return: Tuple of (compiled_graph, fn_name_to_config_key, node_keys) where + fn_name_to_config_key maps tool function __name__ to LD config key, and + node_keys is the set of all agent node keys in the graph. """ from langchain_core.messages import SystemMessage from langgraph.graph import END, START, StateGraph @@ -99,10 +111,12 @@ class WorkflowState(TypedDict): tools_ref = self._tools graph_structure: List[str] = [] fn_name_to_config_key: Dict[str, str] = {} + node_keys: Set[str] = set() def handle_traversal(node: AgentGraphNode, ctx: dict) -> None: node_config = node.get_config() node_key = node.get_key() + node_keys.add(node_key) instructions = node_config.instructions if hasattr(node_config, 'instructions') else None outgoing_edges = node.get_edges() @@ -190,6 +204,12 @@ async def invoke(state: WorkflowState) -> dict: if not handoff_fns: # No handoff tools: standard loop-back after tool execution. after_loop = outgoing_edges[0].target_config if outgoing_edges else END + if len(outgoing_edges) > 1: + log.warning( + f"Node '{node_key}' has {len(outgoing_edges)} outgoing edges but no handoff " + "tools; only the first edge will be used after the tool loop. " + "Use handoff tools for multi-child routing." + ) agent_builder.add_edge(tools_node_key, node_key) agent_builder.add_conditional_edges( node_key, @@ -212,10 +232,11 @@ async def invoke(state: WorkflowState) -> dict: def make_after_tools_router(parent_key: str, ht_names: frozenset): def route(state: WorkflowState) -> str: - for msg in reversed(state['messages']): - if hasattr(msg, 'name') and msg.name: - return END if msg.name in ht_names else parent_key - break + msgs = state['messages'] + if msgs: + last = msgs[-1] + if hasattr(last, 'name') and last.name in ht_names: + return END return parent_key return route @@ -247,7 +268,7 @@ def route(state: WorkflowState) -> str: ) compiled = agent_builder.compile() - return compiled, fn_name_to_config_key + return compiled, fn_name_to_config_key, node_keys async def run(self, input: Any) -> AgentGraphResult: """ @@ -266,12 +287,10 @@ async def run(self, input: Any) -> AgentGraphResult: try: from langchain_core.messages import HumanMessage - compiled, fn_name_to_config_key = self._build_graph() - - node_keys = {node.get_key() for node in self._graph._nodes.values()} - handler = LDMetricsCallbackHandler(node_keys, fn_name_to_config_key) + self._ensure_compiled() + handler = LDMetricsCallbackHandler(self._node_keys, self._fn_name_to_config_key) - result = await compiled.ainvoke( # type: ignore[call-overload] + result = await self._compiled.ainvoke( # type: ignore[call-overload] {'messages': [HumanMessage(content=str(input))]}, config={'callbacks': [handler], 'recursion_limit': 25}, ) diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py index 2a04e51e..026b8d6f 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py @@ -42,8 +42,8 @@ def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]): self._node_tokens: Dict[str, TokenUsage] = {} # tool config keys called per node self._node_tool_calls: Dict[str, List[str]] = {} - # start time (ns) per node — only set while running - self._node_start_ns: Dict[str, int] = {} + # start time (ns) per active run_id — keyed by run_id to handle re-entrant nodes + self._node_start_ns: Dict[UUID, int] = {} # accumulated duration (ms) per node self._node_duration_ms: Dict[str, int] = {} # execution path in order (deduplicated) @@ -96,7 +96,7 @@ def on_chain_start( if name in self._node_keys: self._run_to_node[run_id] = name - self._node_start_ns[name] = time.perf_counter_ns() + self._node_start_ns[run_id] = time.perf_counter_ns() if name not in self._path_set: self._path.append(name) self._path_set.add(name) @@ -117,7 +117,7 @@ def on_chain_end( node_key = self._run_to_node.get(run_id) if node_key is None: return - start_ns = self._node_start_ns.pop(node_key, None) + start_ns = self._node_start_ns.pop(run_id, None) if start_ns is not None: elapsed_ms = (time.perf_counter_ns() - start_ns) // 1_000_000 self._node_duration_ms[node_key] = (