From 6dff618eca56a8b42f7c42c4a53bfff6b7f2b6e3 Mon Sep 17 00:00:00 2001 From: Ellis Low Date: Tue, 28 Apr 2026 13:55:11 -0400 Subject: [PATCH 1/4] feat(observability): add user_agent to ResponsesEventData for CLA/Goose differentiation Adds a sanitized user_agent field to ResponsesEventData and the Splunk event payload, enabling differentiation between Goose and other clients in telemetry. Extracts and sanitizes the User-Agent header (strips control characters, truncates to 128 chars) before storing. Closes RSPEED-2849 --- src/app/endpoints/responses.py | 42 +++++++++++++ src/observability/formats/responses.py | 4 +- .../app/endpoints/test_responses_splunk.py | 56 +++++++++++++++++ .../observability/formats/test_responses.py | 63 +++++++++++++++++++ 4 files changed, 164 insertions(+), 1 deletion(-) diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 331341279..d999dbed8 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -112,6 +112,30 @@ logger = get_logger(__name__) router = APIRouter(tags=["responses"]) +_USER_AGENT_MAX_LENGTH = 128 + + +def _get_user_agent(request: Request) -> Optional[str]: + """Extract and sanitize the User-Agent header from the request. + + Parses the raw User-Agent header, strips control characters and newlines, + and truncates to a safe maximum length. Returns None when the header is + absent or empty. + + Args: + request: The FastAPI request object. + + Returns: + Sanitized User-Agent string, or None if the header is absent or empty. + """ + raw = request.headers.get("User-Agent", "") + if not raw: + return None + sanitized = "".join(c for c in raw if ord(c) >= 32 and c not in ("\r", "\n")) + sanitized = sanitized[:_USER_AGENT_MAX_LENGTH] + return sanitized or None + + responses_response: dict[int | str, dict[str, Any]] = { 200: ResponsesResponse.openapi_response(), 401: UnauthorizedResponse.openapi_response( @@ -153,6 +177,7 @@ def _queue_responses_splunk_event( # pylint: disable=too-many-arguments,too-man input_tokens: int = 0, output_tokens: int = 0, fire_and_forget: bool = False, + user_agent: Optional[str] = None, ) -> None: """Build and queue a Splunk telemetry event for the responses endpoint. @@ -173,6 +198,7 @@ def _queue_responses_splunk_event( # pylint: disable=too-many-arguments,too-man fire_and_forget: When True, dispatch via asyncio.create_task() instead of background_tasks. Use for error paths where an HTTPException follows, since FastAPI discards BackgroundTasks on non-2xx responses. + user_agent: Sanitized User-Agent string from the request header, or None. """ if not fire_and_forget and background_tasks is None: return @@ -187,6 +213,7 @@ def _queue_responses_splunk_event( # pylint: disable=too-many-arguments,too-man inference_time=inference_time, input_tokens=input_tokens, output_tokens=output_tokens, + user_agent=user_agent, ) event = build_responses_event(event_data) if fire_and_forget: @@ -360,6 +387,7 @@ async def responses_endpoint_handler( filter_server_tools=filter_server_tools, background_tasks=background_tasks, rh_identity_context=rh_identity_context, + user_agent=_get_user_agent(request), ) @@ -375,6 +403,7 @@ async def handle_streaming_response( filter_server_tools: bool = False, background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), + user_agent: Optional[str] = None, ) -> StreamingResponse: """Handle streaming response from Responses API. 
@@ -424,6 +453,7 @@ async def handle_streaming_response( rh_identity_context=rh_identity_context, inference_time=(datetime.now(UTC) - started_at).total_seconds(), sourcetype="responses_shield_blocked", + user_agent=user_agent, ) else: try: @@ -452,6 +482,7 @@ async def handle_streaming_response( inference_time=(datetime.now(UTC) - started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, + user_agent=user_agent, ) error_response = PromptTooLongResponse(model=api_params.model) raise HTTPException(**error_response.model_dump()) from e @@ -467,6 +498,7 @@ async def handle_streaming_response( inference_time=(datetime.now(UTC) - started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, + user_agent=user_agent, ) error_response = ServiceUnavailableResponse( backend_name="Llama Stack", @@ -484,6 +516,7 @@ async def handle_streaming_response( inference_time=(datetime.now(UTC) - started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, + user_agent=user_agent, ) error_response = handle_known_apistatus_errors(e, api_params.model) raise HTTPException(**error_response.model_dump()) from e @@ -501,6 +534,7 @@ async def handle_streaming_response( background_tasks=background_tasks, rh_identity_context=rh_identity_context, shield_blocked=(moderation_result.decision == "blocked"), + user_agent=user_agent, ), media_type="text/event-stream", ) @@ -894,6 +928,7 @@ async def generate_response( background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), shield_blocked: bool = False, + user_agent: Optional[str] = None, ) -> AsyncIterator[str]: """Stream the response from the generator and persist conversation details. @@ -956,6 +991,7 @@ async def generate_response( if turn_summary.token_usage else 0 ), + user_agent=user_agent, ) @@ -971,6 +1007,7 @@ async def handle_non_streaming_response( filter_server_tools: bool = False, background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), + user_agent: Optional[str] = None, ) -> ResponsesResponse: """Handle non-streaming response from Responses API. 
@@ -1019,6 +1056,7 @@ async def handle_non_streaming_response( rh_identity_context=rh_identity_context, inference_time=(datetime.now(UTC) - started_at).total_seconds(), sourcetype="responses_shield_blocked", + user_agent=user_agent, ) else: try: @@ -1057,6 +1095,7 @@ async def handle_non_streaming_response( inference_time=(datetime.now(UTC) - started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, + user_agent=user_agent, ) error_response = PromptTooLongResponse(model=api_params.model) raise HTTPException(**error_response.model_dump()) from e @@ -1072,6 +1111,7 @@ async def handle_non_streaming_response( inference_time=(datetime.now(UTC) - started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, + user_agent=user_agent, ) error_response = ServiceUnavailableResponse( backend_name="Llama Stack", @@ -1089,6 +1129,7 @@ async def handle_non_streaming_response( inference_time=(datetime.now(UTC) - started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, + user_agent=user_agent, ) error_response = handle_known_apistatus_errors(e, api_params.model) raise HTTPException(**error_response.model_dump()) from e @@ -1135,6 +1176,7 @@ async def handle_non_streaming_response( if turn_summary.token_usage else 0 ), + user_agent=user_agent, ) if api_params.store: store_query_results( diff --git a/src/observability/formats/responses.py b/src/observability/formats/responses.py index d8564de8d..380cfd8eb 100644 --- a/src/observability/formats/responses.py +++ b/src/observability/formats/responses.py @@ -6,7 +6,7 @@ """ from dataclasses import dataclass -from typing import Any +from typing import Any, Optional from configuration import configuration @@ -24,6 +24,7 @@ class ResponsesEventData: # pylint: disable=too-many-instance-attributes inference_time: float input_tokens: int = 0 output_tokens: int = 0 + user_agent: Optional[str] = None def build_responses_event(data: ResponsesEventData) -> dict[str, Any]: @@ -45,4 +46,5 @@ def build_responses_event(data: ResponsesEventData) -> dict[str, Any]: "org_id": data.org_id, "system_id": data.system_id, "total_llm_tokens": data.input_tokens + data.output_tokens, + "user_agent": data.user_agent, } diff --git a/tests/unit/app/endpoints/test_responses_splunk.py b/tests/unit/app/endpoints/test_responses_splunk.py index ce1588a75..1b6ccf283 100644 --- a/tests/unit/app/endpoints/test_responses_splunk.py +++ b/tests/unit/app/endpoints/test_responses_splunk.py @@ -16,6 +16,7 @@ from app.endpoints.responses import ( _background_splunk_tasks, + _get_user_agent, _queue_responses_splunk_event, handle_non_streaming_response, handle_streaming_response, @@ -715,3 +716,58 @@ async def test_splunk_disabled_no_background_tasks( mock_queue.assert_called_once() assert mock_queue.call_args[1]["background_tasks"] is None + + +class TestGetUserAgent: + """Tests for _get_user_agent header extraction and sanitization.""" + + def test_returns_user_agent_from_header(self, mocker: MockerFixture) -> None: + """Test that a valid User-Agent header is returned as-is.""" + request = mocker.MagicMock() + request.headers.get.return_value = "goose/1.0.0" + + result = _get_user_agent(request) + + assert result == "goose/1.0.0" + + def test_returns_none_when_header_absent(self, mocker: MockerFixture) -> None: + """Test that None is returned when User-Agent header is empty.""" + request = mocker.MagicMock() + request.headers.get.return_value = "" + + result = _get_user_agent(request) + + assert result is None + + def 
test_strips_control_characters(self, mocker: MockerFixture) -> None: + """Test that control characters and newlines are stripped from User-Agent.""" + request = mocker.MagicMock() + request.headers.get.return_value = "goose/1.0.0\r\nX-Injected: evil" + + result: str = _get_user_agent(request) or "" + + assert result != "" + assert "\r" not in result + assert "\n" not in result + assert "goose/1.0.0" in result + + def test_truncates_to_128_characters(self, mocker: MockerFixture) -> None: + """Test that User-Agent is truncated to 128 characters.""" + request = mocker.MagicMock() + request.headers.get.return_value = "a" * 200 + + result = _get_user_agent(request) + + assert isinstance(result, str) + assert len(result) == 128 + + def test_returns_none_for_only_control_characters( + self, mocker: MockerFixture + ) -> None: + """Test that None is returned when User-Agent contains only control characters.""" + request = mocker.MagicMock() + request.headers.get.return_value = "\r\n\x01\x02" + + result = _get_user_agent(request) + + assert result is None diff --git a/tests/unit/observability/formats/test_responses.py b/tests/unit/observability/formats/test_responses.py index 03e585076..25d71267e 100644 --- a/tests/unit/observability/formats/test_responses.py +++ b/tests/unit/observability/formats/test_responses.py @@ -20,6 +20,21 @@ def sample_event_data_fixture() -> ResponsesEventData: ) +@pytest.fixture(name="sample_event_data_with_user_agent") +def sample_event_data_with_user_agent_fixture() -> ResponsesEventData: + """Create sample responses event data with user_agent set.""" + return ResponsesEventData( + input_text="How do I configure SSH?", + response_text="To configure SSH, edit /etc/ssh/sshd_config...", + conversation_id="conv-abc-123", + model="granite-3-8b-instruct", + org_id="12345678", + system_id="abc-def-123", + inference_time=2.34, + user_agent="goose/1.0.0", + ) + + def test_builds_event_with_all_fields( mocker: MockerFixture, sample_event_data: ResponsesEventData ) -> None: @@ -97,3 +112,51 @@ def test_default_token_values() -> None: assert data.input_tokens == 0 assert data.output_tokens == 0 + + +def test_user_agent_defaults_to_none() -> None: + """Test user_agent field defaults to None when not provided.""" + data = ResponsesEventData( + input_text="test", + response_text="test", + conversation_id="conv-789", + inference_time=1.0, + model="test-model", + org_id="org1", + system_id="sys1", + ) + + assert data.user_agent is None + + +def test_user_agent_included_in_splunk_event( + mocker: MockerFixture, + sample_event_data_with_user_agent: ResponsesEventData, +) -> None: + """Test user_agent field is included in the Splunk event payload.""" + mock_config = mocker.patch("observability.formats.responses.configuration") + mock_config.deployment_environment = "production" + + event = build_responses_event(sample_event_data_with_user_agent) + + assert event["user_agent"] == "goose/1.0.0" + + +def test_user_agent_none_included_in_splunk_event(mocker: MockerFixture) -> None: + """Test user_agent=None is included in the Splunk event payload.""" + mock_config = mocker.patch("observability.formats.responses.configuration") + mock_config.deployment_environment = "production" + + data = ResponsesEventData( + input_text="test", + response_text="test", + conversation_id="conv-123", + inference_time=1.0, + model="test-model", + org_id="org1", + system_id="sys1", + ) + + event = build_responses_event(data) + + assert event["user_agent"] is None From b1b141dd2f0dcb52991875b7921578e7eba12bfd Mon 
Sep 17 00:00:00 2001 From: Ellis Low Date: Tue, 28 Apr 2026 14:17:53 -0400 Subject: [PATCH 2/4] fix: add Final annotation and docstring updates for user_agent --- src/app/endpoints/responses.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index d999dbed8..230ce111e 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -6,7 +6,7 @@ import json from collections.abc import AsyncIterator from datetime import UTC, datetime -from typing import Annotated, Any, Optional, cast +from typing import Annotated, Any, Final, Optional, cast from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from fastapi.responses import StreamingResponse @@ -112,7 +112,7 @@ logger = get_logger(__name__) router = APIRouter(tags=["responses"]) -_USER_AGENT_MAX_LENGTH = 128 +_USER_AGENT_MAX_LENGTH: Final[int] = 128 def _get_user_agent(request: Request) -> Optional[str]: @@ -418,6 +418,7 @@ async def handle_streaming_response( filter_server_tools: Whether to filter server-deployed MCP tool events from the stream background_tasks: FastAPI background task manager for telemetry events rh_identity_context: Tuple of (org_id, system_id) from RH identity + user_agent: Sanitized User-Agent string from request headers, or None. Returns: StreamingResponse with SSE-formatted events """ @@ -946,6 +947,7 @@ async def generate_response( background_tasks: FastAPI background task manager for telemetry events rh_identity_context: Tuple of (org_id, system_id) from RH identity shield_blocked: Whether the request was blocked by a shield + user_agent: Sanitized User-Agent string from request headers, or None. Yields: SSE-formatted strings from the generator """ @@ -1023,6 +1025,7 @@ async def handle_non_streaming_response( filter_server_tools: Whether to filter server-deployed MCP tool output background_tasks: FastAPI background task manager for telemetry events rh_identity_context: Tuple of (org_id, system_id) from RH identity + user_agent: Sanitized User-Agent string from request headers, or None. Returns: ResponsesResponse with the completed response """ From b0ec99a46ece494f7cd1f68c11936e1a329d5e13 Mon Sep 17 00:00:00 2001 From: Ellis Low Date: Tue, 28 Apr 2026 14:45:05 -0400 Subject: [PATCH 3/4] docs: add PR title prefix requirement to AGENTS.md --- AGENTS.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index d36362f09..119022978 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -193,6 +193,15 @@ uv run make test-e2e # End-to-end tests - **SQLAlchemy**: Database ORM - **Kubernetes**: K8s auth integration +## Pull Request Requirements + +**PR titles MUST start with a JIRA issue key prefix.** CI enforces this via `pr-title-checker` (config: `.github/pr-title-checker-config.json`). + +Allowed prefixes: `LCORE-`, `RSPEED-`, `MGTM-`, `OLS-`, `RHDHPAI-`, `LEADS-` + +- ✅ `RSPEED-2849: add user_agent to ResponsesEventData` +- ❌ `feat(observability): add user_agent to ResponsesEventData` + ## Development Workflow 1. Use `uv sync --group dev --group llslibdev` for dependencies 2. 
Always use `uv run` prefix for commands From 8ae0d656fe377de8d2b1f7b059cd16f54da79f5e Mon Sep 17 00:00:00 2001 From: Ellis Low Date: Tue, 28 Apr 2026 17:13:02 -0400 Subject: [PATCH 4/4] RSPEED-2849: add endpoint label to LLM Prometheus metrics --- src/app/endpoints/query.py | 14 +++++++-- src/app/endpoints/responses.py | 14 +++++++-- src/app/endpoints/rlsapi_v1.py | 37 +++++++++++++++++++----- src/app/endpoints/streaming_query.py | 11 +++++-- src/metrics/__init__.py | 14 +++++---- src/utils/responses.py | 27 ++++++++++------- src/utils/shields.py | 4 ++- tests/unit/app/endpoints/test_metrics.py | 1 - tests/unit/utils/test_responses.py | 22 +++++++------- tests/unit/utils/test_shields.py | 35 +++++++++++++++------- 10 files changed, 124 insertions(+), 55 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index b4c31b017..61e7bd81e 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -170,9 +170,10 @@ async def query_endpoint_handler( # Moderation input is the raw user content (query + attachments) without injected RAG # context, to avoid false positives from retrieved document content. + endpoint_path = "/v1/query" moderation_input = prepare_input(query_request) moderation_result = await run_shield_moderation( - client, moderation_input, query_request.shield_ids + client, moderation_input, endpoint_path, query_request.shield_ids ) # Build RAG context from Inline RAG sources @@ -207,7 +208,9 @@ async def query_endpoint_handler( client = await update_azure_token(client) # Retrieve response using Responses API - turn_summary = await retrieve_response(client, responses_params, moderation_result) + turn_summary = await retrieve_response( + client, responses_params, moderation_result, endpoint_path + ) if moderation_result.decision == "passed": # Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript @@ -280,6 +283,7 @@ async def retrieve_response( client: AsyncLlamaStackClient, responses_params: ResponsesApiParams, moderation_result: ShieldModerationResult, + endpoint_path: str = "", ) -> TurnSummary: """ Retrieve response from LLMs and agents. @@ -332,5 +336,9 @@ async def retrieve_response( vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools) rag_id_mapping = configuration.rag_id_mapping return build_turn_summary( - response, responses_params.model, vector_store_ids, rag_id_mapping + response, + responses_params.model, + endpoint_path, + vector_store_ids, + rag_id_mapping, ) diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 230ce111e..b13cde090 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -331,9 +331,11 @@ async def responses_endpoint_handler( ) attachments_text = extract_attachments_text(original_request.input) + endpoint_path = "/v1/responses" moderation_result = await run_shield_moderation( client, input_text + "\n\n" + attachments_text, + endpoint_path, original_request.shield_ids, ) @@ -388,6 +390,7 @@ async def responses_endpoint_handler( background_tasks=background_tasks, rh_identity_context=rh_identity_context, user_agent=_get_user_agent(request), + endpoint_path=endpoint_path, ) @@ -404,6 +407,7 @@ async def handle_streaming_response( background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), user_agent: Optional[str] = None, + endpoint_path: str = "", ) -> StreamingResponse: """Handle streaming response from Responses API. 
@@ -470,6 +474,7 @@ async def handle_streaming_response( turn_summary=turn_summary, inline_rag_context=inline_rag_context, filter_server_tools=filter_server_tools, + endpoint_path=endpoint_path, ) except RuntimeError as e: # library mode wraps 413 into runtime error if is_context_length_error(str(e)): @@ -798,6 +803,7 @@ async def response_generator( turn_summary: TurnSummary, inline_rag_context: RAGContext, filter_server_tools: bool = False, + endpoint_path: str = "", ) -> AsyncIterator[str]: """Generate SSE-formatted streaming response with LCORE-enriched events. @@ -873,7 +879,7 @@ async def response_generator( # Extract and consume tokens if any were used turn_summary.token_usage = extract_token_usage( - latest_response_object.usage, api_params.model + latest_response_object.usage, api_params.model, endpoint_path ) consume_query_tokens( user_id=user_id, @@ -1010,6 +1016,7 @@ async def handle_non_streaming_response( background_tasks: Optional[BackgroundTasks] = None, rh_identity_context: tuple[str, str] = ("", ""), user_agent: Optional[str] = None, + endpoint_path: str = "", ) -> ResponsesResponse: """Handle non-streaming response from Responses API. @@ -1069,7 +1076,9 @@ async def handle_non_streaming_response( **api_params.model_dump(exclude_none=True) ), ) - token_usage = extract_token_usage(api_response.usage, api_params.model) + token_usage = extract_token_usage( + api_response.usage, api_params.model, endpoint_path + ) logger.info("Consuming tokens") consume_query_tokens( user_id=user_id, @@ -1152,6 +1161,7 @@ async def handle_non_streaming_response( turn_summary = build_turn_summary( api_response, api_params.model, + endpoint_path, vector_store_ids, configuration.rag_id_mapping, filter_server_tools=filter_server_tools, diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index c38a46a04..e2c077d9b 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -241,6 +241,7 @@ async def retrieve_simple_response( instructions: str, tools: Optional[list[Any]] = None, model_id: Optional[str] = None, + endpoint_path: str = "/v1/infer", ) -> str: """Retrieve a simple response from the LLM for a stateless query. @@ -263,7 +264,7 @@ async def retrieve_simple_response( """ resolved_model_id = model_id or await _get_default_model_id() response = await _call_llm(question, instructions, tools, resolved_model_id) - extract_token_usage(response.usage, resolved_model_id) + extract_token_usage(response.usage, resolved_model_id, endpoint_path) return extract_text_from_response_items(response.output) @@ -366,12 +367,13 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position background_tasks.add_task(send_splunk_event, event, sourcetype) -async def _check_shield_moderation( +async def _check_shield_moderation( # pylint: disable=too-many-arguments,too-many-positional-arguments input_text: str, request_id: str, background_tasks: BackgroundTasks, infer_request: RlsapiV1InferRequest, request: Request, + endpoint_path: str, ) -> Optional[RlsapiV1InferResponse]: """Run shield moderation and return a refusal response if blocked. @@ -384,13 +386,14 @@ async def _check_shield_moderation( background_tasks: FastAPI background tasks for async Splunk event sending. infer_request: The original inference request (for Splunk event context). request: The FastAPI request object (for Splunk event context). + endpoint_path: The API endpoint path for metric labeling. 
Returns: An RlsapiV1InferResponse containing the refusal message if the input was blocked, or None if moderation passed. """ client = AsyncLlamaStackClientHolder().get_client() - moderation_result = await run_shield_moderation(client, input_text) + moderation_result = await run_shield_moderation(client, input_text, endpoint_path) if moderation_result.decision != "blocked": return None @@ -432,6 +435,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po start_time: float, model: str, provider: str, + endpoint_path: str, ) -> float: """Record metrics and queue Splunk event for an inference failure. @@ -442,12 +446,15 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po request_id: Unique identifier for the request. error: The exception that caused the failure. start_time: Monotonic clock time when inference started. + model: The model name. + provider: The provider name. + endpoint_path: The API endpoint path for metric labeling. Returns: The total inference time in seconds. """ inference_time = time.monotonic() - start_time - metrics.llm_calls_failures_total.labels(provider, model).inc() + metrics.llm_calls_failures_total.labels(provider, model, endpoint_path).inc() _queue_splunk_event( background_tasks, infer_request, @@ -530,6 +537,7 @@ def _build_infer_response( request_id: str, response: Optional[OpenAIResponseObject], model_id: str, + endpoint_path: str, ) -> RlsapiV1InferResponse: """Build the final inference response, with optional verbose metadata. @@ -549,7 +557,11 @@ def _build_infer_response( """ if response is not None: turn_summary = build_turn_summary( - response, model_id, vector_store_ids=None, rag_id_mapping=None + response, + model_id, + endpoint_path, + vector_store_ids=None, + rag_id_mapping=None, ) return RlsapiV1InferResponse( data=RlsapiV1InferData( @@ -673,12 +685,19 @@ async def infer_endpoint( # pylint: disable=R0914 "Request %s: Combined input source length: %d", request_id, len(input_source) ) + endpoint_path = "/v1/infer" + # Run shield moderation on user input before inference. # Uses all configured shields; no-op when no shields are registered. # Runs before model/tool discovery so blocked requests short-circuit # without incurring external I/O. 
blocked_response = await _check_shield_moderation( - input_source, request_id, background_tasks, infer_request, request + input_source, + request_id, + background_tasks, + infer_request, + request, + endpoint_path, ) if blocked_response is not None: return blocked_response @@ -700,11 +719,11 @@ async def infer_endpoint( # pylint: disable=R0914 model_id=model_id, ) response_text = extract_text_from_response_items(response.output) - token_usage = extract_token_usage(response.usage, model_id) + token_usage = extract_token_usage(response.usage, model_id, endpoint_path) inference_time = time.monotonic() - start_time except _INFER_HANDLED_EXCEPTIONS as error: if response is not None: - extract_token_usage(response.usage, model_id) # type: ignore[arg-type] + extract_token_usage(response.usage, model_id, endpoint_path) # type: ignore[arg-type] _record_inference_failure( background_tasks, infer_request, @@ -714,6 +733,7 @@ async def infer_endpoint( # pylint: disable=R0914 start_time, model, provider, + endpoint_path, ) mapped_error = _map_inference_error_to_http_exception( error, @@ -755,4 +775,5 @@ async def infer_endpoint( # pylint: disable=R0914 request_id, response if verbose_enabled else None, model_id, + endpoint_path, ) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index acab21f40..6cc20e8a3 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -226,8 +226,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals # Moderation input is the raw user content (query + attachments) without injected RAG # context, to avoid false positives from retrieved document content. moderation_input = prepare_input(query_request) + endpoint_path = "/v1/streaming_query" moderation_result = await run_shield_moderation( - client, moderation_input, query_request.shield_ids + client, moderation_input, endpoint_path, query_request.shield_ids ) # Build RAG context from Inline RAG sources @@ -283,11 +284,12 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals provider_id, model_id = extract_provider_and_model_from_model_id( responses_params.model ) - metrics.llm_calls_total.labels(provider_id, model_id).inc() + metrics.llm_calls_total.labels(provider_id, model_id, endpoint_path).inc() generator, turn_summary = await retrieve_response_generator( responses_params=responses_params, context=context, + endpoint_path=endpoint_path, ) # Combine inline RAG results (BYOK + Solr) with tool-based results @@ -316,6 +318,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals async def retrieve_response_generator( responses_params: ResponsesApiParams, context: ResponseGeneratorContext, + endpoint_path: str = "", ) -> tuple[AsyncIterator[str], TurnSummary]: """ Retrieve the appropriate response generator. @@ -360,6 +363,7 @@ async def retrieve_response_generator( response, context, turn_summary, + endpoint_path, ), turn_summary, ) @@ -685,6 +689,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat turn_response: AsyncIterator[OpenAIResponseObjectStream], context: ResponseGeneratorContext, turn_summary: TurnSummary, + endpoint_path: str = "", ) -> AsyncIterator[str]: """Generate SSE formatted streaming response. 
@@ -862,7 +867,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat return turn_summary.token_usage = extract_token_usage( - latest_response_object.usage, context.model_id + latest_response_object.usage, context.model_id, endpoint_path ) # Parse tool-based referenced documents from the final response object tool_rag_docs = parse_referenced_documents( diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py index 5c4e4e44f..31aa641d8 100644 --- a/src/metrics/__init__.py +++ b/src/metrics/__init__.py @@ -29,27 +29,31 @@ # Metric that counts how many LLM calls were made for each provider + model llm_calls_total = Counter( - "ls_llm_calls_total", "LLM calls counter", ["provider", "model"] + "ls_llm_calls_total", "LLM calls counter", ["provider", "model", "endpoint"] ) # Metric that counts how many LLM calls failed llm_calls_failures_total = Counter( - "ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"] + "ls_llm_calls_failures_total", + "LLM calls failures", + ["provider", "model", "endpoint"], ) # Metric that counts how many LLM calls had validation errors llm_calls_validation_errors_total = Counter( - "ls_llm_validation_errors_total", "LLM validation errors" + "ls_llm_validation_errors_total", "LLM validation errors", ["endpoint"] ) # TODO(lucasagomes): Add metric for token usage # https://issues.redhat.com/browse/LCORE-411 llm_token_sent_total = Counter( - "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"] + "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model", "endpoint"] ) # TODO(lucasagomes): Add metric for token usage # https://issues.redhat.com/browse/LCORE-411 llm_token_received_total = Counter( - "ls_llm_token_received_total", "LLM tokens received", ["provider", "model"] + "ls_llm_token_received_total", + "LLM tokens received", + ["provider", "model", "endpoint"], ) diff --git a/src/utils/responses.py b/src/utils/responses.py index 858973d02..7d4ad3a88 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -907,12 +907,15 @@ def parse_rag_chunks( return rag_chunks -def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCounter: +def extract_token_usage( + usage: Optional[ResponseUsage], model: str, endpoint_path: str +) -> TokenCounter: """Extract token usage from Responses API usage object and update metrics. Args: usage: ResponseUsage from the Responses API response, or None if not available. model: The model identifier in "provider/model" format + endpoint_path: The API endpoint path for metric labeling. 
Returns: TokenCounter with input_tokens and output_tokens @@ -922,7 +925,7 @@ def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCoun logger.debug( "No usage information in Responses API response, token counts will be 0" ) - _increment_llm_call_metric(provider_id, model_id) + _increment_llm_call_metric(provider_id, model_id, endpoint_path) return TokenCounter(llm_calls=1) token_counter = TokenCounter( @@ -936,16 +939,16 @@ def extract_token_usage(usage: Optional[ResponseUsage], model: str) -> TokenCoun # Update Prometheus metrics only when we have actual usage data try: - metrics.llm_token_sent_total.labels(provider_id, model_id).inc( + metrics.llm_token_sent_total.labels(provider_id, model_id, endpoint_path).inc( token_counter.input_tokens ) - metrics.llm_token_received_total.labels(provider_id, model_id).inc( - token_counter.output_tokens - ) + metrics.llm_token_received_total.labels( + provider_id, model_id, endpoint_path + ).inc(token_counter.output_tokens) except (AttributeError, TypeError, ValueError) as e: logger.warning("Failed to update token metrics: %s", e) - _increment_llm_call_metric(provider_id, model_id) + _increment_llm_call_metric(provider_id, model_id, endpoint_path) return token_counter @@ -1251,10 +1254,10 @@ def extract_rag_chunks_from_file_search_item( return rag_chunks -def _increment_llm_call_metric(provider: str, model: str) -> None: +def _increment_llm_call_metric(provider: str, model: str, endpoint_path: str) -> None: """Safely increment LLM call metric.""" try: - metrics.llm_calls_total.labels(provider, model).inc() + metrics.llm_calls_total.labels(provider, model, endpoint_path).inc() except (AttributeError, TypeError, ValueError) as e: logger.warning("Failed to update LLM call metric: %s", e) @@ -1444,9 +1447,10 @@ def is_server_deployed_output(output_item: ResponseOutput) -> bool: return True -def build_turn_summary( +def build_turn_summary( # pylint: disable=too-many-arguments,too-many-positional-arguments response: Optional[OpenAIResponseObject], model: str, + endpoint_path: str, vector_store_ids: Optional[list[str]] = None, rag_id_mapping: Optional[dict[str, str]] = None, filter_server_tools: bool = False, @@ -1456,6 +1460,7 @@ def build_turn_summary( Args: response: The ResponseObject to build the turn summary from, or None model: The model identifier in "provider/model" format + endpoint_path: The API endpoint path for metric labeling. vector_store_ids: Vector store IDs used in the query for source resolution. rag_id_mapping: Mapping from vector_db_id to user-facing rag_id. filter_server_tools: When True, skip client-provided tool output items @@ -1490,7 +1495,7 @@ def build_turn_summary( summary.tool_results.append(tool_result) summary.rag_chunks = parse_rag_chunks(response, vector_store_ids, rag_id_mapping) - summary.token_usage = extract_token_usage(response.usage, model) + summary.token_usage = extract_token_usage(response.usage, model, endpoint_path) return summary diff --git a/src/utils/shields.py b/src/utils/shields.py index 7e3b4cc0a..c6da708f9 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -122,6 +122,7 @@ def validate_shield_ids_override( async def run_shield_moderation( client: AsyncLlamaStackClient, input_text: str, + endpoint_path: str, shield_ids: Optional[list[str]] = None, ) -> ShieldModerationResult: """ @@ -134,6 +135,7 @@ async def run_shield_moderation( ---------- client: The Llama Stack client. input_text: The text to moderate. + endpoint_path: The API endpoint path for metric labeling. 
shield_ids: Optional list of shield IDs to use. If None, uses all shields. If empty list, skips all shields. @@ -178,7 +180,7 @@ async def run_shield_moderation( if moderation_result.results and moderation_result.results[0].flagged: result = moderation_result.results[0] - metrics.llm_calls_validation_errors_total.inc() + metrics.llm_calls_validation_errors_total.labels(endpoint_path).inc() logger.warning( "Shield '%s' flagged content: categories=%s", shield.identifier, diff --git a/tests/unit/app/endpoints/test_metrics.py b/tests/unit/app/endpoints/test_metrics.py index d2e1dbbab..5cebc0529 100644 --- a/tests/unit/app/endpoints/test_metrics.py +++ b/tests/unit/app/endpoints/test_metrics.py @@ -43,6 +43,5 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None: assert "# TYPE ls_llm_calls_total counter" in response_body assert "# TYPE ls_llm_calls_failures_total counter" in response_body assert "# TYPE ls_llm_validation_errors_total counter" in response_body - assert "# TYPE ls_llm_validation_errors_created gauge" in response_body assert "# TYPE ls_llm_token_sent_total counter" in response_body assert "# TYPE ls_llm_token_received_total counter" in response_body diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index 485ecbf26..2c11d3540 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -2229,7 +2229,7 @@ def test_extract_token_usage_with_usage_object( mocker.patch("utils.responses.metrics.llm_token_received_total") mocker.patch("utils.responses._increment_llm_call_metric") - result = extract_token_usage(mock_usage, "provider1/model1") + result = extract_token_usage(mock_usage, "provider1/model1", "/test-endpoint") assert result.input_tokens == input_tokens assert result.output_tokens == output_tokens assert result.llm_calls == 1 @@ -2242,7 +2242,7 @@ def test_extract_token_usage_no_usage(self, mocker: MockerFixture) -> None: ) mocker.patch("utils.responses._increment_llm_call_metric") - result = extract_token_usage(None, "provider1/model1") + result = extract_token_usage(None, "provider1/model1", "/test-endpoint") assert result.input_tokens == 0 assert result.output_tokens == 0 assert result.llm_calls == 1 @@ -2259,7 +2259,7 @@ def test_extract_token_usage_zero_tokens(self, mocker: MockerFixture) -> None: ) mocker.patch("utils.responses._increment_llm_call_metric") - result = extract_token_usage(mock_usage, "provider1/model1") + result = extract_token_usage(mock_usage, "provider1/model1", "/test-endpoint") assert result.input_tokens == 0 assert result.output_tokens == 0 @@ -2271,7 +2271,7 @@ def test_extract_token_usage_none_response(self, mocker: MockerFixture) -> None: ) mocker.patch("utils.responses._increment_llm_call_metric") - result = extract_token_usage(None, "provider1/model1") + result = extract_token_usage(None, "provider1/model1", "/test-endpoint") assert result.input_tokens == 0 assert result.output_tokens == 0 @@ -2296,7 +2296,7 @@ def test_extract_token_usage_metrics_error(self, mocker: MockerFixture) -> None: mocker.patch("utils.responses._increment_llm_call_metric") # Should not raise, just log warning - result = extract_token_usage(mock_usage, "provider1/model1") + result = extract_token_usage(mock_usage, "provider1/model1", "/test-endpoint") assert result.input_tokens == 100 assert result.output_tokens == 50 @@ -2550,9 +2550,11 @@ def test_increment_llm_call_metric_success(self, mocker: MockerFixture) -> None: mock_metric.labels.return_value.inc = mocker.Mock() 
mocker.patch("utils.responses.metrics.llm_calls_total", mock_metric) - _increment_llm_call_metric("provider1", "model1") + _increment_llm_call_metric("provider1", "model1", "/test-endpoint") - mock_metric.labels.assert_called_once_with("provider1", "model1") + mock_metric.labels.assert_called_once_with( + "provider1", "model1", "/test-endpoint" + ) mock_metric.labels.return_value.inc.assert_called_once() def test_increment_llm_call_metric_attribute_error( @@ -2566,7 +2568,7 @@ def test_increment_llm_call_metric_attribute_error( mocker.patch("utils.responses.logger") # Should not raise exception - _increment_llm_call_metric("provider1", "model1") + _increment_llm_call_metric("provider1", "model1", "/test-endpoint") def test_increment_llm_call_metric_type_error(self, mocker: MockerFixture) -> None: """Test metric increment handles TypeError.""" @@ -2578,7 +2580,7 @@ def test_increment_llm_call_metric_type_error(self, mocker: MockerFixture) -> No mocker.patch("utils.responses.logger") # Should not raise exception - _increment_llm_call_metric("provider1", "model1") + _increment_llm_call_metric("provider1", "model1", "/test-endpoint") def test_increment_llm_call_metric_value_error(self, mocker: MockerFixture) -> None: """Test metric increment handles ValueError.""" @@ -2590,7 +2592,7 @@ def test_increment_llm_call_metric_value_error(self, mocker: MockerFixture) -> N mocker.patch("utils.responses.logger") # Should not raise exception - _increment_llm_call_metric("provider1", "model1") + _increment_llm_call_metric("provider1", "model1", "/test-endpoint") class TestBuildMCPToolCallFromArgumentsDone: diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py index 333c96df0..05bcf649d 100644 --- a/tests/unit/utils/test_shields.py +++ b/tests/unit/utils/test_shields.py @@ -118,7 +118,9 @@ async def test_returns_not_blocked_when_no_shields( mock_client.shields.list = mocker.AsyncMock(return_value=[]) mock_client.models.list = mocker.AsyncMock(return_value=[]) - result = await run_shield_moderation(mock_client, "test input") + result = await run_shield_moderation( + mock_client, "test input", "/test-endpoint" + ) assert result.decision == "passed" @@ -147,7 +149,9 @@ async def test_returns_not_blocked_when_moderation_passes( return_value=moderation_result ) - result = await run_shield_moderation(mock_client, "safe input") + result = await run_shield_moderation( + mock_client, "safe input", "/test-endpoint" + ) assert result.decision == "passed" mock_client.moderations.create.assert_called_once_with( @@ -187,11 +191,14 @@ async def test_returns_blocked_when_content_flagged( return_value=moderation_result ) - result = await run_shield_moderation(mock_client, "violent content") + result = await run_shield_moderation( + mock_client, "violent content", "/test-endpoint" + ) assert result.decision == "blocked" assert result.message == "Content blocked for violence" - mock_metric.inc.assert_called_once() + mock_metric.labels.assert_called_once_with("/test-endpoint") + mock_metric.labels.return_value.inc.assert_called_once() @pytest.mark.asyncio async def test_returns_blocked_with_default_message_when_no_user_message( @@ -224,7 +231,9 @@ async def test_returns_blocked_with_default_message_when_no_user_message( return_value=moderation_result ) - result = await run_shield_moderation(mock_client, "spam content") + result = await run_shield_moderation( + mock_client, "spam content", "/test-endpoint" + ) assert result.decision == "blocked" assert result.message == DEFAULT_VIOLATION_MESSAGE @@ 
-253,7 +262,9 @@ async def test_skips_model_check_for_non_llama_guard_shields( return_value=moderation_result ) - result = await run_shield_moderation(mock_client, "test input") + result = await run_shield_moderation( + mock_client, "test input", "/test-endpoint" + ) assert result.decision == "passed" mock_client.moderations.create.assert_called_once_with( @@ -280,7 +291,7 @@ async def test_raises_http_exception_when_shield_model_not_found( mock_client.models.list = mocker.AsyncMock(return_value=[model]) with pytest.raises(HTTPException) as exc_info: - await run_shield_moderation(mock_client, "test input") + await run_shield_moderation(mock_client, "test input", "/test-endpoint") assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND assert "missing-model" in exc_info.value.detail["cause"] # type: ignore @@ -302,7 +313,7 @@ async def test_raises_http_exception_when_shield_has_no_provider_resource_id( mock_client.models.list = mocker.AsyncMock(return_value=[]) with pytest.raises(HTTPException) as exc_info: - await run_shield_moderation(mock_client, "test input") + await run_shield_moderation(mock_client, "test input", "/test-endpoint") assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -317,7 +328,9 @@ async def test_shield_ids_empty_list_runs_no_shields_returns_passed( mock_client.shields.list = mocker.AsyncMock(return_value=[shield]) mock_client.models.list = mocker.AsyncMock(return_value=[]) - result = await run_shield_moderation(mock_client, "test input", shield_ids=[]) + result = await run_shield_moderation( + mock_client, "test input", "/test-endpoint", shield_ids=[] + ) assert result.decision == "passed" @@ -333,7 +346,7 @@ async def test_shield_ids_raises_404_when_no_shields_found( with pytest.raises(HTTPException) as exc_info: await run_shield_moderation( - mock_client, "test input", shield_ids=["typo-shield"] + mock_client, "test input", "/test-endpoint", shield_ids=["typo-shield"] ) assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND @@ -366,7 +379,7 @@ async def test_shield_ids_filters_to_specific_shield( ) result = await run_shield_moderation( - mock_client, "test input", shield_ids=["shield-1"] + mock_client, "test input", "/test-endpoint", shield_ids=["shield-1"] ) assert result.decision == "passed"
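
For reference, a minimal standalone sketch of the User-Agent sanitization introduced in PATCH 1/4. The helper name sanitize_user_agent is hypothetical and the function takes the raw header value directly so it runs without FastAPI; the shipped code is _get_user_agent() in src/app/endpoints/responses.py, which reads the header from the FastAPI Request object.

    # Illustrative sketch only; mirrors the logic of _get_user_agent() from PATCH 1/4.
    from typing import Optional

    _USER_AGENT_MAX_LENGTH = 128

    def sanitize_user_agent(raw: str) -> Optional[str]:
        """Strip control characters, truncate to 128 chars; None if nothing remains."""
        if not raw:
            return None
        # Keep printable characters only (ord >= 32), dropping CR/LF and other controls.
        sanitized = "".join(c for c in raw if ord(c) >= 32 and c not in ("\r", "\n"))
        sanitized = sanitized[:_USER_AGENT_MAX_LENGTH]
        return sanitized or None

    # Expected behaviour, matching the unit tests added in the patch:
    assert sanitize_user_agent("goose/1.0.0") == "goose/1.0.0"
    assert sanitize_user_agent("") is None
    assert sanitize_user_agent("goose/1.0.0\r\nX-Injected: evil") == "goose/1.0.0X-Injected: evil"
    assert sanitize_user_agent("a" * 200) == "a" * 128
    assert sanitize_user_agent("\r\n\x01\x02") is None

Separately, the endpoint label added to ls_llm_calls_total and the other counters in PATCH 4/4 lets dashboards break LLM call, failure, and token totals down per route (/v1/query, /v1/streaming_query, /v1/responses, /v1/infer) rather than only per provider and model.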