4 changes: 2 additions & 2 deletions src/app/endpoints/rlsapi_v1.py
@@ -17,13 +17,13 @@
 from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 
 import constants
-import metrics
 from authentication import get_auth_dependency
 from authentication.interface import AuthTuple
 from authorization.middleware import authorize
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 from log import get_logger
+from metrics import recording
 from models.config import Action
 from models.responses import (
     UNAUTHORIZED_OPENAPI_EXAMPLES,
@@ -447,7 +447,7 @@ def _record_inference_failure(  # pylint: disable=too-many-arguments,too-many-po
         The total inference time in seconds.
     """
     inference_time = time.monotonic() - start_time
-    metrics.llm_calls_failures_total.labels(provider, model).inc()
+    recording.record_llm_failure(provider, model)
     _queue_splunk_event(
         background_tasks,
         infer_request,
4 changes: 2 additions & 2 deletions src/app/endpoints/streaming_query.py
@@ -40,7 +40,6 @@
 )
 from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 
-import metrics
 from authentication import get_auth_dependency
 from authentication.interface import AuthTuple
 from authorization.azure_token_manager import AzureEntraIDManager
@@ -59,6 +58,7 @@
     TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS,
 )
 from log import get_logger
+from metrics import recording
 from models.config import Action
 from models.context import ResponseGeneratorContext
 from models.requests import QueryRequest
@@ -283,7 +283,7 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     provider_id, model_id = extract_provider_and_model_from_model_id(
         responses_params.model
     )
-    metrics.llm_calls_total.labels(provider_id, model_id).inc()
+    recording.record_llm_call(provider_id, model_id)
 
     generator, turn_summary = await retrieve_response_generator(
         responses_params=responses_params,
6 changes: 3 additions & 3 deletions src/app/main.py
@@ -12,7 +12,6 @@
 from starlette.routing import Mount, Route, WebSocketRoute
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
-import metrics
 import version
 from a2a_storage import A2AStorageFactory
 from app import routers
@@ -22,6 +21,7 @@
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 from log import get_logger
+from metrics import recording
 from models.responses import InternalServerErrorResponse
 from sentry import initialize_sentry
 from utils.common import register_mcp_servers_async
@@ -182,12 +182,12 @@ async def send_wrapper(message: Message) -> None:
         # Measure duration and forward the request. Use try/finally so the
         # call counter is always incremented, even when the inner app raises.
         try:
-            with metrics.response_duration_seconds.labels(path).time():
+            with recording.measure_response_duration(path):
                 await self.app(scope, receive, send_wrapper)
         finally:
             # Ignore /metrics endpoint that will be called periodically.
             if not path.endswith("/metrics"):
-                metrics.rest_api_calls_total.labels(path, status_code).inc()
+                recording.record_rest_api_call(path, status_code)
 
 
 class GlobalExceptionMiddleware:  # pylint: disable=too-few-public-methods
6 changes: 2 additions & 4 deletions src/metrics/__init__.py
@@ -42,14 +42,12 @@
     "ls_llm_validation_errors_total", "LLM validation errors"
 )
 
-# TODO(lucasagomes): Add metric for token usage
-# https://issues.redhat.com/browse/LCORE-411
+# Metric that counts how many tokens were sent to LLMs
 llm_token_sent_total = Counter(
     "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"]
 )
 
-# TODO(lucasagomes): Add metric for token usage
-# https://issues.redhat.com/browse/LCORE-411
+# Metric that counts how many tokens were received from LLMs
 llm_token_received_total = Counter(
     "ls_llm_token_received_total", "LLM tokens received", ["provider", "model"]
 )
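
Side note on what these definitions produce: each labeled Counter yields one time series per unique (provider, model) pair in the Prometheus exposition output. A minimal standalone sketch, using a throwaway registry and illustrative label values (output abbreviated):

```python
from prometheus_client import CollectorRegistry, Counter, generate_latest

# Throwaway registry so the example does not collide with the default one.
registry = CollectorRegistry()
tokens_sent = Counter(
    "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"],
    registry=registry,
)
tokens_sent.labels("openai", "gpt-4o").inc(128)

# Exposition emits one sample line per label combination, e.g.:
# ls_llm_token_sent_total{provider="openai",model="gpt-4o"} 128.0
print(generate_latest(registry).decode())
```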
99 changes: 99 additions & 0 deletions src/metrics/recording.py
@@ -0,0 +1,99 @@
"""Recording helpers for Prometheus metrics.

This module keeps metric definitions in ``metrics.__init__`` while providing a
small facade for application code. New metrics should add a recording helper
here so callers do not need to know Prometheus object details.
"""

from collections.abc import Iterator
from contextlib import contextmanager

import metrics
from log import get_logger

logger = get_logger(__name__)


@contextmanager
def measure_response_duration(path: str) -> Iterator[None]:
"""Measure REST API response duration for a route path.

Args:
path: Normalized route path used as the metric label.
"""
try:
cm = metrics.response_duration_seconds.labels(path).time()
except (AttributeError, TypeError, ValueError):
logger.warning("Failed to start response duration metric", exc_info=True)
yield
return
with cm:
yield


def record_rest_api_call(path: str, status_code: int) -> None:
"""Record one REST API request.

Args:
path: Normalized route path used as the metric label.
status_code: HTTP response status code returned by the endpoint.
"""
try:
metrics.rest_api_calls_total.labels(path, status_code).inc()
except (AttributeError, TypeError, ValueError):
logger.warning("Failed to update REST API call metric", exc_info=True)


def record_llm_call(provider: str, model: str) -> None:
"""Record one LLM call for a provider and model.

Args:
provider: LLM provider identifier.
model: LLM model identifier without the provider prefix.
"""
try:
metrics.llm_calls_total.labels(provider, model).inc()
except (AttributeError, TypeError, ValueError):
logger.warning("Failed to update LLM call metric", exc_info=True)


def record_llm_failure(provider: str, model: str) -> None:
"""Record one failed LLM call for a provider and model.

Args:
provider: LLM provider identifier.
model: LLM model identifier without the provider prefix.
"""
try:
metrics.llm_calls_failures_total.labels(provider, model).inc()
except (AttributeError, TypeError, ValueError):
logger.warning("Failed to update LLM failure metric", exc_info=True)


def record_llm_validation_error() -> None:
"""Record one LLM validation error, such as a shield violation."""
try:
metrics.llm_calls_validation_errors_total.inc()
except (AttributeError, TypeError, ValueError):
logger.warning("Failed to update LLM validation error metric", exc_info=True)


def record_llm_token_usage(
provider: str,
model: str,
input_tokens: int,
output_tokens: int,
) -> None:
"""Record LLM token usage for a provider and model.

Args:
provider: LLM provider identifier.
model: LLM model identifier without the provider prefix.
input_tokens: Number of tokens sent to the LLM.
output_tokens: Number of tokens received from the LLM.
"""
try:
metrics.llm_token_sent_total.labels(provider, model).inc(input_tokens)
metrics.llm_token_received_total.labels(provider, model).inc(output_tokens)
except (AttributeError, TypeError, ValueError):
logger.warning("Failed to update token metrics", exc_info=True)
32 changes: 10 additions & 22 deletions src/utils/responses.py
@@ -86,11 +86,11 @@
 from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
 
 import constants
-import metrics
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 from constants import DEFAULT_RAG_TOOL
 from log import get_logger
+from metrics import recording
 from models.config import ByokRag
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
@@ -922,7 +922,7 @@ def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCoun
         logger.debug(
             "No usage information in Responses API response, token counts will be 0"
         )
-        _increment_llm_call_metric(provider_id, model_id)
+        recording.record_llm_call(provider_id, model_id)
         return TokenCounter(llm_calls=1)
 
     token_counter = TokenCounter(
@@ -934,18 +934,14 @@ def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCoun
         token_counter.output_tokens,
     )
 
-    # Update Prometheus metrics only when we have actual usage data
-    try:
-        metrics.llm_token_sent_total.labels(provider_id, model_id).inc(
-            token_counter.input_tokens
-        )
-        metrics.llm_token_received_total.labels(provider_id, model_id).inc(
-            token_counter.output_tokens
-        )
-    except (AttributeError, TypeError, ValueError) as e:
-        logger.warning("Failed to update token metrics: %s", e)
-
-    _increment_llm_call_metric(provider_id, model_id)
+    # Update Prometheus metrics only when we have actual usage data.
+    recording.record_llm_token_usage(
+        provider_id,
+        model_id,
+        token_counter.input_tokens,
+        token_counter.output_tokens,
+    )
+    recording.record_llm_call(provider_id, model_id)
     return token_counter


@@ -1251,14 +1247,6 @@ def extract_rag_chunks_from_file_search_item(
     return rag_chunks
 
 
-def _increment_llm_call_metric(provider: str, model: str) -> None:
-    """Safely increment LLM call metric."""
-    try:
-        metrics.llm_calls_total.labels(provider, model).inc()
-    except (AttributeError, TypeError, ValueError) as e:
-        logger.warning("Failed to update LLM call metric: %s", e)
-
-
 def parse_arguments_string(arguments_str: str) -> dict[str, Any]:
     """Parse an arguments string into a dictionary.

6 changes: 3 additions & 3 deletions src/utils/shields.py
@@ -14,10 +14,10 @@
 from llama_stack_client.types import ShieldListResponse
 from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 
-import metrics
 from configuration import AppConfig
 from constants import DEFAULT_VIOLATION_MESSAGE
 from log import get_logger
+from metrics import recording
 from models.requests import QueryRequest
 from models.responses import (
     InternalServerErrorResponse,
@@ -77,7 +77,7 @@ def detect_shield_violations(output_items: list[Any]) -> bool:
         refusal = getattr(output_item, "refusal", None)
         if refusal:
             # Metric for LLM validation errors (shield violations)
-            metrics.llm_calls_validation_errors_total.inc()
+            recording.record_llm_validation_error()
             logger.warning("Shield violation detected: %s", refusal)
             return True
     return False
@@ -178,7 +178,7 @@ async def run_shield_moderation(
 
     if moderation_result.results and moderation_result.results[0].flagged:
         result = moderation_result.results[0]
-        metrics.llm_calls_validation_errors_total.inc()
+        recording.record_llm_validation_error()
         logger.warning(
             "Shield '%s' flagged content: categories=%s",
             shield.identifier,
10 changes: 5 additions & 5 deletions tests/unit/app/endpoints/test_streaming_query.py
@@ -389,7 +389,7 @@ async def test_successful_streaming_query(
         "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
         return_value=("provider1", "model1"),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
@@ -476,7 +476,7 @@ async def test_streaming_query_text_media_type_header(
         "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
         return_value=("provider1", "model1"),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
@@ -574,7 +574,7 @@ async def test_streaming_query_with_conversation(
         "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
         return_value=("provider1", "model1"),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
@@ -670,7 +670,7 @@ async def test_streaming_query_with_attachments(
         "app.endpoints.streaming_query.extract_provider_and_model_from_model_id",
         return_value=("provider1", "model1"),
    )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
@@ -770,7 +770,7 @@ async def test_streaming_query_azure_token_refresh(
         "app.endpoints.streaming_query.run_shield_moderation",
         new=mocker.AsyncMock(return_value=ShieldModerationPassed()),
     )
-    mocker.patch("app.endpoints.streaming_query.metrics.llm_calls_total")
+    mocker.patch("app.endpoints.streaming_query.recording.record_llm_call")
 
     async def mock_generator() -> AsyncIterator[str]:
         yield "data: test\n\n"
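
These patches only neutralize the helper; since the patched attribute is a `MagicMock`, a test could also keep the handle and assert the call. Patching `app.endpoints.streaming_query.recording.record_llm_call` and `metrics.recording.record_llm_call` hit the same attribute, because `from metrics import recording` binds the module object itself. A hypothetical standalone sketch of that principle (not part of this PR; assumes this repo's module layout):

```python
from unittest.mock import patch

from metrics import recording


def fake_endpoint() -> None:
    """Hypothetical stand-in for the endpoint code that records the call."""
    recording.record_llm_call("provider1", "model1")


def test_fake_endpoint_records_llm_call() -> None:
    # Patching the attribute on the recording module works because the
    # endpoint looks it up at call time via the module object.
    with patch("metrics.recording.record_llm_call") as mock_record:
        fake_endpoint()
        mock_record.assert_called_once_with("provider1", "model1")
```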