diff --git a/backend/src/apis/inference_api/chat/voice_routes.py b/backend/src/apis/inference_api/chat/voice_routes.py index 1727cede..bef70709 100644 --- a/backend/src/apis/inference_api/chat/voice_routes.py +++ b/backend/src/apis/inference_api/chat/voice_routes.py @@ -268,8 +268,35 @@ async def _finalize_voice_session(session_id: str, user_id: str, voice_agent: An logger.error(f"Failed to store voice cost metadata for {_sanitize_log(session_id)}: {e}", exc_info=True) -@router.websocket("/ws") -async def voice_stream(websocket: WebSocket): +def _get_param_from_request(websocket: WebSocket, header_suffix: str, query_param: Optional[str]) -> Optional[str]: + """Extract param from AgentCore custom header (cloud) or query param (local).""" + header_name = f"x-amzn-bedrock-agentcore-runtime-custom-{header_suffix}" + custom_header = websocket.headers.get(header_name) + if custom_header: + return custom_header + return query_param + + +def _get_enabled_tools_from_request(websocket: WebSocket, query_param: Optional[str]) -> Optional[list]: + """Extract enabled_tools from AgentCore custom header or query param.""" + tools_json = _get_param_from_request(websocket, "enabled-tools", query_param) + if not tools_json: + return None + try: + return json.loads(tools_json) + except json.JSONDecodeError as e: + logger.warning(f"Invalid enabled_tools JSON: {_sanitize_log(str(e))}") + return None + + +@router.websocket("/voice/stream") +async def voice_stream( + websocket: WebSocket, + session_id: Optional[str] = None, + user_id: Optional[str] = None, + enabled_tools: Optional[str] = None, + token: Optional[str] = None, +): """ Bidirectional voice streaming endpoint. @@ -277,108 +304,74 @@ async def voice_stream(websocket: WebSocket): **AgentCore (deployed):** Browser connects via ``wss://bedrock-agentcore..amazonaws.com/runtimes//ws``. - Auth is handled by AgentCore's JWT Authorizer; the bearer token is sent - in ``Sec-WebSocket-Protocol`` (base64url-encoded). Session ID and - enabled_tools arrive as ``X-Amzn-Bedrock-AgentCore-Runtime-Custom-*`` - headers (set via query params on the original URL). The token for - user-claim extraction comes in the first ``config`` message. - - **Local dev:** Browser connects directly to ``ws://localhost:8001/ws``. - Session ID and token are plain query params; auth is checked pre-accept. - """ - # Read session_id: AgentCore custom header → query param → auto-generate - session_id = ( - websocket.headers.get("x-amzn-bedrock-agentcore-runtime-custom-sessionid") - or websocket.query_params.get("session_id") - or str(uuid.uuid4()) - ) - # Read enabled_tools: AgentCore custom header → query param - enabled_tools_raw = ( - websocket.headers.get("x-amzn-bedrock-agentcore-runtime-custom-enabledtools") - or websocket.query_params.get("enabled_tools") - or "" - ) - # Local dev sends token as query param; AgentCore flow uses config message - token = websocket.query_params.get("token", "") - voice_agent = None - user_id = None + Auth is handled by AgentCore's JWT Authorizer at the proxy layer. + The bearer token for user-claim extraction arrives in the first + ``config`` message sent by the client after connection opens. + **Local dev:** Browser connects directly to + ``ws://localhost:8001/voice/stream``. Session ID and token are + plain query params; the config message supplements them. + """ + # Accept immediately — AgentCore validates auth at the proxy layer; + # user claims are extracted from the config message after accept. + await websocket.accept() + + # Resolve params: AgentCore custom header → query param → default + session_id = _get_param_from_request(websocket, "session-id", session_id) + user_id = _get_param_from_request(websocket, "user-id", user_id) + enabled_tools_list = _get_enabled_tools_from_request(websocket, enabled_tools) + auth_token = _get_param_from_request(websocket, "auth-token", token) or "" + + # Always read config message from client (sent on WebSocket open). + # Required for auth_token in AgentCore mode and supplements any + # missing params (AgentCore proxy may not forward all query params). try: - # Parse enabled tools - enabled_tools = None - if enabled_tools_raw: - try: - enabled_tools = json.loads(enabled_tools_raw) - except json.JSONDecodeError: - logger.warning(f"Invalid enabled_tools JSON: {_sanitize_log(enabled_tools_raw)}") - - # Detect AgentCore path via custom headers (AgentCore strips - # Sec-WebSocket-Protocol, so we can't rely on subprotocol detection) - is_agentcore = bool( - websocket.headers.get("x-amzn-bedrock-agentcore-runtime-custom-sessionid") + first_msg = await asyncio.wait_for( + websocket.receive_json(), timeout=10.0 ) + if first_msg.get("type") == "config": + session_id = first_msg.get("session_id") or session_id + user_id = first_msg.get("user_id") or user_id + enabled_tools_list = first_msg.get("enabled_tools") or enabled_tools_list + auth_token = first_msg.get("auth_token") or auth_token + logger.info(f"Voice config received from client message") + except asyncio.TimeoutError: + logger.warning("No config message received within 10s, using query params") + except Exception as e: + logger.warning(f"Error reading config message: {e}") - # Check if browser sent subprotocol (may reach container in local-proxy setups) - requested_protocols = websocket.headers.get("sec-websocket-protocol", "") - accept_subprotocol = None - if "base64UrlBearerAuthorization" in requested_protocols: - accept_subprotocol = "base64UrlBearerAuthorization" - - # Local dev: authenticate before accepting using query-param token - user_info = _extract_user_from_token(token) if token else None - if not is_agentcore and not accept_subprotocol and not user_info: - # Not an AgentCore connection AND no valid local token → reject - await websocket.close(code=4001, reason="Authentication required") - return + # Generate session_id if not provided by any source + if not session_id: + session_id = str(uuid.uuid4()) + logger.info(f"Generated new voice session ID: {_sanitize_log(session_id)}") + # Extract user from token (query param or config message) + if not user_id and auth_token: + user_info = _extract_user_from_token(auth_token) if user_info: user_id = user_info["user_id"] - auth_token = user_info["raw_token"] - else: - # AgentCore path: auth_token will come from config message - auth_token = "" - # Accept the WebSocket connection (with subprotocol if applicable) - await websocket.accept(subprotocol=accept_subprotocol) - logger.info(f"Voice WebSocket connected: session={_sanitize_log(session_id)}, user={_sanitize_log(user_id or 'pending')}") + if not user_id: + await websocket.send_json({"type": "bidi_error", "message": "Authentication required"}) + await websocket.close(code=4001, reason="Authentication required") + return - # Wait for initial config message (supplements query params) - try: - first_msg = await asyncio.wait_for( - websocket.receive_json(), timeout=10.0 - ) - if first_msg.get("type") == "config": - if first_msg.get("auth_token"): - auth_token = first_msg["auth_token"] - if first_msg.get("enabled_tools"): - enabled_tools = first_msg["enabled_tools"] - logger.info(f"Voice config received: session={_sanitize_log(session_id)}") - except asyncio.TimeoutError: - logger.warning("No config message received within 10s, using query params") - except Exception as e: - logger.warning(f"Error reading config message: {e}") - if not user_id: - await websocket.send_json({"type": "bidi_error", "message": "Config message required"}) - await websocket.close(code=4001, reason="Authentication required") - return + logger.info( + f"Voice WebSocket connected: session={_sanitize_log(session_id)}, " + f"user={_sanitize_log(user_id)}, tools={len(enabled_tools_list or [])}, " + f"auth_token={'present' if auth_token else 'missing'}" + ) - # AgentCore path: extract user from config message token if not yet identified - if not user_id and auth_token: - user_info = _extract_user_from_token(auth_token) - if user_info: - user_id = user_info["user_id"] - else: - await websocket.send_json({"type": "bidi_error", "message": "Authentication required"}) - await websocket.close(code=4001, reason="Authentication required") - return + voice_agent = None + try: # Create VoiceAgent VoiceAgent = _get_voice_agent_class() voice_agent = VoiceAgent( session_id=session_id, user_id=user_id, auth_token=auth_token, - enabled_tools=enabled_tools, + enabled_tools=enabled_tools_list, ) _active_sessions[session_id] = voice_agent @@ -553,3 +546,31 @@ async def stop_voice_session(session_id: str): _active_sessions.pop(session_id, None) return {"status": "stopped", "session_id": session_id} + + +# ============================================================================= +# /ws alias for AgentCore Runtime +# AgentCore Runtime routes WebSocket requests to /ws on the container. +# This delegates to /voice/stream for cloud deployment compatibility. +# ============================================================================= + +@router.websocket("/ws") +async def ws_stream( + websocket: WebSocket, + session_id: Optional[str] = None, + user_id: Optional[str] = None, + enabled_tools: Optional[str] = None, + token: Optional[str] = None, +): + """WebSocket endpoint for AgentCore Runtime (cloud mode). + + AgentCore Runtime expects containers to implement WebSocket at /ws + path on port 8080. This endpoint delegates to voice_stream. + """ + await voice_stream( + websocket=websocket, + session_id=session_id, + user_id=user_id, + enabled_tools=enabled_tools, + token=token, + ) diff --git a/frontend/ai.client/src/app/session/services/voice/voice-chat.service.ts b/frontend/ai.client/src/app/session/services/voice/voice-chat.service.ts index 019582a0..50a2817f 100644 --- a/frontend/ai.client/src/app/session/services/voice/voice-chat.service.ts +++ b/frontend/ai.client/src/app/session/services/voice/voice-chat.service.ts @@ -167,10 +167,8 @@ export class VoiceChatService implements OnDestroy { let protocols: string[] | undefined; if (isAgentCore) { - // AgentCore: /ws path, session via custom header param, auth via Sec-WebSocket-Protocol - const params = new URLSearchParams(); - params.set('X-Amzn-Bedrock-AgentCore-Runtime-Custom-SessionId', this.sessionId!); - url = `${wsUrl}/ws?${params.toString()}`; + // AgentCore: /ws path, auth via Sec-WebSocket-Protocol + url = `${wsUrl}/ws`; const base64url = btoa(token) .replace(/\+/g, '-') @@ -178,8 +176,8 @@ export class VoiceChatService implements OnDestroy { .replace(/=/g, ''); protocols = [`base64UrlBearerAuthorization.${base64url}`, 'base64UrlBearerAuthorization']; } else { - // Local dev: direct connection with query params - url = `${wsUrl}/ws?session_id=${encodeURIComponent(this.sessionId!)}&token=${encodeURIComponent(token)}`; + // Local dev: /voice/stream path with query params + url = `${wsUrl}/voice/stream?session_id=${encodeURIComponent(this.sessionId!)}&token=${encodeURIComponent(token)}`; } await this.openWebSocket(url, token, protocols);