From 7a76f832c79e2ce4ae425afc525e2a42d7caa29f Mon Sep 17 00:00:00 2001
From: Major Hayden
Date: Tue, 28 Apr 2026 15:03:38 -0500
Subject: [PATCH] refactor(metrics): centralize metric recording

Signed-off-by: Major Hayden
---
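Notes for reviewers (this section sits below the "---" cut line, so it is
ignored by git-am). First, a sketch of how endpoint code is expected to use
the new facade; the handler and the call_llm() helper are hypothetical
illustrations, not code from this patch:

    from metrics import recording

    async def handle_infer(provider: str, model: str) -> str:
        # Count the attempt up front, as streaming_query.py does.
        recording.record_llm_call(provider, model)
        try:
            response = await call_llm(provider, model)  # hypothetical helper
        except Exception:
            # Mirrors _record_inference_failure() in rlsapi_v1.py.
            recording.record_llm_failure(provider, model)
            raise
        # Token counters are only touched when real usage data exists.
        recording.record_llm_token_usage(
            provider, model, response.input_tokens, response.output_tokens
        )
        return response.text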
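
The middleware hunk in src/app/main.py keeps its try/finally shape so the
call counter still increments when the inner app raises, and
measure_response_duration() degrades to a plain pass-through if the metric
object is unavailable. A condensed sketch (class name and signature are
simplified assumptions; the real middleware also wraps send() to capture
the status code):

    from metrics import recording

    class MetricsMiddleware:
        def __init__(self, app):
            self.app = app

        async def __call__(self, scope, receive, send):
            path = scope["path"]
            status_code = 500  # the real code updates this in a send wrapper
            try:
                with recording.measure_response_duration(path):
                    await self.app(scope, receive, send)
            finally:
                if not path.endswith("/metrics"):
                    recording.record_rest_api_call(path, status_code)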
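
Per the new module docstring, adding a metric means defining it in
src/metrics/__init__.py and pairing it with a guarded helper in
src/metrics/recording.py. A hypothetical tool-call counter, following the
same error contract as the helpers in this patch:

    # src/metrics/__init__.py (hypothetical metric)
    tool_calls_total = Counter(
        "ls_tool_calls_total", "Tool calls", ["tool"]
    )

    # src/metrics/recording.py (hypothetical helper)
    def record_tool_call(tool: str) -> None:
        """Record one tool invocation."""
        try:
            metrics.tool_calls_total.labels(tool).inc()
        except (AttributeError, TypeError, ValueError):
            logger.warning("Failed to update tool call metric", exc_info=True)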
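
Finally, for token accounting, extract_token_usage() in
src/utils/responses.py now reduces to two paths, condensed here from the
hunk below:

    if usage is None:
        # No usage payload: count the call, leave token counters untouched.
        recording.record_llm_call(provider_id, model_id)
        return TokenCounter(llm_calls=1)

    # Real usage data: record tokens first, then the call itself.
    recording.record_llm_token_usage(
        provider_id, model_id, token_counter.input_tokens, token_counter.output_tokens
    )
    recording.record_llm_call(provider_id, model_id)
    return token_counter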

 src/app/endpoints/rlsapi_v1.py                |   4 +-
 src/app/endpoints/streaming_query.py          |   4 +-
 src/app/main.py                               |   6 +-
 src/metrics/__init__.py                       |   6 +-
 src/metrics/recording.py                      |  99 +++++++++++
 src/utils/responses.py                        |  32 ++--
 src/utils/shields.py                          |   6 +-
 .../app/endpoints/test_streaming_query.py     |  10 +-
 tests/unit/app/test_main_middleware.py        |  31 ++--
 tests/unit/metrics/test_recording.py          | 158 ++++++++++++++++++
 tests/unit/utils/test_responses.py            | 102 ++---------
 tests/unit/utils/test_shields.py              |  35 ++--
 12 files changed, 340 insertions(+), 153 deletions(-)
 create mode 100644 src/metrics/recording.py
 create mode 100644 tests/unit/metrics/test_recording.py

diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py
index c38a46a04..3c7c7a6e6 100644
--- a/src/app/endpoints/rlsapi_v1.py
+++ b/src/app/endpoints/rlsapi_v1.py
@@ -17,13 +17,13 @@
 from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 
 import constants
-import metrics
 from authentication import get_auth_dependency
 from authentication.interface import AuthTuple
 from authorization.middleware import authorize
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 from log import get_logger
+from metrics import recording
 from models.config import Action
 from models.responses import (
     UNAUTHORIZED_OPENAPI_EXAMPLES,
@@ -447,7 +447,7 @@ def _record_inference_failure(  # pylint: disable=too-many-arguments,too-many-po
         The total inference time in seconds.
     """
     inference_time = time.monotonic() - start_time
-    metrics.llm_calls_failures_total.labels(provider, model).inc()
+    recording.record_llm_failure(provider, model)
     _queue_splunk_event(
         background_tasks,
         infer_request,
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index acab21f40..f11887d2f 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -40,7 +40,6 @@
 )
 from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 
-import metrics
 from authentication import get_auth_dependency
 from authentication.interface import AuthTuple
 from authorization.azure_token_manager import AzureEntraIDManager
@@ -59,6 +58,7 @@
     TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS,
 )
 from log import get_logger
+from metrics import recording
 from models.config import Action
 from models.context import ResponseGeneratorContext
 from models.requests import QueryRequest
@@ -283,7 +283,7 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     provider_id, model_id = extract_provider_and_model_from_model_id(
         responses_params.model
     )
-    metrics.llm_calls_total.labels(provider_id, model_id).inc()
+    recording.record_llm_call(provider_id, model_id)
 
     generator, turn_summary = await retrieve_response_generator(
         responses_params=responses_params,
diff --git a/src/app/main.py b/src/app/main.py
index fab5d9573..1535d0532 100644
--- a/src/app/main.py
+++ b/src/app/main.py
@@ -12,7 +12,6 @@
 from starlette.routing import Mount, Route, WebSocketRoute
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
-import metrics
 import version
 from a2a_storage import A2AStorageFactory
 from app import routers
@@ -22,6 +21,7 @@
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 from log import get_logger
+from metrics import recording
 from models.responses import InternalServerErrorResponse
 from sentry import initialize_sentry
 from utils.common import register_mcp_servers_async
@@ -182,12 +182,12 @@ async def send_wrapper(message: Message) -> None:
         # Measure duration and forward the request. Use try/finally so the
         # call counter is always incremented, even when the inner app raises.
         try:
-            with metrics.response_duration_seconds.labels(path).time():
+            with recording.measure_response_duration(path):
                 await self.app(scope, receive, send_wrapper)
         finally:
             # Ignore /metrics endpoint that will be called periodically.
             if not path.endswith("/metrics"):
-                metrics.rest_api_calls_total.labels(path, status_code).inc()
+                recording.record_rest_api_call(path, status_code)
 
 
 class GlobalExceptionMiddleware:  # pylint: disable=too-few-public-methods
diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py
index 5c4e4e44f..49c6767a0 100644
--- a/src/metrics/__init__.py
+++ b/src/metrics/__init__.py
@@ -42,14 +42,12 @@
     "ls_llm_validation_errors_total", "LLM validation errors"
 )
 
-# TODO(lucasagomes): Add metric for token usage
-# https://issues.redhat.com/browse/LCORE-411
+# Metric that counts how many tokens were sent to LLMs
 llm_token_sent_total = Counter(
     "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"]
 )
 
-# TODO(lucasagomes): Add metric for token usage
-# https://issues.redhat.com/browse/LCORE-411
+# Metric that counts how many tokens were received from LLMs
 llm_token_received_total = Counter(
     "ls_llm_token_received_total", "LLM tokens received", ["provider", "model"]
 )
diff --git a/src/metrics/recording.py b/src/metrics/recording.py
new file mode 100644
index 000000000..abea2270d
--- /dev/null
+++ b/src/metrics/recording.py
@@ -0,0 +1,99 @@
+"""Recording helpers for Prometheus metrics.
+
+This module keeps metric definitions in ``metrics.__init__`` while providing a
+small facade for application code. New metrics should add a recording helper
+here so callers do not need to know Prometheus object details.
+"""
+
+from collections.abc import Iterator
+from contextlib import contextmanager
+
+import metrics
+from log import get_logger
+
+logger = get_logger(__name__)
+
+
+@contextmanager
+def measure_response_duration(path: str) -> Iterator[None]:
+    """Measure REST API response duration for a route path.
+
+    Args:
+        path: Normalized route path used as the metric label.
+    """
+    try:
+        cm = metrics.response_duration_seconds.labels(path).time()
+    except (AttributeError, TypeError, ValueError):
+        logger.warning("Failed to start response duration metric", exc_info=True)
+        yield
+        return
+    with cm:
+        yield
+
+
+def record_rest_api_call(path: str, status_code: int) -> None:
+    """Record one REST API request.
+
+    Args:
+        path: Normalized route path used as the metric label.
+        status_code: HTTP response status code returned by the endpoint.
+    """
+    try:
+        metrics.rest_api_calls_total.labels(path, status_code).inc()
+    except (AttributeError, TypeError, ValueError):
+        logger.warning("Failed to update REST API call metric", exc_info=True)
+
+
+def record_llm_call(provider: str, model: str) -> None:
+    """Record one LLM call for a provider and model.
+
+    Args:
+        provider: LLM provider identifier.
+        model: LLM model identifier without the provider prefix.
+    """
+    try:
+        metrics.llm_calls_total.labels(provider, model).inc()
+    except (AttributeError, TypeError, ValueError):
+        logger.warning("Failed to update LLM call metric", exc_info=True)
+
+
+def record_llm_failure(provider: str, model: str) -> None:
+    """Record one failed LLM call for a provider and model.
+
+    Args:
+        provider: LLM provider identifier.
+        model: LLM model identifier without the provider prefix.
+    """
+    try:
+        metrics.llm_calls_failures_total.labels(provider, model).inc()
+    except (AttributeError, TypeError, ValueError):
+        logger.warning("Failed to update LLM failure metric", exc_info=True)
+
+
+def record_llm_validation_error() -> None:
+    """Record one LLM validation error, such as a shield violation."""
+    try:
+        metrics.llm_calls_validation_errors_total.inc()
+    except (AttributeError, TypeError, ValueError):
+        logger.warning("Failed to update LLM validation error metric", exc_info=True)
+
+
+def record_llm_token_usage(
+    provider: str,
+    model: str,
+    input_tokens: int,
+    output_tokens: int,
+) -> None:
+    """Record LLM token usage for a provider and model.
+
+    Args:
+        provider: LLM provider identifier.
+        model: LLM model identifier without the provider prefix.
+        input_tokens: Number of tokens sent to the LLM.
+        output_tokens: Number of tokens received from the LLM.
+    """
+    try:
+        metrics.llm_token_sent_total.labels(provider, model).inc(input_tokens)
+        metrics.llm_token_received_total.labels(provider, model).inc(output_tokens)
+    except (AttributeError, TypeError, ValueError):
+        logger.warning("Failed to update token metrics", exc_info=True)
diff --git a/src/utils/responses.py b/src/utils/responses.py
index 858973d02..869a3a959 100644
--- a/src/utils/responses.py
+++ b/src/utils/responses.py
@@ -86,11 +86,11 @@
 from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
 
 import constants
-import metrics
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 from constants import DEFAULT_RAG_TOOL
 from log import get_logger
+from metrics import recording
 from models.config import ByokRag
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
@@ -922,7 +922,7 @@ def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCoun
         logger.debug(
             "No usage information in Responses API response, token counts will be 0"
         )
-        _increment_llm_call_metric(provider_id, model_id)
+        recording.record_llm_call(provider_id, model_id)
         return TokenCounter(llm_calls=1)
 
     token_counter = TokenCounter(
@@ -934,18 +934,14 @@ def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCoun
         token_counter.output_tokens,
     )
 
-    # Update Prometheus metrics only when we have actual usage data
-    try:
-        metrics.llm_token_sent_total.labels(provider_id, model_id).inc(
-            token_counter.input_tokens
-        )
-        metrics.llm_token_received_total.labels(provider_id, model_id).inc(
-            token_counter.output_tokens
-        )
-    except (AttributeError, TypeError, ValueError) as e:
-        logger.warning("Failed to update token metrics: %s", e)
-
-    _increment_llm_call_metric(provider_id, model_id)
+    # Update Prometheus metrics only when we have actual usage data.
+    recording.record_llm_token_usage(
+        provider_id,
+        model_id,
+        token_counter.input_tokens,
+        token_counter.output_tokens,
+    )
+    recording.record_llm_call(provider_id, model_id)
 
     return token_counter
 
@@ -1251,14 +1247,6 @@
     return rag_chunks
 
 
-def _increment_llm_call_metric(provider: str, model: str) -> None:
-    """Safely increment LLM call metric."""
-    try:
-        metrics.llm_calls_total.labels(provider, model).inc()
-    except (AttributeError, TypeError, ValueError) as e:
-        logger.warning("Failed to update LLM call metric: %s", e)
-
-
 def parse_arguments_string(arguments_str: str) -> dict[str, Any]:
     """Parse an arguments string into a dictionary.
 
diff --git a/src/utils/shields.py b/src/utils/shields.py
index 7e3b4cc0a..6d6089139 100644
--- a/src/utils/shields.py
+++ b/src/utils/shields.py
@@ -14,10 +14,10 @@
 from llama_stack_client.types import ShieldListResponse
 from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 
-import metrics
 from configuration import AppConfig
 from constants import DEFAULT_VIOLATION_MESSAGE
 from log import get_logger
+from metrics import recording
 from models.requests import QueryRequest
 from models.responses import (
     InternalServerErrorResponse,
@@ -77,7 +77,7 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
         refusal = getattr(output_item, "refusal", None)
         if refusal:
             # Metric for LLM validation errors (shield violations)
-            metrics.llm_calls_validation_errors_total.inc()
+            recording.record_llm_validation_error()
             logger.warning("Shield violation detected: %s", refusal)
             return True
     return False
@@ -178,7 +178,7 @@ async def run_shield_moderation(
 
     if moderation_result.results and moderation_result.results[0].flagged:
         result = moderation_result.results[0]
-        metrics.llm_calls_validation_errors_total.inc()
+        recording.record_llm_validation_error()
         logger.warning(
             "Shield '%s' flagged content: categories=%s",
             shield.identifier,
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index bf2738cf4..f6c2bc958 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -389,7 +389,7 @@ async def test_successful_streaming_query(
         "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
         return_value=("provider1", "model1"),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
@@ -476,7 +476,7 @@ async def test_streaming_query_text_media_type_header(
         "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
         return_value=("provider1", "model1"),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
@@ -574,7 +574,7 @@ async def test_streaming_query_with_conversation(
         "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
         return_value=("provider1", "model1"),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
@@ -670,7 +670,7 @@ async def test_streaming_query_with_attachments(
         "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
         return_value=("provider1", "model1"),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
@@ -770,7 +770,7 @@ async def test_streaming_query_azure_token_refresh(
         "app.endpoints.streaming_query.run_shield_moderation",
         new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
diff --git a/tests/unit/app/test_main_middleware.py b/tests/unit/app/test_main_middleware.py
index f4f5420a8..24a4a366a 100644
--- a/tests/unit/app/test_main_middleware.py
+++ b/tests/unit/app/test_main_middleware.py
@@ -1,6 +1,7 @@
 """Unit tests for the pure ASGI middlewares in main.py."""
 
 import json
+from contextlib import nullcontext
 from typing import cast
 
 import pytest
@@ -165,7 +166,10 @@ async def test_rest_api_metrics_increments_counter_on_exception(
 ) -> None:
     """Counter must be incremented even when the inner app raises."""
     mocker.patch("app.main.app_routes_paths", ["/v1/infer"])
-    mock_metrics = mocker.patch("app.main.metrics")
+    mock_measure_duration = mocker.patch(
+        "app.main.recording.measure_response_duration", return_value=nullcontext()
+    )
+    mock_record_call = mocker.patch("app.main.recording.record_rest_api_call")
 
     async def failing_app(_scope: Scope, _receive: Receive, _send: Send) -> None:
         raise RuntimeError("boom")
@@ -175,9 +179,8 @@ async def failing_app(_scope: Scope, _receive: Receive, _send: Send) -> None:
     with pytest.raises(RuntimeError, match="boom"):
         await middleware(_make_scope("/v1/infer"), _noop_receive, _ResponseCollector())
 
-    mock_metrics.response_duration_seconds.labels.assert_called_once_with("/v1/infer")
-    mock_metrics.rest_api_calls_total.labels.assert_called_once_with("/v1/infer", 500)
-    mock_metrics.rest_api_calls_total.labels.return_value.inc.assert_called_once()
+    mock_measure_duration.assert_called_once_with("/v1/infer")
+    mock_record_call.assert_called_once_with("/v1/infer", 500)
 
 
 @pytest.mark.asyncio
@@ -186,7 +189,10 @@ async def test_rest_api_metrics_strips_root_path(
 ) -> None:
     """Middleware must strip root_path so prefixed requests still match routes."""
     mocker.patch("app.main.app_routes_paths", ["/v1/infer"])
-    mock_metrics = mocker.patch("app.main.metrics")
+    mock_measure_duration = mocker.patch(
+        "app.main.recording.measure_response_duration", return_value=nullcontext()
+    )
+    mock_record_call = mocker.patch("app.main.recording.record_rest_api_call")
 
     async def ok_app(_scope: Scope, _receive: Receive, send: Send) -> None:
         await send({"type": "http.response.start", "status": 200, "headers": []})
@@ -204,9 +210,8 @@ async def ok_app(_scope: Scope, _receive: Receive, send: Send) -> None:
 
     assert collector.status_code == 200
     # Metrics labels should use the stripped path, not the full prefixed path.
-    mock_metrics.response_duration_seconds.labels.assert_called_once_with("/v1/infer")
-    mock_metrics.rest_api_calls_total.labels.assert_called_once_with("/v1/infer", 200)
-    mock_metrics.rest_api_calls_total.labels.return_value.inc.assert_called_once()
+    mock_measure_duration.assert_called_once_with("/v1/infer")
+    mock_record_call.assert_called_once_with("/v1/infer", 200)
 
 
 @pytest.mark.asyncio
@@ -215,7 +220,10 @@ async def test_rest_api_metrics_no_root_path_unchanged(
 ) -> None:
     """Without root_path, middleware behaves as before."""
     mocker.patch("app.main.app_routes_paths", ["/v1/infer"])
-    mock_metrics = mocker.patch("app.main.metrics")
+    mock_measure_duration = mocker.patch(
+        "app.main.recording.measure_response_duration", return_value=nullcontext()
+    )
+    mock_record_call = mocker.patch("app.main.recording.record_rest_api_call")
 
     async def ok_app(_scope: Scope, _receive: Receive, send: Send) -> None:
         await send({"type": "http.response.start", "status": 200, "headers": []})
@@ -231,6 +239,5 @@ async def ok_app(_scope: Scope, _receive: Receive, send: Send) -> None:
     )
 
     assert collector.status_code == 200
-    mock_metrics.response_duration_seconds.labels.assert_called_once_with("/v1/infer")
-    mock_metrics.rest_api_calls_total.labels.assert_called_once_with("/v1/infer", 200)
-    mock_metrics.rest_api_calls_total.labels.return_value.inc.assert_called_once()
+    mock_measure_duration.assert_called_once_with("/v1/infer")
+    mock_record_call.assert_called_once_with("/v1/infer", 200)
diff --git a/tests/unit/metrics/test_recording.py b/tests/unit/metrics/test_recording.py
new file mode 100644
index 000000000..d132a8293
--- /dev/null
+++ b/tests/unit/metrics/test_recording.py
@@ -0,0 +1,158 @@
+"""Unit tests for Prometheus metric recording helpers."""
+
+from pytest_mock import MockerFixture
+
+from metrics import recording
+
+
+def test_measure_response_duration_records_timer(mocker: MockerFixture) -> None:
+    """Test that response duration measurement uses the path label timer."""
+    mock_timer = mocker.MagicMock()
+    mock_metric = mocker.patch("metrics.recording.metrics.response_duration_seconds")
+    mock_metric.labels.return_value.time.return_value = mock_timer
+
+    with recording.measure_response_duration("/v1/infer"):
+        pass
+
+    mock_metric.labels.assert_called_once_with("/v1/infer")
+    mock_metric.labels.return_value.time.assert_called_once()
+    mock_timer.__enter__.assert_called_once()
+    mock_timer.__exit__.assert_called_once()
+
+
+def test_measure_response_duration_logs_metric_errors(mocker: MockerFixture) -> None:
+    """Test that response duration metric errors are logged and request still proceeds."""
+    mock_metric = mocker.patch("metrics.recording.metrics.response_duration_seconds")
+    mock_metric.labels.return_value.time.side_effect = AttributeError("missing")
+    mock_logger = mocker.patch("metrics.recording.logger")
+
+    with recording.measure_response_duration("/v1/infer"):
+        pass
+
+    mock_logger.warning.assert_called_once_with(
+        "Failed to start response duration metric", exc_info=True
+    )
+
+
+def test_record_rest_api_call_records_counter(mocker: MockerFixture) -> None:
+    """Test that REST API call recording increments the labeled counter."""
+    mock_metric = mocker.patch("metrics.recording.metrics.rest_api_calls_total")
+
+    recording.record_rest_api_call("/v1/infer", 200)
+
+    mock_metric.labels.assert_called_once_with("/v1/infer", 200)
+    mock_metric.labels.return_value.inc.assert_called_once()
+
+
+def test_record_rest_api_call_logs_metric_errors(mocker: MockerFixture) -> None:
+    """Test that REST API call metric errors are logged and swallowed."""
+    mock_metric = mocker.patch("metrics.recording.metrics.rest_api_calls_total")
+    mock_metric.labels.return_value.inc.side_effect = AttributeError("missing")
+    mock_logger = mocker.patch("metrics.recording.logger")
+
+    recording.record_rest_api_call("/v1/infer", 200)
+
+    mock_logger.warning.assert_called_once_with(
+        "Failed to update REST API call metric", exc_info=True
+    )
+
+
+def test_record_llm_call_records_counter(mocker: MockerFixture) -> None:
+    """Test that LLM call recording increments the provider/model counter."""
+    mock_metric = mocker.patch("metrics.recording.metrics.llm_calls_total")
+
+    recording.record_llm_call("provider1", "model1")
+
+    mock_metric.labels.assert_called_once_with("provider1", "model1")
+    mock_metric.labels.return_value.inc.assert_called_once()
+
+
+def test_record_llm_call_logs_metric_errors(mocker: MockerFixture) -> None:
+    """Test that LLM call metric errors are logged and swallowed."""
+    mock_metric = mocker.patch("metrics.recording.metrics.llm_calls_total")
+    mock_metric.labels.return_value.inc.side_effect = AttributeError("missing")
+    mock_logger = mocker.patch("metrics.recording.logger")
+
+    recording.record_llm_call("provider1", "model1")
+
+    mock_logger.warning.assert_called_once_with(
+        "Failed to update LLM call metric", exc_info=True
+    )
+
+
+def test_record_llm_failure_records_counter(mocker: MockerFixture) -> None:
+    """Test that LLM failure recording increments the provider/model counter."""
+    mock_metric = mocker.patch("metrics.recording.metrics.llm_calls_failures_total")
+
+    recording.record_llm_failure("provider1", "model1")
+
+    mock_metric.labels.assert_called_once_with("provider1", "model1")
+    mock_metric.labels.return_value.inc.assert_called_once()
+
+
+def test_record_llm_failure_logs_metric_errors(mocker: MockerFixture) -> None:
+    """Test that LLM failure metric errors are logged and swallowed."""
+    mock_metric = mocker.patch("metrics.recording.metrics.llm_calls_failures_total")
+    mock_metric.labels.return_value.inc.side_effect = TypeError("bad")
+    mock_logger = mocker.patch("metrics.recording.logger")
+
+    recording.record_llm_failure("provider1", "model1")
+
+    mock_logger.warning.assert_called_once_with(
+        "Failed to update LLM failure metric", exc_info=True
+    )
+
+
+def test_record_llm_validation_error_records_counter(mocker: MockerFixture) -> None:
+    """Test that validation error recording increments the counter."""
+    mock_metric = mocker.patch(
+        "metrics.recording.metrics.llm_calls_validation_errors_total"
+    )
+
+    recording.record_llm_validation_error()
+
+    mock_metric.inc.assert_called_once()
+
+
+def test_record_llm_validation_error_logs_metric_errors(
+    mocker: MockerFixture,
+) -> None:
+    """Test that validation error metric failures are logged and swallowed."""
+    mock_metric = mocker.patch(
+        "metrics.recording.metrics.llm_calls_validation_errors_total"
+    )
+    mock_metric.inc.side_effect = ValueError("bad")
+    mock_logger = mocker.patch("metrics.recording.logger")
+
+    recording.record_llm_validation_error()
+
+    mock_logger.warning.assert_called_once_with(
+        "Failed to update LLM validation error metric", exc_info=True
+    )
+
+
+def test_record_llm_token_usage_records_counters(mocker: MockerFixture) -> None:
+    """Test that token usage recording increments sent and received counters."""
+    mock_sent = mocker.patch("metrics.recording.metrics.llm_token_sent_total")
+    mock_received = mocker.patch("metrics.recording.metrics.llm_token_received_total")
+
+    recording.record_llm_token_usage("provider1", "model1", 100, 50)
+
+    mock_sent.labels.assert_called_once_with("provider1", "model1")
+    mock_sent.labels.return_value.inc.assert_called_once_with(100)
+    mock_received.labels.assert_called_once_with("provider1", "model1")
+    mock_received.labels.return_value.inc.assert_called_once_with(50)
+
+
+def test_record_llm_token_usage_logs_metric_errors(mocker: MockerFixture) -> None:
+    """Test that token metric failures are logged and swallowed."""
+    mock_sent = mocker.patch("metrics.recording.metrics.llm_token_sent_total")
+    mock_sent.labels.return_value.inc.side_effect = ValueError("bad")
+    mocker.patch("metrics.recording.metrics.llm_token_received_total")
+    mock_logger = mocker.patch("metrics.recording.logger")
+
+    recording.record_llm_token_usage("provider1", "model1", 100, 50)
+
+    mock_logger.warning.assert_called_once_with(
+        "Failed to update token metrics", exc_info=True
+    )
diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py
index 485ecbf26..2335193e3 100644
--- a/tests/unit/utils/test_responses.py
+++ b/tests/unit/utils/test_responses.py
@@ -63,7 +63,6 @@
 from models.requests import QueryRequest
 from utils.responses import (
     _build_chunk_attributes,
-    _increment_llm_call_metric,
     _merge_tools,
     _resolve_source_for_result,
     build_mcp_tool_call_from_arguments_done,
@@ -2225,14 +2224,19 @@ def test_extract_token_usage_with_usage_object(
             "utils.responses.extract_provider_and_model_from_model_id",
             return_value=("provider1", "model1"),
         )
-        mocker.patch("utils.responses.metrics.llm_token_sent_total")
-        mocker.patch("utils.responses.metrics.llm_token_received_total")
-        mocker.patch("utils.responses._increment_llm_call_metric")
+        mock_token_usage = mocker.patch(
+            "utils.responses.recording.record_llm_token_usage"
+        )
+        mock_llm_call = mocker.patch("utils.responses.recording.record_llm_call")
 
         result = extract_token_usage(mock_usage, "provider1/model1")
         assert result.input_tokens == input_tokens
         assert result.output_tokens == output_tokens
         assert result.llm_calls == 1
+        mock_token_usage.assert_called_once_with(
+            "provider1", "model1", input_tokens, output_tokens
+        )
+        mock_llm_call.assert_called_once_with("provider1", "model1")
 
     def test_extract_token_usage_no_usage(self, mocker: MockerFixture) -> None:
         """Test extracting token usage when usage is None."""
@@ -2240,12 +2244,13 @@ def test_extract_token_usage_no_usage(self, mocker: MockerFixture) -> None:
             "utils.responses.extract_provider_and_model_from_model_id",
             return_value=("provider1", "model1"),
         )
-        mocker.patch("utils.responses._increment_llm_call_metric")
+        mock_llm_call = mocker.patch("utils.responses.recording.record_llm_call")
 
         result = extract_token_usage(None, "provider1/model1")
         assert result.input_tokens == 0
        assert result.output_tokens == 0
         assert result.llm_calls == 1
+        mock_llm_call.assert_called_once_with("provider1", "model1")
 
     def test_extract_token_usage_zero_tokens(self, mocker: MockerFixture) -> None:
         """Test extracting token usage when tokens are 0."""
@@ -2256,11 +2261,16 @@ def test_extract_token_usage_zero_tokens(self, mocker: MockerFixture) -> None:
             "utils.responses.extract_provider_and_model_from_model_id",
             return_value=("provider1", "model1"),
         )
-        mocker.patch("utils.responses._increment_llm_call_metric")
+        mock_token_usage = mocker.patch(
+            "utils.responses.recording.record_llm_token_usage"
+        )
+        mock_llm_call = mocker.patch("utils.responses.recording.record_llm_call")
 
         result = extract_token_usage(mock_usage, "provider1/model1")
         assert result.input_tokens == 0
         assert result.output_tokens == 0
+        mock_token_usage.assert_called_once_with("provider1", "model1", 0, 0)
+        mock_llm_call.assert_called_once_with("provider1", "model1")
 
     def test_extract_token_usage_none_response(self, mocker: MockerFixture) -> None:
         """Test extracting token usage with None response."""
@@ -2268,36 +2278,12 @@ def test_extract_token_usage_none_response(self, mocker: MockerFixture) -> None
             "utils.responses.extract_provider_and_model_from_model_id",
             return_value=("provider1", "model1"),
         )
-        mocker.patch("utils.responses._increment_llm_call_metric")
+        mock_llm_call = mocker.patch("utils.responses.recording.record_llm_call")
 
         result = extract_token_usage(None, "provider1/model1")
         assert result.input_tokens == 0
         assert result.output_tokens == 0
-
-    def test_extract_token_usage_metrics_error(self, mocker: MockerFixture) -> None:
-        """Test extracting token usage handles errors when updating metrics."""
-        mock_usage = mocker.Mock()
-        mock_usage.input_tokens = 100
-        mock_usage.output_tokens = 50
-
-        mocker.patch(
-            "utils.responses.extract_provider_and_model_from_model_id",
-            return_value=("provider1", "model1"),
-        )
-        # Make metrics raise an error
-        mock_metric = mocker.Mock()
-        mock_metric.labels.return_value.inc = mocker.Mock(
-            side_effect=AttributeError("No attribute")
-        )
-        mocker.patch("utils.responses.metrics.llm_token_sent_total", mock_metric)
-        mocker.patch("utils.responses.metrics.llm_token_received_total", mock_metric)
-        mocker.patch("utils.responses.logger")
-        mocker.patch("utils.responses._increment_llm_call_metric")
-
-        # Should not raise, just log warning
-        result = extract_token_usage(mock_usage, "provider1/model1")
-        assert result.input_tokens == 100
-        assert result.output_tokens == 50
+        mock_llm_call.assert_called_once_with("provider1", "model1")
 
 
 class TestBuildToolCallSummary:
@@ -2541,58 +2527,6 @@ def test_parse_arguments_string_empty_string(self) -> None:
         assert result == {"args": ""}
 
 
-class TestIncrementLlmCallMetric:
-    """Tests for _increment_llm_call_metric function."""
-
-    def test_increment_llm_call_metric_success(self, mocker: MockerFixture) -> None:
-        """Test successful metric increment."""
-        mock_metric = mocker.Mock()
-        mock_metric.labels.return_value.inc = mocker.Mock()
-        mocker.patch("utils.responses.metrics.llm_calls_total", mock_metric)
-
-        _increment_llm_call_metric("provider1", "model1")
-
-        mock_metric.labels.assert_called_once_with("provider1", "model1")
-        mock_metric.labels.return_value.inc.assert_called_once()
-
-    def test_increment_llm_call_metric_attribute_error(
-        self, mocker: MockerFixture
-    ) -> None:
-        """Test metric increment handles AttributeError."""
-        mocker.patch(
-            "utils.responses.metrics.llm_calls_total",
-            side_effect=AttributeError("No attribute"),
-        )
-        mocker.patch("utils.responses.logger")
-
-        # Should not raise exception
-        _increment_llm_call_metric("provider1", "model1")
-
-    def test_increment_llm_call_metric_type_error(self, mocker: MockerFixture) -> None:
-        """Test metric increment handles TypeError."""
-        mock_metric = mocker.Mock()
-        mock_metric.labels.return_value.inc = mocker.Mock(
-            side_effect=TypeError("Invalid type")
-        )
-        mocker.patch("utils.responses.metrics.llm_calls_total", mock_metric)
-        mocker.patch("utils.responses.logger")
-
-        # Should not raise exception
-        _increment_llm_call_metric("provider1", "model1")
-
-    def test_increment_llm_call_metric_value_error(self, mocker: MockerFixture) -> None:
-        """Test metric increment handles ValueError."""
-        mock_metric = mocker.Mock()
-        mock_metric.labels.return_value.inc = mocker.Mock(
-            side_effect=ValueError("Invalid value")
-        )
-        mocker.patch("utils.responses.metrics.llm_calls_total", mock_metric)
-        mocker.patch("utils.responses.logger")
-
-        # Should not raise exception
-        _increment_llm_call_metric("provider1", "model1")
-
-
 class TestBuildMCPToolCallFromArgumentsDone:
     """Tests for build_mcp_tool_call_from_arguments_done function."""
 
result.decision == "blocked" assert result.message == "Content blocked for violence" - mock_metric.inc.assert_called_once() + mock_record_error.assert_called_once() @pytest.mark.asyncio async def test_returns_blocked_with_default_message_when_no_user_message( self, mocker: MockerFixture ) -> None: """Test that run_shield_moderation uses default message when user_message is None.""" - mocker.patch("utils.shields.metrics.llm_calls_validation_errors_total") + mock_record_error = mocker.patch( + "utils.shields.recording.record_llm_validation_error" + ) mock_client = mocker.Mock() # Setup shield @@ -228,6 +230,7 @@ async def test_returns_blocked_with_default_message_when_no_user_message( assert result.decision == "blocked" assert result.message == DEFAULT_VIOLATION_MESSAGE + mock_record_error.assert_called_once() @pytest.mark.asyncio async def test_skips_model_check_for_non_llama_guard_shields(