diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py index 61c85b3..94c427a 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py @@ -39,7 +39,11 @@ def create_agent(self, config: Any, tools: Optional[ToolRegistry] = None) -> Lan ) return LangChainAgentRunner(agent) - def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any: + def create_agent_graph( + self, + graph_def: Any, + tools: ToolRegistry, + ) -> Any: """ CAUTION: This feature is experimental and should NOT be considered ready for production use. diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py index c828105..9ecb235 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py @@ -1,7 +1,9 @@ """LangGraph agent graph runner for LaunchDarkly AI SDK.""" +import asyncio import time -from typing import Annotated, Any, Dict, List, Optional, Set, Tuple +from contextvars import ContextVar +from typing import Annotated, Any, Dict, List, Set, Tuple from ldai import log from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode @@ -16,6 +18,9 @@ ) from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler +# Per-run eval task accumulator, isolated per concurrent run() call via ContextVar. +_run_eval_tasks: ContextVar[Dict[str, List[asyncio.Task]]] = ContextVar('_run_eval_tasks') + def _make_handoff_tool(child_key: str, description: str) -> Any: """ @@ -67,7 +72,11 @@ class LangGraphAgentGraphRunner(AgentGraphRunner): Requires ``langgraph`` to be installed. """ - def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry): + def __init__( + self, + graph: AgentGraphDefinition, + tools: ToolRegistry, + ): """ Initialize the runner. 
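Note: the `_run_eval_tasks` ContextVar above exists so that two overlapping `run()` calls on the same runner each accumulate their own eval tasks. A minimal, self-contained sketch of that isolation pattern follows (illustrative names only, not SDK code):

import asyncio
from contextvars import ContextVar
from typing import Dict, List

# Minimal sketch of the ContextVar isolation pattern used by _run_eval_tasks
# (names here are illustrative, not SDK code). Each asyncio Task created by
# gather() gets its own copy of the context, so set() inside one run never
# leaks into a concurrently executing run.
_accumulator: ContextVar[Dict[str, List[str]]] = ContextVar('_accumulator')

async def node(run_name: str, node_key: str) -> None:
    # Node code reads the ContextVar instead of taking the dict as a parameter,
    # mirroring how invoke() stays decoupled from run() in the runner above.
    _accumulator.get().setdefault(node_key, []).append(f'eval scheduled by {run_name}')

async def run(run_name: str) -> Dict[str, List[str]]:
    token = _accumulator.set({})             # fresh accumulator for this run only
    try:
        await node(run_name, 'root-agent')
        await node(run_name, 'root-agent')   # revisits append to the same list
        return _accumulator.get()
    finally:
        _accumulator.reset(token)            # mirrors the reset(token) in run()

async def main() -> None:
    a, b = await asyncio.gather(run('run-a'), run('run-b'))
    print(a)  # {'root-agent': ['eval scheduled by run-a', 'eval scheduled by run-a']}
    print(b)  # {'root-agent': ['eval scheduled by run-b', 'eval scheduled by run-b']}

asyncio.run(main())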
@@ -172,6 +181,26 @@ async def invoke(state: WorkflowState) -> dict: if node_instructions: msgs = [SystemMessage(content=node_instructions)] + msgs response = await bound_model.ainvoke(msgs) + + node_obj = self._graph.get_node(nk) + if node_obj is not None: + input_text = '\r\n'.join( + m.content if isinstance(m.content, str) else str(m.content) + for m in msgs + ) if msgs else '' + output_text = ( + response.content if hasattr(response, 'content') else str(response) + ) + task = node_obj.get_config().evaluator.evaluate(input_text, output_text) + run_tasks = _run_eval_tasks.get(None) + if run_tasks is not None: + run_tasks.setdefault(nk, []).append(task) + else: + log.warning( + f"LangGraphAgentGraphRunner: eval task for node '{nk}' " + "has no run context; judge results will not be tracked" + ) + return {'messages': [response]} invoke.__name__ = nk @@ -280,7 +309,9 @@ async def run(self, input: Any) -> AgentGraphResult: :param input: The string prompt to send to the agent graph :return: AgentGraphResult with the final output and metrics """ - tracker = self._graph.create_tracker() if self._graph.create_tracker is not None else None + pending_eval_tasks: Dict[str, List[asyncio.Task]] = {} + token = _run_eval_tasks.set(pending_eval_tasks) + tracker = self._graph.create_tracker() start_ns = time.perf_counter_ns() try: @@ -299,19 +330,18 @@ async def run(self, input: Any) -> AgentGraphResult: output = extract_last_message_content(messages) # Flush per-node metrics to LD trackers - handler.flush(self._graph) + all_eval_results = await handler.flush(self._graph, pending_eval_tasks) - # Graph-level metrics - if tracker: - tracker.track_path(handler.path) - tracker.track_duration(duration) - tracker.track_invocation_success() - tracker.track_total_tokens(sum_token_usage_from_messages(messages)) + tracker.track_path(handler.path) + tracker.track_duration(duration) + tracker.track_invocation_success() + tracker.track_total_tokens(sum_token_usage_from_messages(messages)) return AgentGraphResult( output=output, raw=result, metrics=LDAIMetrics(success=True), + evaluations=all_eval_results, ) except Exception as exc: @@ -323,11 +353,12 @@ async def run(self, input: Any) -> AgentGraphResult: else: log.warning(f'LangGraphAgentGraphRunner run failed: {exc}') duration = (time.perf_counter_ns() - start_ns) // 1_000_000 - if tracker: - tracker.track_duration(duration) - tracker.track_invocation_failure() + tracker.track_duration(duration) + tracker.track_invocation_failure() return AgentGraphResult( output='', raw=None, metrics=LDAIMetrics(success=False), ) + finally: + _run_eval_tasks.reset(token) diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py index a517a03..183a3eb 100644 --- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py +++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import ChatGeneration, LLMResult from ldai.agent_graph import AgentGraphDefinition +from ldai.providers.types import JudgeResult from ldai.tracker import TokenUsage from ldai_langchain.langchain_helper import get_ai_usage_from_response @@ -188,15 +189,22 @@ def on_tool_end( # Flush # ------------------------------------------------------------------ - def flush(self, graph: AgentGraphDefinition) -> 
None: + async def flush( + self, graph: AgentGraphDefinition, eval_tasks=None + ) -> List[JudgeResult]: """ Emit all collected per-node metrics to the LaunchDarkly trackers. Call this once after the graph run completes. :param graph: The AgentGraphDefinition whose nodes hold the LD config trackers. + :param eval_tasks: Optional dict mapping node key to a list of awaitables that + return judge evaluation results. Multiple tasks arise when a node is visited + more than once (e.g. in a graph with cycles). + :return: All judge results collected across all nodes. """ node_trackers: Dict[str, Any] = {} + all_eval_results: List[JudgeResult] = [] for node_key in self._path: if node_key in node_trackers: continue @@ -220,3 +228,15 @@ def flush(self, graph: AgentGraphDefinition) -> None: for tool_key in self._node_tool_calls.get(node_key, []): config_tracker.track_tool_call(tool_key) + + if not eval_tasks: + continue + + for eval_task in eval_tasks.get(node_key, []): + results = await eval_task + all_eval_results.extend(results) + for r in results: + if r.success: + config_tracker.track_judge_result(r) + + return all_eval_results diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py index 5f78f45..4018e7c 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py @@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from ldai import LDMessage +from ldai.evaluator import Evaluator from ldai_langchain import ( LangChainModelRunner, @@ -530,6 +531,7 @@ def sync_tool(x: str = '') -> str: cfg = AIAgentConfig( key='n', enabled=True, + evaluator=Evaluator.noop(), create_tracker=MagicMock(), model=ModelConfig( name='gpt-4', @@ -553,6 +555,7 @@ async def async_tool(x: str = '') -> str: cfg = AIAgentConfig( key='n', enabled=True, + evaluator=Evaluator.noop(), create_tracker=MagicMock(), model=ModelConfig( name='gpt-4', diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py index 23fb345..0a3ff6c 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_agent_graph_runner.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from ldai.agent_graph import AgentGraphDefinition +from ldai.evaluator import Evaluator from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig from ldai.providers import AgentGraphResult, ToolRegistry from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner @@ -20,6 +21,7 @@ def _make_graph(enabled: bool = True) -> AgentGraphDefinition: model=ModelConfig(name='gpt-4'), provider=ProviderConfig(name='openai'), instructions='You are a helpful assistant.', + evaluator=Evaluator.noop(), ) graph_config = AIAgentGraphConfig( key='test-graph', diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py index 73330ef..65592fa 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langgraph_callback_handler.py @@ -17,6 
+17,7 @@ from ldai.agent_graph import AgentGraphDefinition from ldai.models import AIAgentConfig, AIAgentGraphConfig, ModelConfig, ProviderConfig from ldai.tracker import AIGraphTracker, LDAIConfigTracker, TokenUsage +from ldai.evaluator import Evaluator from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler @@ -48,6 +49,7 @@ def _make_graph(mock_ld_client: MagicMock, node_key: str = 'root-agent', graph_k node_config = AIAgentConfig( key=node_key, enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='Be helpful.', @@ -317,7 +319,8 @@ def test_on_tool_end_none_name_ignored(): # flush() tests # --------------------------------------------------------------------------- -def test_flush_emits_token_events_to_ld_tracker(): +@pytest.mark.asyncio +async def test_flush_emits_token_events_to_ld_tracker(): """flush() calls track_tokens on the node's config tracker.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client, node_key='root-agent', graph_key='g1') @@ -327,7 +330,7 @@ def test_flush_emits_token_events_to_ld_tracker(): node_run_id = uuid4() handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') handler.on_llm_end(_llm_result(15, 10, 5), run_id=uuid4(), parent_run_id=node_run_id) - handler.flush(graph) + await handler.flush(graph) ev = _events(mock_ld_client) assert ev['$ld:ai:tokens:total'][0][1] == 15 @@ -336,7 +339,8 @@ def test_flush_emits_token_events_to_ld_tracker(): assert ev['$ld:ai:generation:success'][0][1] == 1 -def test_flush_emits_duration(): +@pytest.mark.asyncio +async def test_flush_emits_duration(): """flush() calls track_duration when duration was recorded.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client) @@ -346,13 +350,14 @@ def test_flush_emits_duration(): run_id = uuid4() handler.on_chain_start({}, {}, run_id=run_id, name='root-agent') handler.on_chain_end({}, run_id=run_id) - handler.flush(graph) + await handler.flush(graph) ev = _events(mock_ld_client) assert '$ld:ai:duration:total' in ev -def test_flush_emits_tool_calls(): +@pytest.mark.asyncio +async def test_flush_emits_tool_calls(): """flush() calls track_tool_call for each recorded tool invocation.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client) @@ -366,7 +371,7 @@ def test_flush_emits_tool_calls(): tools_run_id = uuid4() handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') handler.on_tool_end('r', run_id=uuid4(), parent_run_id=tools_run_id, name='fn_search') - handler.flush(graph) + await handler.flush(graph) ev = _events(mock_ld_client) tool_events = ev.get('$ld:ai:tool_call', []) @@ -374,7 +379,8 @@ def test_flush_emits_tool_calls(): assert tool_events[0][0]['toolKey'] == 'search' -def test_flush_includes_graph_key_in_node_events(): +@pytest.mark.asyncio +async def test_flush_includes_graph_key_in_node_events(): """flush() passes graph_key to the node tracker so graphKey appears in events.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client, graph_key='my-graph') @@ -384,14 +390,15 @@ def test_flush_includes_graph_key_in_node_events(): node_run_id = uuid4() handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) - handler.flush(graph) + await handler.flush(graph) ev = _events(mock_ld_client) token_data = ev['$ld:ai:tokens:total'][0][0] assert token_data.get('graphKey') == 'my-graph' -def 
test_flush_with_no_graph_key_on_node_tracker(): +@pytest.mark.asyncio +async def test_flush_with_no_graph_key_on_node_tracker(): """When node tracker has no graph_key, events omit graphKey.""" mock_ld_client = MagicMock() context = MagicMock() @@ -408,6 +415,7 @@ def test_flush_with_no_graph_key_on_node_tracker(): node_config = AIAgentConfig( key='root-agent', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='Be helpful.', @@ -425,21 +433,22 @@ def test_flush_with_no_graph_key_on_node_tracker(): nodes=nodes, context=context, enabled=True, - create_tracker=lambda: None, + create_tracker=lambda: AIGraphTracker(mock_ld_client, 'v1', 'test-graph', 1, context), ) handler = LDMetricsCallbackHandler({'root-agent'}, {}) node_run_id = uuid4() handler.on_chain_start({}, {}, run_id=node_run_id, name='root-agent') handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) - handler.flush(graph) + await handler.flush(graph) ev = _events(mock_ld_client) token_data = ev['$ld:ai:tokens:total'][0][0] assert 'graphKey' not in token_data -def test_flush_skips_nodes_not_in_path(): +@pytest.mark.asyncio +async def test_flush_skips_nodes_not_in_path(): """flush() only emits events for nodes that were actually executed.""" mock_ld_client = MagicMock() graph = _make_graph(mock_ld_client) @@ -447,14 +456,15 @@ def test_flush_skips_nodes_not_in_path(): # Handler with 'root-agent' in node_keys but never started handler = LDMetricsCallbackHandler({'root-agent'}, {}) - handler.flush(graph) + await handler.flush(graph) ev = _events(mock_ld_client) assert '$ld:ai:tokens:total' not in ev assert '$ld:ai:generation:success' not in ev -def test_flush_skips_node_without_tracker(): +@pytest.mark.asyncio +async def test_flush_skips_node_without_tracker(): """flush() silently skips nodes whose config has no tracker.""" mock_ld_client = MagicMock() context = MagicMock() @@ -463,6 +473,7 @@ def test_flush_skips_node_without_tracker(): key='no-track', enabled=True, create_tracker=lambda: None, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='', @@ -483,7 +494,7 @@ def test_flush_skips_node_without_tracker(): node_run_id = uuid4() handler.on_chain_start({}, {}, run_id=node_run_id, name='no-track') handler.on_llm_end(_llm_result(5, 3, 2), run_id=uuid4(), parent_run_id=node_run_id) - handler.flush(graph) # should not raise + await handler.flush(graph) # should not raise mock_ld_client.track.assert_not_called() diff --git a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py index 832d8c1..3b45783 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py @@ -13,6 +13,7 @@ from ldai.agent_graph import AgentGraphDefinition from ldai.models import AIAgentGraphConfig, AIAgentConfig, Edge, ModelConfig, ProviderConfig from ldai.tracker import AIGraphTracker, LDAIConfigTracker +from ldai.evaluator import Evaluator from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner pytestmark = pytest.mark.skipif( @@ -66,6 +67,7 @@ def _make_graph( root_config = AIAgentConfig( key=node_key, enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs} if tool_defs else {}), 
provider=ProviderConfig(name='openai'), instructions='You are a helpful assistant.', @@ -168,6 +170,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': root_config = AIAgentConfig( key='root-agent', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You are root.', @@ -176,6 +179,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> 'AgentGraphDefinition': child_config = AIAgentConfig( key='child-agent', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You are child.', @@ -246,7 +250,7 @@ async def test_tracks_node_and_graph_tokens_on_success(): ) handler.on_llm_end(llm_result, run_id=uuid4(), parent_run_id=node_run_id) handler.on_chain_end({}, run_id=node_run_id) - handler.flush(graph2) + await handler.flush(graph2) ev2 = _events(mock_ld_client2) assert ev2['$ld:ai:tokens:total'][0][1] == 15 @@ -320,7 +324,7 @@ def get_weather(location: str = 'NYC') -> str: tools_run_id = uuid4() handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') handler.on_tool_end('sunny', run_id=uuid4(), parent_run_id=tools_run_id, name='get_weather') - handler.flush(graph2) + await handler.flush(graph2) ev2 = _events(mock_ld_client2) tool_events = ev2.get('$ld:ai:tool_call', []) @@ -374,7 +378,7 @@ def summarize(text: str = '') -> str: handler.on_chain_start({}, {}, run_id=tools_run_id, name='root-agent__tools') handler.on_tool_end('result', run_id=uuid4(), parent_run_id=tools_run_id, name='search') handler.on_tool_end('summary', run_id=uuid4(), parent_run_id=tools_run_id, name='summarize') - handler.flush(graph2) + await handler.flush(graph2) ev2 = _events(mock_ld_client2) tool_keys = [data['toolKey'] for data, _ in ev2.get('$ld:ai:tool_call', [])] @@ -405,7 +409,7 @@ async def test_tracks_graph_key_on_node_events(): llm_output={}, ) handler.on_llm_end(llm_result, run_id=uuid4(), parent_run_id=node_run_id) - handler.flush(graph) + await handler.flush(graph) ev = _events(mock_ld_client) token_data = ev['$ld:ai:tokens:total'][0][0] @@ -490,7 +494,7 @@ def model_factory(node_config, **kwargs): ) handler.on_llm_end(child_llm_result, run_id=uuid4(), parent_run_id=child_run_id) - handler.flush(graph2) + await handler.flush(graph2) ev2 = _events(mock_ld_client2) @@ -539,6 +543,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: 'orchestrator': AIAgentConfig( key='orchestrator', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='Route to the appropriate specialist agent.', @@ -547,6 +552,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: 'agent-a': AIAgentConfig( key='agent-a', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You handle topic A.', @@ -555,6 +561,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: 'agent-b': AIAgentConfig( key='agent-b', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You handle topic B.', @@ -652,6 +659,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: 'orchestrator': AIAgentConfig( key='orchestrator', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs}), 
provider=ProviderConfig(name='openai'), instructions='Route to a specialist after gathering info.', @@ -660,6 +668,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: 'agent-a': AIAgentConfig( key='agent-a', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You handle topic A.', @@ -668,6 +677,7 @@ def _node_tracker(key: str) -> LDAIConfigTracker: 'agent-b': AIAgentConfig( key='agent-b', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You handle topic B.', diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py index 51351f9..6d35328 100644 --- a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py @@ -46,7 +46,11 @@ class OpenAIAgentGraphRunner(AgentGraphRunner): Requires ``openai-agents`` to be installed. """ - def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry): + def __init__( + self, + graph: AgentGraphDefinition, + tools: ToolRegistry, + ): """ Initialize the runner. @@ -70,36 +74,36 @@ async def run(self, input: Any) -> AgentGraphResult: :param input: The string prompt to send to the agent graph :return: AgentGraphResult with the final output and metrics """ - tracker = self._graph.create_tracker() if self._graph.create_tracker is not None else None + tracker = self._graph.create_tracker() path: List[str] = [] root_node = self._graph.root() root_key = root_node.get_key() if root_node else '' if root_key: path.append(root_key) + input_str = str(input) start_ns = time.perf_counter_ns() state = _RunState(last_handoff_ns=start_ns, last_node_key=root_key) try: from agents import Runner root_agent = self._build_agents(path, state, tracker) - result = await Runner.run(root_agent, str(input)) + result = await Runner.run(root_agent, input_str) self._flush_final_segment(state, result) self._track_tool_calls(result) duration = (time.perf_counter_ns() - start_ns) // 1_000_000 + token_usage = get_ai_usage_from_response(result) - if tracker: - tracker.track_path(path) - tracker.track_duration(duration) - tracker.track_invocation_success() - token_usage = get_ai_usage_from_response(result) - if token_usage is not None: - tracker.track_total_tokens(token_usage) + tracker.track_path(path) + tracker.track_duration(duration) + tracker.track_invocation_success() + if token_usage is not None: + tracker.track_total_tokens(token_usage) return AgentGraphResult( output=str(result.final_output), raw=result, - metrics=LDAIMetrics(success=True), + metrics=LDAIMetrics(success=True, usage=token_usage), ) except Exception as exc: if isinstance(exc, ImportError): @@ -110,9 +114,8 @@ async def run(self, input: Any) -> AgentGraphResult: else: log.warning(f'OpenAIAgentGraphRunner run failed: {exc}') duration = (time.perf_counter_ns() - start_ns) // 1_000_000 - if tracker: - tracker.track_duration(duration) - tracker.track_invocation_failure() + tracker.track_duration(duration) + tracker.track_invocation_failure() return AgentGraphResult( output='', raw=None, @@ -222,9 +225,7 @@ def _make_on_handoff( state: _RunState, ): def on_handoff(run_ctx: Any) -> None: - self._handle_handoff( - run_ctx, src, tgt, path, tracker, config_tracker, state - ) + 
self._handle_handoff(run_ctx, src, tgt, path, tracker, config_tracker, state) return on_handoff def _handle_handoff( @@ -239,8 +240,7 @@ def _handle_handoff( ) -> None: path.append(tgt) state.last_node_key = tgt - if tracker: - tracker.track_handoff_success(src, tgt) + tracker.track_handoff_success(src, tgt) now_ns = time.perf_counter_ns() duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000 @@ -261,11 +261,7 @@ def _handle_handoff( config_tracker.track_duration(int(duration_ms)) config_tracker.track_success() - def _flush_final_segment( - self, - state: _RunState, - result: Any, - ) -> None: + def _flush_final_segment(self, state: _RunState, result: Any) -> None: """Record duration/tokens for the last active agent (no handoff after it).""" if not state.last_node_key: return diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py index a653d1f..93d5577 100644 --- a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py +++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py @@ -64,7 +64,11 @@ def create_agent(self, config: Any, tools: Optional[ToolRegistry] = None) -> 'Op tools or {}, ) - def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any: + def create_agent_graph( + self, + graph_def: Any, + tools: ToolRegistry, + ) -> Any: """ CAUTION: This feature is experimental and should NOT be considered ready for production use. diff --git a/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py b/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py index 9a8369c..bba5a84 100644 --- a/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py +++ b/packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py @@ -8,6 +8,7 @@ from ldai.providers import AgentGraphResult, ToolRegistry from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner from ldai_openai.openai_runner_factory import OpenAIRunnerFactory +from ldai.evaluator import Evaluator def _make_graph(enabled: bool = True) -> AgentGraphDefinition: @@ -19,6 +20,7 @@ def _make_graph(enabled: bool = True) -> AgentGraphDefinition: root_config = AIAgentConfig( key='root-agent', enabled=enabled, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4'), provider=ProviderConfig(name='openai'), instructions='You are a helpful assistant.', diff --git a/packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py b/packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py index e751fc7..6d8cbc4 100644 --- a/packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py +++ b/packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py @@ -14,6 +14,7 @@ from ldai.models import AIAgentGraphConfig, AIAgentConfig, Edge, ModelConfig, ProviderConfig from ldai.tracker import AIGraphTracker, LDAIConfigTracker from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner +from ldai.evaluator import Evaluator # --------------------------------------------------------------------------- @@ -61,6 +62,7 @@ def _make_graph( root_config = AIAgentConfig( key=node_key, enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs} if tool_defs else {}), provider=ProviderConfig(name='openai'), instructions='You are a helpful assistant.', @@ -205,6 
+207,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> AgentGraphDefinition: root_config = AIAgentConfig( key='root-agent', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You are root.', @@ -213,6 +216,7 @@ def _make_two_node_graph(mock_ld_client: MagicMock) -> AgentGraphDefinition: child_config = AIAgentConfig( key='child-agent', enabled=True, + evaluator=Evaluator.noop(), model=ModelConfig(name='gpt-4', parameters={}), provider=ProviderConfig(name='openai'), instructions='You are child.', diff --git a/packages/sdk/server-ai/src/ldai/__init__.py b/packages/sdk/server-ai/src/ldai/__init__.py index b7d8752..405ec5a 100644 --- a/packages/sdk/server-ai/src/ldai/__init__.py +++ b/packages/sdk/server-ai/src/ldai/__init__.py @@ -5,6 +5,7 @@ from ldai.agent_graph import AgentGraphDefinition from ldai.chat import Chat # Deprecated — use ManagedModel from ldai.client import LDAIClient +from ldai.evaluator import Evaluator from ldai.judge import Judge from ldai.managed_agent import ManagedAgent from ldai.managed_agent_graph import ManagedAgentGraph @@ -42,6 +43,7 @@ __all__ = [ 'LDAIClient', + 'Evaluator', 'AgentRunner', 'AgentGraphRunner', 'AgentResult', diff --git a/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py b/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py index ffb1995..5980952 100644 --- a/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py +++ b/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py @@ -52,7 +52,7 @@ def __init__( nodes: Dict[str, AgentGraphNode], context: Context, enabled: bool, - create_tracker: Optional[Callable[[], AIGraphTracker]] = None, + create_tracker: Callable[[], AIGraphTracker], ): self._agent_graph = agent_graph self._context = context diff --git a/packages/sdk/server-ai/src/ldai/client.py b/packages/sdk/server-ai/src/ldai/client.py index 4023574..065da73 100644 --- a/packages/sdk/server-ai/src/ldai/client.py +++ b/packages/sdk/server-ai/src/ldai/client.py @@ -7,6 +7,7 @@ from ldai import log from ldai.agent_graph import AgentGraphDefinition +from ldai.evaluator import Evaluator from ldai.judge import Judge from ldai.managed_agent import ManagedAgent from ldai.managed_agent_graph import ManagedAgentGraph @@ -107,12 +108,14 @@ def _completion_config( context: Context, default: AICompletionConfigDefault, variables: Optional[Dict[str, Any]] = None, + default_ai_provider: Optional[str] = None, ) -> AICompletionConfig: (model, provider, messages, instructions, tracker_factory, enabled, judge_configuration, variation) = self.__evaluate( key, context, default.to_dict(), variables ) + evaluator = self._build_evaluator(judge_configuration, context, default_ai_provider, variables) tools = _parse_tools(variation.get('tools')) config = AICompletionConfig( @@ -122,6 +125,7 @@ def _completion_config( messages=messages, provider=provider, create_tracker=tracker_factory, + evaluator=evaluator, judge_configuration=judge_configuration, tools=tools, ) @@ -134,6 +138,7 @@ def completion_config( context: Context, default: Optional[AICompletionConfigDefault] = None, variables: Optional[Dict[str, Any]] = None, + default_ai_provider: Optional[str] = None, ) -> AICompletionConfig: """ Get the value of a completion configuration. @@ -143,12 +148,13 @@ def completion_config( :param default: The default value of the completion configuration. When not provided, a disabled config is used as the fallback. 
:param variables: Additional variables for the completion configuration. + :param default_ai_provider: Optional default AI provider to use for judge evaluation. :return: The completion configuration with a tracker used for gathering metrics. """ self._client.track(_TRACK_USAGE_COMPLETION_CONFIG, context, key, 1) return self._completion_config( - key, context, default or _DISABLED_COMPLETION_DEFAULT, variables + key, context, default or _DISABLED_COMPLETION_DEFAULT, variables, default_ai_provider ) def config( @@ -235,7 +241,7 @@ def judge_config( key, context, default or _DISABLED_JUDGE_DEFAULT, variables ) - async def create_judge( + def create_judge( self, key: str, context: Context, @@ -304,7 +310,7 @@ async def create_judge( except Exception as error: return None - async def _initialize_judges( + def _initialize_judges( self, judge_configs: List[JudgeConfiguration.Judge], context: Context, @@ -322,33 +328,48 @@ async def _initialize_judges( """ judges: Dict[str, Judge] = {} - async def create_judge_for_config(judge_key: str): - judge = await self.create_judge( - judge_key, - context, - AIJudgeConfigDefault.disabled(), - variables, - default_ai_provider, - ) - return judge_key, judge - - judge_promises = [ - create_judge_for_config(judge_config.key) - for judge_config in judge_configs - ] - - import asyncio - results = await asyncio.gather(*judge_promises, return_exceptions=True) - - for result in results: - if isinstance(result, Exception): + for judge_config in judge_configs: + try: + judge = self.create_judge( + judge_config.key, + context, + AIJudgeConfigDefault.disabled(), + variables, + default_ai_provider, + ) + if judge: + judges[judge_config.key] = judge + except Exception as e: + log.warning(f'Failed to initialize judge {judge_config.key!r}: {e}') continue - judge_key, judge = result # type: ignore[misc] - if judge: - judges[judge_key] = judge return judges + def _build_evaluator( + self, + judge_configuration: Optional[JudgeConfiguration], + context: Context, + default_ai_provider: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> Evaluator: + """ + Build an Evaluator for the given judge configuration. 
+ + :param judge_configuration: The judge configuration listing judges to initialize + :param context: Standard Context used when evaluating flags + :param default_ai_provider: Optional default AI provider to use + :param variables: Optional variables for judge instruction interpolation + :return: Evaluator wrapping the initialized judges, or a no-op Evaluator if + judge_configuration is None or has no judges + """ + if not judge_configuration or not judge_configuration.judges: + return Evaluator.noop() + judges = self._initialize_judges( + judge_configuration.judges, context, default_ai_provider=default_ai_provider, + variables=variables, + ) + return Evaluator(judges, judge_configuration) + async def create_model( self, key: str, @@ -388,7 +409,9 @@ async def create_model( """ self._client.track(_TRACK_USAGE_CREATE_MODEL, context, key, 1) log.debug(f"Creating managed model for key: {key}") - config = self._completion_config(key, context, default or _DISABLED_COMPLETION_DEFAULT, variables) + config = self._completion_config( + key, context, default or _DISABLED_COMPLETION_DEFAULT, variables, default_ai_provider + ) if not config.enabled: return None @@ -397,16 +420,7 @@ async def create_model( if not runner: return None - judges = {} - if config.judge_configuration and config.judge_configuration.judges: - judges = await self._initialize_judges( - config.judge_configuration.judges, - context, - variables, - default_ai_provider, - ) - - return ManagedModel(config, runner, judges) + return ManagedModel(config, runner) async def create_chat( self, @@ -471,7 +485,10 @@ async def create_agent( """ self._client.track(_TRACK_USAGE_CREATE_AGENT, context, key, 1) log.debug(f"Creating managed agent for key: {key}") - config = self.__evaluate_agent(key, context, default or _DISABLED_AGENT_DEFAULT, variables) + config = self.__evaluate_agent( + key, context, default or _DISABLED_AGENT_DEFAULT, variables, + default_ai_provider=default_ai_provider, + ) if not config.enabled: return None @@ -613,6 +630,7 @@ def agent_graph( self, key: str, context: Context, + default_ai_provider: Optional[str] = None, ) -> AgentGraphDefinition: """` Retrieve an AI agent graph. @@ -657,7 +675,8 @@ def graph_tracker_factory() -> AIGraphTracker: graph_key_value = key agent_configs = { agent_key: self.__evaluate_agent( - agent_key, context, AIAgentConfigDefault.disabled(), graph_key=graph_key_value + agent_key, context, AIAgentConfigDefault.disabled(), graph_key=graph_key_value, + default_ai_provider=default_ai_provider, ) for agent_key in all_agent_keys } @@ -765,12 +784,12 @@ async def create_agent_graph( self._client.track(_TRACK_USAGE_CREATE_AGENT_GRAPH, context, key, 1) log.debug(f"Creating managed agent graph for key: {key}") - graph = self.agent_graph(key, context) + graph = self.agent_graph(key, context, default_ai_provider) if not graph.enabled: return None runner = RunnerFactory.create_agent_graph( - graph, tools or {}, default_ai_provider + graph, tools or {}, default_ai_provider, ) if not runner: return None @@ -902,6 +921,7 @@ def __evaluate_agent( default: AIAgentConfigDefault, variables: Optional[Dict[str, Any]] = None, graph_key: Optional[str] = None, + default_ai_provider: Optional[str] = None, ) -> AIAgentConfig: """ Internal method to evaluate an agent configuration. @@ -911,6 +931,7 @@ def __evaluate_agent( :param default: Default agent values. :param variables: Variables for interpolation. :param graph_key: When set, passed to the tracker so all events include ``graphKey``. 
+ :param default_ai_provider: Optional default AI provider for judge evaluation. :return: Configured AIAgentConfig instance. """ (model, provider, messages, instructions, @@ -921,6 +942,9 @@ def __evaluate_agent( # For agents, prioritize instructions over messages final_instructions = instructions if instructions is not None else default.instructions + effective_judge_configuration = judge_configuration or JudgeConfiguration(judges=[]) + + evaluator = self._build_evaluator(effective_judge_configuration, context, default_ai_provider, variables) tools = _parse_tools(variation.get('tools')) return AIAgentConfig( @@ -930,7 +954,8 @@ def __evaluate_agent( provider=provider or default.provider, instructions=final_instructions, create_tracker=tracker_factory, - judge_configuration=judge_configuration or default.judge_configuration, + evaluator=evaluator, + judge_configuration=effective_judge_configuration, tools=tools, ) diff --git a/packages/sdk/server-ai/src/ldai/evaluator.py b/packages/sdk/server-ai/src/ldai/evaluator.py new file mode 100644 index 0000000..dce8b83 --- /dev/null +++ b/packages/sdk/server-ai/src/ldai/evaluator.py @@ -0,0 +1,77 @@ +"""Evaluator implementation for coordinating multiple judges.""" + +from __future__ import annotations + +import asyncio +from typing import Dict, List + +from ldai import log +from ldai.judge import Judge +from ldai.models import JudgeConfiguration +from ldai.providers.types import JudgeResult + + +class Evaluator: + """ + Coordinates multiple judge evaluations for a single AI config invocation. + + Instances are created by the SDK client via ``_build_evaluator()`` and injected + into ``AIConfig`` objects (and runners) at construction time. User code should + not need to construct this directly. + """ + + def __init__(self, judges: Dict[str, Judge], judge_configuration: JudgeConfiguration): + """ + Initialize the Evaluator. + + :param judges: Mapping of judge config key to initialized Judge instances + :param judge_configuration: The judge configuration specifying which judges to run + """ + self._judges = judges + self._judge_configuration = judge_configuration + + @classmethod + def noop(cls) -> Evaluator: + return cls({}, JudgeConfiguration(judges=[])) + + def evaluate( + self, + input_text: str, + output_text: str, + ) -> asyncio.Task[List[JudgeResult]]: + """ + Run all configured judges against the given input/output pair. + + Schedules the judge evaluations as an asyncio Task and returns it + immediately. The caller can await the task to get results or pass it + to tracking helpers. + + :param input_text: The input that was provided to the AI model + :param output_text: The AI-generated output to evaluate + :return: An asyncio Task that resolves to a list of JudgeResult instances + """ + return asyncio.create_task(self._run_judges(input_text, output_text)) + + async def _run_judges( + self, + input_text: str, + output_text: str, + ) -> List[JudgeResult]: + """ + Execute all configured judges and collect results. 
+ + :param input_text: The input that was provided to the AI model + :param output_text: The AI-generated output to evaluate + :return: List of JudgeResult instances (one per configured judge that was found) + """ + if not self._judge_configuration.judges: + log.debug('No judges configured, no evaluations to run') + return [] + results: List[JudgeResult] = [] + for jc in self._judge_configuration.judges: + judge = self._judges.get(jc.key) + if not judge: + log.warning(f'Judge not enabled: {jc.key}') + continue + results.append(await judge.evaluate(input_text, output_text, jc.sampling_rate)) + return results diff --git a/packages/sdk/server-ai/src/ldai/managed_model.py b/packages/sdk/server-ai/src/ldai/managed_model.py index ef6f21e..566af7a 100644 --- a/packages/sdk/server-ai/src/ldai/managed_model.py +++ b/packages/sdk/server-ai/src/ldai/managed_model.py @@ -1,8 +1,6 @@ import asyncio -from typing import Any, Dict, List, Optional +from typing import List, Optional -from ldai import log -from ldai.judge import Judge from ldai.models import AICompletionConfig, LDMessage from ldai.providers.model_runner import ModelRunner from ldai.providers.types import JudgeResult, ModelResponse @@ -22,11 +20,9 @@ def __init__( self, ai_config: AICompletionConfig, model_runner: ModelRunner, - judges: Optional[Dict[str, Judge]] = None, ): self._ai_config = ai_config self._model_runner = model_runner - self._judges = judges or {} self._messages: List[LDMessage] = [] async def invoke(self, prompt: str) -> ModelResponse: @@ -53,38 +49,32 @@ async def invoke(self, prompt: str) -> ModelResponse: lambda: self._model_runner.invoke_model(all_messages), ) - if ( - self._ai_config.judge_configuration - and self._ai_config.judge_configuration.judges - ): - response.evaluations = self._start_judge_evaluations(tracker, self._messages, response) + input_text = '\r\n'.join(m.content for m in self._messages) if self._messages else '' + output_text = response.message.content + response.evaluations = self._track_judge_results(tracker, input_text, output_text) self._messages.append(response.message) return response - def _start_judge_evaluations( + def _track_judge_results( self, tracker: LDAIConfigTracker, - messages: List[LDMessage], - response: ModelResponse, - ) -> List[asyncio.Task[Optional[JudgeResult]]]: - if not self._ai_config.judge_configuration or not self._ai_config.judge_configuration.judges: - return [] - - async def evaluate_judge(judge_config: Any) -> Optional[JudgeResult]: - judge = self._judges.get(judge_config.key) - if not judge: - log.warning(f'Judge configuration is not enabled: {judge_config.key}') - return None - judge_result = await judge.evaluate_messages(messages, response, judge_config.sampling_rate) - if judge_result.success: - tracker.track_judge_result(judge_result) - return judge_result - - return [ - asyncio.create_task(evaluate_judge(jc)) - for jc in self._ai_config.judge_configuration.judges - ] + input_text: str, + output_text: str, + ) -> asyncio.Task[List[JudgeResult]]: + eval_task = self._ai_config.evaluator.evaluate(input_text, output_text) + + def _on_done(task: asyncio.Task) -> None: + if task.cancelled(): + return + if task.exception() is not None: + return + for r in task.result(): + if r.success: + tracker.track_judge_result(r) + + eval_task.add_done_callback(_on_done) + return eval_task def get_messages(self, include_config_messages: bool = False) -> List[LDMessage]: """ @@ -116,7 +106,3 @@ def get_model_runner(self) -> ModelRunner: def get_config(self) -> AICompletionConfig: 
"""Return the AI completion config.""" return self._ai_config - - def get_judges(self) -> Dict[str, Judge]: - """Return the judges associated with this model.""" - return self._judges diff --git a/packages/sdk/server-ai/src/ldai/models.py b/packages/sdk/server-ai/src/ldai/models.py index 28d8cef..47fbddd 100644 --- a/packages/sdk/server-ai/src/ldai/models.py +++ b/packages/sdk/server-ai/src/ldai/models.py @@ -1,6 +1,9 @@ import warnings from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union + +if TYPE_CHECKING: + from ldai.evaluator import Evaluator from typing_extensions import Self @@ -253,6 +256,7 @@ class AICompletionConfig(AIConfig): """ Completion AI Config (default mode). """ + evaluator: 'Evaluator' = field(kw_only=True, repr=False, compare=False, hash=False) messages: Optional[List[LDMessage]] = None judge_configuration: Optional[JudgeConfiguration] = None tools: Optional[Dict[str, 'LDTool']] = None @@ -302,6 +306,7 @@ class AIAgentConfig(AIConfig): """ Agent-specific AI Config with instructions. """ + evaluator: 'Evaluator' = field(kw_only=True, repr=False, compare=False, hash=False) instructions: Optional[str] = None judge_configuration: Optional[JudgeConfiguration] = None tools: Optional[Dict[str, 'LDTool']] = None diff --git a/packages/sdk/server-ai/src/ldai/providers/ai_provider.py b/packages/sdk/server-ai/src/ldai/providers/ai_provider.py index 70acbe0..6e2cb6c 100644 --- a/packages/sdk/server-ai/src/ldai/providers/ai_provider.py +++ b/packages/sdk/server-ai/src/ldai/providers/ai_provider.py @@ -91,7 +91,11 @@ def create_agent(self, config: Any, tools: Optional[ToolRegistry] = None) -> Opt log.warning('create_agent not implemented by this provider') return None - def create_agent_graph(self, graph_def: Any, tools: Any) -> Optional[Any]: + def create_agent_graph( + self, + graph_def: Any, + tools: Any, + ) -> Optional[Any]: """ CAUTION: This feature is experimental and should NOT be considered ready for production use. 
diff --git a/packages/sdk/server-ai/src/ldai/providers/runner_factory.py b/packages/sdk/server-ai/src/ldai/providers/runner_factory.py index 4c28334..9363f8e 100644 --- a/packages/sdk/server-ai/src/ldai/providers/runner_factory.py +++ b/packages/sdk/server-ai/src/ldai/providers/runner_factory.py @@ -176,7 +176,10 @@ def create_agent_graph( if graph_def.root() and graph_def.root().get_config() and graph_def.root().get_config().provider: provider_name = graph_def.root().get_config().provider.name.lower() providers = RunnerFactory._get_providers_to_try(default_ai_provider, provider_name) - return RunnerFactory._with_fallback(providers, lambda p: p.create_agent_graph(graph_def, tools)) + return RunnerFactory._with_fallback( + providers, + lambda p: p.create_agent_graph(graph_def, tools), + ) @staticmethod def _pkg_exists(package_name: str) -> None: diff --git a/packages/sdk/server-ai/src/ldai/providers/types.py b/packages/sdk/server-ai/src/ldai/providers/types.py index 083141d..aa53788 100644 --- a/packages/sdk/server-ai/src/ldai/providers/types.py +++ b/packages/sdk/server-ai/src/ldai/providers/types.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional @@ -44,7 +45,7 @@ class ModelResponse: """ message: LDMessage metrics: LDAIMetrics - evaluations: Optional[List[JudgeResult]] = None + evaluations: Optional[asyncio.Task[List[JudgeResult]]] = None @dataclass @@ -109,3 +110,4 @@ class AgentGraphResult: output: str raw: Any metrics: LDAIMetrics + evaluations: Optional[List[JudgeResult]] = None diff --git a/packages/sdk/server-ai/tests/test_agent_graph.py b/packages/sdk/server-ai/tests/test_agent_graph.py index f16a174..140dbe7 100644 --- a/packages/sdk/server-ai/tests/test_agent_graph.py +++ b/packages/sdk/server-ai/tests/test_agent_graph.py @@ -11,6 +11,7 @@ AIAgentConfig, Edge, ) +from ldai.evaluator import Evaluator @pytest.fixture @@ -270,16 +271,16 @@ def test_agent_graph_build_nodes(ldai_client: LDAIClient): ai_graph_config, { "customer-support-agent": AIAgentConfig( - key="customer-support-agent", enabled=True, create_tracker=MagicMock(), + key="customer-support-agent", enabled=True, create_tracker=MagicMock(), evaluator=Evaluator.noop(), ), "personalized-agent": AIAgentConfig( - key="personalized-agent", enabled=True, create_tracker=MagicMock(), + key="personalized-agent", enabled=True, create_tracker=MagicMock(), evaluator=Evaluator.noop(), ), "multi-context-agent": AIAgentConfig( - key="multi-context-agent", enabled=True, create_tracker=MagicMock(), + key="multi-context-agent", enabled=True, create_tracker=MagicMock(), evaluator=Evaluator.noop(), ), "minimal-agent": AIAgentConfig( - key="minimal-agent", enabled=True, create_tracker=MagicMock(), + key="minimal-agent", enabled=True, create_tracker=MagicMock(), evaluator=Evaluator.noop(), ), }, )
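With this change, graph callers read judge results directly from `AgentGraphResult.evaluations`, which `flush()` now populates. A short sketch of that usage; building the AgentGraphDefinition and ToolRegistry is elided (assume something like the `_make_graph` helpers in the tests above):

from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner

async def run_and_report(graph_def, tools) -> None:
    runner = LangGraphAgentGraphRunner(graph_def, tools)
    result = await runner.run('How do I reset my password?')
    print(result.output, result.metrics.success)
    # Judge results collected across every visited node; empty when all nodes
    # use Evaluator.noop() or when the run failed.
    for judge_result in result.evaluations or []:
        print(judge_result.success)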