diff --git a/src/uipath_langchain/_cli/cli_debug.py b/src/uipath_langchain/_cli/cli_debug.py index 9bc7bcad..5760589b 100644 --- a/src/uipath_langchain/_cli/cli_debug.py +++ b/src/uipath_langchain/_cli/cli_debug.py @@ -39,7 +39,6 @@ async def execute(): context.entrypoint = entrypoint context.input = input context.resume = resume - context.execution_id = context.job_id or "default" _instrument_traceable_attributes() diff --git a/src/uipath_langchain/_cli/cli_eval.py b/src/uipath_langchain/_cli/cli_eval.py index 51d00eec..efbf6a27 100644 --- a/src/uipath_langchain/_cli/cli_eval.py +++ b/src/uipath_langchain/_cli/cli_eval.py @@ -1,6 +1,8 @@ import asyncio +import os from typing import List, Optional +from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from openinference.instrumentation.langchain import ( LangChainInstrumentor, get_current_span, @@ -9,6 +11,7 @@ from uipath._cli._evals._progress_reporter import StudioWebProgressReporter from uipath._cli._evals._runtime import UiPathEvalContext, UiPathEvalRuntime from uipath._cli._runtime._contracts import ( + UiPathRuntimeContext, UiPathRuntimeFactory, ) from uipath._cli._utils._eval_set import EvalHelpers @@ -25,6 +28,13 @@ ) +def get_connection_string(context: UiPathRuntimeContext) -> str: + if context.runtime_dir and context.state_file: + os.makedirs(context.runtime_dir, exist_ok=True) + return os.path.join(context.runtime_dir, context.state_file) + return os.path.join("__uipath", "state.db") + + def langgraph_eval_middleware( entrypoint: Optional[str], eval_set: Optional[str], eval_ids: List[str], **kwargs ) -> MiddlewareResult: @@ -35,57 +45,70 @@ def langgraph_eval_middleware( ) # Continue with normal flow if no langgraph.json try: - _instrument_traceable_attributes() - - event_bus = EventBus() - if kwargs.get("register_progress_reporter", False): - progress_reporter = StudioWebProgressReporter( - spans_exporter=LangChainExporter() - ) - asyncio.run(progress_reporter.subscribe_to_eval_runtime_events(event_bus)) - console_reporter = ConsoleProgressReporter() - asyncio.run(console_reporter.subscribe_to_eval_runtime_events(event_bus)) - - def generate_runtime_context( - context_entrypoint: str, **context_kwargs - ) -> LangGraphRuntimeContext: - context = LangGraphRuntimeContext.with_defaults(**context_kwargs) - context.entrypoint = context_entrypoint - return context + async def execute(): + _instrument_traceable_attributes() - runtime_entrypoint = entrypoint or auto_discover_entrypoint() + event_bus = EventBus() - eval_context = UiPathEvalContext.with_defaults( - entrypoint=runtime_entrypoint, **kwargs - ) - eval_context.eval_set = eval_set or EvalHelpers.auto_discover_eval_set() - eval_context.eval_ids = eval_ids + if kwargs.get("register_progress_reporter", False): + progress_reporter = StudioWebProgressReporter( + spans_exporter=LangChainExporter() + ) + await progress_reporter.subscribe_to_eval_runtime_events(event_bus) - def generate_runtime(ctx: LangGraphRuntimeContext) -> LangGraphScriptRuntime: - return LangGraphScriptRuntime(ctx, ctx.entrypoint) + console_reporter = ConsoleProgressReporter() + await console_reporter.subscribe_to_eval_runtime_events(event_bus) - runtime_factory = UiPathRuntimeFactory( - LangGraphScriptRuntime, - LangGraphRuntimeContext, - context_generator=lambda **context_kwargs: generate_runtime_context( - context_entrypoint=runtime_entrypoint, - **context_kwargs, - ), - runtime_generator=generate_runtime, - ) + runtime_entrypoint = entrypoint or auto_discover_entrypoint() - if eval_context.job_id: - runtime_factory.add_span_exporter(LangChainExporter()) + eval_context = UiPathEvalContext.with_defaults( + entrypoint=runtime_entrypoint, **kwargs + ) + eval_context.eval_set = eval_set or EvalHelpers.auto_discover_eval_set() + eval_context.eval_ids = eval_ids - runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span) + def generate_runtime( + ctx: LangGraphRuntimeContext, + ) -> LangGraphScriptRuntime: + return LangGraphScriptRuntime(ctx, ctx.entrypoint) - async def execute(): - async with UiPathEvalRuntime.from_eval_context( - factory=runtime_factory, context=eval_context, event_bus=event_bus - ) as eval_runtime: - await eval_runtime.execute() - await event_bus.wait_for_all() + def generate_runtime_context( + context_entrypoint: str, + context_memory: AsyncSqliteSaver, + **context_kwargs, + ) -> LangGraphRuntimeContext: + context = LangGraphRuntimeContext.with_defaults(**context_kwargs) + context.entrypoint = context_entrypoint + context.memory = context_memory + return context + + async with AsyncSqliteSaver.from_conn_string( + get_connection_string(eval_context) + ) as memory: + runtime_factory = UiPathRuntimeFactory( + LangGraphScriptRuntime, + LangGraphRuntimeContext, + context_generator=lambda **context_kwargs: generate_runtime_context( + context_entrypoint=runtime_entrypoint, + context_memory=memory, + **context_kwargs, + ), + runtime_generator=generate_runtime, + ) + + if eval_context.job_id: + runtime_factory.add_span_exporter(LangChainExporter()) + + runtime_factory.add_instrumentor( + LangChainInstrumentor, get_current_span + ) + + async with UiPathEvalRuntime.from_eval_context( + factory=runtime_factory, context=eval_context, event_bus=event_bus + ) as eval_runtime: + await eval_runtime.execute() + await event_bus.wait_for_all() asyncio.run(execute()) return MiddlewareResult(should_continue=False) diff --git a/src/uipath_langchain/_cli/cli_run.py b/src/uipath_langchain/_cli/cli_run.py index b5e2c49a..c875d6e9 100644 --- a/src/uipath_langchain/_cli/cli_run.py +++ b/src/uipath_langchain/_cli/cli_run.py @@ -40,7 +40,6 @@ async def execute(): context.entrypoint = entrypoint context.input = input context.resume = resume - context.execution_id = context.job_id or "default" _instrument_traceable_attributes() def generate_runtime( @@ -66,7 +65,7 @@ def generate_runtime( await runtime_factory.execute(context) else: debug_bridge: UiPathDebugBridge = ConsoleDebugBridge() - await debug_bridge.emit_execution_started(context.execution_id) + await debug_bridge.emit_execution_started("default") async for event in runtime_factory.stream(context): if isinstance(event, UiPathRuntimeResult): await debug_bridge.emit_execution_completed(event)