Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 108 additions & 87 deletions backend/src/apis/inference_api/chat/voice_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,117 +268,110 @@ async def _finalize_voice_session(session_id: str, user_id: str, voice_agent: An
logger.error(f"Failed to store voice cost metadata for {_sanitize_log(session_id)}: {e}", exc_info=True)


@router.websocket("/ws")
async def voice_stream(websocket: WebSocket):
def _get_param_from_request(websocket: WebSocket, header_suffix: str, query_param: Optional[str]) -> Optional[str]:
    """Resolve a request parameter, preferring the AgentCore custom header.

    Looks up ``x-amzn-bedrock-agentcore-runtime-custom-<header_suffix>`` in the
    WebSocket handshake headers (cloud) and falls back to ``query_param``
    (local) when that header is absent or empty.

    Args:
        websocket: Connection whose ``headers`` mapping is consulted.
        header_suffix: Trailing part of the custom header name.
        query_param: Value to return when the header yields nothing; may be
            ``None``.

    Returns:
        The header value if it is non-empty, otherwise ``query_param``.
    """
    header_value = websocket.headers.get(
        f"x-amzn-bedrock-agentcore-runtime-custom-{header_suffix}"
    )
    # A falsy header (missing or empty string) defers to the query-param value.
    return header_value if header_value else query_param


def _get_enabled_tools_from_request(websocket: WebSocket, query_param: Optional[str]) -> Optional[list]:
    """Extract enabled_tools from AgentCore custom header or query param.

    The raw value is resolved with ``_get_param_from_request`` using the
    ``enabled-tools`` header suffix (falling back to ``query_param``) and is
    then decoded as JSON.

    Args:
        websocket: Connection whose headers may carry the custom header.
        query_param: Raw JSON text from the query string, if any.

    Returns:
        The decoded list, or ``None`` when no value was supplied, the text is
        not valid JSON, or the JSON document is not an array. Both failure
        cases are logged as warnings rather than raised.
    """
    tools_json = _get_param_from_request(websocket, "enabled-tools", query_param)
    if not tools_json:
        return None
    try:
        tools = json.loads(tools_json)
    except json.JSONDecodeError as e:
        logger.warning(f"Invalid enabled_tools JSON: {_sanitize_log(str(e))}")
        return None
    # The value is client-controlled: syntactically valid JSON may still be an
    # object, string, number or bool. Honour the declared Optional[list]
    # contract so callers can safely take len() of / iterate the result.
    if not isinstance(tools, list):
        logger.warning(
            f"Ignoring enabled_tools: expected a JSON array, got {type(tools).__name__}"
        )
        return None
    return tools


@router.websocket("/voice/stream")
async def voice_stream(
websocket: WebSocket,
session_id: Optional[str] = None,
user_id: Optional[str] = None,
enabled_tools: Optional[str] = None,
token: Optional[str] = None,
):
"""
Bidirectional voice streaming endpoint.

Supports two connection modes:

**AgentCore (deployed):** Browser connects via
``wss://bedrock-agentcore.<region>.amazonaws.com/runtimes/<ARN>/ws``.
Auth is handled by AgentCore's JWT Authorizer; the bearer token is sent
in ``Sec-WebSocket-Protocol`` (base64url-encoded). Session ID and
enabled_tools arrive as ``X-Amzn-Bedrock-AgentCore-Runtime-Custom-*``
headers (set via query params on the original URL). The token for
user-claim extraction comes in the first ``config`` message.

**Local dev:** Browser connects directly to ``ws://localhost:8001/ws``.
Session ID and token are plain query params; auth is checked pre-accept.
"""
# Read session_id: AgentCore custom header → query param → auto-generate
session_id = (
websocket.headers.get("x-amzn-bedrock-agentcore-runtime-custom-sessionid")
or websocket.query_params.get("session_id")
or str(uuid.uuid4())
)
# Read enabled_tools: AgentCore custom header → query param
enabled_tools_raw = (
websocket.headers.get("x-amzn-bedrock-agentcore-runtime-custom-enabledtools")
or websocket.query_params.get("enabled_tools")
or ""
)
# Local dev sends token as query param; AgentCore flow uses config message
token = websocket.query_params.get("token", "")
voice_agent = None
user_id = None
Auth is handled by AgentCore's JWT Authorizer at the proxy layer.
The bearer token for user-claim extraction arrives in the first
``config`` message sent by the client after connection opens.

**Local dev:** Browser connects directly to
``ws://localhost:8001/voice/stream``. Session ID and token are
plain query params; the config message supplements them.
"""
# Accept immediately — AgentCore validates auth at the proxy layer;
# user claims are extracted from the config message after accept.
await websocket.accept()

# Resolve params: AgentCore custom header → query param → default
session_id = _get_param_from_request(websocket, "session-id", session_id)
user_id = _get_param_from_request(websocket, "user-id", user_id)
enabled_tools_list = _get_enabled_tools_from_request(websocket, enabled_tools)
auth_token = _get_param_from_request(websocket, "auth-token", token) or ""

# Always read config message from client (sent on WebSocket open).
# Required for auth_token in AgentCore mode and supplements any
# missing params (AgentCore proxy may not forward all query params).
try:
# Parse enabled tools
enabled_tools = None
if enabled_tools_raw:
try:
enabled_tools = json.loads(enabled_tools_raw)
except json.JSONDecodeError:
logger.warning(f"Invalid enabled_tools JSON: {_sanitize_log(enabled_tools_raw)}")

# Detect AgentCore path via custom headers (AgentCore strips
# Sec-WebSocket-Protocol, so we can't rely on subprotocol detection)
is_agentcore = bool(
websocket.headers.get("x-amzn-bedrock-agentcore-runtime-custom-sessionid")
first_msg = await asyncio.wait_for(
websocket.receive_json(), timeout=10.0
)
if first_msg.get("type") == "config":
session_id = first_msg.get("session_id") or session_id
user_id = first_msg.get("user_id") or user_id
enabled_tools_list = first_msg.get("enabled_tools") or enabled_tools_list
auth_token = first_msg.get("auth_token") or auth_token
logger.info(f"Voice config received from client message")
except asyncio.TimeoutError:
logger.warning("No config message received within 10s, using query params")
except Exception as e:
logger.warning(f"Error reading config message: {e}")

# Check if browser sent subprotocol (may reach container in local-proxy setups)
requested_protocols = websocket.headers.get("sec-websocket-protocol", "")
accept_subprotocol = None
if "base64UrlBearerAuthorization" in requested_protocols:
accept_subprotocol = "base64UrlBearerAuthorization"

# Local dev: authenticate before accepting using query-param token
user_info = _extract_user_from_token(token) if token else None
if not is_agentcore and not accept_subprotocol and not user_info:
# Not an AgentCore connection AND no valid local token → reject
await websocket.close(code=4001, reason="Authentication required")
return
# Generate session_id if not provided by any source
if not session_id:
session_id = str(uuid.uuid4())
logger.info(f"Generated new voice session ID: {_sanitize_log(session_id)}")

# Extract user from token (query param or config message)
if not user_id and auth_token:
user_info = _extract_user_from_token(auth_token)
if user_info:
user_id = user_info["user_id"]
auth_token = user_info["raw_token"]
else:
# AgentCore path: auth_token will come from config message
auth_token = ""

# Accept the WebSocket connection (with subprotocol if applicable)
await websocket.accept(subprotocol=accept_subprotocol)
logger.info(f"Voice WebSocket connected: session={_sanitize_log(session_id)}, user={_sanitize_log(user_id or 'pending')}")
if not user_id:
await websocket.send_json({"type": "bidi_error", "message": "Authentication required"})
await websocket.close(code=4001, reason="Authentication required")
return

# Wait for initial config message (supplements query params)
try:
first_msg = await asyncio.wait_for(
websocket.receive_json(), timeout=10.0
)
if first_msg.get("type") == "config":
if first_msg.get("auth_token"):
auth_token = first_msg["auth_token"]
if first_msg.get("enabled_tools"):
enabled_tools = first_msg["enabled_tools"]
logger.info(f"Voice config received: session={_sanitize_log(session_id)}")
except asyncio.TimeoutError:
logger.warning("No config message received within 10s, using query params")
except Exception as e:
logger.warning(f"Error reading config message: {e}")
if not user_id:
await websocket.send_json({"type": "bidi_error", "message": "Config message required"})
await websocket.close(code=4001, reason="Authentication required")
return
logger.info(
f"Voice WebSocket connected: session={_sanitize_log(session_id)}, "
f"user={_sanitize_log(user_id)}, tools={len(enabled_tools_list or [])}, "
f"auth_token={'present' if auth_token else 'missing'}"
)

# AgentCore path: extract user from config message token if not yet identified
if not user_id and auth_token:
user_info = _extract_user_from_token(auth_token)
if user_info:
user_id = user_info["user_id"]
else:
await websocket.send_json({"type": "bidi_error", "message": "Authentication required"})
await websocket.close(code=4001, reason="Authentication required")
return
voice_agent = None

try:
# Create VoiceAgent
VoiceAgent = _get_voice_agent_class()
voice_agent = VoiceAgent(
session_id=session_id,
user_id=user_id,
auth_token=auth_token,
enabled_tools=enabled_tools,
enabled_tools=enabled_tools_list,
)

_active_sessions[session_id] = voice_agent
Expand Down Expand Up @@ -553,3 +546,31 @@ async def stop_voice_session(session_id: str):

_active_sessions.pop(session_id, None)
return {"status": "stopped", "session_id": session_id}


# =============================================================================
# /ws alias for AgentCore Runtime
# AgentCore Runtime routes WebSocket requests to /ws on the container.
# This delegates to /voice/stream for cloud deployment compatibility.
# =============================================================================

@router.websocket("/ws")
async def ws_stream(
    websocket: WebSocket,
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    enabled_tools: Optional[str] = None,
    token: Optional[str] = None,
):
    """Serve the ``/ws`` WebSocket route by handing off to ``voice_stream``.

    AgentCore Runtime (cloud mode) expects containers to implement WebSocket
    at the /ws path on port 8080, so this route exists purely as an alias:
    the connection and every optional parameter are forwarded unchanged.
    """
    # No processing happens here; voice_stream owns the whole session.
    await voice_stream(
        websocket,
        session_id=session_id,
        user_id=user_id,
        enabled_tools=enabled_tools,
        token=token,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,17 @@ export class VoiceChatService implements OnDestroy {
let protocols: string[] | undefined;

if (isAgentCore) {
// AgentCore: /ws path, session via custom header param, auth via Sec-WebSocket-Protocol
const params = new URLSearchParams();
params.set('X-Amzn-Bedrock-AgentCore-Runtime-Custom-SessionId', this.sessionId!);
url = `${wsUrl}/ws?${params.toString()}`;
// AgentCore: /ws path, auth via Sec-WebSocket-Protocol
url = `${wsUrl}/ws`;

const base64url = btoa(token)
.replace(/\+/g, '-')
.replace(/\//g, '_')
.replace(/=/g, '');
protocols = [`base64UrlBearerAuthorization.${base64url}`, 'base64UrlBearerAuthorization'];
} else {
// Local dev: direct connection with query params
url = `${wsUrl}/ws?session_id=${encodeURIComponent(this.sessionId!)}&token=${encodeURIComponent(token)}`;
// Local dev: /voice/stream path with query params
url = `${wsUrl}/voice/stream?session_id=${encodeURIComponent(this.sessionId!)}&token=${encodeURIComponent(token)}`;
}

await this.openWebSocket(url, token, protocols);
Expand Down
Loading