From b923543da2e1525a89bb5b401b24003dd24cc21c Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Tue, 28 Apr 2026 13:42:08 +0200 Subject: [PATCH 1/2] Use context data class in responses --- src/app/endpoints/query.py | 2 +- src/app/endpoints/responses.py | 419 +++++++----------- src/app/endpoints/streaming_query.py | 3 +- .../common/responses/responses_api_params.py | 175 ++++++++ .../common/responses/responses_context.py | 45 ++ src/models/requests.py | 27 -- src/utils/responses.py | 2 +- src/utils/types.py | 117 ----- tests/unit/app/endpoints/test_query.py | 2 +- tests/unit/app/endpoints/test_responses.py | 276 ++++++++---- .../app/endpoints/test_responses_splunk.py | 71 ++- .../app/endpoints/test_streaming_query.py | 2 +- tests/unit/utils/test_types.py | 2 +- 13 files changed, 638 insertions(+), 505 deletions(-) create mode 100644 src/models/common/responses/responses_api_params.py create mode 100644 src/models/common/responses/responses_context.py diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index b4c31b017..5e5c119bd 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -23,6 +23,7 @@ from client import AsyncLlamaStackClientHolder from configuration import configuration from log import get_logger +from models.common.responses.responses_api_params import ResponsesApiParams from models.config import Action from models.requests import QueryRequest from models.responses import ( @@ -65,7 +66,6 @@ from utils.shields import run_shield_moderation, validate_shield_ids_override from utils.suid import normalize_conversation_id from utils.types import ( - ResponsesApiParams, ShieldModerationResult, TurnSummary, ) diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 331341279..7031ca676 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -22,7 +22,6 @@ ) from llama_stack_client import ( APIConnectionError, - AsyncLlamaStackClient, ) from llama_stack_client import ( APIStatusError as LLSApiStatusError, @@ -39,6 +38,8 @@ from configuration import configuration from constants import SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER from log import get_logger +from models.common.responses.responses_api_params import ResponsesApiParams +from models.common.responses.responses_context import ResponsesContext from models.config import Action from models.requests import ResponsesRequest from models.responses import ( @@ -98,10 +99,7 @@ ) from utils.tool_formatter import translate_vector_store_ids_to_user_facing from utils.types import ( - RAGContext, - ResponsesApiParams, ShieldModerationBlocked, - ShieldModerationResult, TurnSummary, ) from utils.vector_search import ( @@ -343,15 +341,9 @@ async def responses_endpoint_handler( original_request.input, inline_rag_context.context_text ) - response_handler = ( - handle_streaming_response - if original_request.stream - else handle_non_streaming_response - ) - return await response_handler( + api_params = ResponsesApiParams.model_validate(updated_request.model_dump()) + context = ResponsesContext( client=client, - original_request=original_request, - updated_request=updated_request, auth=auth, input_text=input_text, started_at=started_at, @@ -360,96 +352,82 @@ async def responses_endpoint_handler( filter_server_tools=filter_server_tools, background_tasks=background_tasks, rh_identity_context=rh_identity_context, + generate_topic_summary=updated_request.generate_topic_summary, + ) + response_handler = ( + handle_streaming_response + if original_request.stream + else 
handle_non_streaming_response
+    )
+    return await response_handler(
+        original_request=original_request,
+        api_params=api_params,
+        context=context,
     )
 
 
 async def handle_streaming_response(
-    client: AsyncLlamaStackClient,
     original_request: ResponsesRequest,
-    updated_request: ResponsesRequest,
-    auth: AuthTuple,
-    input_text: str,
-    started_at: datetime,
-    moderation_result: ShieldModerationResult,
-    inline_rag_context: RAGContext,
-    filter_server_tools: bool = False,
-    background_tasks: Optional[BackgroundTasks] = None,
-    rh_identity_context: tuple[str, str] = ("", ""),
+    api_params: ResponsesApiParams,
+    context: ResponsesContext,
 ) -> StreamingResponse:
     """Handle streaming response from Responses API.
 
     Args:
-        client: The AsyncLlamaStackClient instance
-        request: ResponsesRequest (LCORE-specific fields e.g. generate_topic_summary)
-        auth: Authentication tuple
-        input_text: The extracted input text
-        started_at: Timestamp when the conversation started
-        moderation_result: Result of shield moderation check
-        inline_rag_context: Inline RAG context to be used for the response
-        filter_server_tools: Whether to filter server-deployed MCP tool events from the stream
-        background_tasks: FastAPI background task manager for telemetry events
-        rh_identity_context: Tuple of (org_id, system_id) from RH identity
+        original_request: Original request (read-only)
+        api_params: Validated parameters for the Responses API call
+        context: Request-scoped responses context
 
     Returns:
         StreamingResponse with SSE-formatted events
     """
-    api_params = ResponsesApiParams.model_validate(updated_request.model_dump())
     turn_summary = TurnSummary()
 
     # Handle blocked response
-    if moderation_result.decision == "blocked":
-        turn_summary.id = moderation_result.moderation_id
-        turn_summary.llm_response = moderation_result.message
-        available_quotas = get_available_quotas(
-            quota_limiters=configuration.quota_limiters, user_id=auth[0]
-        )
-        generator = shield_violation_generator(
-            moderation_result,
-            api_params.conversation,
-            updated_request.echoed_params(),
-            started_at,
-            available_quotas,
-        )
+    if context.moderation_result.decision == "blocked":
+        turn_summary.id = context.moderation_result.moderation_id
+        turn_summary.llm_response = context.moderation_result.message
+        generator = shield_violation_generator(api_params, context)
         if api_params.store:
             await append_turn_items_to_conversation(
-                client=client,
+                client=context.client,
                 conversation_id=api_params.conversation,
-                user_input=updated_request.input,
-                llm_output=[moderation_result.refusal_response],
+                user_input=api_params.input,
+                llm_output=[context.moderation_result.refusal_response],
             )
         _queue_responses_splunk_event(
-            background_tasks=background_tasks,
-            input_text=input_text,
-            response_text=moderation_result.message,
+            background_tasks=context.background_tasks,
+            input_text=context.input_text,
+            response_text=context.moderation_result.message,
             conversation_id=normalize_conversation_id(api_params.conversation),
             model=api_params.model,
-            rh_identity_context=rh_identity_context,
-            inference_time=(datetime.now(UTC) - started_at).total_seconds(),
+            rh_identity_context=context.rh_identity_context,
+            inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
             sourcetype="responses_shield_blocked",
         )
     else:
         try:
-            response = await client.responses.create(
+            response = await context.client.responses.create(
                 **api_params.model_dump(exclude_none=True)
             )
             generator = response_generator(
                 stream=cast(AsyncIterator[OpenAIResponseObjectStream], response),
                 original_request=original_request,
-
updated_request=updated_request, api_params=api_params, - user_id=auth[0], + context=context, turn_summary=turn_summary, - inline_rag_context=inline_rag_context, - filter_server_tools=filter_server_tools, ) except RuntimeError as e: # library mode wraps 413 into runtime error if is_context_length_error(str(e)): _queue_responses_splunk_event( - background_tasks=background_tasks, - input_text=input_text, + background_tasks=context.background_tasks, + input_text=context.input_text, response_text=str(e), conversation_id=normalize_conversation_id(api_params.conversation), model=api_params.model, - rh_identity_context=rh_identity_context, - inference_time=(datetime.now(UTC) - started_at).total_seconds(), + rh_identity_context=context.rh_identity_context, + inference_time=( + datetime.now(UTC) - context.started_at + ).total_seconds(), sourcetype="responses_error", fire_and_forget=True, ) @@ -458,13 +436,13 @@ async def handle_streaming_response( raise e except APIConnectionError as e: _queue_responses_splunk_event( - background_tasks=background_tasks, - input_text=input_text, + background_tasks=context.background_tasks, + input_text=context.input_text, response_text=str(e), conversation_id=normalize_conversation_id(api_params.conversation), model=api_params.model, - rh_identity_context=rh_identity_context, - inference_time=(datetime.now(UTC) - started_at).total_seconds(), + rh_identity_context=context.rh_identity_context, + inference_time=(datetime.now(UTC) - context.started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, ) @@ -475,13 +453,13 @@ async def handle_streaming_response( raise HTTPException(**error_response.model_dump()) from e except (LLSApiStatusError, OpenAIAPIStatusError) as e: _queue_responses_splunk_event( - background_tasks=background_tasks, - input_text=input_text, + background_tasks=context.background_tasks, + input_text=context.input_text, response_text=str(e), conversation_id=normalize_conversation_id(api_params.conversation), model=api_params.model, - rh_identity_context=rh_identity_context, - inference_time=(datetime.now(UTC) - started_at).total_seconds(), + rh_identity_context=context.rh_identity_context, + inference_time=(datetime.now(UTC) - context.started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, ) @@ -491,60 +469,42 @@ async def handle_streaming_response( return StreamingResponse( generate_response( generator=generator, - turn_summary=turn_summary, - client=client, - auth=auth, - input_text=input_text, - started_at=started_at, api_params=api_params, - generate_topic_summary=updated_request.generate_topic_summary or False, - background_tasks=background_tasks, - rh_identity_context=rh_identity_context, - shield_blocked=(moderation_result.decision == "blocked"), + context=context, + turn_summary=turn_summary, ), media_type="text/event-stream", ) async def shield_violation_generator( - moderation_result: ShieldModerationBlocked, - conversation_id: str, - echoed_params: dict[str, Any], - created_at: datetime, - available_quotas: dict[str, int], + api_params: ResponsesApiParams, + context: ResponsesContext, ) -> AsyncIterator[str]: """Generate SSE-formatted streaming response for shield-blocked requests. 
- Follows the Open Responses spec: - - Content-Type: text/event-stream - - Each event has 'event:' field matching the type in the event body - - Data objects are JSON-encoded strings - - Terminal event is the literal string [DONE] - - Emits full event sequence: response.created (in_progress), output_item.added, - output_item.done, response.completed (completed) - - Performs topic summary and persistence after [DONE] is emitted - Args: - moderation_result: The moderation result - conversation_id: The conversation ID to include in the response - echoed_params: Echoed parameters from the request - created_at: Unix timestamp when the response was created - available_quotas: Available quotas dictionary for the user + api_params: ResponsesApiParams + context: ResponsesContext Yields: SSE-formatted strings for streaming events, ending with [DONE] """ - normalized_conv_id = normalize_conversation_id(conversation_id) + normalized_conv_id = normalize_conversation_id(api_params.conversation) + available_quotas = get_available_quotas( + quota_limiters=configuration.quota_limiters, user_id=context.auth[0] + ) + moderation_result = cast(ShieldModerationBlocked, context.moderation_result) # 1. Send response.created event with status "in_progress" and empty output created_response_object = ResponsesResponse.model_construct( id=moderation_result.moderation_id, - created_at=int(created_at.timestamp()), + created_at=int(context.started_at.timestamp()), status="in_progress", output=[], conversation=normalized_conv_id, available_quotas={}, output_text="", - **echoed_params, + **api_params.echoed_params(configuration.rag_id_mapping), ) created_response_dict = created_response_object.model_dump( exclude_none=True, by_alias=True @@ -582,7 +542,7 @@ async def shield_violation_generator( # 4. Send response.completed event with status "completed" and output populated completed_response_object = ResponsesResponse.model_construct( id=moderation_result.moderation_id, - created_at=int(created_at.timestamp()), + created_at=int(context.started_at.timestamp()), completed_at=int(datetime.now(UTC).timestamp()), status="completed", output=[moderation_result.refusal_response], @@ -590,7 +550,7 @@ async def shield_violation_generator( conversation=normalized_conv_id, available_quotas=available_quotas, output_text=moderation_result.message, - **echoed_params, + **api_params.echoed_params(configuration.rag_id_mapping), ) completed_response_dict = completed_response_object.model_dump( exclude_none=True, by_alias=True @@ -667,7 +627,6 @@ def _is_server_mcp_output_item( def _should_filter_mcp_chunk( chunk: OpenAIResponseObjectStream, - event_type: Optional[str], configured_mcp_labels: set[str], server_mcp_output_indices: set[int], ) -> bool: @@ -682,7 +641,7 @@ def _should_filter_mcp_chunk( Returns: True if the chunk should be filtered out from the client stream. 
""" - if event_type == "response.output_item.added": + if chunk.type == "response.output_item.added": item_added_chunk = cast(OutputItemAddedChunk, chunk) item = item_added_chunk.item item_type = getattr(item, "type", None) @@ -692,16 +651,16 @@ def _should_filter_mcp_chunk( server_mcp_output_indices.add(item_added_chunk.output_index) return True - if event_type and ( - event_type.startswith("response.mcp_call.") - or event_type.startswith("response.mcp_list_tools.") - or event_type.startswith("response.mcp_approval_request.") + if chunk.type and ( + chunk.type.startswith("response.mcp_call.") + or chunk.type.startswith("response.mcp_list_tools.") + or chunk.type.startswith("response.mcp_approval_request.") ): output_index = getattr(chunk, "output_index", None) if output_index in server_mcp_output_indices: return True - if event_type == "response.output_item.done": + if chunk.type == "response.output_item.done": item_done_chunk = cast(OutputItemDoneChunk, chunk) item = item_done_chunk.item item_type = getattr(item, "type", None) @@ -715,19 +674,17 @@ def _should_filter_mcp_chunk( def _populate_turn_summary( response_object: OpenAIResponseObject, - turn_summary: TurnSummary, api_params: ResponsesApiParams, - inline_rag_context: RAGContext, - filter_server_tools: bool, + context: ResponsesContext, + turn_summary: TurnSummary, ) -> None: """Populate turn summary with metadata extracted from the final response object. Args: response_object: The completed response object from Llama Stack - turn_summary: TurnSummary to populate api_params: ResponsesApiParams - inline_rag_context: Inline RAG context used for the response - filter_server_tools: Whether to filter server-deployed MCP tool events + context: Responses context + turn_summary: TurnSummary to populate """ turn_summary.id = response_object.id vector_store_ids = extract_vector_store_ids_from_tools(api_params.tools) @@ -735,10 +692,10 @@ def _populate_turn_summary( response_object, vector_store_ids, configuration.rag_id_mapping ) turn_summary.referenced_documents = deduplicate_referenced_documents( - inline_rag_context.referenced_documents + tool_rag_docs + context.inline_rag_context.referenced_documents + tool_rag_docs ) for item in response_object.output: - if filter_server_tools and not is_server_deployed_output(item): + if context.filter_server_tools and not is_server_deployed_output(item): continue tool_call, tool_result = build_tool_call_summary(item) if tool_call: @@ -751,35 +708,27 @@ def _populate_turn_summary( vector_store_ids, configuration.rag_id_mapping, ) - turn_summary.rag_chunks = inline_rag_context.rag_chunks + tool_rag_chunks + turn_summary.rag_chunks = context.inline_rag_context.rag_chunks + tool_rag_chunks async def response_generator( stream: AsyncIterator[OpenAIResponseObjectStream], original_request: ResponsesRequest, - updated_request: ResponsesRequest, api_params: ResponsesApiParams, - user_id: str, + context: ResponsesContext, turn_summary: TurnSummary, - inline_rag_context: RAGContext, - filter_server_tools: bool = False, ) -> AsyncIterator[str]: """Generate SSE-formatted streaming response with LCORE-enriched events. 
Args: stream: The streaming response from Llama Stack - original_request: Original request object - updated_request: Updated request object + original_request: Original request (read-only) api_params: ResponsesApiParams - user_id: User ID for quota retrieval + context: Responses context turn_summary: TurnSummary to populate during streaming - inline_rag_context: Inline RAG context to be used for the response - filter_server_tools: Whether to filter server-deployed MCP tool events from the stream Yields: SSE-formatted strings for streaming events, ending with [DONE] """ - normalized_conv_id = normalize_conversation_id(api_params.conversation) - logger.debug("Starting streaming response (Responses API) processing") latest_response_object: Optional[OpenAIResponseObject] = None @@ -789,14 +738,13 @@ async def response_generator( server_mcp_output_indices: set[int] = set() async for chunk in stream: - event_type = getattr(chunk, "type", None) - logger.debug("Processing streaming chunk, type: %s", event_type) + logger.debug("Processing streaming chunk, type: %s", chunk.type) # Filter out streaming events for server-deployed MCP tools. # These are handled internally by LCS and should not be forwarded # to clients that don't understand the mcp_call item type. if _should_filter_mcp_chunk( - chunk, event_type, configured_mcp_labels, server_mcp_output_indices + chunk, configured_mcp_labels, server_mcp_output_indices ): continue @@ -807,7 +755,9 @@ async def response_generator( sequence_number += 1 if "response" in chunk_dict: - chunk_dict["response"]["conversation"] = normalized_conv_id + chunk_dict["response"]["conversation"] = normalize_conversation_id( + api_params.conversation + ) _sanitize_response_dict( chunk_dict["response"], configured_mcp_labels, @@ -822,12 +772,12 @@ async def response_generator( ) ) # Intermediate response - no quota consumption and text yet - if event_type == "response.in_progress": + if chunk.type == "response.in_progress": chunk_dict["response"]["available_quotas"] = {} chunk_dict["response"]["output_text"] = "" # Handle completion, incomplete, and failed events - only quota handling here - if event_type in ( + if chunk.type in ( "response.completed", "response.incomplete", "response.failed", @@ -841,41 +791,37 @@ async def response_generator( latest_response_object.usage, api_params.model ) consume_query_tokens( - user_id=user_id, + user_id=context.auth[0], model_id=api_params.model, token_usage=turn_summary.token_usage, ) # Get available quotas after token consumption - available_quotas = get_available_quotas( - quota_limiters=configuration.quota_limiters, user_id=user_id + chunk_dict["response"]["available_quotas"] = get_available_quotas( + quota_limiters=configuration.quota_limiters, user_id=context.auth[0] ) - chunk_dict["response"]["available_quotas"] = available_quotas turn_summary.llm_response = extract_text_from_response_items( latest_response_object.output ) chunk_dict["response"]["output_text"] = turn_summary.llm_response - data_json = json.dumps(chunk_dict) - yield f"event: {event_type or 'error'}\ndata: {data_json}\n\n" + yield f"event: {chunk.type or 'error'}\ndata: {json.dumps(chunk_dict)}\n\n" # Extract response metadata from final response object if latest_response_object: _populate_turn_summary( latest_response_object, - turn_summary, api_params, - inline_rag_context, - filter_server_tools, + context, + turn_summary, ) - client = AsyncLlamaStackClientHolder().get_client() # Explicitly append the turn to conversation if context passed by previous 
response
     if api_params.store and api_params.previous_response_id and latest_response_object:
         await append_turn_items_to_conversation(
-            client,
+            context.client,
             api_params.conversation,
-            updated_request.input,
+            api_params.input,
             latest_response_object.output,
         )
 
@@ -884,16 +830,9 @@
 async def generate_response(
     generator: AsyncIterator[str],
-    turn_summary: TurnSummary,
-    client: AsyncLlamaStackClient,
-    auth: AuthTuple,
-    input_text: str,
-    started_at: datetime,
     api_params: ResponsesApiParams,
-    generate_topic_summary: bool,
-    background_tasks: Optional[BackgroundTasks] = None,
-    rh_identity_context: tuple[str, str] = ("", ""),
-    shield_blocked: bool = False,
+    context: ResponsesContext,
+    turn_summary: TurnSummary,
 ) -> AsyncIterator[str]:
     """Stream the response from the generator and persist conversation details.
 
     Args:
         generator: The SSE event generator
-        turn_summary: TurnSummary populated during streaming
-        client: The AsyncLlamaStackClient instance
-        auth: Authentication tuple
-        input_text: The extracted input text
-        started_at: Timestamp when the conversation started
         api_params: ResponsesApiParams
-        generate_topic_summary: Whether to generate topic summary for new conversations
-        background_tasks: FastAPI background task manager for telemetry events
-        rh_identity_context: Tuple of (org_id, system_id) from RH identity
-        shield_blocked: Whether the request was blocked by a shield
+        context: Request-scoped responses context
+        turn_summary: TurnSummary populated during streaming
 
     Yields:
         SSE-formatted strings from the generator
     """
-    user_id, _, skip_userid_check, _ = auth
+    user_id, _, skip_userid_check, _ = context.auth
     async for event in generator:
         yield event
 
     # Get topic summary for new conversation
     topic_summary = None
-    if generate_topic_summary:
+    if context.generate_topic_summary:
         logger.debug("Generating topic summary for new conversation")
-        topic_summary = await get_topic_summary(input_text, client, api_params.model)
+        topic_summary = await get_topic_summary(
+            context.input_text, context.client, api_params.model
+        )
 
     completed_at = datetime.now(UTC)
     if api_params.store:
         store_query_results(
             user_id=user_id,
             conversation_id=normalize_conversation_id(api_params.conversation),
             model=api_params.model,
-            started_at=started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
+            started_at=context.started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
             completed_at=completed_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
             summary=turn_summary,
-            query=input_text,
+            query=context.input_text,
             attachments=[],
             skip_userid_check=skip_userid_check,
             topic_summary=topic_summary,
         )
 
-    if not shield_blocked:
+    if context.moderation_result.decision == "passed":
         _queue_responses_splunk_event(
-            background_tasks=background_tasks,
-            input_text=input_text,
+            background_tasks=context.background_tasks,
+            input_text=context.input_text,
             response_text=turn_summary.llm_response,
             conversation_id=normalize_conversation_id(api_params.conversation),
             model=api_params.model,
-            rh_identity_context=rh_identity_context,
-            inference_time=(completed_at - started_at).total_seconds(),
+            rh_identity_context=context.rh_identity_context,
+            inference_time=(completed_at - context.started_at).total_seconds(),
             sourcetype="responses_completed",
             input_tokens=(
                 turn_summary.token_usage.input_tokens if turn_summary.token_usage else 0
             ),
             output_tokens=(
                 turn_summary.token_usage.output_tokens
                 if turn_summary.token_usage
                 else 0
             ),
         )
 
 
 async def handle_non_streaming_response(
-    client: AsyncLlamaStackClient,
     original_request: ResponsesRequest,
-    updated_request: ResponsesRequest,
-    auth: AuthTuple,
-    input_text: str,
-    started_at: datetime,
-    moderation_result: ShieldModerationResult,
-    inline_rag_context: RAGContext,
-    filter_server_tools: bool = False,
-    background_tasks: Optional[BackgroundTasks] = None,
-    rh_identity_context: tuple[str, str] = ("", ""),
+    api_params: ResponsesApiParams,
+    context: ResponsesContext,
 ) -> ResponsesResponse:
     """Handle non-streaming response from Responses API.
 
     Args:
-        client: The AsyncLlamaStackClient instance
-        original_request: Original request object
-        updated_request: Updated request object
-        auth: Authentication tuple
-        input_text: The extracted input text
-        started_at: Timestamp when the conversation started
-        moderation_result: Result of shield moderation check
-        inline_rag_context: Inline RAG context to be used for the response
-        filter_server_tools: Whether to filter server-deployed MCP tool output
-        background_tasks: FastAPI background task manager for telemetry events
-        rh_identity_context: Tuple of (org_id, system_id) from RH identity
+        original_request: Original request (read-only)
+        api_params: Validated parameters for the Responses API call
+        context: Request-scoped responses context
 
     Returns:
         ResponsesResponse with the completed response
     """
-    user_id, _, skip_userid_check, _ = auth
-    api_params = ResponsesApiParams.model_validate(updated_request.model_dump())
+    user_id, _, skip_userid_check, _ = context.auth
 
     # Fork: Get response object (blocked vs normal)
-    if moderation_result.decision == "blocked":
-        output_text = moderation_result.message
+    if context.moderation_result.decision == "blocked":
+        output_text = context.moderation_result.message
         api_response = OpenAIResponseObject.model_construct(
-            id=moderation_result.moderation_id,
-            created_at=int(started_at.timestamp()),
+            id=context.moderation_result.moderation_id,
+            created_at=int(context.started_at.timestamp()),
             status="completed",
-            output=[moderation_result.refusal_response],
+            output=[context.moderation_result.refusal_response],
             usage=get_zero_usage(),
-            **updated_request.echoed_params(),
+            **api_params.echoed_params(configuration.rag_id_mapping),
         )
         if api_params.store:
             await append_turn_items_to_conversation(
-                client=client,
+                client=context.client,
                 conversation_id=api_params.conversation,
-                user_input=updated_request.input,
-                llm_output=[moderation_result.refusal_response],
+                user_input=api_params.input,
+                llm_output=[context.moderation_result.refusal_response],
             )
         _queue_responses_splunk_event(
-            background_tasks=background_tasks,
-            input_text=input_text,
+            background_tasks=context.background_tasks,
+            input_text=context.input_text,
             response_text=output_text,
             conversation_id=normalize_conversation_id(api_params.conversation),
             model=api_params.model,
-            rh_identity_context=rh_identity_context,
-            inference_time=(datetime.now(UTC) - started_at).total_seconds(),
+            rh_identity_context=context.rh_identity_context,
+            inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
             sourcetype="responses_shield_blocked",
         )
     else:
         try:
             api_response = cast(
                 OpenAIResponseObject,
-                await client.responses.create(
+                await context.client.responses.create(
                     **api_params.model_dump(exclude_none=True)
                 ),
             )
 
             # Explicitly append the turn to conversation if context passed by previous response
             if
api_params.store and api_params.previous_response_id: await append_turn_items_to_conversation( - client, + context.client, api_params.conversation, - updated_request.input, + api_params.input, api_response.output, ) except RuntimeError as e: if is_context_length_error(str(e)): _queue_responses_splunk_event( - background_tasks=background_tasks, - input_text=input_text, + background_tasks=context.background_tasks, + input_text=context.input_text, response_text=str(e), conversation_id=normalize_conversation_id(api_params.conversation), model=api_params.model, - rh_identity_context=rh_identity_context, - inference_time=(datetime.now(UTC) - started_at).total_seconds(), + rh_identity_context=context.rh_identity_context, + inference_time=( + datetime.now(UTC) - context.started_at + ).total_seconds(), sourcetype="responses_error", fire_and_forget=True, ) @@ -1063,13 +977,13 @@ async def handle_non_streaming_response( raise e except APIConnectionError as e: _queue_responses_splunk_event( - background_tasks=background_tasks, - input_text=input_text, + background_tasks=context.background_tasks, + input_text=context.input_text, response_text=str(e), conversation_id=normalize_conversation_id(api_params.conversation), model=api_params.model, - rh_identity_context=rh_identity_context, - inference_time=(datetime.now(UTC) - started_at).total_seconds(), + rh_identity_context=context.rh_identity_context, + inference_time=(datetime.now(UTC) - context.started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, ) @@ -1080,13 +994,13 @@ async def handle_non_streaming_response( raise HTTPException(**error_response.model_dump()) from e except (LLSApiStatusError, OpenAIAPIStatusError) as e: _queue_responses_splunk_event( - background_tasks=background_tasks, - input_text=input_text, + background_tasks=context.background_tasks, + input_text=context.input_text, response_text=str(e), conversation_id=normalize_conversation_id(api_params.conversation), model=api_params.model, - rh_identity_context=rh_identity_context, - inference_time=(datetime.now(UTC) - started_at).total_seconds(), + rh_identity_context=context.rh_identity_context, + inference_time=(datetime.now(UTC) - context.started_at).total_seconds(), sourcetype="responses_error", fire_and_forget=True, ) @@ -1100,9 +1014,11 @@ async def handle_non_streaming_response( ) # Get topic summary for new conversation topic_summary = None - if updated_request.generate_topic_summary: + if context.generate_topic_summary: logger.debug("Generating topic summary for new conversation") - topic_summary = await get_topic_summary(input_text, client, api_params.model) + topic_summary = await get_topic_summary( + context.input_text, context.client, api_params.model + ) vector_store_ids = extract_vector_store_ids_from_tools(api_params.tools) turn_summary = build_turn_summary( @@ -1110,41 +1026,36 @@ async def handle_non_streaming_response( api_params.model, vector_store_ids, configuration.rag_id_mapping, - filter_server_tools=filter_server_tools, + filter_server_tools=context.filter_server_tools, ) turn_summary.referenced_documents = deduplicate_referenced_documents( - inline_rag_context.referenced_documents + turn_summary.referenced_documents + context.inline_rag_context.referenced_documents + + turn_summary.referenced_documents ) - turn_summary.rag_chunks.extend(inline_rag_context.rag_chunks) + turn_summary.rag_chunks.extend(context.inline_rag_context.rag_chunks) completed_at = datetime.now(UTC) - if moderation_result.decision != "blocked": + if 
context.moderation_result.decision == "passed":
         _queue_responses_splunk_event(
-            background_tasks=background_tasks,
-            input_text=input_text,
+            background_tasks=context.background_tasks,
+            input_text=context.input_text,
             response_text=output_text,
             conversation_id=normalize_conversation_id(api_params.conversation),
             model=api_params.model,
-            rh_identity_context=rh_identity_context,
-            inference_time=(completed_at - started_at).total_seconds(),
+            rh_identity_context=context.rh_identity_context,
+            inference_time=(completed_at - context.started_at).total_seconds(),
             sourcetype="responses_completed",
             input_tokens=(
                 turn_summary.token_usage.input_tokens if turn_summary.token_usage else 0
             ),
             output_tokens=(
                 turn_summary.token_usage.output_tokens
                 if turn_summary.token_usage
                 else 0
             ),
         )
     if api_params.store:
         store_query_results(
             user_id=user_id,
             conversation_id=normalize_conversation_id(api_params.conversation),
             model=api_params.model,
-            started_at=started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
+            started_at=context.started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
             completed_at=completed_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
             summary=turn_summary,
-            query=input_text,
+            query=context.input_text,
             attachments=[],
             skip_userid_check=skip_userid_check,
             topic_summary=topic_summary,
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index acab21f40..8f7ac23a1 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -59,6 +59,7 @@
     TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS,
 )
 from log import get_logger
+from models.common.responses.responses_api_params import ResponsesApiParams
 from models.config import Action
 from models.context import ResponseGeneratorContext
 from models.requests import QueryRequest
@@ -115,7 +116,7 @@
 from utils.stream_interrupts import get_stream_interrupt_registry
 from utils.suid import get_suid, normalize_conversation_id
 from utils.token_counter import TokenCounter
-from utils.types import ReferencedDocument, ResponsesApiParams, TurnSummary
+from utils.types import ReferencedDocument, TurnSummary
 from utils.vector_search import build_rag_context
 
 logger = get_logger(__name__)
diff --git a/src/models/common/responses/responses_api_params.py b/src/models/common/responses/responses_api_params.py
new file mode 100644
index 000000000..64709811b
--- /dev/null
+++ b/src/models/common/responses/responses_api_params.py
@@ -0,0 +1,175 @@
+"""Request parameter model for Llama Stack responses API calls."""
+
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from llama_stack_api.openai_responses import (
+    OpenAIResponseInputTool as InputTool,
+)
+from llama_stack_api.openai_responses import (
+    OpenAIResponseInputToolChoice as ToolChoice,
+)
+from llama_stack_api.openai_responses import (
+    OpenAIResponsePrompt as Prompt,
+)
+from llama_stack_api.openai_responses import (
+    OpenAIResponseReasoning as Reasoning,
+)
+from llama_stack_api.openai_responses import (
+    OpenAIResponseText as Text,
+)
+from llama_stack_api.openai_responses import (
+    OpenAIResponseToolMCP as OutputToolMCP,
+)
+from pydantic import BaseModel, Field
+
+from utils.tool_formatter import translate_vector_store_ids_to_user_facing
+from utils.types import IncludeParameter, ResponseInput
+
+# Attribute names that are echoed back in the response.
+_ECHOED_FIELDS = {
+    "instructions",
+    "max_tool_calls",
+    "max_output_tokens",
+    "metadata",
+    "model",
+    "parallel_tool_calls",
+    "previous_response_id",
+    "prompt",
+    "reasoning",
+    "safety_identifier",
+    "temperature",
+    "top_p",
+    "truncation",
+    "text",
+    "tool_choice",
+    "store",
+}
+
+
+class ResponsesApiParams(BaseModel):
+    """Parameters for a Llama Stack Responses API request.
+
+    All fields accepted by the Llama Stack client ``responses.create()`` body
+    are included so that the dumped model can be passed directly to
+    ``responses.create()``.
+    """
+
+    input: ResponseInput = Field(description="The input text or structured input items")
+    model: str = Field(description='The full model ID in format "provider/model"')
+    conversation: str = Field(description="The conversation ID in llama-stack format")
+    include: Optional[list[IncludeParameter]] = Field(
+        default=None,
+        description="Output item types to include in the response",
+    )
+    instructions: Optional[str] = Field(
+        default=None, description="The resolved system prompt"
+    )
+    max_infer_iters: Optional[int] = Field(
+        default=None,
+        description="Maximum number of inference iterations",
+    )
+    max_output_tokens: Optional[int] = Field(
+        default=None,
+        description="Maximum number of tokens allowed in the response",
+    )
+    max_tool_calls: Optional[int] = Field(
+        default=None,
+        description="Maximum tool calls allowed in a single response",
+    )
+    metadata: Optional[dict[str, str]] = Field(
+        default=None,
+        description="Custom metadata for tracking or logging",
+    )
+    parallel_tool_calls: Optional[bool] = Field(
+        default=None,
+        description="Whether the model can make multiple tool calls in parallel",
+    )
+    previous_response_id: Optional[str] = Field(
+        default=None,
+        description="Identifier of the previous response in a multi-turn conversation",
+    )
+    prompt: Optional[Prompt] = Field(
+        default=None,
+        description="Prompt template with variables for dynamic substitution",
+    )
+    reasoning: Optional[Reasoning] = Field(
+        default=None,
+        description="Reasoning configuration for the response",
+    )
+    safety_identifier: Optional[str] = Field(
+        default=None,
+        description="Stable identifier for safety monitoring and abuse detection",
+    )
+    store: bool = Field(description="Whether to store the response")
+    stream: bool = Field(description="Whether to stream the response")
+    temperature: Optional[float] = Field(
+        default=None,
+        description="Sampling temperature (e.g. 0.0-2.0)",
+    )
+    text: Optional[Text] = Field(
+        default=None,
+        description="Text response configuration (format constraints)",
+    )
+    tool_choice: Optional[ToolChoice] = Field(
+        default=None,
+        description="Tool selection strategy",
+    )
+    tools: Optional[list[InputTool]] = Field(
+        default=None,
+        description="Prepared tool groups for Responses API (same type as ResponsesRequest.tools)",
+    )
+    extra_headers: Optional[dict[str, str]] = Field(
+        default=None,
+        description="Extra HTTP headers to send with the request (e.g. x-llamastack-provider-data)",
+    )
+
+    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
+        """Serialize params, re-injecting MCP authorization stripped by exclude=True.
+
+        llama-stack-api marks ``InputToolMCP.authorization`` with
+        ``Field(exclude=True)`` to prevent token leakage in API responses.
+        The base ``model_dump()`` therefore strips the field, but we need it
+        in the request payload so the llama-stack server can authenticate with
+        MCP servers. See LCORE-1414 / GitHub issue #1269.
+ """ + result = super().model_dump(*args, **kwargs) + # Only one context option is allowed, previous_response_id has priority + # Turn is added to conversation manually if previous_response_id is used + if self.previous_response_id: + result.pop("conversation", None) + dumped_tools = result.get("tools") + if not self.tools or not isinstance(dumped_tools, list): + return result + if len(dumped_tools) != len(self.tools): + return result + for tool, dumped_tool in zip(self.tools, dumped_tools): + authorization = getattr(tool, "authorization", None) + if authorization is not None and isinstance(dumped_tool, dict): + dumped_tool["authorization"] = authorization + return result + + def echoed_params(self, rag_id_mapping: Mapping[str, str]) -> dict[str, Any]: + """Build kwargs echoed into synthetic OpenAI-style responses (e.g. moderation blocks). + + Parameters: + rag_id_mapping: Llama Stack vector_db_id to user-facing RAG id (from app config). + Returns: + dict[str, Any]: Field names and values to merge into the response object. + """ + data = self.model_dump(include=_ECHOED_FIELDS) + if self.tools is not None: + tool_dicts: list[dict[str, Any]] = [] + for t in self.tools: + if t.type == "mcp": + validated = OutputToolMCP.model_validate(t.model_dump()) + tool_dicts.append(validated.model_dump()) + else: + tool_dicts.append(t.model_dump()) + + data["tools"] = translate_vector_store_ids_to_user_facing( + tool_dicts, rag_id_mapping + ) + + return data diff --git a/src/models/common/responses/responses_context.py b/src/models/common/responses/responses_context.py new file mode 100644 index 000000000..251a671ef --- /dev/null +++ b/src/models/common/responses/responses_context.py @@ -0,0 +1,45 @@ +"""Request-scoped context model for the responses endpoint pipeline.""" + +from datetime import datetime +from typing import Optional + +from fastapi import BackgroundTasks +from llama_stack_client import AsyncLlamaStackClient +from pydantic import BaseModel, ConfigDict, Field + +from utils.types import RAGContext, ShieldModerationResult + + +class ResponsesContext(BaseModel): + """Shared request-scoped context for the /responses endpoint pipeline.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + client: AsyncLlamaStackClient = Field(description="The Llama Stack client") + auth: tuple[str, str, bool, str] = Field( + description="Authentication tuple (user_id, username, skip_userid_check, token)", + ) + input_text: str = Field(description="Extracted user input text for the turn") + started_at: datetime = Field(description="UTC timestamp when the request started") + moderation_result: ShieldModerationResult = Field( + description="Shield moderation outcome", + ) + inline_rag_context: RAGContext = Field( + description="Inline RAG context for the turn" + ) + filter_server_tools: bool = Field( + default=False, + description="Whether to filter server-deployed MCP tool events from output", + ) + background_tasks: Optional[BackgroundTasks] = Field( + default=None, + description="Background tasks for telemetry, if enabled", + ) + rh_identity_context: tuple[str, str] = Field( + default_factory=lambda: ("", ""), + description="RH identity (org_id, system_id) for Splunk events", + ) + generate_topic_summary: bool = Field( + default=False, + description="Whether to generate a topic summary for new conversations", + ) diff --git a/src/models/requests.py b/src/models/requests.py index 9299bb4e1..ebdb90a87 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -21,12 +21,8 @@ from 
llama_stack_api.openai_responses import ( OpenAIResponseText as Text, ) -from llama_stack_api.openai_responses import ( - OpenAIResponseToolMCP as OutputToolMCP, -) from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from configuration import configuration from constants import ( MCP_AUTH_CLIENT, MCP_AUTH_KUBERNETES, @@ -38,7 +34,6 @@ ) from log import get_logger from utils import suid -from utils.tool_formatter import translate_vector_store_ids_to_user_facing from utils.types import IncludeParameter, ResponseInput logger = get_logger(__name__) @@ -867,28 +862,6 @@ def check_previous_response_id(cls, value: Optional[str]) -> Optional[str]: raise ValueError("You cannot provide context by moderation response.") return value - def echoed_params(self) -> dict[str, Any]: - """Build kwargs echoed into synthetic OpenAI-style responses (e.g. moderation blocks). - - Returns: - dict[str, Any]: Field names and values to merge into the response object. - """ - data = self.model_dump(include=_ECHOED_FIELDS) - if self.tools is not None: - tool_dicts: list[dict[str, Any]] = [ - ( - OutputToolMCP.model_validate(t.model_dump()).model_dump() - if t.type == "mcp" - else t.model_dump() - ) - for t in self.tools - ] - data["tools"] = translate_vector_store_ids_to_user_facing( - tool_dicts, configuration.rag_id_mapping - ) - - return data - class MCPServerRegistrationRequest(BaseModel): """Request model for dynamically registering an MCP server. diff --git a/src/utils/responses.py b/src/utils/responses.py index 858973d02..e0956a65c 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -91,6 +91,7 @@ from configuration import configuration from constants import DEFAULT_RAG_TOOL from log import get_logger +from models.common.responses.responses_api_params import ResponsesApiParams from models.config import ByokRag from models.database.conversations import UserConversation from models.requests import QueryRequest @@ -118,7 +119,6 @@ ReferencedDocument, ResponseInput, ResponseItem, - ResponsesApiParams, ToolCallSummary, ToolResultSummary, TurnSummary, diff --git a/src/utils/types.py b/src/utils/types.py index 018c264f6..48f0fc226 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -6,12 +6,6 @@ from llama_stack_api.openai_responses import ( OpenAIResponseInputFunctionToolCallOutput as FunctionToolCallOutput, ) -from llama_stack_api.openai_responses import ( - OpenAIResponseInputTool as InputTool, -) -from llama_stack_api.openai_responses import ( - OpenAIResponseInputToolChoice as ToolChoice, -) from llama_stack_api.openai_responses import ( OpenAIResponseMCPApprovalRequest as McpApprovalRequest, ) @@ -36,15 +30,6 @@ from llama_stack_api.openai_responses import ( OpenAIResponseOutputMessageWebSearchToolCall as WebSearchToolCall, ) -from llama_stack_api.openai_responses import ( - OpenAIResponsePrompt as Prompt, -) -from llama_stack_api.openai_responses import ( - OpenAIResponseReasoning as Reasoning, -) -from llama_stack_api.openai_responses import ( - OpenAIResponseText as Text, -) from pydantic import AnyUrl, BaseModel, ConfigDict, Field from models.database.conversations import UserConversation @@ -167,108 +152,6 @@ class ShieldModerationBlocked(BaseModel): type ResponseInput = str | list[ResponseItem] -class ResponsesApiParams(BaseModel): - """Parameters for a Llama Stack Responses API request. - - All fields accepted by the Llama Stack client responses.create() body are - included so that dumped model can be passed directly to response create. 
- """ - - input: ResponseInput = Field(description="The input text or structured input items") - model: str = Field(description='The full model ID in format "provider/model"') - conversation: str = Field(description="The conversation ID in llama-stack format") - include: Optional[list[IncludeParameter]] = Field( - default=None, - description="Output item types to include in the response", - ) - instructions: Optional[str] = Field( - default=None, description="The resolved system prompt" - ) - max_infer_iters: Optional[int] = Field( - default=None, - description="Maximum number of inference iterations", - ) - max_output_tokens: Optional[int] = Field( - default=None, - description="Maximum number of tokens allowed in the response", - ) - max_tool_calls: Optional[int] = Field( - default=None, - description="Maximum tool calls allowed in a single response", - ) - metadata: Optional[dict[str, str]] = Field( - default=None, - description="Custom metadata for tracking or logging", - ) - parallel_tool_calls: Optional[bool] = Field( - default=None, - description="Whether the model can make multiple tool calls in parallel", - ) - previous_response_id: Optional[str] = Field( - default=None, - description="Identifier of the previous response in a multi-turn conversation", - ) - prompt: Optional[Prompt] = Field( - default=None, - description="Prompt template with variables for dynamic substitution", - ) - reasoning: Optional[Reasoning] = Field( - default=None, - description="Reasoning configuration for the response", - ) - safety_identifier: Optional[str] = Field( - default=None, - description="Stable identifier for safety monitoring and abuse detection", - ) - store: bool = Field(description="Whether to store the response") - stream: bool = Field(description="Whether to stream the response") - temperature: Optional[float] = Field( - default=None, - description="Sampling temperature (e.g. 0.0-2.0)", - ) - text: Optional[Text] = Field( - default=None, - description="Text response configuration (format constraints)", - ) - tool_choice: Optional[ToolChoice] = Field( - default=None, - description="Tool selection strategy", - ) - tools: Optional[list[InputTool]] = Field( - default=None, - description="Prepared tool groups for Responses API (same type as ResponsesRequest.tools)", - ) - extra_headers: Optional[dict[str, str]] = Field( - default=None, - description="Extra HTTP headers to send with the request (e.g. x-llamastack-provider-data)", - ) - - def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - """Serialize params, re-injecting MCP authorization stripped by exclude=True. - - llama-stack-api marks ``InputToolMCP.authorization`` with - ``Field(exclude=True)`` to prevent token leakage in API responses. - The base ``model_dump()`` therefore strips the field, but we need it - in the request payload so llama-stack server can authenticate with - MCP servers. See LCORE-1414 / GitHub issue #1269. 
- """ - result = super().model_dump(*args, **kwargs) - # Only one context option is allowed, previous_response_id has priority - # Turn is added to conversation manually if previous_response_id is used - if self.previous_response_id: - result.pop("conversation", None) - dumped_tools = result.get("tools") - if not self.tools or not isinstance(dumped_tools, list): - return result - if len(dumped_tools) != len(self.tools): - return result - for tool, dumped_tool in zip(self.tools, dumped_tools): - authorization = getattr(tool, "authorization", None) - if authorization is not None and isinstance(dumped_tool, dict): - dumped_tool["authorization"] = authorization - return result - - class ToolCallSummary(BaseModel): """Model representing a tool call made during response generation (for tool_calls list).""" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 06ee69926..58a955289 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -12,6 +12,7 @@ from app.endpoints.query import query_endpoint_handler, retrieve_response from configuration import AppConfig +from models.common.responses.responses_api_params import ResponsesApiParams from models.database.conversations import UserConversation from models.requests import Attachment, QueryRequest from models.responses import QueryResponse @@ -20,7 +21,6 @@ RAGChunk, RAGContext, ReferencedDocument, - ResponsesApiParams, ShieldModerationPassed, ToolCallSummary, ToolResultSummary, diff --git a/tests/unit/app/endpoints/test_responses.py b/tests/unit/app/endpoints/test_responses.py index 96550af75..afa9cd530 100644 --- a/tests/unit/app/endpoints/test_responses.py +++ b/tests/unit/app/endpoints/test_responses.py @@ -26,11 +26,19 @@ ) from configuration import AppConfig from constants import DEFAULT_SYSTEM_PROMPT, SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER +from models.common.responses.responses_api_params import ResponsesApiParams +from models.common.responses.responses_context import ResponsesContext from models.config import Action, ModelContextProtocolServer from models.database.conversations import UserConversation from models.requests import ResponsesRequest from models.responses import ResponsesResponse -from utils.types import RAGContext, ResponsesConversationContext, TurnSummary +from utils.types import ( + RAGContext, + ResponsesConversationContext, + ShieldModerationBlocked, + ShieldModerationPassed, + TurnSummary, +) MOCK_AUTH = ( "00000001-0001-0001-0001-000000000001", @@ -47,6 +55,37 @@ SERVER_INSTRUCTIONS = "Server instructions" +def build_api_params_and_context( # pylint: disable=too-many-arguments + *, + updated_request: ResponsesRequest, + client: Any, + auth: tuple[str, str, bool, str], + input_text: str, + started_at: datetime, + moderation_result: Any, + inline_rag_context: RAGContext, + background_tasks: Any = None, + rh_identity_context: tuple[str, str] = ("", ""), + filter_server_tools: bool = False, + generate_topic_summary: bool = False, +) -> tuple[ResponsesApiParams, ResponsesContext]: + """Build api_params/context for direct helper invocation tests.""" + api_params = ResponsesApiParams.model_validate(updated_request.model_dump()) + context = ResponsesContext.model_construct( + client=client, + auth=auth, + input_text=input_text, + started_at=started_at, + moderation_result=moderation_result, + inline_rag_context=inline_rag_context, + filter_server_tools=filter_server_tools, + background_tasks=background_tasks, + 
rh_identity_context=rh_identity_context, + generate_topic_summary=generate_topic_summary, + ) + return api_params, context + + def _patch_base(mocker: MockerFixture, config: AppConfig) -> None: """Patch configuration and mandatory checks for responses endpoint.""" mocker.patch(f"{MODULE}.configuration", config) @@ -115,14 +154,24 @@ def _patch_rag( def _patch_moderation(mocker: MockerFixture, decision: str = "passed") -> Any: - """Patch run_shield_moderation; return mock moderation result.""" - mock_moderation = mocker.Mock() - mock_moderation.decision = decision + """Patch run_shield_moderation; return typed moderation result.""" + if decision == "blocked": + moderation_result = ShieldModerationBlocked( + message="Content blocked", + moderation_id="mod_blocked", + refusal_response=OpenAIResponseMessage( + role="assistant", + content="Content blocked", + type="message", + ), + ) + else: + moderation_result = ShieldModerationPassed() mocker.patch( f"{MODULE}.run_shield_moderation", - new=mocker.AsyncMock(return_value=mock_moderation), + new=mocker.AsyncMock(return_value=moderation_result), ) - return mock_moderation + return moderation_result def _make_responses_response( @@ -635,8 +684,8 @@ async def test_tool_choice_none_without_tools_does_not_load_server_tools( # The handler passes tools=None and tool_choice=None to the response handler # (the endpoint deep-copies the request, so we inspect the handler call args) call_kwargs = mock_handle.call_args[1] - assert call_kwargs["updated_request"].tools is None - assert call_kwargs["updated_request"].tool_choice is None + assert call_kwargs["api_params"].tools is None + assert call_kwargs["api_params"].tool_choice is None @pytest.mark.asyncio async def test_responses_endpoint_rejects_without_responses_action( @@ -720,16 +769,20 @@ async def test_handle_non_streaming_blocked_returns_refusal( return_value=mock_api_response, ) - response = await handle_non_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Bad input", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + response = await handle_non_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) assert isinstance(response, ResponsesResponse) assert response.output_text == "Content blocked" mock_client.responses.create.assert_not_called() @@ -791,16 +844,20 @@ async def test_handle_non_streaming_success_returns_response( return_value=VALID_CONV_ID_NORMALIZED, ) - response = await handle_non_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Hello", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + response = await handle_non_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) assert isinstance(response, ResponsesResponse) assert response.output_text == "Model reply" @@ -868,16 +925,20 @@ async def test_handle_non_streaming_with_previous_response_id_appends_turn( new=mocker.AsyncMock(), ) - await handle_non_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Hi", 
started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + await handle_non_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) mock_append.assert_awaited_once() call_args = mock_append.call_args[0] @@ -906,16 +967,20 @@ async def test_handle_non_streaming_context_length_raises_413( ) with pytest.raises(HTTPException) as exc_info: - await handle_non_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Long input", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + await handle_non_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) assert exc_info.value.status_code == 413 @@ -944,16 +1009,20 @@ async def test_handle_non_streaming_connection_error_raises_503( ) with pytest.raises(HTTPException) as exc_info: - await handle_non_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + await handle_non_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) assert exc_info.value.status_code == 503 @@ -992,16 +1061,20 @@ async def test_handle_non_streaming_api_status_error_raises_http( ) with pytest.raises(HTTPException) as exc_info: - await handle_non_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + await handle_non_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) assert exc_info.value.status_code == 500 @@ -1027,16 +1100,20 @@ async def test_handle_non_streaming_runtime_error_without_context_reraises( ) with pytest.raises(RuntimeError, match="Some other error"): - await handle_non_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + await handle_non_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) class TestHandleStreamingResponse: @@ -1073,16 +1150,20 @@ async def test_handle_streaming_blocked_returns_sse_consumes_shield_generator( mocker.patch(f"{MODULE}.store_query_results") mock_client.conversations.items.create = mocker.AsyncMock() - response = await handle_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Bad", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + response = await handle_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) assert isinstance(response, StreamingResponse) assert response.media_type == "text/event-stream" @@ -1153,16 +1234,20 
         mock_holder = mocker.Mock()
         mock_holder.get_client.return_value = mock_client
         mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder)

-        response = await handle_streaming_response(
-            client=mock_client,
-            original_request=request,
+        api_params, context = build_api_params_and_context(
             updated_request=request,
+            client=mock_client,
             auth=MOCK_AUTH,
             input_text="Hi",
             started_at=datetime.now(UTC),
             moderation_result=mock_moderation,
             inline_rag_context=RAGContext(),
         )
+        response = await handle_streaming_response(
+            original_request=request,
+            api_params=api_params,
+            context=context,
+        )

         assert isinstance(response, StreamingResponse)
         collected: list[str] = []
         async for part in response.body_iterator:
@@ -1237,16 +1322,20 @@ async def mock_stream() -> Any:
         mock_holder.get_client.return_value = mock_client
         mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder)

-        response = await handle_streaming_response(
-            client=mock_client,
-            original_request=request,
+        api_params, context = build_api_params_and_context(
             updated_request=request,
+            client=mock_client,
             auth=MOCK_AUTH,
             input_text="Hi",
             started_at=datetime.now(UTC),
             moderation_result=mock_moderation,
             inline_rag_context=RAGContext(),
         )
+        response = await handle_streaming_response(
+            original_request=request,
+            api_params=api_params,
+            context=context,
+        )

         collected: list[str] = []
         async for part in response.body_iterator:
             chunk_str = (
@@ -1321,16 +1410,20 @@ async def mock_stream() -> Any:
         mock_holder.get_client.return_value = mock_client
         mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder)

-        response = await handle_streaming_response(
-            client=mock_client,
-            original_request=request,
+        api_params, context = build_api_params_and_context(
             updated_request=request,
+            client=mock_client,
             auth=MOCK_AUTH,
             input_text="Hi",
             started_at=datetime.now(UTC),
             moderation_result=mock_moderation,
             inline_rag_context=RAGContext(),
         )
+        response = await handle_streaming_response(
+            original_request=request,
+            api_params=api_params,
+            context=context,
+        )

         collected: list[str] = []
         async for part in response.body_iterator:
             chunk_str = (
@@ -1399,16 +1492,20 @@ async def mock_stream() -> Any:
         mock_holder.get_client.return_value = mock_client
         mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder)

-        response = await handle_streaming_response(
-            client=mock_client,
-            original_request=request,
+        api_params, context = build_api_params_and_context(
             updated_request=request,
+            client=mock_client,
             auth=MOCK_AUTH,
             input_text="Hi",
             started_at=datetime.now(UTC),
             moderation_result=mock_moderation,
             inline_rag_context=RAGContext(),
         )
+        response = await handle_streaming_response(
+            original_request=request,
+            api_params=api_params,
+            context=context,
+        )

         collected: list[str] = []
         async for part in response.body_iterator:
             chunk_str = (
@@ -1443,16 +1540,20 @@ async def test_handle_streaming_context_length_raises_413(
             return_value=VALID_CONV_ID_NORMALIZED,
         )

         with pytest.raises(HTTPException) as exc_info:
-            await handle_streaming_response(
-                client=mock_client,
-                original_request=request,
+            api_params, context = build_api_params_and_context(
                 updated_request=request,
+                client=mock_client,
                 auth=MOCK_AUTH,
                 input_text="Long",
                 started_at=datetime.now(UTC),
                 moderation_result=mock_moderation,
                 inline_rag_context=RAGContext(),
             )
+            await handle_streaming_response(
+                original_request=request,
+                api_params=api_params,
+                context=context,
+            )

         assert exc_info.value.status_code == 413

     @pytest.mark.asyncio
@@ -1479,16 +1580,20 @@ async def test_handle_streaming_connection_error_raises_503(
             return_value=VALID_CONV_ID_NORMALIZED,
         )

         with pytest.raises(HTTPException) as exc_info:
-            await handle_streaming_response(
-                client=mock_client,
-                original_request=request,
+            api_params, context = build_api_params_and_context(
                 updated_request=request,
+                client=mock_client,
                 auth=MOCK_AUTH,
                 input_text="Hi",
                 started_at=datetime.now(UTC),
                 moderation_result=mock_moderation,
                 inline_rag_context=RAGContext(),
             )
+            await handle_streaming_response(
+                original_request=request,
+                api_params=api_params,
+                context=context,
+            )

         assert exc_info.value.status_code == 503
@@ -1539,7 +1644,7 @@ async def test_default_instructions_applied_when_client_omits_them(
         # The updated request passed to handle_non_streaming_response should have
         # instructions resolved to the default system prompt.
         call_kwargs = mock_handler.call_args[1]
-        assert call_kwargs["updated_request"].instructions == DEFAULT_SYSTEM_PROMPT
+        assert call_kwargs["api_params"].instructions == DEFAULT_SYSTEM_PROMPT

     @pytest.mark.asyncio
     async def test_client_provided_instructions_pass_through(
@@ -1584,7 +1689,7 @@ async def test_client_provided_instructions_pass_through(
         )

         call_kwargs = mock_handler.call_args[1]
-        assert call_kwargs["updated_request"].instructions == custom_instructions
+        assert call_kwargs["api_params"].instructions == custom_instructions

     @pytest.mark.asyncio
     async def test_configured_system_prompt_used_when_no_client_instructions(
@@ -1646,8 +1751,7 @@ async def test_configured_system_prompt_used_when_no_client_instructions(

         call_kwargs = mock_handler.call_args[1]
         assert (
-            call_kwargs["updated_request"].instructions
-            == "You are a deployment assistant."
+            call_kwargs["api_params"].instructions == "You are a deployment assistant."
         )

     @pytest.mark.asyncio
@@ -1734,7 +1838,7 @@ async def test_streaming_response_uses_resolved_instructions(
         )

         call_kwargs = mock_handler.call_args[1]
-        assert call_kwargs["updated_request"].instructions == DEFAULT_SYSTEM_PROMPT
+        assert call_kwargs["api_params"].instructions == DEFAULT_SYSTEM_PROMPT


 class TestIsServerMcpOutputItem:
@@ -1796,11 +1900,11 @@ def test_filters_mcp_call_substream_events(self, mocker: MockerFixture) -> None:
         """Test that response.mcp_call.* events are filtered for tracked indices."""
         chunk = mocker.Mock()
         chunk.output_index = 5
+        chunk.type = "response.mcp_call.in_progress"
         server_mcp_output_indices: set[int] = {5}
         assert (
             _should_filter_mcp_chunk(
                 chunk,
-                "response.mcp_call.in_progress",
                 {"server-a"},
                 server_mcp_output_indices,
             )
@@ -1813,11 +1917,11 @@ def test_filters_mcp_list_tools_substream_events(
         self, mocker: MockerFixture
     ) -> None:
         """Test that response.mcp_list_tools.* events are filtered for tracked indices."""
         chunk = mocker.Mock()
         chunk.output_index = 3
+        chunk.type = "response.mcp_list_tools.in_progress"
         server_mcp_output_indices: set[int] = {3}
         assert (
             _should_filter_mcp_chunk(
                 chunk,
-                "response.mcp_list_tools.in_progress",
                 {"server-a"},
                 server_mcp_output_indices,
             )
@@ -1830,11 +1934,11 @@ def test_filters_mcp_approval_request_substream_events(
         self, mocker: MockerFixture
     ) -> None:
         """Test that response.mcp_approval_request.* events are filtered for tracked indices."""
         chunk = mocker.Mock()
         chunk.output_index = 7
+        chunk.type = "response.mcp_approval_request.in_progress"
         server_mcp_output_indices: set[int] = {7}
         assert (
             _should_filter_mcp_chunk(
                 chunk,
-                "response.mcp_approval_request.in_progress",
                 {"server-a"},
                 server_mcp_output_indices,
             )
@@ -1847,11 +1951,11 @@ def test_does_not_filter_untracked_mcp_approval_request(
         self, mocker: MockerFixture
     ) -> None:
         """Test that mcp_approval_request events for untracked indices pass through."""
         chunk = mocker.Mock()
         chunk.output_index = 7
+        chunk.type = "response.mcp_approval_request.in_progress"
         server_mcp_output_indices: set[int] = {99}
         assert (
             _should_filter_mcp_chunk(
                 chunk,
-                "response.mcp_approval_request.in_progress",
                 {"server-a"},
                 server_mcp_output_indices,
             )
@@ -1862,11 +1966,11 @@ def test_does_not_filter_untracked_mcp_call(self, mocker: MockerFixture) -> None:
         """Test that mcp_call events for untracked indices pass through."""
         chunk = mocker.Mock()
         chunk.output_index = 10
+        chunk.type = "response.mcp_call.completed"
         server_mcp_output_indices: set[int] = {5}
         assert (
             _should_filter_mcp_chunk(
                 chunk,
-                "response.mcp_call.completed",
                 {"server-a"},
                 server_mcp_output_indices,
             )
@@ -1883,11 +1987,11 @@ def test_filters_output_item_added_for_server_mcp(
         chunk = mocker.Mock()
         chunk.item = item
         chunk.output_index = 2
+        chunk.type = "response.output_item.added"
         server_mcp_output_indices: set[int] = set()
         assert (
             _should_filter_mcp_chunk(
                 chunk,
-                "response.output_item.added",
                 {"server-a"},
                 server_mcp_output_indices,
             )
@@ -1904,11 +2008,11 @@ def test_filters_output_item_done_for_server_mcp(
         chunk = mocker.Mock()
         chunk.item = item
         chunk.output_index = 2
+        chunk.type = "response.output_item.done"
         server_mcp_output_indices: set[int] = {2}
         assert (
             _should_filter_mcp_chunk(
                 chunk,
-                "response.output_item.done",
                 {"server-a"},
                 server_mcp_output_indices,
             )
@@ -1919,12 +2023,8 @@ def test_filters_output_item_done_for_server_mcp(
     def test_does_not_filter_non_mcp_event(self, mocker: MockerFixture) -> None:
         """Test that non-MCP events pass through."""
         chunk = mocker.Mock()
-        assert (
-            _should_filter_mcp_chunk(
-                chunk, "response.output_text.delta", {"server-a"}, set()
-            )
-            is False
-        )
"response.output_text.delta" + assert _should_filter_mcp_chunk(chunk, {"server-a"}, set()) is False def mock_original_request( @@ -2222,16 +2322,20 @@ async def test_non_streaming_sanitizes_mcp_output_and_model( return_value=VALID_CONV_ID_NORMALIZED, ) - response = await handle_non_streaming_response( - client=mock_client, - original_request=original_request, + api_params, context = build_api_params_and_context( updated_request=updated_request, + client=mock_client, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), moderation_result=mock_moderation, inline_rag_context=RAGContext(), ) + response = await handle_non_streaming_response( + original_request=original_request, + api_params=api_params, + context=context, + ) assert isinstance(response, ResponsesResponse) # Model provider prefix should be stripped when server-substituted @@ -2340,10 +2444,9 @@ async def mock_stream() -> Any: mock_holder.get_client.return_value = mock_client mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder) - response = await handle_streaming_response( - client=mock_client, - original_request=original_request, + api_params, context = build_api_params_and_context( updated_request=updated_request, + client=mock_client, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -2351,6 +2454,11 @@ async def mock_stream() -> Any: inline_rag_context=RAGContext(), filter_server_tools=False, ) + response = await handle_streaming_response( + original_request=original_request, + api_params=api_params, + context=context, + ) collected: list[str] = [] async for part in response.body_iterator: chunk_str = ( @@ -2458,10 +2566,9 @@ async def mock_stream() -> Any: mock_holder.get_client.return_value = mock_client mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder) - response = await handle_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -2469,6 +2576,11 @@ async def mock_stream() -> Any: inline_rag_context=RAGContext(), filter_server_tools=False, ) + response = await handle_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) collected: list[str] = [] async for part in response.body_iterator: chunk_str = ( @@ -2550,10 +2662,9 @@ async def mock_stream() -> Any: mock_holder.get_client.return_value = mock_client mocker.patch(f"{MODULE}.AsyncLlamaStackClientHolder", return_value=mock_holder) - response = await handle_streaming_response( - client=mock_client, - original_request=request, + api_params, context = build_api_params_and_context( updated_request=request, + client=mock_client, auth=MOCK_AUTH, input_text="Hi", started_at=datetime.now(UTC), @@ -2561,6 +2672,11 @@ async def mock_stream() -> Any: inline_rag_context=RAGContext(), filter_server_tools=False, ) + response = await handle_streaming_response( + original_request=request, + api_params=api_params, + context=context, + ) collected: list[str] = [] async for part in response.body_iterator: chunk_str = ( diff --git a/tests/unit/app/endpoints/test_responses_splunk.py b/tests/unit/app/endpoints/test_responses_splunk.py index ce1588a75..73d585963 100644 --- a/tests/unit/app/endpoints/test_responses_splunk.py +++ b/tests/unit/app/endpoints/test_responses_splunk.py @@ -23,6 +23,7 @@ from configuration import AppConfig from models.requests import ResponsesRequest from 
 from observability.formats.responses import ResponsesEventData
+from tests.unit.app.endpoints.test_responses import build_api_params_and_context
 from utils.types import RAGContext, TurnSummary

 MODULE = "app.endpoints.responses"
@@ -256,10 +257,9 @@ async def test_non_streaming_shield_blocked(
     )
     mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event")

-    await handle_non_streaming_response(
-        client=mock_client,
-        original_request=request,
+    api_params, context = build_api_params_and_context(
         updated_request=request,
+        client=mock_client,
         auth=MOCK_AUTH,
         input_text="Bad input",
         started_at=datetime.now(UTC),
@@ -268,6 +268,11 @@
         background_tasks=mock_background_tasks,
         rh_identity_context=("org1", "sys1"),
     )
+    await handle_non_streaming_response(
+        original_request=request,
+        api_params=api_params,
+        context=context,
+    )

     mock_queue.assert_called_once()
     call_kwargs = mock_queue.call_args[1]
@@ -333,10 +338,9 @@ async def test_non_streaming_error_fires_telemetry(
     mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event")

     with pytest.raises(HTTPException):
-        await handle_non_streaming_response(
-            client=mock_client,
-            original_request=request,
+        api_params, context = build_api_params_and_context(
             updated_request=request,
+            client=mock_client,
             auth=MOCK_AUTH,
             input_text="Hello",
             started_at=datetime.now(UTC),
@@ -345,6 +349,11 @@
             background_tasks=mock_background_tasks,
             rh_identity_context=("org1", "sys1"),
         )
+        await handle_non_streaming_response(
+            original_request=request,
+            api_params=api_params,
+            context=context,
+        )

     mock_queue.assert_called_once()
     assert mock_queue.call_args[1]["sourcetype"] == "responses_error"
@@ -412,10 +421,9 @@ async def test_non_streaming_success(

     mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event")

-    await handle_non_streaming_response(
-        client=mock_client,
-        original_request=request,
+    api_params, context = build_api_params_and_context(
         updated_request=request,
+        client=mock_client,
         auth=MOCK_AUTH,
         input_text="Hello",
         started_at=datetime.now(UTC),
@@ -424,6 +432,11 @@
         background_tasks=mock_background_tasks,
         rh_identity_context=("org1", "sys1"),
     )
+    await handle_non_streaming_response(
+        original_request=request,
+        api_params=api_params,
+        context=context,
+    )

     # The success hook fires once (blocked hook is skipped because decision != "blocked")
     mock_queue.assert_called_once()
@@ -469,10 +482,9 @@ async def test_streaming_shield_blocked(

     mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event")

-    response = await handle_streaming_response(
-        client=mock_client,
-        original_request=request,
+    api_params, context = build_api_params_and_context(
         updated_request=request,
+        client=mock_client,
         auth=MOCK_AUTH,
         input_text="Bad",
         started_at=datetime.now(UTC),
@@ -481,6 +493,11 @@
         background_tasks=mock_background_tasks,
         rh_identity_context=("org1", "sys1"),
     )
+    response = await handle_streaming_response(
+        original_request=request,
+        api_params=api_params,
+        context=context,
+    )

     assert isinstance(response, StreamingResponse)
     # Consume the stream to trigger generate_response() completion
@@ -547,10 +564,9 @@ async def test_streaming_error_fires_telemetry(
     mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event")

     with pytest.raises(HTTPException):
-        await handle_streaming_response(
-            client=mock_client,
-            original_request=request,
+        api_params, context = build_api_params_and_context(
             updated_request=request,
+            client=mock_client,
             auth=MOCK_AUTH,
             input_text="Hello",
             started_at=datetime.now(UTC),
@@ -559,6 +575,11 @@
             background_tasks=mock_background_tasks,
             rh_identity_context=("org1", "sys1"),
         )
+        await handle_streaming_response(
+            original_request=request,
+            api_params=api_params,
+            context=context,
+        )

     mock_queue.assert_called_once()
     assert mock_queue.call_args[1]["sourcetype"] == "responses_error"
@@ -628,10 +649,9 @@ async def mock_stream() -> Any:

     mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event")

-    response = await handle_streaming_response(
-        client=mock_client,
-        original_request=request,
+    api_params, context = build_api_params_and_context(
         updated_request=request,
+        client=mock_client,
         auth=MOCK_AUTH,
         input_text="Hi",
         started_at=datetime.now(UTC),
@@ -640,6 +660,11 @@
         background_tasks=mock_background_tasks,
         rh_identity_context=("org1", "sys1"),
     )
+    response = await handle_streaming_response(
+        original_request=request,
+        api_params=api_params,
+        context=context,
+    )

     assert isinstance(response, StreamingResponse)

@@ -700,10 +725,9 @@ async def test_splunk_disabled_no_background_tasks(
     mock_queue = mocker.patch(f"{MODULE}._queue_responses_splunk_event")

     # background_tasks=None (the default) means Splunk is disabled
-    await handle_non_streaming_response(
-        client=mock_client,
-        original_request=request,
+    api_params, context = build_api_params_and_context(
         updated_request=request,
+        client=mock_client,
         auth=MOCK_AUTH,
         input_text="Bad input",
         started_at=datetime.now(UTC),
@@ -712,6 +736,11 @@
         background_tasks=None,
         rh_identity_context=("org1", "sys1"),
     )
+    await handle_non_streaming_response(
+        original_request=request,
+        api_params=api_params,
+        context=context,
+    )

     mock_queue.assert_called_once()
     assert mock_queue.call_args[1]["background_tasks"] is None
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index bf2738cf4..073930941 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -64,6 +64,7 @@
     MEDIA_TYPE_JSON,
     MEDIA_TYPE_TEXT,
 )
+from models.common.responses.responses_api_params import ResponsesApiParams
 from models.config import Action
 from models.context import ResponseGeneratorContext
 from models.requests import Attachment, QueryRequest
@@ -74,7 +75,6 @@
     RAGChunk,
     RAGContext,
     ReferencedDocument,
-    ResponsesApiParams,
     ShieldModerationPassed,
     TurnSummary,
 )
diff --git a/tests/unit/utils/test_types.py b/tests/unit/utils/test_types.py
index e05219d81..6a62c5da5 100644
--- a/tests/unit/utils/test_types.py
+++ b/tests/unit/utils/test_types.py
@@ -10,9 +10,9 @@
 )
 from pydantic import AnyUrl, ValidationError

+from models.common.responses.responses_api_params import ResponsesApiParams
 from utils.types import (
     ReferencedDocument,
-    ResponsesApiParams,
     ToolCallSummary,
     ToolResultSummary,
     content_to_str,

From c9d74ee210294a456fe98010c9b883fcb09fe28 Mon Sep 17 00:00:00 2001
From: Andrej Simurka
Date: Tue, 28 Apr 2026 15:27:04 +0200
Subject: [PATCH 2/2] Fix tools validation when building ResponsesApiParams

---
 src/app/endpoints/responses.py                      | 9 +++++++--
 src/models/common/responses/responses_api_params.py | 4 ++--
 tests/unit/app/endpoints/test_responses.py          | 7 ++++++-
 3 files changed, 15 insertions(+), 5 deletions(-)

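Note (this region of a format-patch mail, like the diffstat above, is ignored
by git am): the fix replaces a full model_dump() with a dump that excludes
"tools" and passes the typed tool objects through unchanged, so they are not
serialized to plain dicts and re-validated. A minimal sketch of the pattern,
using hypothetical stand-in models (WebSearchTool, SourceRequest, and
TargetParams below are illustrative, not the project's real types):

    from pydantic import BaseModel

    class WebSearchTool(BaseModel):
        type: str = "web_search"

    class SourceRequest(BaseModel):
        model: str
        tools: list[WebSearchTool] | None = None

    class TargetParams(BaseModel):
        model: str
        tools: list[WebSearchTool] | None = None

    req = SourceRequest(model="m", tools=[WebSearchTool()])

    # Lossy round-trip: model_dump() serializes each tool to a plain dict,
    # so TargetParams must re-validate every tool field from scratch.
    roundtrip = TargetParams.model_validate(req.model_dump())

    # Pattern applied by this patch: dump everything except the typed
    # field, then hand the original objects through as-is.
    direct = TargetParams.model_validate(
        {**req.model_dump(exclude={"tools"}), "tools": req.tools}
    )
    assert direct.tools == req.tools
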
diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py
index 7031ca676..971bb1b59 100644
--- a/src/app/endpoints/responses.py
+++ b/src/app/endpoints/responses.py
@@ -341,7 +341,12 @@ async def responses_endpoint_handler(
             original_request.input, inline_rag_context.context_text
         )

-    api_params = ResponsesApiParams.model_validate(updated_request.model_dump())
+    api_params = ResponsesApiParams.model_validate(
+        {
+            **updated_request.model_dump(exclude={"tools"}),
+            "tools": updated_request.tools,
+        }
+    )
     context = ResponsesContext(
         client=client,
         auth=auth,
@@ -821,7 +826,7 @@ async def response_generator(
         await append_turn_items_to_conversation(
             context.client,
             api_params.conversation,
-            context.input_text,
+            api_params.input,
             latest_response_object.output,
         )

diff --git a/src/models/common/responses/responses_api_params.py b/src/models/common/responses/responses_api_params.py
index 64709811b..6767c392e 100644
--- a/src/models/common/responses/responses_api_params.py
+++ b/src/models/common/responses/responses_api_params.py
@@ -1,7 +1,7 @@
 """Request parameter model for Llama Stack responses API calls."""

 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any, Final, Optional

 from llama_stack_api.openai_responses import (
     OpenAIResponseInputTool as InputTool,
@@ -27,7 +27,7 @@
 from utils.types import IncludeParameter, ResponseInput

 # Attribute names that are echoed back in the response.
-_ECHOED_FIELDS = set(
+_ECHOED_FIELDS: Final[set[str]] = set(
     {
         "instructions",
         "max_tool_calls",
diff --git a/tests/unit/app/endpoints/test_responses.py b/tests/unit/app/endpoints/test_responses.py
index afa9cd530..9533af0a9 100644
--- a/tests/unit/app/endpoints/test_responses.py
+++ b/tests/unit/app/endpoints/test_responses.py
@@ -70,7 +70,12 @@ def build_api_params_and_context(  # pylint: disable=too-many-arguments
     generate_topic_summary: bool = False,
 ) -> tuple[ResponsesApiParams, ResponsesContext]:
     """Build api_params/context for direct helper invocation tests."""
-    api_params = ResponsesApiParams.model_validate(updated_request.model_dump())
+    api_params = ResponsesApiParams.model_validate(
+        {
+            **updated_request.model_dump(exclude={"tools"}),
+            "tools": updated_request.tools,
+        }
+    )
     context = ResponsesContext.model_construct(
         client=client,
         auth=auth,