diff --git a/packages/sdk/server-ai/src/ldai/judge/__init__.py b/packages/sdk/server-ai/src/ldai/judge/__init__.py index d889f17f..f2e8c362 100644 --- a/packages/sdk/server-ai/src/ldai/judge/__init__.py +++ b/packages/sdk/server-ai/src/ldai/judge/__init__.py @@ -75,8 +75,8 @@ async def evaluate( assert self._evaluation_response_structure is not None response = await tracker.track_metrics_of_async( - lambda: self._model_runner.invoke_structured_model(messages, self._evaluation_response_structure), lambda result: result.metrics, + lambda: self._model_runner.invoke_structured_model(messages, self._evaluation_response_structure), ) parsed = self._parse_evaluation_response(response.data) diff --git a/packages/sdk/server-ai/src/ldai/managed_agent.py b/packages/sdk/server-ai/src/ldai/managed_agent.py index eb2dee7f..ab3ee5e6 100644 --- a/packages/sdk/server-ai/src/ldai/managed_agent.py +++ b/packages/sdk/server-ai/src/ldai/managed_agent.py @@ -29,8 +29,8 @@ async def run(self, input: str) -> AgentResult: """ tracker = self._ai_config.create_tracker() return await tracker.track_metrics_of_async( - lambda: self._agent_runner.run(input), lambda result: result.metrics, + lambda: self._agent_runner.run(input), ) def get_agent_runner(self) -> AgentRunner: diff --git a/packages/sdk/server-ai/src/ldai/managed_model.py b/packages/sdk/server-ai/src/ldai/managed_model.py index dc4393a0..ef6f21e9 100644 --- a/packages/sdk/server-ai/src/ldai/managed_model.py +++ b/packages/sdk/server-ai/src/ldai/managed_model.py @@ -49,8 +49,8 @@ async def invoke(self, prompt: str) -> ModelResponse: all_messages = config_messages + self._messages response = await tracker.track_metrics_of_async( - lambda: self._model_runner.invoke_model(all_messages), lambda result: result.metrics, + lambda: self._model_runner.invoke_model(all_messages), ) if ( diff --git a/packages/sdk/server-ai/src/ldai/tracker.py b/packages/sdk/server-ai/src/ldai/tracker.py index e446011f..ed5d3e7a 100644 --- a/packages/sdk/server-ai/src/ldai/tracker.py +++ b/packages/sdk/server-ai/src/ldai/tracker.py @@ -262,8 +262,8 @@ def _track_from_metrics_extractor( def track_metrics_of( self, - func: Callable[[], Any], metrics_extractor: Callable[[Any], Any], + func: Callable[[], Any], ) -> Any: """ Track metrics for a synchronous AI operation. @@ -277,8 +277,8 @@ def track_metrics_of( For async operations, use :meth:`track_metrics_of_async`. - :param func: Synchronous callable that runs the operation :param metrics_extractor: Function that extracts LDAIMetrics from the operation result + :param func: Synchronous callable that runs the operation :return: The result of the operation """ start_ns = time.perf_counter_ns() @@ -294,14 +294,14 @@ def track_metrics_of( self.track_duration(duration) return self._track_from_metrics_extractor(result, metrics_extractor) - async def track_metrics_of_async(self, func, metrics_extractor): + async def track_metrics_of_async(self, metrics_extractor, func): """ Track metrics for an async AI operation (``func`` is awaited). Same event semantics as :meth:`track_metrics_of`. - :param func: Async callable or zero-arg callable that returns an awaitable when called :param metrics_extractor: Function that extracts LDAIMetrics from the operation result + :param func: Async callable or zero-arg callable that returns an awaitable when called :return: The result of the operation """ start_ns = time.perf_counter_ns() diff --git a/packages/sdk/server-ai/tests/test_tracker.py b/packages/sdk/server-ai/tests/test_tracker.py index 2350e61d..09f12f0e 100644 --- a/packages/sdk/server-ai/tests/test_tracker.py +++ b/packages/sdk/server-ai/tests/test_tracker.py @@ -531,7 +531,7 @@ def fn(): def extract(r): return LDAIMetrics(success=True, usage=TokenUsage(5, 2, 3)) - out = tracker.track_metrics_of(fn, extract) + out = tracker.track_metrics_of(extract, fn) assert out == "done" calls = client.track.mock_calls # type: ignore assert any(c.args[0] == "$ld:ai:generation:success" for c in calls) @@ -551,7 +551,7 @@ async def fn(): def extract(r): return LDAIMetrics(success=True, usage=TokenUsage(5, 2, 3)) - await tracker.track_metrics_of_async(fn, extract) + await tracker.track_metrics_of_async(extract, fn) gk_td = {**_base_td(), "graphKey": "gg"} calls = client.track.mock_calls # type: ignore assert any(