From ddea438f4fc3fd573952cdce02565688437f3e15 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Tue, 21 Apr 2026 22:03:07 -0600 Subject: [PATCH 01/35] feat(connectors): add AgentCore Identity wrapper and Runtime context middleware First phase of the Connectors refactor, which will eventually replace the bespoke OAuth token store (OAuthTokenRepository, KMS-encrypted DynamoDB, Secrets Manager client credentials, manual refresh) with AgentCore Identity's managed token vault and credential providers. - AgentCoreContextMiddleware copies the four Runtime headers (WorkloadAccessToken, OAuth2CallbackUrl, session ID, request ID) into BedrockAgentCoreContext on every invocation. Required because the Inference API is a plain FastAPI app rather than BedrockAgentCoreApp, so the SDK does not populate the context for us. No-op when headers are absent, so local development and unit tests continue to work without mocks. - AgentCoreIdentityClient wraps IdentityClient.get_token() with a narrower, platform-friendly surface for USER_FEDERATION (3LO) flows. Surfaces the "user consent required" case as a structured TokenResult(authorization_url=...) rather than an exception, so it can flow through the existing SSE stream as a new event type in a later phase. Both modules are pure additions; no existing code path calls them yet. Co-Authored-By: Claude Opus 4.7 --- .../integrations/agentcore_identity.py | 169 ++++++++++++++++++ .../apis/inference_api/middleware/__init__.py | 1 + .../middleware/agentcore_context.py | 56 ++++++ .../integrations/test_agentcore_identity.py | 169 ++++++++++++++++++ .../test_agentcore_context_middleware.py | 103 +++++++++++ 5 files changed, 498 insertions(+) create mode 100644 backend/src/agents/main_agent/integrations/agentcore_identity.py create mode 100644 backend/src/apis/inference_api/middleware/__init__.py create mode 100644 backend/src/apis/inference_api/middleware/agentcore_context.py create mode 100644 backend/tests/agents/main_agent/integrations/test_agentcore_identity.py create mode 100644 backend/tests/apis/inference_api/test_agentcore_context_middleware.py diff --git a/backend/src/agents/main_agent/integrations/agentcore_identity.py b/backend/src/agents/main_agent/integrations/agentcore_identity.py new file mode 100644 index 00000000..195464ec --- /dev/null +++ b/backend/src/agents/main_agent/integrations/agentcore_identity.py @@ -0,0 +1,169 @@ +"""AgentCore Identity integration for external MCP tool authorization. + +Wraps `bedrock_agentcore.services.identity.IdentityClient` with a narrower, +platform-friendly surface for retrieving OAuth2 access tokens on behalf of a +user via the USER_FEDERATION (3LO) flow. + +The client pulls the per-invocation workload identity token from +`BedrockAgentCoreContext`, which is populated by `AgentCoreContextMiddleware` +on the Inference API request path. No workload token has to be threaded +through function arguments. + +Two results are possible when fetching a token: + +1. A valid token exists in the AgentCore Token Vault for this user+provider + → returned synchronously as `TokenResult(access_token=...)`. +2. The user has never consented (or consent has been revoked, or scopes have + changed) → the caller receives `TokenResult(authorization_url=...)`. The + URL must be surfaced to the user; after they complete the consent flow the + frontend calls `CompleteResourceTokenAuthCommand` and the next tool call + will hit case 1. + +This module intentionally does not raise on "consent required" — it returns +a structured result because surfacing an auth URL is a normal, expected +outcome that flows through our SSE stream, not an error. +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from typing import List, Optional + +from bedrock_agentcore.runtime import BedrockAgentCoreContext +from bedrock_agentcore.services.identity import IdentityClient + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class TokenResult: + """Result of a token fetch attempt. + + Exactly one of `access_token` or `authorization_url` will be populated. + """ + + access_token: Optional[str] = None + authorization_url: Optional[str] = None + + @property + def requires_consent(self) -> bool: + return self.access_token is None and self.authorization_url is not None + + def __post_init__(self) -> None: + if bool(self.access_token) == bool(self.authorization_url): + raise ValueError( + "TokenResult must have exactly one of access_token or authorization_url" + ) + + +class WorkloadTokenUnavailableError(RuntimeError): + """Raised when no workload access token is present on the current context. + + This indicates the caller is running outside an AgentCore Runtime + invocation, or the `AgentCoreContextMiddleware` was not applied. + """ + + +class AgentCoreIdentityClient: + """Thin async-friendly wrapper around `IdentityClient` for 3LO tokens. + + The underlying `IdentityClient` is synchronous and uses boto3; callers + should treat `get_token_for_user` as potentially blocking and run it via + `asyncio.to_thread` when invoked from async code. + """ + + def __init__(self, region: Optional[str] = None): + self._region = region or os.environ.get("AWS_REGION", "us-east-1") + self._client = IdentityClient(region=self._region) + + def get_token_for_user( + self, + *, + provider_name: str, + scopes: List[str], + callback_url: Optional[str] = None, + force_authentication: bool = False, + ) -> TokenResult: + """Fetch a user-federated OAuth2 access token for `provider_name`. + + Pulls the workload identity token from `BedrockAgentCoreContext`, so + this must be called from inside an AgentCore Runtime invocation that + has been processed by `AgentCoreContextMiddleware`. + + If the user has not consented (or re-consent is required), returns a + `TokenResult` with `authorization_url` populated instead of raising. + + Args: + provider_name: Credential provider name registered with AgentCore + Identity (e.g. "google-workspace"). + scopes: OAuth2 scopes to request for this token. + callback_url: OAuth2 return URL. Defaults to the callback URL on + the current context (injected by Runtime via the + `OAuth2CallbackUrl` header). + force_authentication: If True, bypasses the token vault cache and + forces the user through the consent flow again. Used for + scope upgrades. + + Returns: + `TokenResult` with either `access_token` or `authorization_url`. + + Raises: + WorkloadTokenUnavailableError: No workload token on context. + """ + workload_token = BedrockAgentCoreContext.get_workload_access_token() + if not workload_token: + raise WorkloadTokenUnavailableError( + "No WorkloadAccessToken on context — ensure " + "AgentCoreContextMiddleware is installed and this call " + "runs inside an AgentCore Runtime invocation." + ) + + resolved_callback_url = ( + callback_url or BedrockAgentCoreContext.get_oauth2_callback_url() + ) + + captured_url: dict[str, Optional[str]] = {"url": None} + + def _capture_auth_url(url: str) -> None: + captured_url["url"] = url + + token = self._client.get_token( + provider_name=provider_name, + scopes=scopes, + agent_identity_token=workload_token, + auth_flow="USER_FEDERATION", + callback_url=resolved_callback_url, + force_authentication=force_authentication, + on_auth_url=_capture_auth_url, + ) + + # `get_token` returns either the token string or triggers on_auth_url + # when consent is required. Guard both: if we captured a URL, surface + # it as a TokenResult even if the SDK also returned a (stale) token. + if captured_url["url"]: + logger.info( + "AgentCore Identity requires user consent for provider=%s", + provider_name, + ) + return TokenResult(authorization_url=captured_url["url"]) + + if not token: + raise RuntimeError( + f"AgentCore Identity returned neither a token nor an " + f"authorization URL for provider={provider_name}" + ) + + return TokenResult(access_token=token) + + +_default_client: Optional[AgentCoreIdentityClient] = None + + +def get_agentcore_identity_client() -> AgentCoreIdentityClient: + """Return the process-wide `AgentCoreIdentityClient` singleton.""" + global _default_client + if _default_client is None: + _default_client = AgentCoreIdentityClient() + return _default_client diff --git a/backend/src/apis/inference_api/middleware/__init__.py b/backend/src/apis/inference_api/middleware/__init__.py new file mode 100644 index 00000000..44a0d67c --- /dev/null +++ b/backend/src/apis/inference_api/middleware/__init__.py @@ -0,0 +1 @@ +"""Inference API middleware.""" diff --git a/backend/src/apis/inference_api/middleware/agentcore_context.py b/backend/src/apis/inference_api/middleware/agentcore_context.py new file mode 100644 index 00000000..2bf8fe90 --- /dev/null +++ b/backend/src/apis/inference_api/middleware/agentcore_context.py @@ -0,0 +1,56 @@ +"""AgentCore Runtime context middleware. + +Bridges AgentCore Runtime request headers into BedrockAgentCoreContext so that +downstream code (e.g. IdentityClient token lookups) can access the per-invocation +workload identity token without threading it through every function call. + +AgentCore Runtime injects these headers on every invocation: + - WorkloadAccessToken: per-user workload identity token, derived from the + validated inbound JWT by the Runtime's managed JWT authorizer. + - OAuth2CallbackUrl: OAuth2 callback URL registered on the workload identity. + - X-Amzn-Bedrock-AgentCore-Runtime-Session-Id: current session ID. + - X-Amzn-Request-Id: per-request trace ID. + +When the inference API is wrapped by BedrockAgentCoreApp, these are populated +automatically. Because this service runs as a plain FastAPI app inside +AgentCore Runtime, we populate the context ourselves. + +The middleware is a no-op in local development where these headers are absent, +which keeps tests and `python -m main` runs working without mocks. +""" + +import logging + +from bedrock_agentcore.runtime import BedrockAgentCoreContext +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +logger = logging.getLogger(__name__) + +HEADER_WORKLOAD_ACCESS_TOKEN = "WorkloadAccessToken" +HEADER_OAUTH2_CALLBACK_URL = "OAuth2CallbackUrl" +HEADER_SESSION_ID = "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id" +HEADER_REQUEST_ID = "X-Amzn-Request-Id" + + +class AgentCoreContextMiddleware(BaseHTTPMiddleware): + """Populates BedrockAgentCoreContext from Runtime request headers.""" + + async def dispatch(self, request: Request, call_next) -> Response: + workload_token = request.headers.get(HEADER_WORKLOAD_ACCESS_TOKEN) + if workload_token: + BedrockAgentCoreContext.set_workload_access_token(workload_token) + + callback_url = request.headers.get(HEADER_OAUTH2_CALLBACK_URL) + if callback_url: + BedrockAgentCoreContext.set_oauth2_callback_url(callback_url) + + session_id = request.headers.get(HEADER_SESSION_ID) + if session_id: + BedrockAgentCoreContext.set_request_context( + request_id=request.headers.get(HEADER_REQUEST_ID, ""), + session_id=session_id, + ) + + return await call_next(request) diff --git a/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py b/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py new file mode 100644 index 00000000..40348df1 --- /dev/null +++ b/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py @@ -0,0 +1,169 @@ +"""Tests for AgentCoreIdentityClient.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from agents.main_agent.integrations.agentcore_identity import ( + AgentCoreIdentityClient, + TokenResult, + WorkloadTokenUnavailableError, +) + + +class TestTokenResult: + def test_access_token_only_is_valid(self) -> None: + result = TokenResult(access_token="abc") + assert result.access_token == "abc" + assert result.authorization_url is None + assert result.requires_consent is False + + def test_authorization_url_only_is_valid(self) -> None: + result = TokenResult(authorization_url="https://example.com/auth") + assert result.requires_consent is True + + def test_both_populated_raises(self) -> None: + with pytest.raises(ValueError): + TokenResult(access_token="a", authorization_url="https://example.com") + + def test_neither_populated_raises(self) -> None: + with pytest.raises(ValueError): + TokenResult() + + +@pytest.fixture +def mock_identity_sdk(): + """Patch the IdentityClient class used inside the wrapper.""" + with patch( + "agents.main_agent.integrations.agentcore_identity.IdentityClient" + ) as sdk_cls: + yield sdk_cls + + +@pytest.fixture +def mock_context(): + """Patch BedrockAgentCoreContext accessors used inside the wrapper.""" + with patch( + "agents.main_agent.integrations.agentcore_identity.BedrockAgentCoreContext" + ) as ctx: + ctx.get_workload_access_token.return_value = "workload-token-xyz" + ctx.get_oauth2_callback_url.return_value = "https://cb.example.com/oauth" + yield ctx + + +class TestGetTokenForUserCacheHit: + def test_returns_access_token_when_vault_has_token( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token.return_value = "ya29.access-token" + + client = AgentCoreIdentityClient(region="us-east-1") + result = client.get_token_for_user( + provider_name="google-workspace", scopes=["openid"] + ) + + assert result.access_token == "ya29.access-token" + assert result.requires_consent is False + + sdk_instance.get_token.assert_called_once() + kwargs = sdk_instance.get_token.call_args.kwargs + assert kwargs["provider_name"] == "google-workspace" + assert kwargs["scopes"] == ["openid"] + assert kwargs["auth_flow"] == "USER_FEDERATION" + assert kwargs["agent_identity_token"] == "workload-token-xyz" + assert kwargs["callback_url"] == "https://cb.example.com/oauth" + assert kwargs["force_authentication"] is False + + def test_explicit_callback_url_overrides_context( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token.return_value = "t" + + client = AgentCoreIdentityClient() + client.get_token_for_user( + provider_name="p", + scopes=["s"], + callback_url="https://override.example.com/cb", + ) + + kwargs = sdk_instance.get_token.call_args.kwargs + assert kwargs["callback_url"] == "https://override.example.com/cb" + + +class TestGetTokenForUserConsentRequired: + def test_returns_authorization_url_when_sdk_invokes_callback( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + """When the user needs to consent, the SDK calls on_auth_url with the + consent URL. The wrapper captures it and returns a TokenResult with + authorization_url set rather than raising.""" + sdk_instance = mock_identity_sdk.return_value + + def fake_get_token(**kwargs): + kwargs["on_auth_url"]("https://accounts.example.com/consent?x=1") + return None + + sdk_instance.get_token.side_effect = fake_get_token + + client = AgentCoreIdentityClient() + result = client.get_token_for_user(provider_name="p", scopes=["s"]) + + assert result.requires_consent is True + assert result.authorization_url == "https://accounts.example.com/consent?x=1" + assert result.access_token is None + + def test_auth_url_takes_precedence_over_stale_token( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + """Defensive: if the SDK both returns a token AND invokes on_auth_url, + we treat consent-required as the authoritative signal.""" + sdk_instance = mock_identity_sdk.return_value + + def fake_get_token(**kwargs): + kwargs["on_auth_url"]("https://consent.example.com") + return "stale-token" + + sdk_instance.get_token.side_effect = fake_get_token + + client = AgentCoreIdentityClient() + result = client.get_token_for_user(provider_name="p", scopes=["s"]) + + assert result.requires_consent is True + assert result.authorization_url == "https://consent.example.com" + + +class TestGetTokenForUserErrors: + def test_raises_when_no_workload_token_on_context( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + mock_context.get_workload_access_token.return_value = None + + client = AgentCoreIdentityClient() + with pytest.raises(WorkloadTokenUnavailableError): + client.get_token_for_user(provider_name="p", scopes=["s"]) + + def test_raises_when_sdk_returns_nothing_and_no_auth_url( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token.return_value = None + + client = AgentCoreIdentityClient() + with pytest.raises(RuntimeError, match="neither a token nor"): + client.get_token_for_user(provider_name="p", scopes=["s"]) + + def test_force_authentication_flag_is_forwarded( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token.return_value = "t" + + client = AgentCoreIdentityClient() + client.get_token_for_user( + provider_name="p", scopes=["s"], force_authentication=True + ) + + kwargs = sdk_instance.get_token.call_args.kwargs + assert kwargs["force_authentication"] is True diff --git a/backend/tests/apis/inference_api/test_agentcore_context_middleware.py b/backend/tests/apis/inference_api/test_agentcore_context_middleware.py new file mode 100644 index 00000000..15516179 --- /dev/null +++ b/backend/tests/apis/inference_api/test_agentcore_context_middleware.py @@ -0,0 +1,103 @@ +"""Tests for AgentCoreContextMiddleware. + +Verifies that Runtime headers are copied into BedrockAgentCoreContext on +each request and that the middleware is a no-op when headers are absent +(local development, unit tests without Runtime). +""" + +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from apis.inference_api.middleware.agentcore_context import ( + HEADER_OAUTH2_CALLBACK_URL, + HEADER_REQUEST_ID, + HEADER_SESSION_ID, + HEADER_WORKLOAD_ACCESS_TOKEN, + AgentCoreContextMiddleware, +) + + +@pytest.fixture +def app() -> FastAPI: + app = FastAPI() + app.add_middleware(AgentCoreContextMiddleware) + + @app.get("/echo") + def echo() -> dict: + return {"ok": True} + + return app + + +@pytest.fixture +def client(app: FastAPI) -> TestClient: + return TestClient(app) + + +class TestAgentCoreContextMiddleware: + def test_copies_workload_access_token_to_context(self, client: TestClient) -> None: + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + response = client.get( + "/echo", headers={HEADER_WORKLOAD_ACCESS_TOKEN: "wat-abc123"} + ) + + assert response.status_code == 200 + ctx.set_workload_access_token.assert_called_once_with("wat-abc123") + + def test_copies_oauth2_callback_url_to_context(self, client: TestClient) -> None: + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get( + "/echo", + headers={HEADER_OAUTH2_CALLBACK_URL: "https://example.com/cb"}, + ) + + ctx.set_oauth2_callback_url.assert_called_once_with("https://example.com/cb") + + def test_copies_session_and_request_id_to_context( + self, client: TestClient + ) -> None: + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get( + "/echo", + headers={ + HEADER_SESSION_ID: "sess-1", + HEADER_REQUEST_ID: "req-1", + }, + ) + + ctx.set_request_context.assert_called_once_with( + request_id="req-1", session_id="sess-1" + ) + + def test_noop_when_headers_absent(self, client: TestClient) -> None: + """Local dev and tests without AgentCore Runtime must still work.""" + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + response = client.get("/echo") + + assert response.status_code == 200 + ctx.set_workload_access_token.assert_not_called() + ctx.set_oauth2_callback_url.assert_not_called() + ctx.set_request_context.assert_not_called() + + def test_session_id_defaults_request_id_to_empty(self, client: TestClient) -> None: + """When session is present but request-id header is missing, the + request_id falls back to empty string rather than None.""" + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get("/echo", headers={HEADER_SESSION_ID: "sess-1"}) + + ctx.set_request_context.assert_called_once_with( + request_id="", session_id="sess-1" + ) From 529c49ac8aff4ecd8e14e485bbf7b3709dc5efac Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Tue, 21 Apr 2026 22:15:55 -0600 Subject: [PATCH 02/35] feat(connectors): route external MCP OAuth through AgentCore Identity Wires the Runtime context middleware into the Inference API and swaps the external MCP client's token source from the bespoke OAuthService to AgentCore Identity's USER_FEDERATION flow. - main.py: installs AgentCoreContextMiddleware so WorkloadAccessToken and OAuth2CallbackUrl Runtime headers populate BedrockAgentCoreContext on every invocation. - external_mcp_client.py: _get_oauth_token now returns a TokenResult from AgentCoreIdentityClient instead of a decrypted token string from OAuthService. Scopes are read from the platform's OAuth provider record so organizations can change them without code. When the SDK signals that user consent is required, the authorization URL is stashed per-user for the inference route to surface via an oauth_required SSE event (emitter to follow in a subsequent commit). load_external_tools skips client creation on consent-required rather than creating a client that would fail at the first request. - Convention: the platform's provider_id is used verbatim as the AgentCore Identity credential-provider name. Admins register matching names via CreateOauth2CredentialProvider during provider setup. The OAuthService, token vault, and encryption layer are still referenced by unrelated code paths (admin routes, connections UI) and will be removed in Phase 3 once the AgentCore-backed flow is validated end-to-end. Co-Authored-By: Claude Opus 4.7 --- .../integrations/external_mcp_client.py | 122 ++++++++++--- backend/src/apis/inference_api/main.py | 8 + .../integrations/test_external_mcp_client.py | 166 +++++++++++++++++- 3 files changed, 268 insertions(+), 28 deletions(-) diff --git a/backend/src/agents/main_agent/integrations/external_mcp_client.py b/backend/src/agents/main_agent/integrations/external_mcp_client.py index 403bc9ea..3a9c08bc 100644 --- a/backend/src/agents/main_agent/integrations/external_mcp_client.py +++ b/backend/src/agents/main_agent/integrations/external_mcp_client.py @@ -23,6 +23,11 @@ MCPTransport, ToolDefinition, ) +from agents.main_agent.integrations.agentcore_identity import ( + TokenResult, + WorkloadTokenUnavailableError, + get_agentcore_identity_client, +) from agents.main_agent.integrations.gateway_auth import get_sigv4_auth from agents.main_agent.integrations.oauth_auth import ( CompositeAuth, @@ -213,6 +218,11 @@ def __init__(self): """Initialize external MCP integration.""" # Cache key: tool_id for non-OAuth tools, "user_id:tool_id" for OAuth tools self.clients: dict[str, MCPClient] = {} + # Consent URLs collected during load_external_tools, keyed by user_id. + # Consumed (and cleared) by the inference route on the next response so + # they surface as an oauth_required SSE event. Shape: + # { user_id: [ { "provider_id": str, "authorization_url": str }, ... ] } + self.pending_consent: dict[str, list[dict[str, str]]] = {} def _get_cache_key(self, tool_id: str, user_id: Optional[str], requires_oauth: bool) -> str: """Get the cache key for a tool client.""" @@ -222,28 +232,59 @@ def _get_cache_key(self, tool_id: str, user_id: Optional[str], requires_oauth: b async def _get_oauth_token( self, - user_id: str, provider_id: str, - ) -> Optional[str]: - """ - Get decrypted OAuth token for a user and provider. + ) -> TokenResult: + """Fetch an OAuth token for `provider_id` via AgentCore Identity. - Args: - user_id: The user's ID - provider_id: The OAuth provider ID + The user is identified implicitly by the WorkloadAccessToken on + `BedrockAgentCoreContext` (populated from request headers by + `AgentCoreContextMiddleware`). Scopes are read from the platform's + OAuth provider record so organizations can change them without code + changes. + + Convention: `provider_id` is used verbatim as the AgentCore Identity + credential-provider name. Admins register providers with matching + names via `CreateOauth2CredentialProvider`. Returns: - Decrypted access token or None if not connected + `TokenResult` — either `.access_token` on cache hit or + `.authorization_url` when user consent is required. + + Raises: + WorkloadTokenUnavailableError: Not running inside an AgentCore + Runtime invocation (e.g. misconfigured middleware). """ - try: - from apis.shared.oauth.service import get_oauth_service - - oauth_service = get_oauth_service() - token = await oauth_service.get_decrypted_token(user_id, provider_id) - return token - except Exception as e: - logger.error("Error getting OAuth token") - return None + from apis.shared.oauth.provider_repository import get_provider_repository + + provider = await get_provider_repository().get_provider(provider_id) + scopes = provider.scopes if provider else [] + + identity_client = get_agentcore_identity_client() + return identity_client.get_token_for_user( + provider_name=provider_id, scopes=scopes + ) + + def _record_pending_consent( + self, user_id: str, provider_id: str, authorization_url: str + ) -> None: + """Stash a consent URL to be surfaced to the user via SSE.""" + bucket = self.pending_consent.setdefault(user_id, []) + # Dedupe on provider_id — if the user has two tools needing the same + # provider, one consent covers both. + if any(entry["provider_id"] == provider_id for entry in bucket): + return + bucket.append( + {"provider_id": provider_id, "authorization_url": authorization_url} + ) + + def drain_pending_consent(self, user_id: str) -> list[dict[str, str]]: + """Consume and return pending consent prompts for a user. + + Called by the inference route on each response so the frontend can + render "Connect to X" affordances. Idempotent across repeated calls + because consent entries are removed once read. + """ + return self.pending_consent.pop(user_id, []) async def load_external_tools( self, @@ -313,7 +354,7 @@ async def load_external_tools( logger.info(f"Using OIDC token forwarding for tool {tool_id}") elif requires_oauth: - # Use stored OAuth token from provider + # Fetch user-federated token via AgentCore Identity. if not user_id: logger.warning( f"Tool {tool_id} requires OAuth provider '{tool.requires_oauth_provider}' " @@ -321,17 +362,44 @@ async def load_external_tools( ) continue - token_to_use = await self._get_oauth_token( - user_id=user_id, - provider_id=tool.requires_oauth_provider, - ) + try: + token_result = await self._get_oauth_token( + provider_id=tool.requires_oauth_provider, + ) + except WorkloadTokenUnavailableError: + logger.error( + "No workload token on context for tool %s — " + "AgentCoreContextMiddleware may be misconfigured", + tool_id, + ) + continue + except Exception as e: + logger.error( + "Failed to fetch OAuth token for tool %s: %s", + tool_id, + e, + ) + continue - if not token_to_use: - logger.warning( - "User not connected to required OAuth provider for tool" + if token_result.requires_consent: + # Record the auth URL; the inference route will emit + # an oauth_required SSE event on the next response. + self._record_pending_consent( + user_id=user_id, + provider_id=tool.requires_oauth_provider, + authorization_url=token_result.authorization_url, + ) + logger.info( + "User consent required for tool %s (provider=%s); " + "skipping client creation until consent completes", + tool_id, + tool.requires_oauth_provider, ) - # Still create the client - it will fail gracefully when used - # The MCP server should return an appropriate error + # Skip loading this tool — the frontend will prompt + # the user to consent before the next invocation. + continue + + token_to_use = token_result.access_token # Create MCP client with optional token (works for both OAuth and OIDC) client = create_external_mcp_client( diff --git a/backend/src/apis/inference_api/main.py b/backend/src/apis/inference_api/main.py index de9308d0..e3f2177a 100644 --- a/backend/src/apis/inference_api/main.py +++ b/backend/src/apis/inference_api/main.py @@ -33,6 +33,8 @@ from contextlib import asynccontextmanager import logging +from apis.inference_api.middleware.agentcore_context import AgentCoreContextMiddleware + # Set up logging logging.basicConfig( level=logging.INFO, @@ -117,6 +119,12 @@ async def lifespan(app: FastAPI): ) logger.info("Added GZip middleware for response compression") +# Bridge AgentCore Runtime headers (WorkloadAccessToken, OAuth2CallbackUrl, +# session ID) into BedrockAgentCoreContext so downstream code can look up +# per-user OAuth tokens via AgentCore Identity. +app.add_middleware(AgentCoreContextMiddleware) +logger.info("Added AgentCore Runtime context middleware") + # Add CORS middleware - origins from CDK-provided CORS_ORIGINS env var _cors_origins = os.environ.get("CORS_ORIGINS", "").split(",") app.add_middleware( diff --git a/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py b/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py index 2023c2ed..b1e39992 100644 --- a/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py +++ b/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py @@ -5,9 +5,16 @@ """ import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from agents.main_agent.integrations.agentcore_identity import ( + TokenResult, + WorkloadTokenUnavailableError, +) from agents.main_agent.integrations.external_mcp_client import ( - extract_region_from_url, + ExternalMCPIntegration, detect_aws_service_from_url, + extract_region_from_url, ) @@ -62,3 +69,160 @@ def test_defaults_to_lambda_for_unknown_url(self): """Req 25.3: Defaults to 'lambda' for unrecognized URL patterns.""" url = "https://example.com/api/v1" assert detect_aws_service_from_url(url) == "lambda" + + +class TestGetOAuthTokenViaAgentCoreIdentity: + """Tests for ExternalMCPIntegration._get_oauth_token delegating to AgentCore Identity.""" + + @pytest.mark.asyncio + async def test_fetches_scopes_from_provider_repo_and_calls_identity(self): + """Provider scopes from the platform's provider record are forwarded to AgentCore.""" + integration = ExternalMCPIntegration() + + mock_provider = MagicMock() + mock_provider.scopes = ["openid", "profile", "email"] + mock_repo = MagicMock() + mock_repo.get_provider = AsyncMock(return_value=mock_provider) + + mock_identity = MagicMock() + mock_identity.get_token_for_user.return_value = TokenResult(access_token="tok") + + with patch( + "apis.shared.oauth.provider_repository.get_provider_repository", + return_value=mock_repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.get_agentcore_identity_client", + return_value=mock_identity, + ): + result = await integration._get_oauth_token(provider_id="google") + + assert result.access_token == "tok" + mock_identity.get_token_for_user.assert_called_once_with( + provider_name="google", scopes=["openid", "profile", "email"] + ) + + @pytest.mark.asyncio + async def test_returns_authorization_url_when_consent_required(self): + integration = ExternalMCPIntegration() + + mock_repo = MagicMock() + mock_repo.get_provider = AsyncMock( + return_value=MagicMock(scopes=["openid"]) + ) + + mock_identity = MagicMock() + mock_identity.get_token_for_user.return_value = TokenResult( + authorization_url="https://accounts.example.com/consent" + ) + + with patch( + "apis.shared.oauth.provider_repository.get_provider_repository", + return_value=mock_repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.get_agentcore_identity_client", + return_value=mock_identity, + ): + result = await integration._get_oauth_token(provider_id="google") + + assert result.requires_consent is True + assert result.authorization_url == "https://accounts.example.com/consent" + + @pytest.mark.asyncio + async def test_empty_scopes_when_provider_record_missing(self): + """Missing provider record falls back to empty scopes so the call still succeeds + and AgentCore can apply its own provider defaults.""" + integration = ExternalMCPIntegration() + + mock_repo = MagicMock() + mock_repo.get_provider = AsyncMock(return_value=None) + + mock_identity = MagicMock() + mock_identity.get_token_for_user.return_value = TokenResult(access_token="t") + + with patch( + "apis.shared.oauth.provider_repository.get_provider_repository", + return_value=mock_repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.get_agentcore_identity_client", + return_value=mock_identity, + ): + await integration._get_oauth_token(provider_id="unknown") + + mock_identity.get_token_for_user.assert_called_once_with( + provider_name="unknown", scopes=[] + ) + + @pytest.mark.asyncio + async def test_propagates_workload_token_unavailable(self): + """Misconfigured middleware should surface as a typed error, not be swallowed.""" + integration = ExternalMCPIntegration() + + mock_repo = MagicMock() + mock_repo.get_provider = AsyncMock(return_value=MagicMock(scopes=[])) + + mock_identity = MagicMock() + mock_identity.get_token_for_user.side_effect = WorkloadTokenUnavailableError( + "no ctx" + ) + + with patch( + "apis.shared.oauth.provider_repository.get_provider_repository", + return_value=mock_repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.get_agentcore_identity_client", + return_value=mock_identity, + ): + with pytest.raises(WorkloadTokenUnavailableError): + await integration._get_oauth_token(provider_id="google") + + +class TestPendingConsent: + """Tests for the per-user consent URL stash consumed by the SSE emitter.""" + + def test_record_and_drain_roundtrip(self): + integration = ExternalMCPIntegration() + integration._record_pending_consent( + user_id="u1", provider_id="google", authorization_url="https://a/1" + ) + + drained = integration.drain_pending_consent("u1") + + assert drained == [ + {"provider_id": "google", "authorization_url": "https://a/1"} + ] + + def test_drain_is_idempotent(self): + """Second drain returns empty — consent prompts are single-delivery.""" + integration = ExternalMCPIntegration() + integration._record_pending_consent("u1", "google", "https://a/1") + + integration.drain_pending_consent("u1") + second = integration.drain_pending_consent("u1") + + assert second == [] + + def test_dedupe_by_provider(self): + """Two tools needing the same provider produce one consent prompt.""" + integration = ExternalMCPIntegration() + integration._record_pending_consent("u1", "google", "https://a/1") + integration._record_pending_consent("u1", "google", "https://a/1") + + drained = integration.drain_pending_consent("u1") + + assert len(drained) == 1 + + def test_per_user_isolation(self): + integration = ExternalMCPIntegration() + integration._record_pending_consent("u1", "google", "https://a/1") + integration._record_pending_consent("u2", "slack", "https://a/2") + + assert integration.drain_pending_consent("u1") == [ + {"provider_id": "google", "authorization_url": "https://a/1"} + ] + assert integration.drain_pending_consent("u2") == [ + {"provider_id": "slack", "authorization_url": "https://a/2"} + ] + + def test_drain_empty_when_no_prompts(self): + integration = ExternalMCPIntegration() + assert integration.drain_pending_consent("u-nobody") == [] From 57b3f8763bd87cb90d7e5aaf58c7a718596e1e27 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 00:37:13 -0600 Subject: [PATCH 03/35] refactor(frontend): rename connections to connectors Rebrand the user-facing OAuth UI from "connections" to "connectors" for consistent vernacular across the product. Folders, classes, types, and route paths all follow the new name; the /settings/connections URL redirects to /settings/connectors. The backend /oauth/connections endpoint is preserved as a stable contract and translated at the service layer. Co-Authored-By: Claude Opus 4.7 --- .../main_agent/tools/oauth_tool_service.py | 6 +- .../app/admin/tools/pages/tool-form.page.ts | 2 +- frontend/ai.client/src/app/app.routes.ts | 7 +- .../settings/connections/connections.page.ts | 399 ------------------ .../app/settings/connections/models/index.ts | 1 - .../settings/connections/services/index.ts | 1 - .../{connections => connectors}/index.ts | 1 - .../app/settings/connectors/models/index.ts | 1 + .../models/oauth-connector.model.ts} | 25 +- .../services/connectors.service.spec.ts} | 30 +- .../services/connectors.service.ts} | 65 +-- .../app/settings/connectors/services/index.ts | 1 + .../oauth-callback/oauth-callback.page.ts | 14 +- .../connectors-settings.page.ts} | 96 ++--- .../src/app/settings/settings.page.ts | 2 +- .../src/app/settings/settings.routes.ts | 9 +- 16 files changed, 120 insertions(+), 540 deletions(-) delete mode 100644 frontend/ai.client/src/app/settings/connections/connections.page.ts delete mode 100644 frontend/ai.client/src/app/settings/connections/models/index.ts delete mode 100644 frontend/ai.client/src/app/settings/connections/services/index.ts rename frontend/ai.client/src/app/settings/{connections => connectors}/index.ts (60%) create mode 100644 frontend/ai.client/src/app/settings/connectors/models/index.ts rename frontend/ai.client/src/app/settings/{connections/models/oauth-connection.model.ts => connectors/models/oauth-connector.model.ts} (52%) rename frontend/ai.client/src/app/settings/{connections/services/connections.service.spec.ts => connectors/services/connectors.service.spec.ts} (81%) rename frontend/ai.client/src/app/settings/{connections/services/connections.service.ts => connectors/services/connectors.service.ts} (61%) create mode 100644 frontend/ai.client/src/app/settings/connectors/services/index.ts rename frontend/ai.client/src/app/settings/pages/{connections-settings/connections-settings.page.ts => connectors-settings/connectors-settings.page.ts} (75%) diff --git a/backend/src/agents/main_agent/tools/oauth_tool_service.py b/backend/src/agents/main_agent/tools/oauth_tool_service.py index 2d42f763..f5d47d7c 100644 --- a/backend/src/agents/main_agent/tools/oauth_tool_service.py +++ b/backend/src/agents/main_agent/tools/oauth_tool_service.py @@ -69,7 +69,7 @@ def not_connected_response(self, tool_name: str = "this tool") -> dict: Returns a dict suitable for returning from a tool. """ frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - connect_url = f"{frontend_url}/settings/connections" + connect_url = f"{frontend_url}/settings/connectors" if self.needs_reauth: message = f"""⚠️ **Re-authorization Required** @@ -241,7 +241,7 @@ def get_connect_url(self, provider_id: str) -> str: URL to the connections page """ frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - return f"{frontend_url}/settings/connections" + return f"{frontend_url}/settings/connectors" # Singleton instance @@ -314,7 +314,7 @@ def format_oauth_connection_guidance( return "" frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - connect_url = f"{frontend_url}/settings/connections" + connect_url = f"{frontend_url}/settings/connectors" if len(missing_connections) == 1: conn = missing_connections[0] diff --git a/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts b/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts index 7df23c46..dbb20b8d 100644 --- a/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts +++ b/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts @@ -340,7 +340,7 @@ import {
-

User OAuth Connection

+

User OAuth Connector

If this tool requires access to a user's external account (e.g., Google Workspace, Microsoft 365), diff --git a/frontend/ai.client/src/app/app.routes.ts b/frontend/ai.client/src/app/app.routes.ts index a750cf21..1be59cdd 100644 --- a/frontend/ai.client/src/app/app.routes.ts +++ b/frontend/ai.client/src/app/app.routes.ts @@ -32,9 +32,14 @@ export const routes: Routes = [ path: 'auth/callback', loadComponent: () => import('./auth/callback/callback.page').then(m => m.CallbackPage), }, + { + path: 'connectors', + redirectTo: 'settings/connectors', + pathMatch: 'full', + }, { path: 'connections', - redirectTo: 'settings/connections', + redirectTo: 'settings/connectors', pathMatch: 'full', }, { diff --git a/frontend/ai.client/src/app/settings/connections/connections.page.ts b/frontend/ai.client/src/app/settings/connections/connections.page.ts deleted file mode 100644 index 904104c8..00000000 --- a/frontend/ai.client/src/app/settings/connections/connections.page.ts +++ /dev/null @@ -1,399 +0,0 @@ -import { - Component, - ChangeDetectionStrategy, - inject, - signal, - computed, - OnInit, -} from '@angular/core'; -import { Router, RouterLink, ActivatedRoute } from '@angular/router'; -import { NgIcon, provideIcons } from '@ng-icons/core'; -import { - heroArrowLeft, - heroLink, - heroCloud, - heroCodeBracket, - heroAcademicCap, - heroCheck, - heroExclamationTriangle, - heroArrowPath, - heroKey, -} from '@ng-icons/heroicons/outline'; -import { ConnectionsService } from './services'; -import { OAuthConnection, OAuthProviderType } from './models'; -import { ToastService } from '../../services/toast/toast.service'; - -@Component({ - selector: 'app-connections', - changeDetection: ChangeDetectionStrategy.OnPush, - imports: [RouterLink, NgIcon], - providers: [ - provideIcons({ - heroArrowLeft, - heroLink, - heroCloud, - heroCodeBracket, - heroAcademicCap, - heroCheck, - heroExclamationTriangle, - heroArrowPath, - heroKey, - }), - ], - host: { - class: 'block', - }, - template: ` -

-
- - - - Back to Chat - - - -
-

Connected Apps

-

- Connect your accounts to enable tools that require third-party authentication. -

-
- - - - - - - @if (connectionsResource.isLoading() && connections().length === 0) { -
-
-
-

- Loading connections... -

-
-
- } - - - @if (connectionsResource.error()) { -
-

Failed to load connections

-

Please check your connection and try again.

- -
- } - - - @if (!connectionsResource.isLoading() || connections().length > 0) { - @if (connections().length === 0) { - -
-
- -
-

No connections available

-

- There are no OAuth providers configured for your account. Contact an administrator if you need access to external tools. -

-
- } @else { -
- @for (connection of connections(); track connection.providerId) { -
- -
-
- -
-
-

- {{ connection.displayName }} -

- - - @if (isConnected(connection)) { -
- - - Connected - - @if (connection.connectedAt) { - - since {{ formatDate(connection.connectedAt) }} - - } -
- } @else if (connection.needsReauth || connection.status === 'needs_reauth' || connection.status === 'expired') { -
- - - Needs Re-authorization - -
- } @else { -

- Not connected -

- } -
-
- - -
- @if (isConnected(connection) && !connection.needsReauth && connection.status !== 'needs_reauth' && connection.status !== 'expired') { - - } @else { - - } -
-
- } -
- } - } - - - @if (connections().length > 0) { -
-

About Connections

-
-

- Connected apps allow certain tools to access external services on your behalf. -

-

- Re-authorization may be required if the app's permissions change or your token expires. -

-

- Disconnecting will revoke the app's access to your account. You can reconnect at any time. -

-
-
- } -
-
- `, -}) -export class ConnectionsPage implements OnInit { - connectionsService = inject(ConnectionsService); - private router = inject(Router); - private route = inject(ActivatedRoute); - private toast = inject(ToastService); - - readonly connectionsResource = this.connectionsService.connectionsResource; - - // Local state - connecting = signal(null); - disconnecting = signal(null); - - // Computed - readonly connections = computed(() => this.connectionsService.getConnections()); - - ngOnInit(): void { - this.handleCallbackParams(); - } - - /** - * Handle OAuth callback query parameters. - */ - private handleCallbackParams(): void { - const params = this.route.snapshot.queryParams; - - if (params['success'] === 'true') { - const provider = params['provider'] || 'the service'; - this.toast.success('Connected!', `Successfully connected to ${provider}.`); - // Refresh connections after successful OAuth - this.connectionsService.reload(); - // Clear query params - this.router.navigate([], { - relativeTo: this.route, - queryParams: {}, - replaceUrl: true, - }); - } else if (params['error']) { - const error = params['error']; - const provider = params['provider'] || 'the service'; - const description = params['error_description']; - - let message = `Failed to connect to ${provider}.`; - if (description) { - message = description; - } else if (error === 'access_denied') { - message = 'Authorization was denied. Please try again.'; - } else if (error === 'missing_params') { - message = 'Invalid callback. Please try again.'; - } - - this.toast.error('Connection Failed', message); - // Clear query params - this.router.navigate([], { - relativeTo: this.route, - queryParams: {}, - replaceUrl: true, - }); - } - } - - /** - * Check if a connection is actively connected. - */ - isConnected(connection: OAuthConnection): boolean { - return connection.status === 'connected'; - } - - /** - * Initiate OAuth connection flow. - */ - async connect(connection: OAuthConnection): Promise { - this.connecting.set(connection.providerId); - - try { - const redirectUrl = window.location.origin + '/settings/oauth/callback'; - const authUrl = await this.connectionsService.connect(connection.providerId, redirectUrl); - // Redirect to OAuth authorization - window.location.href = authUrl; - } catch (error: any) { - console.error('Error initiating connection:', error); - const message = error?.error?.detail || error?.message || 'Failed to initiate connection.'; - this.toast.error('Connection Error', message); - this.connecting.set(null); - } - } - - /** - * Disconnect from a provider. - */ - async disconnect(connection: OAuthConnection): Promise { - if (!confirm(`Are you sure you want to disconnect from ${connection.displayName}?`)) { - return; - } - - this.disconnecting.set(connection.providerId); - - try { - await this.connectionsService.disconnect(connection.providerId); - this.toast.success('Disconnected', `Successfully disconnected from ${connection.displayName}.`); - } catch (error: any) { - console.error('Error disconnecting:', error); - const message = error?.error?.detail || error?.message || 'Failed to disconnect.'; - this.toast.error('Disconnect Error', message); - } finally { - this.disconnecting.set(null); - } - } - - /** - * Get icon name for a provider. - */ - getProviderIcon(connection: OAuthConnection): string { - if (connection.iconName && connection.iconName !== 'heroLink') { - return connection.iconName; - } - // Default icons by type - switch (connection.providerType) { - case 'google': - case 'microsoft': - return 'heroCloud'; - case 'github': - return 'heroCodeBracket'; - case 'canvas': - return 'heroAcademicCap'; - default: - return 'heroLink'; - } - } - - /** - * Get icon container classes for a provider type. - */ - getProviderIconClasses(type: OAuthProviderType): string { - const baseClasses = 'flex size-12 shrink-0 items-center justify-center rounded-sm'; - switch (type) { - case 'google': - return `${baseClasses} bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400`; - case 'microsoft': - return `${baseClasses} bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400`; - case 'github': - return `${baseClasses} bg-gray-800 text-white dark:bg-gray-700`; - case 'canvas': - return `${baseClasses} bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400`; - default: - return `${baseClasses} bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400`; - } - } - - /** - * Format a date string for display. - */ - formatDate(dateString: string): string { - try { - const date = new Date(dateString); - return date.toLocaleDateString(undefined, { - month: 'short', - day: 'numeric', - year: 'numeric', - }); - } catch { - return dateString; - } - } -} diff --git a/frontend/ai.client/src/app/settings/connections/models/index.ts b/frontend/ai.client/src/app/settings/connections/models/index.ts deleted file mode 100644 index df906738..00000000 --- a/frontend/ai.client/src/app/settings/connections/models/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './oauth-connection.model'; diff --git a/frontend/ai.client/src/app/settings/connections/services/index.ts b/frontend/ai.client/src/app/settings/connections/services/index.ts deleted file mode 100644 index d8c6cca4..00000000 --- a/frontend/ai.client/src/app/settings/connections/services/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './connections.service'; diff --git a/frontend/ai.client/src/app/settings/connections/index.ts b/frontend/ai.client/src/app/settings/connectors/index.ts similarity index 60% rename from frontend/ai.client/src/app/settings/connections/index.ts rename to frontend/ai.client/src/app/settings/connectors/index.ts index b7348c17..b01b9746 100644 --- a/frontend/ai.client/src/app/settings/connections/index.ts +++ b/frontend/ai.client/src/app/settings/connectors/index.ts @@ -1,3 +1,2 @@ export * from './models'; export * from './services'; -export * from './connections.page'; diff --git a/frontend/ai.client/src/app/settings/connectors/models/index.ts b/frontend/ai.client/src/app/settings/connectors/models/index.ts new file mode 100644 index 00000000..eee4f455 --- /dev/null +++ b/frontend/ai.client/src/app/settings/connectors/models/index.ts @@ -0,0 +1 @@ +export * from './oauth-connector.model'; diff --git a/frontend/ai.client/src/app/settings/connections/models/oauth-connection.model.ts b/frontend/ai.client/src/app/settings/connectors/models/oauth-connector.model.ts similarity index 52% rename from frontend/ai.client/src/app/settings/connections/models/oauth-connection.model.ts rename to frontend/ai.client/src/app/settings/connectors/models/oauth-connector.model.ts index af88bfb7..f4fe81fb 100644 --- a/frontend/ai.client/src/app/settings/connections/models/oauth-connection.model.ts +++ b/frontend/ai.client/src/app/settings/connectors/models/oauth-connector.model.ts @@ -1,36 +1,39 @@ /** - * OAuth connection models for user-facing connections UI. + * OAuth connector models for the user-facing Connectors UI. + * + * A "connector" is a single user-to-provider OAuth link surfaced in + * /settings/connectors. The underlying backend endpoint still returns a + * `connections` array — we translate that at the service layer. */ -/** Connection status for user OAuth tokens */ -export type OAuthConnectionStatus = 'connected' | 'expired' | 'revoked' | 'needs_reauth'; +/** Connection status for a user's OAuth connector */ +export type OAuthConnectorStatus = 'connected' | 'expired' | 'revoked' | 'needs_reauth'; /** Supported OAuth provider types */ export type OAuthProviderType = 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; /** - * User's OAuth connection to a provider. - * Returned from GET /oauth/connections + * A user's OAuth connector for a single provider. */ -export interface OAuthConnection { +export interface OAuthConnector { providerId: string; displayName: string; providerType: OAuthProviderType; iconName: string; - status: OAuthConnectionStatus; + status: OAuthConnectorStatus; connectedAt: string | null; needsReauth: boolean; } /** - * Response from GET /oauth/connections + * Response shape returned by {@link ConnectorsService.fetchConnectors}. */ -export interface OAuthConnectionListResponse { - connections: OAuthConnection[]; +export interface OAuthConnectorListResponse { + connectors: OAuthConnector[]; } /** - * Available OAuth provider for connection. + * Available OAuth provider a user may connect to. * Returned from GET /oauth/providers (filtered by user roles) */ export interface OAuthProvider { diff --git a/frontend/ai.client/src/app/settings/connections/services/connections.service.spec.ts b/frontend/ai.client/src/app/settings/connectors/services/connectors.service.spec.ts similarity index 81% rename from frontend/ai.client/src/app/settings/connections/services/connections.service.spec.ts rename to frontend/ai.client/src/app/settings/connectors/services/connectors.service.spec.ts index 22fec454..a135a551 100644 --- a/frontend/ai.client/src/app/settings/connections/services/connections.service.spec.ts +++ b/frontend/ai.client/src/app/settings/connectors/services/connectors.service.spec.ts @@ -2,12 +2,12 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; import { TestBed } from '@angular/core/testing'; import { HttpClientTestingModule, HttpTestingController } from '@angular/common/http/testing'; import { signal } from '@angular/core'; -import { ConnectionsService } from './connections.service'; +import { ConnectorsService } from './connectors.service'; import { ConfigService } from '../../../services/config.service'; import { AuthService } from '../../../auth/auth.service'; -describe('ConnectionsService', () => { - let service: ConnectionsService; +describe('ConnectorsService', () => { + let service: ConnectorsService; let httpMock: HttpTestingController; beforeEach(() => { @@ -15,38 +15,38 @@ describe('ConnectionsService', () => { TestBed.configureTestingModule({ imports: [HttpClientTestingModule], providers: [ - ConnectionsService, + ConnectorsService, { provide: AuthService, useValue: { ensureAuthenticated: vi.fn().mockResolvedValue(undefined) } }, { provide: ConfigService, useValue: { appApiUrl: signal('http://localhost:8000') } }, ], }); - service = TestBed.inject(ConnectionsService); + service = TestBed.inject(ConnectorsService); httpMock = TestBed.inject(HttpTestingController); }); afterEach(() => { - httpMock.match(() => true); // discard pending requests + httpMock.match(() => true); TestBed.resetTestingModule(); }); - it('should fetch connections', async () => { + it('should fetch connectors', async () => { const mockResponse = { connections: [{ provider_id: 'google', status: 'connected' }] }; - const connectionsPromise = service.fetchConnections(); - + const connectorsPromise = service.fetchConnectors(); + await vi.waitFor(() => { httpMock.expectOne('http://localhost:8000/oauth/connections').flush(mockResponse); }); - const connections = await connectionsPromise; - expect(connections.connections[0].providerId).toBe('google'); + const connectors = await connectorsPromise; + expect(connectors.connectors[0].providerId).toBe('google'); }); it('should fetch providers', async () => { const mockResponse = { providers: [{ provider_id: 'google', name: 'Google' }], total: 1 }; const providersPromise = service.fetchProviders(); - + await vi.waitFor(() => { httpMock.expectOne('http://localhost:8000/oauth/providers').flush(mockResponse); }); @@ -60,7 +60,7 @@ describe('ConnectionsService', () => { const mockResponse = { authorization_url: 'https://oauth.example.com/auth' }; const connectPromise = service.connect('google'); - + await vi.waitFor(() => { httpMock.expectOne('http://localhost:8000/oauth/connect/google').flush(mockResponse); }); @@ -71,11 +71,11 @@ describe('ConnectionsService', () => { it('should disconnect from provider', async () => { const disconnectPromise = service.disconnect('google'); - + await vi.waitFor(() => { httpMock.expectOne('http://localhost:8000/oauth/connections/google').flush({}); }); await disconnectPromise; }); -}); \ No newline at end of file +}); diff --git a/frontend/ai.client/src/app/settings/connections/services/connections.service.ts b/frontend/ai.client/src/app/settings/connectors/services/connectors.service.ts similarity index 61% rename from frontend/ai.client/src/app/settings/connections/services/connections.service.ts rename to frontend/ai.client/src/app/settings/connectors/services/connectors.service.ts index b356ba11..55272e2f 100644 --- a/frontend/ai.client/src/app/settings/connections/services/connections.service.ts +++ b/frontend/ai.client/src/app/settings/connectors/services/connectors.service.ts @@ -4,15 +4,12 @@ import { firstValueFrom } from 'rxjs'; import { ConfigService } from '../../../services/config.service'; import { AuthService } from '../../../auth/auth.service'; import { - OAuthConnection, - OAuthConnectionListResponse, + OAuthConnector, + OAuthConnectorListResponse, OAuthProvider, OAuthProviderListResponse, } from '../models'; -/** - * Convert snake_case to camelCase for frontend models. - */ function toCamelCase(obj: Record): Record { const result: Record = {}; for (const [key, value] of Object.entries(obj)) { @@ -23,34 +20,28 @@ function toCamelCase(obj: Record): Record { } /** - * Service for managing user OAuth connections. + * Service for managing user OAuth connectors. * - * Provides access to available providers and user's connections, + * Provides access to available providers and user's connectors, * as well as connect/disconnect operations. */ @Injectable({ providedIn: 'root' }) -export class ConnectionsService { +export class ConnectorsService { private http = inject(HttpClient); private authService = inject(AuthService); private config = inject(ConfigService); private readonly baseUrl = computed(() => `${this.config.appApiUrl()}/oauth`); - /** - * Reactive resource for fetching user's OAuth connections. - */ - readonly connectionsResource = resource({ + readonly connectorsResource = resource({ loader: async () => { await this.authService.ensureAuthenticated(); - return this.fetchConnections(); + return this.fetchConnectors(); } }); - /** - * Reactive resource for fetching available OAuth providers. - */ readonly providersResource = resource({ loader: async () => { await this.authService.ensureAuthenticated(); @@ -58,42 +49,28 @@ export class ConnectionsService { } }); - /** - * Get all user connections (from resource). - */ - getConnections(): OAuthConnection[] { - return this.connectionsResource.value()?.connections ?? []; + getConnectors(): OAuthConnector[] { + return this.connectorsResource.value()?.connectors ?? []; } - /** - * Get all available providers (from resource). - */ getProviders(): OAuthProvider[] { return this.providersResource.value()?.providers ?? []; } - /** - * Get a connection by provider ID. - */ - getConnectionByProviderId(providerId: string): OAuthConnection | undefined { - return this.getConnections().find(c => c.providerId === providerId); + getConnectorByProviderId(providerId: string): OAuthConnector | undefined { + return this.getConnectors().find(c => c.providerId === providerId); } - /** - * Fetch user's OAuth connections from the API. - */ - async fetchConnections(): Promise { + async fetchConnectors(): Promise { + // Backend endpoint still returns a `connections` array — translate at the service layer. const response = await firstValueFrom( this.http.get(`${this.baseUrl()}/connections`) ); return { - connections: response.connections.map((c: any) => toCamelCase(c) as OAuthConnection), + connectors: response.connections.map((c: any) => toCamelCase(c) as OAuthConnector), }; } - /** - * Fetch available OAuth providers from the API. - */ async fetchProviders(): Promise { const response = await firstValueFrom( this.http.get(`${this.baseUrl()}/providers`) @@ -104,10 +81,6 @@ export class ConnectionsService { }; } - /** - * Initiate OAuth connection flow. - * Returns the authorization URL to redirect to. - */ async connect(providerId: string, redirectUrl?: string): Promise { const params = redirectUrl ? `?redirect=${encodeURIComponent(redirectUrl)}` : ''; const response = await firstValueFrom( @@ -116,21 +89,15 @@ export class ConnectionsService { return response.authorization_url; } - /** - * Disconnect from an OAuth provider. - */ async disconnect(providerId: string): Promise { await firstValueFrom( this.http.delete(`${this.baseUrl()}/connections/${providerId}`) ); - this.connectionsResource.reload(); + this.connectorsResource.reload(); } - /** - * Reload both resources. - */ reload(): void { - this.connectionsResource.reload(); + this.connectorsResource.reload(); this.providersResource.reload(); } } diff --git a/frontend/ai.client/src/app/settings/connectors/services/index.ts b/frontend/ai.client/src/app/settings/connectors/services/index.ts new file mode 100644 index 00000000..32879454 --- /dev/null +++ b/frontend/ai.client/src/app/settings/connectors/services/index.ts @@ -0,0 +1 @@ +export * from './connectors.service'; diff --git a/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts b/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts index 3ab5489a..df98aa94 100644 --- a/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts +++ b/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts @@ -85,7 +85,7 @@ type CallbackState = 'processing' | 'success' | 'error'; }

- Redirecting to your connections... + Redirecting to your connectors...

} @@ -587,8 +587,8 @@ export class OAuthCallbackPage implements OnInit, OnDestroy { } else if (params['error']) { this.handleError(params); } else { - // No valid params, redirect to connections - this.redirectToConnections(); + // No valid params, redirect to connectors + this.redirectToConnectors(); } }, 800); } @@ -602,7 +602,7 @@ export class OAuthCallbackPage implements OnInit, OnDestroy { // Redirect after showing success setTimeout(() => { - this.redirectToConnections({ success: 'true', provider }); + this.redirectToConnectors({ success: 'true', provider }); }, 1500); } @@ -632,12 +632,12 @@ export class OAuthCallbackPage implements OnInit, OnDestroy { // Redirect after showing error setTimeout(() => { - this.redirectToConnections({ error, provider }); + this.redirectToConnectors({ error, provider }); }, 2500); } - private redirectToConnections(queryParams?: Record): void { - this.router.navigate(['/settings/connections'], { + private redirectToConnectors(queryParams?: Record): void { + this.router.navigate(['/settings/connectors'], { queryParams, replaceUrl: true, }); diff --git a/frontend/ai.client/src/app/settings/pages/connections-settings/connections-settings.page.ts b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts similarity index 75% rename from frontend/ai.client/src/app/settings/pages/connections-settings/connections-settings.page.ts rename to frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts index a1247468..282e3d06 100644 --- a/frontend/ai.client/src/app/settings/pages/connections-settings/connections-settings.page.ts +++ b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts @@ -18,12 +18,12 @@ import { heroArrowPath, heroKey, } from '@ng-icons/heroicons/outline'; -import { ConnectionsService } from '../../connections/services'; -import { OAuthConnection, OAuthProviderType } from '../../connections/models'; +import { ConnectorsService } from '../../connectors/services'; +import { OAuthConnector, OAuthProviderType } from '../../connectors/models'; import { ToastService } from '../../../services/toast/toast.service'; @Component({ - selector: 'app-connections-settings', + selector: 'app-connectors-settings', changeDetection: ChangeDetectionStrategy.OnPush, imports: [NgIcon], providers: [ @@ -43,29 +43,29 @@ import { ToastService } from '../../../services/toast/toast.service';
-

Connections

+

Connectors

Connect your accounts to enable tools that require third-party authentication.

- @if (connectionsResource.isLoading() && connections().length === 0) { + @if (connectorsResource.isLoading() && connectors().length === 0) {
-

Loading connections...

+

Loading connectors...

} - @if (connectionsResource.error()) { + @if (connectorsResource.error()) {
-

Failed to load connections

+

Failed to load connectors

Please check your connection and try again.

} - - @if (!connectionsResource.isLoading() || connections().length > 0) { - @if (connections().length === 0 && !connectionsResource.error()) { + + @if (!connectorsResource.isLoading() || connectors().length > 0) { + @if (connectors().length === 0 && !connectorsResource.error()) {
-

No connections available

+

No connectors available

There are no OAuth providers configured for your account. Contact an administrator if you need access to external tools.

} @else {
- @for (connection of connections(); track connection.providerId) { + @for (connector of connectors(); track connector.providerId) {
-
- +
+

- {{ connection.displayName }} + {{ connector.displayName }}

- @if (isConnected(connection)) { + @if (isConnected(connector)) {
Connected - @if (connection.connectedAt) { + @if (connector.connectedAt) { - since {{ formatDate(connection.connectedAt) }} + since {{ formatDate(connector.connectedAt) }} }
- } @else if (connection.needsReauth || connection.status === 'needs_reauth' || connection.status === 'expired') { + } @else if (connector.needsReauth || connector.status === 'needs_reauth' || connector.status === 'expired') {
@@ -130,13 +130,13 @@ import { ToastService } from '../../../services/toast/toast.service';
- @if (isConnected(connection) && !connection.needsReauth && connection.status !== 'needs_reauth' && connection.status !== 'expired') { + @if (isConnected(connector) && !connector.needsReauth && connector.status !== 'needs_reauth' && connector.status !== 'expired') { } @else {
`, }) -export class ConnectionsSettingsPage implements OnInit { - readonly connectionsService = inject(ConnectionsService); +export class ConnectorsSettingsPage implements OnInit { + readonly connectorsService = inject(ConnectorsService); private router = inject(Router); private route = inject(ActivatedRoute); private toast = inject(ToastService); - readonly connectionsResource = this.connectionsService.connectionsResource; + readonly connectorsResource = this.connectorsService.connectorsResource; connecting = signal(null); disconnecting = signal(null); - readonly connections = computed(() => this.connectionsService.getConnections()); + readonly connectors = computed(() => this.connectorsService.getConnectors()); ngOnInit(): void { this.handleCallbackParams(); @@ -193,7 +193,7 @@ export class ConnectionsSettingsPage implements OnInit { if (params['success'] === 'true') { const provider = params['provider'] || 'the service'; this.toast.success('Connected!', `Successfully connected to ${provider}.`); - this.connectionsService.reload(); + this.connectorsService.reload(); this.router.navigate([], { relativeTo: this.route, queryParams: {}, @@ -222,16 +222,16 @@ export class ConnectionsSettingsPage implements OnInit { } } - isConnected(connection: OAuthConnection): boolean { - return connection.status === 'connected'; + isConnected(connector: OAuthConnector): boolean { + return connector.status === 'connected'; } - async connect(connection: OAuthConnection): Promise { - this.connecting.set(connection.providerId); + async connect(connector: OAuthConnector): Promise { + this.connecting.set(connector.providerId); try { const redirectUrl = window.location.origin + '/settings/oauth/callback'; - const authUrl = await this.connectionsService.connect(connection.providerId, redirectUrl); + const authUrl = await this.connectorsService.connect(connector.providerId, redirectUrl); window.location.href = authUrl; } catch (error: unknown) { const err = error as { error?: { detail?: string }; message?: string }; @@ -241,16 +241,16 @@ export class ConnectionsSettingsPage implements OnInit { } } - async disconnect(connection: OAuthConnection): Promise { - if (!confirm(`Are you sure you want to disconnect from ${connection.displayName}?`)) { + async disconnect(connector: OAuthConnector): Promise { + if (!confirm(`Are you sure you want to disconnect from ${connector.displayName}?`)) { return; } - this.disconnecting.set(connection.providerId); + this.disconnecting.set(connector.providerId); try { - await this.connectionsService.disconnect(connection.providerId); - this.toast.success('Disconnected', `Successfully disconnected from ${connection.displayName}.`); + await this.connectorsService.disconnect(connector.providerId); + this.toast.success('Disconnected', `Successfully disconnected from ${connector.displayName}.`); } catch (error: unknown) { const err = error as { error?: { detail?: string }; message?: string }; const message = err?.error?.detail || err?.message || 'Failed to disconnect.'; @@ -260,11 +260,11 @@ export class ConnectionsSettingsPage implements OnInit { } } - getProviderIcon(connection: OAuthConnection): string { - if (connection.iconName && connection.iconName !== 'heroLink') { - return connection.iconName; + getProviderIcon(connector: OAuthConnector): string { + if (connector.iconName && connector.iconName !== 'heroLink') { + return connector.iconName; } - switch (connection.providerType) { + switch (connector.providerType) { case 'google': case 'microsoft': return 'heroCloud'; diff --git a/frontend/ai.client/src/app/settings/settings.page.ts b/frontend/ai.client/src/app/settings/settings.page.ts index c9d064ab..5a273d55 100644 --- a/frontend/ai.client/src/app/settings/settings.page.ts +++ b/frontend/ai.client/src/app/settings/settings.page.ts @@ -114,7 +114,7 @@ export class SettingsPage { { label: 'Profile', icon: 'heroUser', route: '/settings/profile', description: 'Your personal information' }, { label: 'Appearance', icon: 'heroPaintBrush', route: '/settings/appearance', description: 'Theme and display' }, { label: 'Chat', icon: 'heroChatBubbleLeftRight', route: '/settings/chat', description: 'Chat preferences' }, - { label: 'Connections', icon: 'heroLink', route: '/settings/connections', description: 'Connected apps' }, + { label: 'Connectors', icon: 'heroLink', route: '/settings/connectors', description: 'Connected apps' }, { label: 'API Keys', icon: 'heroKey', route: '/settings/api-keys', description: 'API key management' }, { label: 'Usage', icon: 'heroChartBar', route: '/settings/usage', description: 'Usage and billing' }, ]; diff --git a/frontend/ai.client/src/app/settings/settings.routes.ts b/frontend/ai.client/src/app/settings/settings.routes.ts index 959ff2c7..121976fd 100644 --- a/frontend/ai.client/src/app/settings/settings.routes.ts +++ b/frontend/ai.client/src/app/settings/settings.routes.ts @@ -22,9 +22,14 @@ export const settingsRoutes: Routes = [ import('./pages/chat-preferences/chat-preferences-settings.page').then(m => m.ChatPreferencesSettingsPage), }, { - path: 'connections', + path: 'connectors', loadComponent: () => - import('./pages/connections-settings/connections-settings.page').then(m => m.ConnectionsSettingsPage), + import('./pages/connectors-settings/connectors-settings.page').then(m => m.ConnectorsSettingsPage), + }, + { + path: 'connections', + redirectTo: 'connectors', + pathMatch: 'full', }, { path: 'api-keys', From 8f0bf7b9ce525dacfdb3025e46513efc679d7220 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 10:40:16 -0600 Subject: [PATCH 04/35] feat(connectors): add AgentCore credential-provider registrar service Wraps bedrock-agentcore-control for admin-side OAuth2 credential provider CRUD: create/update/delete/get with vendor mapping (Google/Microsoft/GitHub to their native vendors; Canvas/Custom routed through CustomOauth2 via an OIDC discovery URL or explicit authorization-server metadata). Domain errors map 404/conflict/invalid-custom to typed exceptions so route handlers can translate cleanly. Update is intentionally non-partial: AgentCore's UpdateOauth2CredentialProvider requires a full oauth2ProviderConfigInput and Get never returns the stored client_secret, so credential rotation always re-submits both clientId and clientSecret. 17 unit tests cover every vendor path, error mapping, and the Custom-only discovery rule. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../apis/shared/oauth/agentcore_registrar.py | 324 ++++++++++++++++++ .../shared/test_oauth_agentcore_registrar.py | 264 ++++++++++++++ 2 files changed, 588 insertions(+) create mode 100644 backend/src/apis/shared/oauth/agentcore_registrar.py create mode 100644 backend/tests/shared/test_oauth_agentcore_registrar.py diff --git a/backend/src/apis/shared/oauth/agentcore_registrar.py b/backend/src/apis/shared/oauth/agentcore_registrar.py new file mode 100644 index 00000000..3e140bac --- /dev/null +++ b/backend/src/apis/shared/oauth/agentcore_registrar.py @@ -0,0 +1,324 @@ +"""AgentCore Identity credential-provider registrar. + +Wraps `bedrock-agentcore-control` for managing OAuth2 credential providers +owned by AgentCore Identity. Callers upsert one of our `OAuthProvider` +records by first registering the client_id + client_secret here, then +storing the returned `callbackUrl` and `credentialProviderArn` on the +DynamoDB record. + +Division of authority: + +- AgentCore Identity: clientId, clientSecret, vendor-specific endpoint + config, callback URL. Returns a `credentialProviderArn` that identifies + the provider within the default token vault. +- Our DynamoDB: displayName, scopes, allowedRoles, iconName, enabled flag. + +Update semantics matter: `UpdateOauth2CredentialProvider` is NOT a partial +update — the full `oauth2ProviderConfigInput` (including clientId and +clientSecret) must be re-submitted on every call. Because +`GetOauth2CredentialProvider` returns only `clientSecretArn` (not the +secret value), credential rotation always requires the admin to re-enter +both fields. `update_credential_provider` enforces this. +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import boto3 +from botocore.exceptions import ClientError + +from .models import OAuthProviderType + +logger = logging.getLogger(__name__) + + +# Mapping from our OAuthProviderType to AgentCore's credentialProviderVendor +# and the corresponding vendor-specific config key inside oauth2ProviderConfigInput. +# CANVAS routes through CustomOauth2 because AgentCore Identity does not ship +# a first-class Canvas vendor. +_VENDOR_BY_TYPE: Dict[OAuthProviderType, str] = { + OAuthProviderType.GOOGLE: "GoogleOauth2", + OAuthProviderType.MICROSOFT: "MicrosoftOauth2", + OAuthProviderType.GITHUB: "GithubOauth2", + OAuthProviderType.CANVAS: "CustomOauth2", + OAuthProviderType.CUSTOM: "CustomOauth2", +} + +_CONFIG_KEY_BY_TYPE: Dict[OAuthProviderType, str] = { + OAuthProviderType.GOOGLE: "googleOauth2ProviderConfig", + OAuthProviderType.MICROSOFT: "microsoftOauth2ProviderConfig", + OAuthProviderType.GITHUB: "githubOauth2ProviderConfig", + OAuthProviderType.CANVAS: "customOauth2ProviderConfig", + OAuthProviderType.CUSTOM: "customOauth2ProviderConfig", +} + + +@dataclass(frozen=True) +class CredentialProviderInfo: + """AgentCore Identity record for one OAuth2 credential provider. + + `client_id` is populated on `get_credential_provider`; Create/Update + responses include it in `oauth2ProviderConfigOutput` when the vendor + echoes it back, and we surface it when present. `client_secret` is + never returned by AgentCore — only `client_secret_arn`. + """ + + provider_id: str + vendor: str + credential_provider_arn: str + client_secret_arn: str + callback_url: str + client_id: Optional[str] = None + + +class CredentialProviderNotFoundError(LookupError): + """Raised when an AgentCore credential provider does not exist.""" + + +class CredentialProviderConflictError(RuntimeError): + """Raised when creating a provider that already exists in AgentCore.""" + + +class InvalidCustomProviderConfigError(ValueError): + """Raised when Custom vendor is selected without exactly one discovery mode.""" + + +class AgentCoreRegistrar: + """Thin wrapper around `bedrock-agentcore-control` for OAuth2 providers. + + Stateless apart from the boto3 client. Safe to share across requests. + """ + + def __init__( + self, + *, + region: Optional[str] = None, + client: Any = None, + ): + self._region = region or os.environ.get("AWS_REGION", "us-east-1") + self._client = client or boto3.client( + "bedrock-agentcore-control", region_name=self._region + ) + + # ------------------------------------------------------------------ create + def create_credential_provider( + self, + *, + provider_id: str, + provider_type: OAuthProviderType, + client_id: str, + client_secret: str, + discovery_url: Optional[str] = None, + authorization_server_metadata: Optional[Dict[str, Any]] = None, + ) -> CredentialProviderInfo: + """Register a new OAuth2 credential provider in AgentCore Identity. + + Raises: + CredentialProviderConflictError: A provider with `provider_id` + already exists. + InvalidCustomProviderConfigError: Custom/Canvas vendor was used + without exactly one of `discovery_url` or + `authorization_server_metadata`. + botocore.exceptions.ClientError: Any other AWS error bubbles up. + """ + vendor, config_input = self._build_config_input( + provider_type=provider_type, + client_id=client_id, + client_secret=client_secret, + discovery_url=discovery_url, + authorization_server_metadata=authorization_server_metadata, + ) + + try: + response = self._client.create_oauth2_credential_provider( + name=provider_id, + credentialProviderVendor=vendor, + oauth2ProviderConfigInput=config_input, + ) + except ClientError as err: + code = err.response.get("Error", {}).get("Code") + if code in ("ConflictException", "ResourceAlreadyExistsException"): + raise CredentialProviderConflictError( + f"AgentCore credential provider '{provider_id}' already exists" + ) from err + raise + + return self._info_from_response( + provider_id=provider_id, vendor=vendor, response=response + ) + + # ------------------------------------------------------------------ update + def update_credential_provider( + self, + *, + provider_id: str, + provider_type: OAuthProviderType, + client_id: str, + client_secret: str, + discovery_url: Optional[str] = None, + authorization_server_metadata: Optional[Dict[str, Any]] = None, + ) -> CredentialProviderInfo: + """Replace the AgentCore provider's full config. + + `UpdateOauth2CredentialProvider` requires the full + `oauth2ProviderConfigInput`, so the caller must supply both + `client_id` and `client_secret`. There is no "change only the + secret" path — Get does not return the existing secret, and the API + does not support partial updates. + + Raises: + CredentialProviderNotFoundError: No such provider. + InvalidCustomProviderConfigError: Custom/Canvas without exactly + one discovery mode. + botocore.exceptions.ClientError: Any other AWS error. + """ + vendor, config_input = self._build_config_input( + provider_type=provider_type, + client_id=client_id, + client_secret=client_secret, + discovery_url=discovery_url, + authorization_server_metadata=authorization_server_metadata, + ) + + try: + response = self._client.update_oauth2_credential_provider( + name=provider_id, + credentialProviderVendor=vendor, + oauth2ProviderConfigInput=config_input, + ) + except ClientError as err: + if self._is_not_found(err): + raise CredentialProviderNotFoundError(provider_id) from err + raise + + return self._info_from_response( + provider_id=provider_id, vendor=vendor, response=response + ) + + # --------------------------------------------------------------------- get + def get_credential_provider(self, provider_id: str) -> CredentialProviderInfo: + """Fetch the AgentCore record for `provider_id`. + + Raises: + CredentialProviderNotFoundError: No such provider. + """ + try: + response = self._client.get_oauth2_credential_provider(name=provider_id) + except ClientError as err: + if self._is_not_found(err): + raise CredentialProviderNotFoundError(provider_id) from err + raise + + return self._info_from_response( + provider_id=provider_id, + vendor=response["credentialProviderVendor"], + response=response, + ) + + # ------------------------------------------------------------------ delete + def delete_credential_provider(self, provider_id: str) -> None: + """Delete the AgentCore provider. Missing providers are treated as success.""" + try: + self._client.delete_oauth2_credential_provider(name=provider_id) + except ClientError as err: + if self._is_not_found(err): + logger.info( + "AgentCore provider '%s' already absent; delete is a no-op", + provider_id, + ) + return + raise + + # ------------------------------------------------------------- build helper + def _build_config_input( + self, + *, + provider_type: OAuthProviderType, + client_id: str, + client_secret: str, + discovery_url: Optional[str], + authorization_server_metadata: Optional[Dict[str, Any]], + ) -> tuple[str, Dict[str, Any]]: + """Return `(vendor, oauth2ProviderConfigInput)` for AgentCore.""" + try: + vendor = _VENDOR_BY_TYPE[provider_type] + config_key = _CONFIG_KEY_BY_TYPE[provider_type] + except KeyError as err: + raise ValueError(f"Unsupported OAuth provider type: {provider_type}") from err + + vendor_config: Dict[str, Any] = { + "clientId": client_id, + "clientSecret": client_secret, + } + + if config_key == "customOauth2ProviderConfig": + vendor_config["oauthDiscovery"] = self._build_oauth_discovery( + discovery_url=discovery_url, + authorization_server_metadata=authorization_server_metadata, + ) + elif discovery_url or authorization_server_metadata: + raise ValueError( + f"Discovery config is only valid for CustomOauth2; " + f"provider_type={provider_type} ignores it" + ) + + return vendor, {config_key: vendor_config} + + @staticmethod + def _build_oauth_discovery( + *, + discovery_url: Optional[str], + authorization_server_metadata: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + if bool(discovery_url) == bool(authorization_server_metadata): + raise InvalidCustomProviderConfigError( + "CustomOauth2 requires exactly one of discovery_url or " + "authorization_server_metadata" + ) + if discovery_url: + return {"discoveryUrl": discovery_url} + return {"authorizationServerMetadata": authorization_server_metadata} + + # ----------------------------------------------------------- parse helpers + @staticmethod + def _info_from_response( + *, provider_id: str, vendor: str, response: Dict[str, Any] + ) -> CredentialProviderInfo: + client_secret = response.get("clientSecretArn") or {} + output_config = response.get("oauth2ProviderConfigOutput") or {} + # Each vendor variant nests its own output object; the clientId lives + # one level deeper when present. We tolerate its absence. + client_id: Optional[str] = None + for nested in output_config.values(): + if isinstance(nested, dict) and "clientId" in nested: + client_id = nested["clientId"] + break + + return CredentialProviderInfo( + provider_id=provider_id, + vendor=vendor, + credential_provider_arn=response["credentialProviderArn"], + client_secret_arn=client_secret.get("secretArn", ""), + callback_url=response.get("callbackUrl", ""), + client_id=client_id, + ) + + @staticmethod + def _is_not_found(err: ClientError) -> bool: + code = err.response.get("Error", {}).get("Code") + return code in ("ResourceNotFoundException", "NotFoundException") + + +_default_registrar: Optional[AgentCoreRegistrar] = None + + +def get_agentcore_registrar() -> AgentCoreRegistrar: + """Return the process-wide `AgentCoreRegistrar` singleton.""" + global _default_registrar + if _default_registrar is None: + _default_registrar = AgentCoreRegistrar() + return _default_registrar diff --git a/backend/tests/shared/test_oauth_agentcore_registrar.py b/backend/tests/shared/test_oauth_agentcore_registrar.py new file mode 100644 index 00000000..09cf3443 --- /dev/null +++ b/backend/tests/shared/test_oauth_agentcore_registrar.py @@ -0,0 +1,264 @@ +"""AgentCore Identity credential-provider registrar tests. + +Mocks the `bedrock-agentcore-control` boto3 client directly — these tests +verify our translation layer (our OAuthProviderType → AgentCore vendor + +config shape), not AWS behaviour. +""" + +from unittest.mock import MagicMock + +import pytest +from botocore.exceptions import ClientError + +from apis.shared.oauth.agentcore_registrar import ( + AgentCoreRegistrar, + CredentialProviderConflictError, + CredentialProviderNotFoundError, + InvalidCustomProviderConfigError, +) +from apis.shared.oauth.models import OAuthProviderType + + +def _client_error(code: str) -> ClientError: + return ClientError( + error_response={"Error": {"Code": code, "Message": code}}, + operation_name="op", + ) + + +def _create_response( + *, arn="arn:aws:acps:us-east-1:123:token-vault/default/oauth2credentialprovider/p", + secret_arn="arn:aws:secretsmanager:us-east-1:123:secret:s", + callback="https://example.invalid/callback/p", + config_output=None, +): + return { + "credentialProviderArn": arn, + "clientSecretArn": {"secretArn": secret_arn}, + "callbackUrl": callback, + "name": "p", + "oauth2ProviderConfigOutput": config_output or {}, + } + + +@pytest.fixture +def boto_client(): + return MagicMock() + + +@pytest.fixture +def registrar(boto_client): + return AgentCoreRegistrar(client=boto_client, region="us-east-1") + + +class TestCreateCredentialProvider: + def test_google_uses_google_vendor_and_config_key(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.return_value = _create_response() + + info = registrar.create_credential_provider( + provider_id="google-workspace", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + ) + + boto_client.create_oauth2_credential_provider.assert_called_once_with( + name="google-workspace", + credentialProviderVendor="GoogleOauth2", + oauth2ProviderConfigInput={ + "googleOauth2ProviderConfig": {"clientId": "cid", "clientSecret": "sec"} + }, + ) + assert info.vendor == "GoogleOauth2" + assert info.callback_url.endswith("/callback/p") + + @pytest.mark.parametrize( + "provider_type,expected_vendor,expected_key", + [ + (OAuthProviderType.MICROSOFT, "MicrosoftOauth2", "microsoftOauth2ProviderConfig"), + (OAuthProviderType.GITHUB, "GithubOauth2", "githubOauth2ProviderConfig"), + ], + ) + def test_other_known_vendors( + self, registrar, boto_client, provider_type, expected_vendor, expected_key + ): + boto_client.create_oauth2_credential_provider.return_value = _create_response() + + registrar.create_credential_provider( + provider_id="p", + provider_type=provider_type, + client_id="cid", + client_secret="sec", + ) + + call = boto_client.create_oauth2_credential_provider.call_args.kwargs + assert call["credentialProviderVendor"] == expected_vendor + assert expected_key in call["oauth2ProviderConfigInput"] + + def test_custom_requires_discovery(self, registrar): + with pytest.raises(InvalidCustomProviderConfigError): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.CUSTOM, + client_id="cid", + client_secret="sec", + ) + + def test_custom_rejects_both_discovery_modes(self, registrar): + with pytest.raises(InvalidCustomProviderConfigError): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.CUSTOM, + client_id="cid", + client_secret="sec", + discovery_url="https://idp.example/.well-known/openid-configuration", + authorization_server_metadata={"authorizationEndpoint": "https://idp/auth"}, + ) + + def test_custom_with_discovery_url(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.return_value = _create_response() + + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.CUSTOM, + client_id="cid", + client_secret="sec", + discovery_url="https://idp.example/.well-known/openid-configuration", + ) + + config = boto_client.create_oauth2_credential_provider.call_args.kwargs[ + "oauth2ProviderConfigInput" + ]["customOauth2ProviderConfig"] + assert config["oauthDiscovery"] == { + "discoveryUrl": "https://idp.example/.well-known/openid-configuration" + } + + def test_canvas_routes_through_custom_vendor(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.return_value = _create_response() + + registrar.create_credential_provider( + provider_id="canvas", + provider_type=OAuthProviderType.CANVAS, + client_id="cid", + client_secret="sec", + authorization_server_metadata={ + "authorizationEndpoint": "https://canvas.example/login/oauth2/auth", + "tokenEndpoint": "https://canvas.example/login/oauth2/token", + }, + ) + + call = boto_client.create_oauth2_credential_provider.call_args.kwargs + assert call["credentialProviderVendor"] == "CustomOauth2" + assert "customOauth2ProviderConfig" in call["oauth2ProviderConfigInput"] + + def test_known_vendor_rejects_discovery_params(self, registrar): + with pytest.raises(ValueError, match="only valid for CustomOauth2"): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + discovery_url="https://idp.example/.well-known/openid-configuration", + ) + + def test_conflict_maps_to_domain_error(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.side_effect = _client_error( + "ConflictException" + ) + + with pytest.raises(CredentialProviderConflictError): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + ) + + def test_surfaces_client_id_when_echoed(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.return_value = _create_response( + config_output={"googleOauth2ProviderConfig": {"clientId": "cid"}}, + ) + + info = registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + ) + + assert info.client_id == "cid" + + +class TestUpdateCredentialProvider: + def test_sends_full_config(self, registrar, boto_client): + boto_client.update_oauth2_credential_provider.return_value = _create_response() + + registrar.update_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GITHUB, + client_id="new-cid", + client_secret="new-sec", + ) + + call = boto_client.update_oauth2_credential_provider.call_args.kwargs + assert call["credentialProviderVendor"] == "GithubOauth2" + assert call["oauth2ProviderConfigInput"]["githubOauth2ProviderConfig"] == { + "clientId": "new-cid", + "clientSecret": "new-sec", + } + + def test_not_found_maps_to_domain_error(self, registrar, boto_client): + boto_client.update_oauth2_credential_provider.side_effect = _client_error( + "ResourceNotFoundException" + ) + + with pytest.raises(CredentialProviderNotFoundError): + registrar.update_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + ) + + +class TestGetCredentialProvider: + def test_returns_info_including_callback_url(self, registrar, boto_client): + boto_client.get_oauth2_credential_provider.return_value = { + **_create_response( + config_output={"googleOauth2ProviderConfig": {"clientId": "cid"}}, + ), + "credentialProviderVendor": "GoogleOauth2", + } + + info = registrar.get_credential_provider("p") + + assert info.vendor == "GoogleOauth2" + assert info.client_id == "cid" + assert info.callback_url.endswith("/callback/p") + + def test_not_found(self, registrar, boto_client): + boto_client.get_oauth2_credential_provider.side_effect = _client_error( + "ResourceNotFoundException" + ) + + with pytest.raises(CredentialProviderNotFoundError): + registrar.get_credential_provider("missing") + + +class TestDeleteCredentialProvider: + def test_calls_boto(self, registrar, boto_client): + registrar.delete_credential_provider("p") + boto_client.delete_oauth2_credential_provider.assert_called_once_with(name="p") + + def test_not_found_is_success(self, registrar, boto_client): + boto_client.delete_oauth2_credential_provider.side_effect = _client_error( + "ResourceNotFoundException" + ) + registrar.delete_credential_provider("missing") # no raise + + def test_other_errors_bubble(self, registrar, boto_client): + boto_client.delete_oauth2_credential_provider.side_effect = _client_error( + "AccessDeniedException" + ) + with pytest.raises(ClientError): + registrar.delete_credential_provider("p") From 2961906217e5b92ed53aef6d95a12848fa513e8e Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 10:40:25 -0600 Subject: [PATCH 05/35] chore(connectors): grant IAM for credential-provider admin ops Adds Create/Update/Delete/Get/List on bedrock-agentcore OAuth2 credential providers to the app-api task role, scoped to the default token vault. Co-Authored-By: Claude Opus 4.7 (1M context) --- infrastructure/lib/app-api-stack.ts | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/infrastructure/lib/app-api-stack.ts b/infrastructure/lib/app-api-stack.ts index 20524a40..90b19e5d 100644 --- a/infrastructure/lib/app-api-stack.ts +++ b/infrastructure/lib/app-api-stack.ts @@ -917,6 +917,28 @@ export class AppApiStack extends cdk.Stack { resources: [`${oauthClientSecretsArn}*`], // Wildcard for random suffix }) ); + + // Admin CRUD for OAuth2 credential providers stored in AgentCore Identity. + // Provider-scoped actions are scoped to the default token vault; List + // requires a broader resource since it enumerates the vault itself. + taskDefinition.taskRole.addToPrincipalPolicy( + new iam.PolicyStatement({ + sid: 'AgentCoreCredentialProviderAdmin', + effect: iam.Effect.ALLOW, + actions: [ + 'bedrock-agentcore:CreateOauth2CredentialProvider', + 'bedrock-agentcore:UpdateOauth2CredentialProvider', + 'bedrock-agentcore:DeleteOauth2CredentialProvider', + 'bedrock-agentcore:GetOauth2CredentialProvider', + 'bedrock-agentcore:ListOauth2CredentialProviders', + ], + resources: [ + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:token-vault/default`, + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:token-vault/default/oauth2credentialprovider/*`, + ], + }) + ); + // Grant permissions for API Keys table (imported from Infrastructure Stack) taskDefinition.taskRole.addToPrincipalPolicy( new iam.PolicyStatement({ From cb32499bfb6d67e2faf43d8f027720234d8d018b Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 10:41:19 -0600 Subject: [PATCH 06/35] refactor(connectors): retire in-house OAuth flow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Deletes the legacy 3LO dance that predates AgentCore Identity — the per-user token vault, PKCE-based authorization service, encryption layer, token cache, user-facing /oauth/* routes, and the tool-side OAuthToolService. AgentCore Identity owns the token vault and consent flow now; the inference path already routes through agentcore_identity.py via the recent external MCP client refactor, so these modules had no live consumers. Also slims shared/oauth/__init__.py to the surviving surface (provider model, repository, registrar) and unwires the user-facing router from app_api/main.py. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../main_agent/tools/oauth_tool_service.py | 336 --------- backend/src/apis/app_api/main.py | 2 - backend/src/apis/shared/oauth/__init__.py | 62 +- backend/src/apis/shared/oauth/encryption.py | 138 ---- backend/src/apis/shared/oauth/routes.py | 268 ------- backend/src/apis/shared/oauth/service.py | 664 ------------------ backend/src/apis/shared/oauth/token_cache.py | 165 ----- .../src/apis/shared/oauth/token_repository.py | 337 --------- backend/tests/shared/test_conftest_smoke.py | 3 - backend/tests/shared/test_coverage_boost.py | 79 --- backend/tests/shared/test_oauth_service.py | 153 ---- 11 files changed, 29 insertions(+), 2178 deletions(-) delete mode 100644 backend/src/agents/main_agent/tools/oauth_tool_service.py delete mode 100644 backend/src/apis/shared/oauth/encryption.py delete mode 100644 backend/src/apis/shared/oauth/routes.py delete mode 100644 backend/src/apis/shared/oauth/service.py delete mode 100644 backend/src/apis/shared/oauth/token_cache.py delete mode 100644 backend/src/apis/shared/oauth/token_repository.py delete mode 100644 backend/tests/shared/test_oauth_service.py diff --git a/backend/src/agents/main_agent/tools/oauth_tool_service.py b/backend/src/agents/main_agent/tools/oauth_tool_service.py deleted file mode 100644 index f5d47d7c..00000000 --- a/backend/src/agents/main_agent/tools/oauth_tool_service.py +++ /dev/null @@ -1,336 +0,0 @@ -""" -OAuth Tool Service - Enables tools to securely access user OAuth tokens - -This service provides a clean interface for tools to: -1. Check if a user has connected to an OAuth provider -2. Retrieve decrypted access tokens for API calls -3. Handle token refresh automatically -4. Generate connection URLs for user guidance - -Usage in tools: - from agents.main_agent.tools.oauth_tool_service import get_oauth_tool_service - - @tool(context=True) - async def my_oauth_tool(query: str, tool_context: ToolContext) -> dict: - oauth_service = get_oauth_tool_service() - - # Get user_id from session manager - user_id = tool_context.agent._session_manager.user_id - - # Get token for this provider - result = await oauth_service.get_token_for_tool( - user_id=user_id, - provider_id="google_workspace" - ) - - if not result.connected: - return result.not_connected_response() - - # Use the token - headers = {"Authorization": f"Bearer {result.access_token}"} - # ... make API calls -""" -import logging -import os -from dataclasses import dataclass -from typing import Optional - -from agents.main_agent.config.constants import EnvVars, Defaults - -logger = logging.getLogger(__name__) - - -@dataclass -class OAuthTokenResult: - """Result of requesting an OAuth token for a tool.""" - - connected: bool - """Whether the user is connected to the provider.""" - - access_token: Optional[str] = None - """The decrypted access token (if connected).""" - - provider_id: str = "" - """The provider ID requested.""" - - provider_name: str = "" - """Human-readable provider name.""" - - error: Optional[str] = None - """Error message if token retrieval failed.""" - - needs_reauth: bool = False - """Whether the user needs to re-authorize (expired/revoked).""" - - def not_connected_response(self, tool_name: str = "this tool") -> dict: - """ - Generate a user-friendly response when not connected. - - Returns a dict suitable for returning from a tool. - """ - frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - connect_url = f"{frontend_url}/settings/connectors" - - if self.needs_reauth: - message = f"""⚠️ **Re-authorization Required** - -Your connection to **{self.provider_name}** has expired or been revoked. - -Please reconnect to continue using {tool_name}: - -👉 [Reconnect to {self.provider_name}]({connect_url}) - -After reconnecting, try your request again.""" - else: - message = f"""🔗 **Connection Required** - -To use {tool_name}, you need to connect your **{self.provider_name}** account. - -This allows me to securely access your data on your behalf. - -👉 [Connect {self.provider_name}]({connect_url}) - -After connecting, try your request again.""" - - return { - "status": "not_connected", - "content": [{"text": message}], - "requires_oauth": True, - "provider_id": self.provider_id, - "provider_name": self.provider_name, - "connect_url": connect_url, - "needs_reauth": self.needs_reauth, - } - - -class OAuthToolService: - """ - Service for tools to access OAuth tokens. - - This service wraps the OAuth service and provides a simplified - interface optimized for tool usage. - """ - - def __init__(self): - self._oauth_service = None - self._provider_repo = None - - async def _get_oauth_service(self): - """Lazy-load OAuth service to avoid circular imports.""" - if self._oauth_service is None: - from apis.shared.oauth.service import get_oauth_service - self._oauth_service = get_oauth_service() - return self._oauth_service - - async def _get_provider_repo(self): - """Lazy-load provider repository.""" - if self._provider_repo is None: - from apis.shared.oauth.provider_repository import get_provider_repository - self._provider_repo = get_provider_repository() - return self._provider_repo - - async def get_token_for_tool( - self, - user_id: str, - provider_id: str, - ) -> OAuthTokenResult: - """ - Get an OAuth token for a tool to use. - - Args: - user_id: The user's ID (from session manager) - provider_id: The OAuth provider ID (e.g., "google_workspace") - - Returns: - OAuthTokenResult with token or connection guidance - """ - try: - oauth_service = await self._get_oauth_service() - provider_repo = await self._get_provider_repo() - - # Get provider info for display name - provider = await provider_repo.get_provider(provider_id) - provider_name = provider.display_name if provider else provider_id - - # Try to get the token — isolated to limit taint scope - result = await self._try_get_token(oauth_service, user_id, provider_id, provider_name) - if result: - return result - - # Check if user has a connection but needs re-auth - from apis.shared.oauth.token_repository import get_token_repository - token_repo = get_token_repository() - user_token = await token_repo.get_user_token(user_id, provider_id) - - if user_token and user_token.status in ("expired", "needs_reauth", "revoked"): - token_status = user_token.status - logger.debug("User needs re-auth for an OAuth provider") - return OAuthTokenResult( - connected=False, - provider_id=provider_id, - provider_name=provider_name, - needs_reauth=True, - error=f"Token {token_status}", - ) - - # User not connected - logger.debug("User not connected to an OAuth provider") - return OAuthTokenResult( - connected=False, - provider_id=provider_id, - provider_name=provider_name, - ) - - except Exception as e: - logger.error("Error checking OAuth connection: %s", type(e).__name__, exc_info=True) - return OAuthTokenResult( - connected=False, - provider_id=provider_id, - provider_name=provider_id, - error=str(e), - ) - - @staticmethod - async def _try_get_token( - oauth_service: "OAuthService", - user_id: str, - provider_id: str, - provider_name: str, - ) -> Optional[OAuthTokenResult]: - """Fetch decrypted token in isolated scope to avoid taint leaking to callers.""" - token = await oauth_service.get_decrypted_token( - user_id=user_id, - provider_id=provider_id, - ) - if token: - logger.debug("OAuth token successfully retrieved") - return OAuthTokenResult( - connected=True, - access_token=token, - provider_id=provider_id, - provider_name=provider_name, - ) - return None - - async def check_connection( - self, - user_id: str, - provider_id: str, - ) -> bool: - """ - Quick check if user is connected to a provider. - - Args: - user_id: The user's ID - provider_id: The OAuth provider ID - - Returns: - True if connected and token is valid - """ - result = await self.get_token_for_tool(user_id, provider_id) - return result.connected - - def get_connect_url(self, provider_id: str) -> str: - """ - Get the URL for the user to connect to a provider. - - Args: - provider_id: The OAuth provider ID - - Returns: - URL to the connections page - """ - frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - return f"{frontend_url}/settings/connectors" - - -# Singleton instance -_oauth_tool_service: Optional[OAuthToolService] = None - - -def get_oauth_tool_service() -> OAuthToolService: - """Get the singleton OAuthToolService instance.""" - global _oauth_tool_service - if _oauth_tool_service is None: - _oauth_tool_service = OAuthToolService() - return _oauth_tool_service - - -async def check_oauth_requirements_for_tools( - user_id: str, - enabled_tool_ids: list[str], -) -> dict[str, OAuthTokenResult]: - """ - Check OAuth connection status for all tools that require OAuth. - - This is useful for determining which tools will work and providing - guidance to users about missing connections. - - Args: - user_id: The user's ID - enabled_tool_ids: List of enabled tool IDs - - Returns: - Dict mapping provider_id to OAuthTokenResult for tools needing OAuth - """ - from apis.app_api.tools.repository import get_tool_catalog_repository - - results: dict[str, OAuthTokenResult] = {} - repository = get_tool_catalog_repository() - oauth_service = get_oauth_tool_service() - - # Find all unique OAuth providers required by enabled tools - providers_needed: set[str] = set() - - for tool_id in enabled_tool_ids: - try: - tool = await repository.get_tool(tool_id) - if tool and tool.requires_oauth_provider: - providers_needed.add(tool.requires_oauth_provider) - except Exception as e: - logger.warning(f"Could not check tool {tool_id}: {e}") - - # Check connection status for each provider - for provider_id in providers_needed: - result = await oauth_service.get_token_for_tool(user_id, provider_id) - results[provider_id] = result - - return results - - -def format_oauth_connection_guidance( - missing_connections: list[OAuthTokenResult], -) -> str: - """ - Format a user-friendly message about missing OAuth connections. - - Args: - missing_connections: List of OAuthTokenResult for unconnected providers - - Returns: - Markdown formatted message for the user - """ - if not missing_connections: - return "" - - frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - connect_url = f"{frontend_url}/settings/connectors" - - if len(missing_connections) == 1: - conn = missing_connections[0] - if conn.needs_reauth: - return f"""⚠️ Your connection to **{conn.provider_name}** has expired. - -Please [reconnect]({connect_url}) to use tools that require {conn.provider_name} access.""" - else: - return f"""🔗 To use tools that require **{conn.provider_name}** access, please [connect your account]({connect_url}).""" - - # Multiple missing connections - provider_names = [c.provider_name for c in missing_connections] - names_str = ", ".join(provider_names[:-1]) + f" and {provider_names[-1]}" - - return f"""🔗 Some tools require account connections. - -To use all your enabled tools, please connect: **{names_str}** - -👉 [Manage Connections]({connect_url})""" diff --git a/backend/src/apis/app_api/main.py b/backend/src/apis/app_api/main.py index 0c16aef8..fca2dd8e 100644 --- a/backend/src/apis/app_api/main.py +++ b/backend/src/apis/app_api/main.py @@ -87,7 +87,6 @@ async def lifespan(app: FastAPI): from apis.app_api.documents.routes import router as documents_router from apis.app_api.users.routes import router as users_router from apis.app_api.user_settings.routes import router as user_settings_router -from apis.shared.oauth.routes import router as oauth_router from apis.app_api.system.routes import router as system_router from apis.app_api.shares.routes import conversations_share_router, shares_router, shared_view_router @@ -108,7 +107,6 @@ async def lifespan(app: FastAPI): app.include_router(memory_router) # AgentCore Memory access endpoints app.include_router(tools_router) # Tool discovery and permissions app.include_router(files_router) # File upload via pre-signed URLs -app.include_router(oauth_router) # OAuth provider connections app.include_router(system_router) # System status and first-boot endpoints app.include_router(conversations_share_router) # Share conversations endpoints app.include_router(shares_router) # Share management (update, revoke, export) diff --git a/backend/src/apis/shared/oauth/__init__.py b/backend/src/apis/shared/oauth/__init__.py index bd792d88..6750d0e6 100644 --- a/backend/src/apis/shared/oauth/__init__.py +++ b/backend/src/apis/shared/oauth/__init__.py @@ -1,52 +1,48 @@ -"""OAuth provider management module. +"""OAuth provider administration. -This module provides OAuth connection management for third-party integrations. -Admins can configure OAuth providers (Google, Microsoft, Canvas, etc.) and -users can connect their accounts for MCP tool requests. +Providers are registered and authenticated against AWS Bedrock AgentCore +Identity. This module exposes the provider metadata model, the DynamoDB +repository, and the AgentCore registrar used by admin CRUD routes. """ +from .agentcore_registrar import ( + AgentCoreRegistrar, + CredentialProviderConflictError, + CredentialProviderInfo, + CredentialProviderNotFoundError, + InvalidCustomProviderConfigError, + get_agentcore_registrar, +) from .models import ( - OAuthProviderType, - OAuthConnectionStatus, OAuthProvider, - OAuthUserToken, OAuthProviderCreate, - OAuthProviderUpdate, - OAuthProviderResponse, OAuthProviderListResponse, - OAuthConnectionResponse, - OAuthConnectionListResponse, - OAuthConnectResponse, + OAuthProviderResponse, + OAuthProviderType, + OAuthProviderUpdate, + OAuthRequiredEvent, + compute_scopes_hash, +) +from .provider_repository import ( + OAuthProviderRepository, + get_provider_repository, ) -from .encryption import TokenEncryptionService, get_token_encryption_service -from .token_cache import TokenCache, get_token_cache -from .provider_repository import OAuthProviderRepository, get_provider_repository -from .token_repository import OAuthTokenRepository, get_token_repository -from .service import OAuthService, get_oauth_service __all__ = [ - # Enums "OAuthProviderType", - "OAuthConnectionStatus", - # Models "OAuthProvider", - "OAuthUserToken", "OAuthProviderCreate", "OAuthProviderUpdate", "OAuthProviderResponse", "OAuthProviderListResponse", - "OAuthConnectionResponse", - "OAuthConnectionListResponse", - "OAuthConnectResponse", - # Services - "TokenEncryptionService", - "get_token_encryption_service", - "TokenCache", - "get_token_cache", + "OAuthRequiredEvent", + "compute_scopes_hash", "OAuthProviderRepository", "get_provider_repository", - "OAuthTokenRepository", - "get_token_repository", - "OAuthService", - "get_oauth_service", + "AgentCoreRegistrar", + "CredentialProviderInfo", + "CredentialProviderConflictError", + "CredentialProviderNotFoundError", + "InvalidCustomProviderConfigError", + "get_agentcore_registrar", ] diff --git a/backend/src/apis/shared/oauth/encryption.py b/backend/src/apis/shared/oauth/encryption.py deleted file mode 100644 index 0b1dc877..00000000 --- a/backend/src/apis/shared/oauth/encryption.py +++ /dev/null @@ -1,138 +0,0 @@ -"""KMS encryption service for OAuth tokens.""" - -import base64 -import logging -import os -from typing import Optional - -logger = logging.getLogger(__name__) - - -class TokenEncryptionService: - """ - Service for encrypting/decrypting OAuth tokens using AWS KMS. - - Tokens are encrypted before storage in DynamoDB and decrypted on retrieval. - Uses AWS KMS with a customer-managed key for secure key management. - """ - - def __init__( - self, - key_arn: Optional[str] = None, - region: Optional[str] = None, - ): - """ - Initialize the encryption service. - - Args: - key_arn: KMS key ARN for encryption (defaults to env var) - region: AWS region (defaults to env var) - """ - self._key_arn = key_arn or os.getenv("OAUTH_TOKEN_ENCRYPTION_KEY_ARN") - self._region = region or os.getenv("AWS_REGION", "us-west-2") - self._client = None - self._enabled = bool(self._key_arn) - - if not self._enabled: - logger.warning( - "OAUTH_TOKEN_ENCRYPTION_KEY_ARN not set. " - "Token encryption is disabled (development mode only)." - ) - - @property - def enabled(self) -> bool: - """Check if encryption is enabled.""" - return self._enabled - - def _get_client(self): - """Lazy initialization of KMS client.""" - if self._client is None: - import boto3 - - profile = os.getenv("AWS_PROFILE") - if profile: - session = boto3.Session(profile_name=profile) - self._client = session.client("kms", region_name=self._region) - else: - self._client = boto3.client("kms", region_name=self._region) - return self._client - - def encrypt(self, plaintext: str) -> str: - """ - Encrypt a plaintext string using KMS. - - Args: - plaintext: The string to encrypt - - Returns: - Base64-encoded ciphertext - - Raises: - RuntimeError: If encryption fails - """ - if not self._enabled: - # Development mode: return base64-encoded plaintext (NOT secure!) - logger.warning("Using development mode encryption (NOT SECURE)") - return f"DEV:{base64.b64encode(plaintext.encode()).decode()}" - - try: - client = self._get_client() - response = client.encrypt( - KeyId=self._key_arn, - Plaintext=plaintext.encode("utf-8"), - ) - ciphertext = base64.b64encode(response["CiphertextBlob"]).decode("utf-8") - logger.debug(f"Encrypted token (length={len(plaintext)} -> {len(ciphertext)})") - return ciphertext - - except Exception as e: - logger.error(f"Failed to encrypt token: {e}") - raise RuntimeError(f"Token encryption failed: {e}") from e - - def decrypt(self, ciphertext: str) -> str: - """ - Decrypt a ciphertext string using KMS. - - Args: - ciphertext: Base64-encoded ciphertext - - Returns: - Decrypted plaintext string - - Raises: - RuntimeError: If decryption fails - """ - if not self._enabled: - # Development mode: decode base64 plaintext - if ciphertext.startswith("DEV:"): - logger.warning("Using development mode decryption (NOT SECURE)") - return base64.b64decode(ciphertext[4:]).decode() - else: - raise RuntimeError("Cannot decrypt production ciphertext without KMS key") - - try: - client = self._get_client() - ciphertext_blob = base64.b64decode(ciphertext) - response = client.decrypt( - CiphertextBlob=ciphertext_blob, - KeyId=self._key_arn, - ) - plaintext = response["Plaintext"].decode("utf-8") - logger.debug(f"Decrypted token (length={len(ciphertext)} -> {len(plaintext)})") - return plaintext - - except Exception as e: - logger.error(f"Failed to decrypt token: {e}") - raise RuntimeError(f"Token decryption failed: {e}") from e - - -# Singleton instance -_encryption_service: Optional[TokenEncryptionService] = None - - -def get_token_encryption_service() -> TokenEncryptionService: - """Get the token encryption service singleton.""" - global _encryption_service - if _encryption_service is None: - _encryption_service = TokenEncryptionService() - return _encryption_service diff --git a/backend/src/apis/shared/oauth/routes.py b/backend/src/apis/shared/oauth/routes.py deleted file mode 100644 index 93f2edf7..00000000 --- a/backend/src/apis/shared/oauth/routes.py +++ /dev/null @@ -1,268 +0,0 @@ -"""User-facing OAuth routes for connection management.""" - -import logging -import os -from typing import Optional -from urllib.parse import urlencode, urlparse - -from fastapi import APIRouter, Depends, HTTPException, Query, status -from fastapi.responses import RedirectResponse - -from apis.shared.auth import User, get_current_user -from apis.shared.rbac.service import AppRoleService, get_app_role_service - -from apis.shared.oauth.models import ( - OAuthConnectionListResponse, - OAuthConnectResponse, - OAuthProviderListResponse, - OAuthProviderResponse, -) -from apis.shared.oauth.provider_repository import OAuthProviderRepository, get_provider_repository -from apis.shared.oauth.service import OAuthService, get_oauth_service - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/oauth", tags=["oauth"]) - - -# ============================================================================= -# Provider Discovery (filtered by user roles) -# ============================================================================= - - -@router.get("/providers", response_model=OAuthProviderListResponse) -async def list_available_providers( - current_user: User = Depends(get_current_user), - provider_repo: OAuthProviderRepository = Depends(get_provider_repository), - role_service: AppRoleService = Depends(get_app_role_service), -): - """ - List OAuth providers available to the current user. - - Filters providers based on user's application roles. - - Returns: - OAuthProviderListResponse with available providers - """ - logger.info(f"User {current_user.name} listing available OAuth providers") - - # Resolve user's application roles - permissions = await role_service.resolve_user_permissions(current_user) - user_roles = permissions.app_roles if permissions.app_roles else [] - - # Get enabled providers - providers = await provider_repo.list_providers(enabled_only=True) - - # Filter by user roles - available = [] - for provider in providers: - if not provider.allowed_roles or any( - role in provider.allowed_roles for role in user_roles - ): - available.append(OAuthProviderResponse.from_provider(provider)) - - return OAuthProviderListResponse( - providers=available, - total=len(available), - ) - - -# ============================================================================= -# User Connections -# ============================================================================= - - -@router.get("/connections", response_model=OAuthConnectionListResponse) -async def list_user_connections( - current_user: User = Depends(get_current_user), - oauth_service: OAuthService = Depends(get_oauth_service), - role_service: AppRoleService = Depends(get_app_role_service), -): - """ - List the current user's OAuth connections. - - Returns all available providers with connection status. - - Returns: - OAuthConnectionListResponse with connection statuses - """ - logger.info(f"User {current_user.name} listing OAuth connections") - - # Resolve user's application roles - permissions = await role_service.resolve_user_permissions(current_user) - user_roles = permissions.app_roles if permissions.app_roles else [] - - connections = await oauth_service.get_user_connections( - user_id=current_user.user_id, - user_roles=user_roles, - ) - - return OAuthConnectionListResponse(connections=connections) - - -# ============================================================================= -# OAuth Flow -# ============================================================================= - - -@router.get("/connect/{provider_id}", response_model=OAuthConnectResponse) -async def initiate_connection( - provider_id: str, - redirect: Optional[str] = Query( - None, - description="Frontend URL to redirect after OAuth callback", - ), - current_user: User = Depends(get_current_user), - oauth_service: OAuthService = Depends(get_oauth_service), - role_service: AppRoleService = Depends(get_app_role_service), -): - """ - Initiate OAuth connection flow for a provider. - - Returns an authorization URL that the frontend should redirect to. - - Args: - provider_id: Provider to connect to - redirect: Optional frontend redirect URL after completion - - Returns: - OAuthConnectResponse with authorization URL - - Raises: - HTTPException: - - 404 if provider not found - - 403 if user not authorized for provider - """ - logger.info( - "User initiating OAuth connection" - ) - - # Resolve user's application roles - permissions = await role_service.resolve_user_permissions(current_user) - user_roles = permissions.app_roles if permissions.app_roles else [] - - authorization_url = await oauth_service.initiate_connect( - provider_id=provider_id, - user_id=current_user.user_id, - user_roles=user_roles, - frontend_redirect=redirect, - ) - - return OAuthConnectResponse(authorization_url=authorization_url) - - -@router.get("/callback") -async def oauth_callback( - code: Optional[str] = Query(None), - state: Optional[str] = Query(None), - error: Optional[str] = Query(None), - error_description: Optional[str] = Query(None), - oauth_service: OAuthService = Depends(get_oauth_service), -): - """ - Handle OAuth callback from provider. - - This endpoint is called by the OAuth provider after user authorization. - Exchanges the code for tokens and redirects to the frontend. - - Args: - code: Authorization code from provider - state: State parameter for validation - error: Error code if authorization failed - error_description: Error description if authorization failed - - Returns: - Redirect to frontend with success/error query params - """ - # Get frontend base URL from environment - frontend_url = os.getenv("FRONTEND_URL", "http://localhost:4200") - callback_path = "/settings/oauth/callback" - - # Handle error from provider - if error: - logger.warning("OAuth callback error") - params = urlencode({"error": error, "error_description": error_description or ""}) - return RedirectResponse( - url=f"{frontend_url}{callback_path}?{params}", - status_code=status.HTTP_302_FOUND, - ) - - # Validate required params - if not code or not state: - logger.warning("OAuth callback missing code or state") - params = urlencode({"error": "missing_params"}) - return RedirectResponse( - url=f"{frontend_url}{callback_path}?{params}", - status_code=status.HTTP_302_FOUND, - ) - - # Process callback - provider_id, frontend_redirect, callback_error = await oauth_service.handle_callback( - code=code, - state=state, - ) - - # Build redirect URL — validate that frontend_redirect is same-origin - # to prevent open redirect attacks via manipulated OAuth state - redirect_base = f"{frontend_url}{callback_path}" - if frontend_redirect: - parsed = urlparse(frontend_redirect) - parsed_frontend = urlparse(frontend_url) - if parsed.scheme == parsed_frontend.scheme and parsed.netloc == parsed_frontend.netloc: - redirect_base = frontend_redirect - else: - logger.warning( - f"OAuth callback redirect blocked — origin mismatch: {parsed.netloc} != {parsed_frontend.netloc}" - ) - - if callback_error: - params = urlencode({"error": callback_error, "provider": provider_id}) - return RedirectResponse( - url=f"{redirect_base}?{params}", - status_code=status.HTTP_302_FOUND, - ) - - # Success - params = urlencode({"success": "true", "provider": provider_id}) - return RedirectResponse( - url=f"{redirect_base}?{params}", - status_code=status.HTTP_302_FOUND, - ) - - -# ============================================================================= -# Disconnect -# ============================================================================= - - -@router.delete("/connections/{provider_id}", status_code=status.HTTP_204_NO_CONTENT) -async def disconnect_provider( - provider_id: str, - current_user: User = Depends(get_current_user), - oauth_service: OAuthService = Depends(get_oauth_service), -): - """ - Disconnect from an OAuth provider. - - Revokes tokens if possible and removes the connection. - - Args: - provider_id: Provider to disconnect from - - Raises: - HTTPException: 404 if not connected to provider - """ - logger.info("User disconnecting from OAuth provider") - - disconnected = await oauth_service.disconnect( - user_id=current_user.user_id, - provider_id=provider_id, - ) - - if not disconnected: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Not connected to provider '{provider_id}'", - ) - - return None diff --git a/backend/src/apis/shared/oauth/service.py b/backend/src/apis/shared/oauth/service.py deleted file mode 100644 index 78612c75..00000000 --- a/backend/src/apis/shared/oauth/service.py +++ /dev/null @@ -1,664 +0,0 @@ -"""OAuth service for managing provider connections and token exchange.""" - -import base64 -import hashlib -import logging -import os -import secrets -import time -from dataclasses import dataclass -from datetime import datetime, timezone -from typing import Dict, List, Optional, Tuple - -import httpx -from authlib.integrations.httpx_client import AsyncOAuth2Client -from fastapi import HTTPException, status - -from apis.shared.auth.state_store import StateStore, create_state_store - -from .encryption import TokenEncryptionService, get_token_encryption_service -from .models import ( - OAuthConnectionResponse, - OAuthConnectionStatus, - OAuthProvider, - OAuthUserToken, -) -from .provider_repository import OAuthProviderRepository, get_provider_repository -from .token_cache import TokenCache, get_token_cache -from .token_repository import OAuthTokenRepository, get_token_repository - -logger = logging.getLogger(__name__) - - -@dataclass -class OAuthStateData: - """Data stored with OAuth state for security validation.""" - - provider_id: str - user_id: str - code_verifier: Optional[str] = None # PKCE code verifier (S256) - redirect_uri: Optional[str] = None # Frontend redirect after callback - - -def generate_pkce_pair() -> Tuple[str, str]: - """ - Generate PKCE code verifier and challenge (S256). - - Returns: - Tuple of (code_verifier, code_challenge) - """ - # Generate 32 bytes of random data for code_verifier - code_verifier = secrets.token_urlsafe(32) - - # Create code_challenge using S256: BASE64URL(SHA256(code_verifier)) - digest = hashlib.sha256(code_verifier.encode("ascii")).digest() - code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") - - return code_verifier, code_challenge - - -class OAuthService: - """ - Service for OAuth flow management. - - Handles: - - Initiating OAuth connection flows - - Processing OAuth callbacks and token exchange - - Token refresh and decryption - - Connection status management - """ - - def __init__( - self, - provider_repo: Optional[OAuthProviderRepository] = None, - token_repo: Optional[OAuthTokenRepository] = None, - encryption_service: Optional[TokenEncryptionService] = None, - token_cache: Optional[TokenCache] = None, - state_store: Optional[StateStore] = None, - ): - """ - Initialize OAuth service. - - Args: - provider_repo: Provider repository (defaults to singleton) - token_repo: Token repository (defaults to singleton) - encryption_service: Token encryption service (defaults to singleton) - token_cache: Token cache (defaults to singleton) - state_store: State store for OAuth state (defaults to create_state_store) - """ - self._provider_repo = provider_repo or get_provider_repository() - self._token_repo = token_repo or get_token_repository() - self._encryption = encryption_service or get_token_encryption_service() - self._cache = token_cache or get_token_cache() - self._state_store = state_store or create_state_store() - - # OAuth callback URL (configured in environment) - self._callback_url = os.getenv("OAUTH_CALLBACK_URL", "") - if not self._callback_url: - logger.warning( - "OAUTH_CALLBACK_URL not set. OAuth flows will fail. " - "Set to e.g. https://your-app.com/oauth/callback" - ) - - # State TTL in seconds (10 minutes) - self._state_ttl = 600 - - @property - def enabled(self) -> bool: - """Check if OAuth service is enabled.""" - return self._provider_repo.enabled and self._token_repo.enabled - - # ========================================================================= - # Connection Flow - # ========================================================================= - - async def initiate_connect( - self, - provider_id: str, - user_id: str, - user_roles: List[str], - frontend_redirect: Optional[str] = None, - ) -> str: - """ - Initiate OAuth connection flow. - - Generates authorization URL for user to visit. - - Args: - provider_id: Provider to connect to - user_id: User initiating connection - user_roles: User's roles for access check - frontend_redirect: URL to redirect after callback - - Returns: - Authorization URL - - Raises: - HTTPException: If provider not found or user not authorized - """ - # Get provider - provider = await self._provider_repo.get_provider(provider_id) - if not provider: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Provider '{provider_id}' not found", - ) - - if not provider.enabled: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Provider '{provider_id}' is not enabled", - ) - - # Check role access - if provider.allowed_roles and not any( - role in provider.allowed_roles for role in user_roles - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You do not have access to this provider", - ) - - # Generate state and PKCE - state = secrets.token_urlsafe(32) - - code_verifier = None - code_challenge = None - if provider.pkce_required: - code_verifier, code_challenge = generate_pkce_pair() - - # Store state data - state_data = OAuthStateData( - provider_id=provider_id, - user_id=user_id, - code_verifier=code_verifier, - redirect_uri=frontend_redirect, - ) - self._store_state(state, state_data) - - # Build authorization URL - params = { - "client_id": provider.client_id, - "redirect_uri": self._callback_url, - "response_type": "code", - "scope": " ".join(provider.scopes), - "state": state, - } - - if code_challenge: - params["code_challenge"] = code_challenge - params["code_challenge_method"] = "S256" - - # Add provider-specific authorization params (e.g., access_type=offline for Google) - if provider.authorization_params: - params.update(provider.authorization_params) - - # Build URL - auth_url = f"{provider.authorization_endpoint}?" - auth_url += "&".join(f"{k}={v}" for k, v in params.items()) - - logger.info(f"Initiated OAuth flow for user {user_id}, provider {provider_id}") - return auth_url - - async def handle_callback( - self, - code: str, - state: str, - ) -> Tuple[str, Optional[str], Optional[str]]: - """ - Handle OAuth callback after user authorization. - - Exchanges code for tokens and stores encrypted. - - Args: - code: Authorization code from provider - state: State parameter for validation - - Returns: - Tuple of (provider_id, frontend_redirect, error) - """ - # Validate and retrieve state data - valid, state_data = self._get_and_delete_state(state) - if not valid or not state_data: - logger.warning(f"Invalid or expired OAuth state: {state[:16]}...") - return "", None, "invalid_state" - - provider_id = state_data.provider_id - user_id = state_data.user_id - code_verifier = state_data.code_verifier - frontend_redirect = state_data.redirect_uri - - try: - # Get provider - provider = await self._provider_repo.get_provider(provider_id) - if not provider: - logger.error(f"Provider not found during callback: {provider_id}") - return provider_id, frontend_redirect, "provider_not_found" - - # Get client secret - client_secret = await self._provider_repo.get_client_secret(provider_id) - if not client_secret: - logger.error(f"Client secret not found for provider: {provider_id}") - return provider_id, frontend_redirect, "configuration_error" - - # Exchange code for tokens - token_data = await self._exchange_code( - provider=provider, - client_secret=client_secret, - code=code, - code_verifier=code_verifier, - ) - - if not token_data: - return provider_id, frontend_redirect, "token_exchange_failed" - - # Encrypt and store tokens - access_token = token_data.get("access_token") - refresh_token = token_data.get("refresh_token") - expires_in = token_data.get("expires_in") - token_type = token_data.get("token_type", "Bearer") - - if not access_token: - logger.error("No access token in response") - return provider_id, frontend_redirect, "no_access_token" - - # Calculate expiration - expires_at = None - if expires_in: - expires_at = int(time.time()) + int(expires_in) - - # Encrypt tokens - encrypted_access = self._encryption.encrypt(access_token) - encrypted_refresh = None - if refresh_token: - encrypted_refresh = self._encryption.encrypt(refresh_token) - - # Create token record - now = datetime.now(timezone.utc).isoformat() + "Z" - user_token = OAuthUserToken( - user_id=user_id, - provider_id=provider_id, - access_token_encrypted=encrypted_access, - refresh_token_encrypted=encrypted_refresh, - token_type=token_type, - expires_at=expires_at, - scopes_hash=provider.scopes_hash, - status=OAuthConnectionStatus.CONNECTED, - connected_at=now, - updated_at=now, - ) - - await self._token_repo.save_token(user_token) - - # Cache the decrypted token - self._cache.set(user_id, provider_id, access_token) - - logger.info(f"Successfully connected user {user_id} to provider {provider_id}") - return provider_id, frontend_redirect, None - - except Exception as e: - logger.error(f"Error handling OAuth callback: {e}", exc_info=True) - return provider_id, frontend_redirect, str(e) - - async def _exchange_code( - self, - provider: OAuthProvider, - client_secret: str, - code: str, - code_verifier: Optional[str] = None, - ) -> Optional[Dict]: - """ - Exchange authorization code for tokens. - - Args: - provider: OAuth provider - client_secret: Provider client secret - code: Authorization code - code_verifier: PKCE code verifier (if used) - - Returns: - Token response dict or None on error - """ - try: - async with AsyncOAuth2Client( - client_id=provider.client_id, - client_secret=client_secret, - token_endpoint=provider.token_endpoint, - ) as client: - token = await client.fetch_token( - url=provider.token_endpoint, - grant_type="authorization_code", - code=code, - redirect_uri=self._callback_url, - code_verifier=code_verifier, - ) - return dict(token) - - except Exception as e: - logger.error(f"Token exchange failed: {e}") - return None - - # ========================================================================= - # Token Access - # ========================================================================= - - async def get_decrypted_token( - self, - user_id: str, - provider_id: str, - ) -> Optional[str]: - """ - Get decrypted access token for a user's provider connection. - - Checks cache first, then decrypts from storage. - Handles token refresh if needed. - - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - Decrypted access token, or None if not connected - """ - # Check cache first - cached = self._cache.get(user_id, provider_id) - if cached: - return cached - - # Get token from storage - token = await self._token_repo.get_token(user_id, provider_id) - if not token: - return None - - if token.status == OAuthConnectionStatus.REVOKED: - return None - - # Check if expired and needs refresh - if token.is_expired: - if token.refresh_token_encrypted: - refreshed = await self._refresh_token(user_id, provider_id) - if refreshed: - return refreshed - # Mark as expired - await self._token_repo.update_token_status( - user_id, provider_id, OAuthConnectionStatus.EXPIRED - ) - return None - - # Decrypt and cache - try: - access_token = self._encryption.decrypt(token.access_token_encrypted) - self._cache.set(user_id, provider_id, access_token) - return access_token - except Exception as e: - logger.error(f"Failed to decrypt token: {e}") - return None - - async def _refresh_token( - self, - user_id: str, - provider_id: str, - ) -> Optional[str]: - """ - Refresh an expired token. - - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - New access token, or None on failure - """ - token = await self._token_repo.get_token(user_id, provider_id) - if not token or not token.refresh_token_encrypted: - return None - - provider = await self._provider_repo.get_provider(provider_id) - if not provider: - return None - - client_secret = await self._provider_repo.get_client_secret(provider_id) - if not client_secret: - return None - - try: - # Decrypt refresh token - refresh_token = self._encryption.decrypt(token.refresh_token_encrypted) - - # Request new tokens - async with AsyncOAuth2Client( - client_id=provider.client_id, - client_secret=client_secret, - token_endpoint=provider.token_endpoint, - ) as client: - new_token = await client.refresh_token( - url=provider.token_endpoint, - refresh_token=refresh_token, - ) - - if not new_token or "access_token" not in new_token: - logger.warning(f"Token refresh failed for {user_id}/{provider_id}") - return None - - # Update stored token - token.access_token_encrypted = self._encryption.encrypt( - new_token["access_token"] - ) - if "refresh_token" in new_token: - token.refresh_token_encrypted = self._encryption.encrypt( - new_token["refresh_token"] - ) - if "expires_in" in new_token: - token.expires_at = int(time.time()) + int(new_token["expires_in"]) - token.status = OAuthConnectionStatus.CONNECTED - - await self._token_repo.save_token(token) - - # Update cache - self._cache.set(user_id, provider_id, new_token["access_token"]) - - logger.info(f"Refreshed token for user {user_id}, provider {provider_id}") - return new_token["access_token"] - - except Exception as e: - logger.error(f"Token refresh failed: {e}") - return None - - # ========================================================================= - # Connection Management - # ========================================================================= - - async def get_user_connections( - self, - user_id: str, - user_roles: List[str], - ) -> List[OAuthConnectionResponse]: - """ - Get user's OAuth connections with status. - - Returns all available providers with connection status. - - Args: - user_id: User identifier - user_roles: User's roles for filtering available providers - - Returns: - List of connection responses - """ - # Get available providers for user's roles - providers = await self._provider_repo.list_providers(enabled_only=True) - available_providers = [ - p for p in providers - if not p.allowed_roles or any(r in p.allowed_roles for r in user_roles) - ] - - # Get user's tokens - tokens = await self._token_repo.list_user_tokens(user_id) - token_map = {t.provider_id: t for t in tokens} - - # Build connection responses - connections = [] - for provider in available_providers: - token = token_map.get(provider.provider_id) - - if token: - # Check if needs re-auth (scope changes) - needs_reauth = token.scopes_hash != provider.scopes_hash - - # Update status based on token state - if token.is_expired: - status = OAuthConnectionStatus.EXPIRED - elif needs_reauth: - status = OAuthConnectionStatus.NEEDS_REAUTH - else: - status = token.status - - connections.append( - OAuthConnectionResponse( - provider_id=provider.provider_id, - display_name=provider.display_name, - provider_type=provider.provider_type, - icon_name=provider.icon_name, - status=status, - connected_at=token.connected_at, - needs_reauth=needs_reauth, - ) - ) - else: - # Not connected - connections.append( - OAuthConnectionResponse( - provider_id=provider.provider_id, - display_name=provider.display_name, - provider_type=provider.provider_type, - icon_name=provider.icon_name, - status=OAuthConnectionStatus.REVOKED, # Use REVOKED as "not connected" - connected_at=None, - needs_reauth=False, - ) - ) - - return connections - - async def disconnect( - self, - user_id: str, - provider_id: str, - ) -> bool: - """ - Disconnect user from a provider. - - Revokes token if possible and deletes from storage. - - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - True if disconnected, False if not connected - """ - # Get token - token = await self._token_repo.get_token(user_id, provider_id) - if not token: - return False - - # Try to revoke token at provider - provider = await self._provider_repo.get_provider(provider_id) - if provider and provider.revocation_endpoint: - try: - access_token = self._encryption.decrypt(token.access_token_encrypted) - await self._revoke_token(provider, access_token) - except Exception as e: - logger.warning(f"Failed to revoke token at provider: {e}") - # Continue with deletion anyway - - # Delete from storage - deleted = await self._token_repo.delete_token(user_id, provider_id) - - # Clear cache - self._cache.delete(user_id, provider_id) - - logger.info(f"Disconnected user {user_id} from provider {provider_id}") - return deleted - - async def _revoke_token( - self, - provider: OAuthProvider, - access_token: str, - ) -> None: - """ - Revoke token at the provider's revocation endpoint. - - Args: - provider: OAuth provider - access_token: Token to revoke - """ - if not provider.revocation_endpoint: - return - - client_secret = await self._provider_repo.get_client_secret(provider.provider_id) - if not client_secret: - return - - try: - async with httpx.AsyncClient() as client: - await client.post( - provider.revocation_endpoint, - data={ - "token": access_token, - "client_id": provider.client_id, - "client_secret": client_secret, - }, - timeout=10.0, - ) - except Exception as e: - logger.warning(f"Token revocation request failed: {e}") - - # ========================================================================= - # State Management (reuses OIDC state store pattern) - # ========================================================================= - - def _store_state(self, state: str, data: OAuthStateData) -> None: - """Store OAuth state data.""" - # Convert to dict for storage - from apis.shared.auth.state_store import OIDCStateData - - oidc_data = OIDCStateData( - redirect_uri=data.redirect_uri, - code_verifier=data.code_verifier, - nonce=f"{data.provider_id}|{data.user_id}", # Encode provider/user in nonce - ) - self._state_store.store_state(state, oidc_data, self._state_ttl) - - def _get_and_delete_state( - self, state: str - ) -> Tuple[bool, Optional[OAuthStateData]]: - """Retrieve and delete OAuth state data.""" - valid, oidc_data = self._state_store.get_and_delete_state(state) - if not valid or not oidc_data: - return False, None - - # Decode provider/user from nonce - if not oidc_data.nonce or "|" not in oidc_data.nonce: - return False, None - - provider_id, user_id = oidc_data.nonce.split("|", 1) - - return True, OAuthStateData( - provider_id=provider_id, - user_id=user_id, - code_verifier=oidc_data.code_verifier, - redirect_uri=oidc_data.redirect_uri, - ) - - -# Singleton instance -_oauth_service: Optional[OAuthService] = None - - -def get_oauth_service() -> OAuthService: - """Get the OAuth service singleton.""" - global _oauth_service - if _oauth_service is None: - _oauth_service = OAuthService() - return _oauth_service diff --git a/backend/src/apis/shared/oauth/token_cache.py b/backend/src/apis/shared/oauth/token_cache.py deleted file mode 100644 index 33224e11..00000000 --- a/backend/src/apis/shared/oauth/token_cache.py +++ /dev/null @@ -1,165 +0,0 @@ -"""In-memory cache for decrypted OAuth tokens. - -Uses TTLCache to avoid repeated KMS decrypt calls for frequently accessed tokens. -Cache entries expire after 5 minutes by default. -""" - -import logging -from typing import Optional - -from cachetools import TTLCache - -logger = logging.getLogger(__name__) - -# Default cache TTL in seconds (5 minutes) -DEFAULT_CACHE_TTL = 300 - -# Default cache size (number of tokens to cache) -DEFAULT_CACHE_SIZE = 1000 - - -class TokenCache: - """ - TTL-based cache for decrypted OAuth access tokens. - - Reduces KMS API calls by caching decrypted tokens in memory. - Tokens are automatically evicted after the TTL expires. - - Thread-safe through cachetools' internal locking. - """ - - def __init__( - self, - maxsize: int = DEFAULT_CACHE_SIZE, - ttl: int = DEFAULT_CACHE_TTL, - ): - """ - Initialize the token cache. - - Args: - maxsize: Maximum number of tokens to cache - ttl: Time-to-live in seconds for cache entries - """ - self._cache: TTLCache = TTLCache(maxsize=maxsize, ttl=ttl) - self._ttl = ttl - self._maxsize = maxsize - - logger.info(f"Initialized token cache: maxsize={maxsize}, ttl={ttl}s") - - def _make_key(self, user_id: str, provider_id: str) -> str: - """Create a cache key from user and provider IDs.""" - return f"{user_id}:{provider_id}" - - def get(self, user_id: str, provider_id: str) -> Optional[str]: - """ - Get a cached access token. - - Args: - user_id: User identifier - provider_id: OAuth provider identifier - - Returns: - Decrypted access token if cached, None otherwise - """ - key = self._make_key(user_id, provider_id) - token = self._cache.get(key) - if token: - logger.debug(f"Token cache hit: user={user_id}, provider={provider_id}") - else: - logger.debug(f"Token cache miss: user={user_id}, provider={provider_id}") - return token - - def set(self, user_id: str, provider_id: str, access_token: str) -> None: - """ - Cache a decrypted access token. - - Args: - user_id: User identifier - provider_id: OAuth provider identifier - access_token: Decrypted access token to cache - """ - key = self._make_key(user_id, provider_id) - self._cache[key] = access_token - logger.debug(f"Cached token: user={user_id}, provider={provider_id}") - - def delete(self, user_id: str, provider_id: str) -> bool: - """ - Remove a token from cache. - - Args: - user_id: User identifier - provider_id: OAuth provider identifier - - Returns: - True if token was in cache, False otherwise - """ - key = self._make_key(user_id, provider_id) - if key in self._cache: - del self._cache[key] - logger.debug(f"Evicted token from cache: user={user_id}, provider={provider_id}") - return True - return False - - def delete_for_user(self, user_id: str) -> int: - """ - Remove all cached tokens for a user. - - Args: - user_id: User identifier - - Returns: - Number of tokens removed - """ - prefix = f"{user_id}:" - keys_to_delete = [k for k in self._cache.keys() if k.startswith(prefix)] - for key in keys_to_delete: - del self._cache[key] - if keys_to_delete: - logger.debug(f"Evicted {len(keys_to_delete)} tokens for user: {user_id}") - return len(keys_to_delete) - - def delete_for_provider(self, provider_id: str) -> int: - """ - Remove all cached tokens for a provider. - - Useful when provider configuration changes (e.g., scopes updated). - - Args: - provider_id: OAuth provider identifier - - Returns: - Number of tokens removed - """ - suffix = f":{provider_id}" - keys_to_delete = [k for k in self._cache.keys() if k.endswith(suffix)] - for key in keys_to_delete: - del self._cache[key] - if keys_to_delete: - logger.debug(f"Evicted {len(keys_to_delete)} tokens for provider: {provider_id}") - return len(keys_to_delete) - - def clear(self) -> None: - """Clear all cached tokens.""" - size = len(self._cache) - self._cache.clear() - logger.info(f"Cleared token cache ({size} entries)") - - def get_stats(self) -> dict: - """Get cache statistics.""" - return { - "size": len(self._cache), - "maxsize": self._maxsize, - "ttl_seconds": self._ttl, - } - - -# Singleton instance -_token_cache: Optional[TokenCache] = None - - -def get_token_cache() -> TokenCache: - """Get the token cache singleton.""" - global _token_cache - if _token_cache is None: - _token_cache = TokenCache() - return _token_cache diff --git a/backend/src/apis/shared/oauth/token_repository.py b/backend/src/apis/shared/oauth/token_repository.py deleted file mode 100644 index a695b148..00000000 --- a/backend/src/apis/shared/oauth/token_repository.py +++ /dev/null @@ -1,337 +0,0 @@ -"""DynamoDB repository for OAuth user tokens.""" - -import logging -import os -from datetime import datetime, timezone -from typing import List, Optional - -import boto3 -from botocore.exceptions import ClientError - -from .models import OAuthUserToken, OAuthConnectionStatus - -logger = logging.getLogger(__name__) - - -class OAuthTokenRepository: - """ - Repository for OAuth user token CRUD operations in DynamoDB. - - Handles encrypted token storage with KMS. - Uses single-table design with GSI for querying by provider. - """ - - def __init__( - self, - table_name: Optional[str] = None, - region: Optional[str] = None, - ): - """ - Initialize repository. - - Args: - table_name: DynamoDB table name (defaults to env var) - region: AWS region (defaults to env var) - """ - self._table_name = table_name or os.getenv("DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME") - self._region = region or os.getenv("AWS_REGION", "us-west-2") - self._enabled = bool(self._table_name) - - if not self._enabled: - logger.warning( - "DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME not set. " - "OAuth token repository is disabled." - ) - return - - # Initialize client - profile = os.getenv("AWS_PROFILE") - if profile: - session = boto3.Session(profile_name=profile) - self._dynamodb = session.resource("dynamodb", region_name=self._region) - else: - self._dynamodb = boto3.resource("dynamodb", region_name=self._region) - - self._table = self._dynamodb.Table(self._table_name) - logger.info(f"Initialized OAuth token repository: table={self._table_name}") - - @property - def enabled(self) -> bool: - """Check if repository is enabled.""" - return self._enabled - - # ========================================================================= - # Token CRUD - # ========================================================================= - - async def get_token( - self, user_id: str, provider_id: str - ) -> Optional[OAuthUserToken]: - """ - Get a user's token for a provider. - - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - OAuthUserToken if found, None otherwise - """ - if not self._enabled: - return None - - try: - response = self._table.get_item( - Key={ - "PK": f"USER#{user_id}", - "SK": f"PROVIDER#{provider_id}", - } - ) - item = response.get("Item") - if not item: - return None - return OAuthUserToken.from_dynamo_item(item) - - except ClientError as e: - logger.error(f"Error getting token for user {user_id}, provider {provider_id}: {e}") - raise - - async def list_user_tokens(self, user_id: str) -> List[OAuthUserToken]: - """ - List all tokens for a user. - - Args: - user_id: User identifier - - Returns: - List of OAuthUserToken objects - """ - if not self._enabled: - return [] - - try: - response = self._table.query( - KeyConditionExpression="PK = :pk AND begins_with(SK, :sk_prefix)", - ExpressionAttributeValues={ - ":pk": f"USER#{user_id}", - ":sk_prefix": "PROVIDER#", - }, - ) - items = response.get("Items", []) - - # Handle pagination - while "LastEvaluatedKey" in response: - response = self._table.query( - KeyConditionExpression="PK = :pk AND begins_with(SK, :sk_prefix)", - ExpressionAttributeValues={ - ":pk": f"USER#{user_id}", - ":sk_prefix": "PROVIDER#", - }, - ExclusiveStartKey=response["LastEvaluatedKey"], - ) - items.extend(response.get("Items", [])) - - return [OAuthUserToken.from_dynamo_item(item) for item in items] - - except ClientError as e: - logger.error(f"Error listing tokens for user {user_id}: {e}") - raise - - async def list_provider_tokens(self, provider_id: str) -> List[OAuthUserToken]: - """ - List all user tokens for a provider (admin view). - - Uses GSI for efficient lookup. - - Args: - provider_id: Provider identifier - - Returns: - List of OAuthUserToken objects - """ - if not self._enabled: - return [] - - try: - response = self._table.query( - IndexName="ProviderUsersIndex", - KeyConditionExpression="GSI1PK = :pk", - ExpressionAttributeValues={":pk": f"PROVIDER#{provider_id}"}, - ) - items = response.get("Items", []) - - # Handle pagination - while "LastEvaluatedKey" in response: - response = self._table.query( - IndexName="ProviderUsersIndex", - KeyConditionExpression="GSI1PK = :pk", - ExpressionAttributeValues={":pk": f"PROVIDER#{provider_id}"}, - ExclusiveStartKey=response["LastEvaluatedKey"], - ) - items.extend(response.get("Items", [])) - - return [OAuthUserToken.from_dynamo_item(item) for item in items] - - except ClientError as e: - logger.error(f"Error listing tokens for provider {provider_id}: {e}") - raise - - async def save_token(self, token: OAuthUserToken) -> OAuthUserToken: - """ - Save or update a user token. - - Args: - token: Token to save - - Returns: - Saved token - """ - if not self._enabled: - raise RuntimeError("OAuth token repository is not enabled") - - try: - token.updated_at = datetime.now(timezone.utc).isoformat() + "Z" - self._table.put_item(Item=token.to_dynamo_item()) - logger.info(f"Saved token for user {token.user_id}, provider {token.provider_id}") - return token - - except ClientError as e: - logger.error(f"Error saving token: {e}") - raise - - async def update_token_status( - self, - user_id: str, - provider_id: str, - status: OAuthConnectionStatus, - ) -> Optional[OAuthUserToken]: - """ - Update token status. - - Args: - user_id: User identifier - provider_id: Provider identifier - status: New status - - Returns: - Updated token, or None if not found - """ - if not self._enabled: - return None - - token = await self.get_token(user_id, provider_id) - if not token: - return None - - token.status = status - return await self.save_token(token) - - async def delete_token(self, user_id: str, provider_id: str) -> bool: - """ - Delete a user's token for a provider. - - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - True if deleted, False if not found - """ - if not self._enabled: - return False - - try: - # Check if exists first - existing = await self.get_token(user_id, provider_id) - if not existing: - return False - - self._table.delete_item( - Key={ - "PK": f"USER#{user_id}", - "SK": f"PROVIDER#{provider_id}", - } - ) - - logger.info(f"Deleted token for user {user_id}, provider {provider_id}") - return True - - except ClientError as e: - logger.error(f"Error deleting token: {e}") - raise - - async def delete_user_tokens(self, user_id: str) -> int: - """ - Delete all tokens for a user. - - Args: - user_id: User identifier - - Returns: - Number of tokens deleted - """ - if not self._enabled: - return 0 - - try: - tokens = await self.list_user_tokens(user_id) - - with self._table.batch_writer() as batch: - for token in tokens: - batch.delete_item( - Key={ - "PK": f"USER#{user_id}", - "SK": f"PROVIDER#{token.provider_id}", - } - ) - - logger.info(f"Deleted {len(tokens)} tokens for user {user_id}") - return len(tokens) - - except ClientError as e: - logger.error(f"Error deleting tokens for user {user_id}: {e}") - raise - - async def delete_provider_tokens(self, provider_id: str) -> int: - """ - Delete all tokens for a provider (when provider is deleted). - - Args: - provider_id: Provider identifier - - Returns: - Number of tokens deleted - """ - if not self._enabled: - return 0 - - try: - tokens = await self.list_provider_tokens(provider_id) - - with self._table.batch_writer() as batch: - for token in tokens: - batch.delete_item( - Key={ - "PK": f"USER#{token.user_id}", - "SK": f"PROVIDER#{provider_id}", - } - ) - - logger.info(f"Deleted {len(tokens)} tokens for provider {provider_id}") - return len(tokens) - - except ClientError as e: - logger.error(f"Error deleting tokens for provider {provider_id}: {e}") - raise - - -# Singleton instance -_token_repository: Optional[OAuthTokenRepository] = None - - -def get_token_repository() -> OAuthTokenRepository: - """Get the token repository singleton.""" - global _token_repository - if _token_repository is None: - _token_repository = OAuthTokenRepository() - return _token_repository diff --git a/backend/tests/shared/test_conftest_smoke.py b/backend/tests/shared/test_conftest_smoke.py index 5ad26732..9aa29a88 100644 --- a/backend/tests/shared/test_conftest_smoke.py +++ b/backend/tests/shared/test_conftest_smoke.py @@ -71,9 +71,6 @@ def test_auth_provider_repository(self, auth_provider_repository): def test_oauth_provider_repository(self, oauth_provider_repository): assert oauth_provider_repository.enabled - def test_oauth_token_repository(self, oauth_token_repository): - assert oauth_token_repository.enabled - def test_file_repository(self, file_repository): assert file_repository is not None diff --git a/backend/tests/shared/test_coverage_boost.py b/backend/tests/shared/test_coverage_boost.py index 72af7502..0f27173e 100644 --- a/backend/tests/shared/test_coverage_boost.py +++ b/backend/tests/shared/test_coverage_boost.py @@ -185,85 +185,6 @@ def test_convert_message_to_response(self): assert resp.role == "user" -class TestOAuthServiceExtended: - """Cover generate_pkce_pair, state management, disconnect, get_user_connections.""" - - def test_generate_pkce_pair(self): - from apis.shared.oauth.service import generate_pkce_pair - verifier, challenge = generate_pkce_pair() - assert len(verifier) > 20 - assert len(challenge) > 20 - assert verifier != challenge - - @pytest.mark.asyncio - async def test_disconnect(self, oauth_providers_table, oauth_tokens_table, secrets_manager, kms_key_arn, monkeypatch): - monkeypatch.setenv("OAUTH_CALLBACK_URL", "http://localhost/callback") - from apis.shared.oauth.service import OAuthService - from apis.shared.oauth.provider_repository import OAuthProviderRepository - from apis.shared.oauth.token_repository import OAuthTokenRepository - from apis.shared.oauth.encryption import TokenEncryptionService - from apis.shared.oauth.token_cache import TokenCache - from apis.shared.oauth.models import OAuthUserToken, OAuthConnectionStatus - - pr = OAuthProviderRepository(table_name="test-oauth-providers", secrets_arn="oauth-client-secrets", region="us-east-1") - tr = OAuthTokenRepository(table_name="test-oauth-user-tokens", region="us-east-1") - enc = TokenEncryptionService(key_arn=kms_key_arn, region="us-east-1") - cache = TokenCache() - - # Save a token - token = OAuthUserToken( - user_id="u1", provider_id="github", - access_token_encrypted=enc.encrypt("tok123"), - status=OAuthConnectionStatus.CONNECTED, - connected_at="2026-01-01", - ) - await tr.save_token(token) - - svc = OAuthService(provider_repo=pr, token_repo=tr, encryption_service=enc, token_cache=cache) - assert await svc.disconnect("u1", "github") is True - assert await svc.disconnect("u1", "github") is False # already deleted - - @pytest.mark.asyncio - async def test_get_decrypted_token_cached(self, oauth_tokens_table, monkeypatch): - monkeypatch.setenv("OAUTH_CALLBACK_URL", "http://localhost/callback") - from apis.shared.oauth.service import OAuthService - from apis.shared.oauth.token_cache import TokenCache - - cache = TokenCache() - cache.set("u1", "github", "cached_token") - svc = OAuthService(token_cache=cache) - result = await svc.get_decrypted_token("u1", "github") - assert result == "cached_token" - - @pytest.mark.asyncio - async def test_get_user_connections(self, oauth_providers_table, oauth_tokens_table, secrets_manager, kms_key_arn, monkeypatch): - monkeypatch.setenv("OAUTH_CALLBACK_URL", "http://localhost/callback") - from apis.shared.oauth.service import OAuthService - from apis.shared.oauth.provider_repository import OAuthProviderRepository - from apis.shared.oauth.token_repository import OAuthTokenRepository - from apis.shared.oauth.encryption import TokenEncryptionService - from apis.shared.oauth.token_cache import TokenCache - from apis.shared.oauth.models import OAuthProviderCreate, OAuthProviderType - - pr = OAuthProviderRepository(table_name="test-oauth-providers", secrets_arn="oauth-client-secrets", region="us-east-1") - tr = OAuthTokenRepository(table_name="test-oauth-user-tokens", region="us-east-1") - enc = TokenEncryptionService(key_arn=kms_key_arn, region="us-east-1") - cache = TokenCache() - - # Create a provider - await pr.create_provider(OAuthProviderCreate( - provider_id="github", display_name="GitHub", provider_type=OAuthProviderType.GITHUB, - authorization_endpoint="https://github.com/login/oauth/authorize", - token_endpoint="https://github.com/login/oauth/access_token", - client_id="cid", client_secret="csec", scopes=["repo"], allowed_roles=["viewer"], - )) - - svc = OAuthService(provider_repo=pr, token_repo=tr, encryption_service=enc, token_cache=cache) - connections = await svc.get_user_connections("u1", ["viewer"]) - assert len(connections) == 1 - assert connections[0].provider_id == "github" - - class TestStateStore: def test_in_memory_store_and_retrieve(self): from apis.shared.auth.state_store import InMemoryStateStore, OIDCStateData diff --git a/backend/tests/shared/test_oauth_service.py b/backend/tests/shared/test_oauth_service.py deleted file mode 100644 index 8b4085cf..00000000 --- a/backend/tests/shared/test_oauth_service.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Task 7: OAuth encryption (moto KMS), token cache, and service tests.""" - -import time -import pytest -from unittest.mock import AsyncMock, MagicMock, patch - - -# =================================================================== -# TokenEncryptionService (moto KMS) -# =================================================================== - -class TestTokenEncryptionService: - def test_encrypt_decrypt_roundtrip(self, kms_key_arn): - from apis.shared.oauth.encryption import TokenEncryptionService - svc = TokenEncryptionService(key_arn=kms_key_arn, region="us-east-1") - ciphertext = svc.encrypt("my-secret-token") - assert ciphertext != "my-secret-token" - plaintext = svc.decrypt(ciphertext) - assert plaintext == "my-secret-token" - - def test_disabled_without_key(self): - from apis.shared.oauth.encryption import TokenEncryptionService - svc = TokenEncryptionService(key_arn=None) - assert svc.enabled is False - - def test_encrypt_disabled_returns_plaintext(self): - from apis.shared.oauth.encryption import TokenEncryptionService - svc = TokenEncryptionService(key_arn=None) - result = svc.encrypt("token") - assert result.startswith("DEV:") - - def test_decrypt_disabled_returns_ciphertext(self): - from apis.shared.oauth.encryption import TokenEncryptionService - svc = TokenEncryptionService(key_arn=None) - encrypted = svc.encrypt("token") - assert svc.decrypt(encrypted) == "token" - - -# =================================================================== -# TokenCache (pure in-memory) -# =================================================================== - -class TestTokenCache: - @pytest.fixture() - def cache(self): - from apis.shared.oauth.token_cache import TokenCache - return TokenCache() - - def test_set_and_get(self, cache): - cache.set("u1", "p1", "token-abc") - assert cache.get("u1", "p1") == "token-abc" - - def test_get_missing(self, cache): - assert cache.get("u1", "p1") is None - - def test_delete(self, cache): - cache.set("u1", "p1", "token") - assert cache.delete("u1", "p1") is True - assert cache.get("u1", "p1") is None - - def test_delete_missing(self, cache): - assert cache.delete("u1", "p1") is False - - def test_delete_for_user(self, cache): - cache.set("u1", "p1", "t1") - cache.set("u1", "p2", "t2") - cache.set("u2", "p1", "t3") - count = cache.delete_for_user("u1") - assert count == 2 - assert cache.get("u2", "p1") == "t3" - - def test_delete_for_provider(self, cache): - cache.set("u1", "p1", "t1") - cache.set("u2", "p1", "t2") - cache.set("u1", "p2", "t3") - count = cache.delete_for_provider("p1") - assert count == 2 - assert cache.get("u1", "p2") == "t3" - - def test_clear(self, cache): - cache.set("u1", "p1", "t") - cache.clear() - assert cache.get("u1", "p1") is None - - def test_get_stats(self, cache): - cache.set("u1", "p1", "t") - stats = cache.get_stats() - assert stats["size"] == 1 - - -# =================================================================== -# OAuthService -# =================================================================== - -class TestOAuthServicePKCE: - def test_generate_pkce_pair(self): - from apis.shared.oauth.service import generate_pkce_pair - verifier, challenge = generate_pkce_pair() - assert len(verifier) > 20 - assert len(challenge) > 20 - assert verifier != challenge - - -class TestOAuthServiceConnect: - @pytest.fixture() - def oauth_service(self, oauth_provider_repository, oauth_token_repository, kms_key_arn, monkeypatch): - monkeypatch.setenv("OAUTH_CALLBACK_URL", "http://localhost:8000/api/oauth/callback") - from apis.shared.oauth.service import OAuthService - from apis.shared.oauth.encryption import TokenEncryptionService - from apis.shared.oauth.token_cache import TokenCache - from apis.shared.auth.state_store import InMemoryStateStore - - enc = TokenEncryptionService(key_arn=kms_key_arn, region="us-east-1") - cache = TokenCache() - state_store = InMemoryStateStore() - - return OAuthService( - provider_repo=oauth_provider_repository, - token_repo=oauth_token_repository, - encryption_service=enc, - token_cache=cache, - state_store=state_store, - ) - - @pytest.mark.asyncio - async def test_initiate_connect(self, oauth_service, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider( - OAuthProviderCreate( - provider_id="github", display_name="GitHub", - provider_type="github", client_id="cid", client_secret="secret", - authorization_endpoint="https://github.com/login/oauth/authorize", - token_endpoint="https://github.com/login/oauth/access_token", - scopes=["repo"], allowed_roles=["editor"], - ) - ) - result = await oauth_service.initiate_connect( - provider_id="github", user_id="u1", - user_roles=["editor"], - frontend_redirect="http://localhost/callback" - ) - assert "github.com" in result - - @pytest.mark.asyncio - async def test_get_user_connections_empty(self, oauth_service): - connections = await oauth_service.get_user_connections("u1", user_roles=["editor"]) - assert connections == [] - - @pytest.mark.asyncio - async def test_disconnect_nonexistent(self, oauth_service): - # Should not raise - result = await oauth_service.disconnect("u1", "github") - assert result is True or result is False # implementation-dependent From 0a1d99b96c88fb43de9c01240ccdbff184e3ff85 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 10:41:41 -0600 Subject: [PATCH 07/35] refactor(connectors): slim OAuth provider model to AgentCore shape MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AgentCore Identity owns the clientId, clientSecret, endpoint config, and callback URL. Our DynamoDB record keeps only the admin metadata (display name, scopes, role gates, icon) plus cached pointers to AgentCore's record (credential_provider_arn, callback_url) for convenience. Drops authorization_endpoint, token_endpoint, authorization_params, userinfo_endpoint, revocation_endpoint, pkce_required, OAuthUserToken, and the user-side connection DTOs — all artifacts of the retired in-house flow. Adds oauth_discovery_url and authorization_server_metadata for Custom/Canvas providers, gated by a pydantic validator. Repository surface tightens to put_provider + apply_metadata_update; the Secrets Manager write/read path is gone. Admin routes (commit next) own the AgentCore round-trip and hand a fully-formed record to the repo. Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/src/apis/shared/oauth/models.py | 300 ++++++-------- .../apis/shared/oauth/provider_repository.py | 365 +++--------------- backend/tests/shared/conftest.py | 9 +- backend/tests/shared/test_models_and_utils.py | 35 +- .../tests/shared/test_oauth_repos_extended.py | 198 ---------- .../tests/shared/test_oauth_repositories.py | 152 +++----- 6 files changed, 228 insertions(+), 831 deletions(-) delete mode 100644 backend/tests/shared/test_oauth_repos_extended.py diff --git a/backend/src/apis/shared/oauth/models.py b/backend/src/apis/shared/oauth/models.py index 5c1a384d..ebbf0aba 100644 --- a/backend/src/apis/shared/oauth/models.py +++ b/backend/src/apis/shared/oauth/models.py @@ -1,4 +1,10 @@ -"""OAuth models for provider configuration and user tokens.""" +"""OAuth provider models. + +Providers are registered and administered through AWS Bedrock AgentCore +Identity — AgentCore owns `clientId`, `clientSecret`, endpoint config, and +the callback URL. Our DynamoDB record keeps the display metadata, scopes, +role gates, and cached pointers (ARN + callback URL) for convenience. +""" import hashlib import logging @@ -7,13 +13,18 @@ from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, model_validator logger = logging.getLogger(__name__) class OAuthProviderType(str, Enum): - """Supported OAuth provider types.""" + """Supported OAuth provider types. + + `CANVAS` routes through AgentCore's `CustomOauth2` vendor but is kept + as a distinct type so the admin UI can surface Canvas-specific guidance + if/when we add a preset. Today the admin form treats it as Custom. + """ GOOGLE = "google" MICROSOFT = "microsoft" @@ -22,27 +33,8 @@ class OAuthProviderType(str, Enum): CUSTOM = "custom" -class OAuthConnectionStatus(str, Enum): - """Connection status for user OAuth tokens.""" - - CONNECTED = "connected" - EXPIRED = "expired" - REVOKED = "revoked" - NEEDS_REAUTH = "needs_reauth" - - def compute_scopes_hash(scopes: List[str]) -> str: - """ - Compute a hash of the scopes list for change detection. - - Used to detect when provider scopes change and user needs to re-authenticate. - - Args: - scopes: List of OAuth scopes - - Returns: - SHA-256 hash of sorted scopes - """ + """Return a short, order-independent hash of `scopes` for change detection.""" sorted_scopes = sorted(scopes) scopes_str = ",".join(sorted_scopes) return hashlib.sha256(scopes_str.encode()).hexdigest()[:16] @@ -50,288 +42,224 @@ def compute_scopes_hash(scopes: List[str]) -> str: @dataclass class OAuthProvider: - """OAuth provider configuration stored in DynamoDB.""" + """OAuth provider record stored in DynamoDB. + + AgentCore-owned fields (`credential_provider_arn`, `callback_url`) are + populated after a successful registration and kept in sync on update. + They are cached for admin UX — the source of truth lives in AgentCore. + """ provider_id: str display_name: str provider_type: OAuthProviderType - authorization_endpoint: str - token_endpoint: str - client_id: str scopes: List[str] allowed_roles: List[str] # AppRole IDs that can use this provider enabled: bool = True - icon_name: str = "heroLink" # Default icon - userinfo_endpoint: Optional[str] = None # Optional userinfo endpoint - revocation_endpoint: Optional[str] = None # Optional token revocation endpoint - pkce_required: bool = True # PKCE is required by default for security - authorization_params: Dict[str, str] = field(default_factory=dict) # Extra params for auth URL (e.g., access_type=offline) + icon_name: str = "heroLink" + credential_provider_arn: Optional[str] = None + callback_url: Optional[str] = None + # Custom vendor only — mirrors AgentCore's Oauth2Discovery union. + # Exactly one of these is populated when `provider_type` is CUSTOM or + # CANVAS; both are None for Google/Microsoft/GitHub. + oauth_discovery_url: Optional[str] = None + authorization_server_metadata: Optional[Dict[str, Any]] = None created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat() + "Z") updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat() + "Z") @property def scopes_hash(self) -> str: - """Get the scopes hash for this provider.""" return compute_scopes_hash(self.scopes) def to_dynamo_item(self) -> Dict[str, Any]: - """Convert to DynamoDB item format.""" return { "PK": f"PROVIDER#{self.provider_id}", "SK": "CONFIG", - # GSI for enabled providers "GSI1PK": f"ENABLED#{str(self.enabled).lower()}", "GSI1SK": f"PROVIDER#{self.provider_id}", - # Main attributes "providerId": self.provider_id, "displayName": self.display_name, "providerType": self.provider_type.value, - "authorizationEndpoint": self.authorization_endpoint, - "tokenEndpoint": self.token_endpoint, - "clientId": self.client_id, "scopes": self.scopes, "scopesHash": self.scopes_hash, "allowedRoles": self.allowed_roles, "enabled": self.enabled, "iconName": self.icon_name, - "userinfoEndpoint": self.userinfo_endpoint, - "revocationEndpoint": self.revocation_endpoint, - "pkceRequired": self.pkce_required, - "authorizationParams": self.authorization_params, + "credentialProviderArn": self.credential_provider_arn, + "callbackUrl": self.callback_url, + "oauthDiscoveryUrl": self.oauth_discovery_url, + "authorizationServerMetadata": self.authorization_server_metadata, "createdAt": self.created_at, "updatedAt": self.updated_at, } @classmethod def from_dynamo_item(cls, item: Dict[str, Any]) -> "OAuthProvider": - """Create from DynamoDB item.""" return cls( provider_id=item["providerId"], display_name=item["displayName"], provider_type=OAuthProviderType(item["providerType"]), - authorization_endpoint=item["authorizationEndpoint"], - token_endpoint=item["tokenEndpoint"], - client_id=item["clientId"], scopes=item.get("scopes", []), allowed_roles=item.get("allowedRoles", []), enabled=item.get("enabled", True), icon_name=item.get("iconName", "heroLink"), - userinfo_endpoint=item.get("userinfoEndpoint"), - revocation_endpoint=item.get("revocationEndpoint"), - pkce_required=item.get("pkceRequired", True), - authorization_params=item.get("authorizationParams", {}), + credential_provider_arn=item.get("credentialProviderArn"), + callback_url=item.get("callbackUrl"), + oauth_discovery_url=item.get("oauthDiscoveryUrl"), + authorization_server_metadata=item.get("authorizationServerMetadata"), created_at=item.get("createdAt", datetime.now(timezone.utc).isoformat() + "Z"), updated_at=item.get("updatedAt", datetime.now(timezone.utc).isoformat() + "Z"), ) -@dataclass -class OAuthUserToken: - """User's OAuth token stored in DynamoDB (encrypted).""" - - user_id: str - provider_id: str - access_token_encrypted: str # KMS-encrypted access token - refresh_token_encrypted: Optional[str] = None # KMS-encrypted refresh token - token_type: str = "Bearer" - expires_at: Optional[int] = None # Unix timestamp - scopes_hash: str = "" # Hash of scopes at time of authorization - status: OAuthConnectionStatus = OAuthConnectionStatus.CONNECTED - connected_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat() + "Z") - updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat() + "Z") - - @property - def is_expired(self) -> bool: - """Check if token has expired.""" - if not self.expires_at: - return False - import time - - return time.time() > self.expires_at - - def to_dynamo_item(self) -> Dict[str, Any]: - """Convert to DynamoDB item format.""" - item = { - "PK": f"USER#{self.user_id}", - "SK": f"PROVIDER#{self.provider_id}", - # GSI for listing users by provider - "GSI1PK": f"PROVIDER#{self.provider_id}", - "GSI1SK": f"USER#{self.user_id}", - # Main attributes - "userId": self.user_id, - "providerId": self.provider_id, - "accessTokenEncrypted": self.access_token_encrypted, - "tokenType": self.token_type, - "scopesHash": self.scopes_hash, - "status": self.status.value, - "connectedAt": self.connected_at, - "updatedAt": self.updated_at, - } - - if self.refresh_token_encrypted: - item["refreshTokenEncrypted"] = self.refresh_token_encrypted - - if self.expires_at: - item["expiresAt"] = self.expires_at - - return item - - @classmethod - def from_dynamo_item(cls, item: Dict[str, Any]) -> "OAuthUserToken": - """Create from DynamoDB item.""" - return cls( - user_id=item["userId"], - provider_id=item["providerId"], - access_token_encrypted=item["accessTokenEncrypted"], - refresh_token_encrypted=item.get("refreshTokenEncrypted"), - token_type=item.get("tokenType", "Bearer"), - expires_at=item.get("expiresAt"), - scopes_hash=item.get("scopesHash", ""), - status=OAuthConnectionStatus(item.get("status", "connected")), - connected_at=item.get("connectedAt", datetime.now(timezone.utc).isoformat() + "Z"), - updated_at=item.get("updatedAt", datetime.now(timezone.utc).isoformat() + "Z"), - ) - - # ============================================================================= -# Pydantic Request/Response Models +# Pydantic request/response models # ============================================================================= +_CUSTOM_TYPES = {OAuthProviderType.CUSTOM, OAuthProviderType.CANVAS} + + class OAuthProviderCreate(BaseModel): - """Request model for creating an OAuth provider.""" + """Request model for creating an OAuth provider. + + `client_id` and `client_secret` are forwarded to AgentCore Identity and + are never persisted in our DynamoDB table. For Custom/Canvas providers + the caller must supply exactly one of `oauth_discovery_url` or + `authorization_server_metadata`. + """ provider_id: str = Field(..., min_length=1, max_length=64, pattern=r"^[a-z0-9-]+$") display_name: str = Field(..., min_length=1, max_length=128) provider_type: OAuthProviderType - authorization_endpoint: str = Field(..., min_length=1) - token_endpoint: str = Field(..., min_length=1) client_id: str = Field(..., min_length=1) - client_secret: str = Field(..., min_length=1) # Will be stored in Secrets Manager + client_secret: str = Field(..., min_length=1) scopes: List[str] = Field(default_factory=list) allowed_roles: List[str] = Field(default_factory=list) enabled: bool = True icon_name: str = "heroLink" - userinfo_endpoint: Optional[str] = None - revocation_endpoint: Optional[str] = None - pkce_required: bool = True - authorization_params: Dict[str, str] = Field(default_factory=dict) + oauth_discovery_url: Optional[str] = None + authorization_server_metadata: Optional[Dict[str, Any]] = None model_config = ConfigDict(json_schema_extra={ "example": { "provider_id": "google-workspace", "display_name": "Google Workspace", "provider_type": "google", - "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth", - "token_endpoint": "https://oauth2.googleapis.com/token", "client_id": "your-client-id.apps.googleusercontent.com", "client_secret": "your-client-secret", - "scopes": ["openid", "email", "profile", "https://www.googleapis.com/auth/drive.readonly"], + "scopes": ["openid", "email", "profile"], "allowed_roles": ["admin", "user"], "enabled": True, "icon_name": "heroCloud", - "pkce_required": True, - "authorization_params": {"access_type": "offline", "prompt": "consent"}, } }) + @model_validator(mode="after") + def _validate_discovery(self) -> "OAuthProviderCreate": + if self.provider_type in _CUSTOM_TYPES: + if bool(self.oauth_discovery_url) == bool(self.authorization_server_metadata): + raise ValueError( + "Custom providers require exactly one of " + "oauth_discovery_url or authorization_server_metadata" + ) + elif self.oauth_discovery_url or self.authorization_server_metadata: + raise ValueError( + f"Discovery config is only valid for custom/canvas providers; " + f"provider_type={self.provider_type.value} does not accept it" + ) + return self + class OAuthProviderUpdate(BaseModel): - """Request model for updating an OAuth provider.""" + """Request model for updating an OAuth provider. + + Credential rotation requires both `client_id` and `client_secret` + because AgentCore's update API demands the full config and does not + echo back the stored secret. Partial edits to metadata (display name, + scopes, roles, icon, enabled) are allowed without touching credentials. + """ display_name: Optional[str] = Field(None, min_length=1, max_length=128) - authorization_endpoint: Optional[str] = None - token_endpoint: Optional[str] = None client_id: Optional[str] = None - client_secret: Optional[str] = None # Only if rotating secret + client_secret: Optional[str] = None scopes: Optional[List[str]] = None allowed_roles: Optional[List[str]] = None enabled: Optional[bool] = None icon_name: Optional[str] = None - userinfo_endpoint: Optional[str] = None - revocation_endpoint: Optional[str] = None - pkce_required: Optional[bool] = None - authorization_params: Optional[Dict[str, str]] = None + oauth_discovery_url: Optional[str] = None + authorization_server_metadata: Optional[Dict[str, Any]] = None + + @model_validator(mode="after") + def _validate_credential_pair(self) -> "OAuthProviderUpdate": + if bool(self.client_id) != bool(self.client_secret): + raise ValueError( + "client_id and client_secret must be provided together for rotation" + ) + if self.oauth_discovery_url and self.authorization_server_metadata: + raise ValueError( + "oauth_discovery_url and authorization_server_metadata are mutually exclusive" + ) + return self class OAuthProviderResponse(BaseModel): - """Response model for an OAuth provider (excludes secrets).""" + """Response model for an OAuth provider.""" provider_id: str display_name: str provider_type: OAuthProviderType - authorization_endpoint: str - token_endpoint: str - client_id: str scopes: List[str] allowed_roles: List[str] enabled: bool icon_name: str - userinfo_endpoint: Optional[str] = None - revocation_endpoint: Optional[str] = None - pkce_required: bool - authorization_params: Dict[str, str] + credential_provider_arn: Optional[str] = None + callback_url: Optional[str] = None + oauth_discovery_url: Optional[str] = None + authorization_server_metadata: Optional[Dict[str, Any]] = None created_at: str updated_at: str @classmethod def from_provider(cls, provider: OAuthProvider) -> "OAuthProviderResponse": - """Create from OAuthProvider dataclass.""" return cls( provider_id=provider.provider_id, display_name=provider.display_name, provider_type=provider.provider_type, - authorization_endpoint=provider.authorization_endpoint, - token_endpoint=provider.token_endpoint, - client_id=provider.client_id, scopes=provider.scopes, allowed_roles=provider.allowed_roles, enabled=provider.enabled, icon_name=provider.icon_name, - userinfo_endpoint=provider.userinfo_endpoint, - revocation_endpoint=provider.revocation_endpoint, - pkce_required=provider.pkce_required, - authorization_params=provider.authorization_params, + credential_provider_arn=provider.credential_provider_arn, + callback_url=provider.callback_url, + oauth_discovery_url=provider.oauth_discovery_url, + authorization_server_metadata=provider.authorization_server_metadata, created_at=provider.created_at, updated_at=provider.updated_at, ) class OAuthProviderListResponse(BaseModel): - """Response model for listing OAuth providers.""" - providers: List[OAuthProviderResponse] total: int -class OAuthConnectionResponse(BaseModel): - """Response model for a user's OAuth connection.""" - - provider_id: str - display_name: str - provider_type: OAuthProviderType - icon_name: str - status: OAuthConnectionStatus - connected_at: Optional[str] = None - needs_reauth: bool = False - - -class OAuthConnectionListResponse(BaseModel): - """Response model for listing user's OAuth connections.""" - - connections: List[OAuthConnectionResponse] - +class OAuthRequiredEvent(BaseModel): + """SSE event signalling that a tool needs user consent before it can run. -class OAuthConnectResponse(BaseModel): - """Response model for initiating OAuth connection.""" - - authorization_url: str + Emitted by the inference route after an agent response finishes, one per + provider with a pending consent URL. The frontend uses it to render a + "Connect to X" affordance that opens `authorizationUrl` in a popup. + """ + model_config = ConfigDict(populate_by_name=True) -class OAuthCallbackResult(BaseModel): - """Internal model for OAuth callback result.""" + type: str = "oauth_required" + provider_id: str = Field(..., alias="providerId") + authorization_url: str = Field(..., alias="authorizationUrl") - success: bool - provider_id: Optional[str] = None - error: Optional[str] = None - error_description: Optional[str] = None + def to_sse_format(self) -> str: + import json + return ( + f"event: oauth_required\n" + f"data: {json.dumps(self.model_dump(by_alias=True, exclude_none=True))}\n\n" + ) diff --git a/backend/src/apis/shared/oauth/provider_repository.py b/backend/src/apis/shared/oauth/provider_repository.py index 31df2825..793a4b49 100644 --- a/backend/src/apis/shared/oauth/provider_repository.py +++ b/backend/src/apis/shared/oauth/provider_repository.py @@ -1,6 +1,10 @@ -"""DynamoDB repository for OAuth provider configurations.""" +"""DynamoDB repository for OAuth provider configurations. + +Only display metadata, scopes, and AgentCore-owned pointers (the credential +provider ARN and callback URL) live here. `clientId` / `clientSecret` are +registered directly with AgentCore Identity by the admin route. +""" -import json import logging import os from datetime import datetime, timezone @@ -9,35 +13,20 @@ import boto3 from botocore.exceptions import ClientError -from .models import OAuthProvider, OAuthProviderCreate, OAuthProviderUpdate +from .models import OAuthProvider, OAuthProviderUpdate logger = logging.getLogger(__name__) class OAuthProviderRepository: - """ - Repository for OAuth provider CRUD operations in DynamoDB. - - Handles provider configurations and client secrets in Secrets Manager. - Uses single-table design with GSI for querying enabled providers. - """ + """CRUD over the oauth-providers DynamoDB table.""" def __init__( self, table_name: Optional[str] = None, - secrets_arn: Optional[str] = None, region: Optional[str] = None, ): - """ - Initialize repository. - - Args: - table_name: DynamoDB table name (defaults to env var) - secrets_arn: Secrets Manager ARN for client secrets (defaults to env var) - region: AWS region (defaults to env var) - """ self._table_name = table_name or os.getenv("DYNAMODB_OAUTH_PROVIDERS_TABLE_NAME") - self._secrets_arn = secrets_arn or os.getenv("OAUTH_CLIENT_SECRETS_ARN") self._region = region or os.getenv("AWS_REGION", "us-west-2") self._enabled = bool(self._table_name) @@ -48,38 +37,18 @@ def __init__( ) return - # Initialize clients profile = os.getenv("AWS_PROFILE") - if profile: - session = boto3.Session(profile_name=profile) - self._dynamodb = session.resource("dynamodb", region_name=self._region) - self._secrets_client = session.client("secretsmanager", region_name=self._region) - else: - self._dynamodb = boto3.resource("dynamodb", region_name=self._region) - self._secrets_client = boto3.client("secretsmanager", region_name=self._region) - + session = boto3.Session(profile_name=profile) if profile else boto3 + self._dynamodb = session.resource("dynamodb", region_name=self._region) self._table = self._dynamodb.Table(self._table_name) - logger.info(f"Initialized OAuth provider repository: table={self._table_name}") + logger.info("Initialized OAuth provider repository: table=%s", self._table_name) @property def enabled(self) -> bool: - """Check if repository is enabled.""" return self._enabled - # ========================================================================= - # Provider CRUD - # ========================================================================= - + # ------------------------------------------------------------------- reads async def get_provider(self, provider_id: str) -> Optional[OAuthProvider]: - """ - Get a provider by ID. - - Args: - provider_id: Provider identifier - - Returns: - OAuthProvider if found, None otherwise - """ if not self._enabled: return None @@ -88,38 +57,23 @@ async def get_provider(self, provider_id: str) -> Optional[OAuthProvider]: Key={"PK": f"PROVIDER#{provider_id}", "SK": "CONFIG"} ) item = response.get("Item") - if not item: - return None - return OAuthProvider.from_dynamo_item(item) - + return OAuthProvider.from_dynamo_item(item) if item else None except ClientError as e: - logger.error(f"Error getting provider {provider_id}: {e}") + logger.error("Error getting provider %s: %s", provider_id, e) raise async def list_providers(self, enabled_only: bool = False) -> List[OAuthProvider]: - """ - List all providers. - - Args: - enabled_only: If True, only return enabled providers - - Returns: - List of OAuthProvider objects - """ if not self._enabled: return [] try: if enabled_only: - # Use GSI for efficient query response = self._table.query( IndexName="EnabledProvidersIndex", KeyConditionExpression="GSI1PK = :pk", ExpressionAttributeValues={":pk": "ENABLED#true"}, ) items = response.get("Items", []) - - # Handle pagination while "LastEvaluatedKey" in response: response = self._table.query( IndexName="EnabledProvidersIndex", @@ -129,14 +83,11 @@ async def list_providers(self, enabled_only: bool = False) -> List[OAuthProvider ) items.extend(response.get("Items", [])) else: - # Scan all providers response = self._table.scan( FilterExpression="SK = :sk", ExpressionAttributeValues={":sk": "CONFIG"}, ) items = response.get("Items", []) - - # Handle pagination while "LastEvaluatedKey" in response: response = self._table.scan( FilterExpression="SK = :sk", @@ -146,157 +97,61 @@ async def list_providers(self, enabled_only: bool = False) -> List[OAuthProvider items.extend(response.get("Items", [])) providers = [OAuthProvider.from_dynamo_item(item) for item in items] - - # Sort by display name providers.sort(key=lambda p: p.display_name.lower()) - return providers - except ClientError as e: - logger.error(f"Error listing providers: {e}") + logger.error("Error listing providers: %s", e) raise - async def create_provider( - self, create_request: OAuthProviderCreate - ) -> OAuthProvider: - """ - Create a new provider. - - Args: - create_request: Provider creation data including client secret + # ------------------------------------------------------------------ writes + async def put_provider(self, provider: OAuthProvider) -> OAuthProvider: + """Upsert a fully-formed provider record. - Returns: - Created OAuthProvider - - Raises: - ValueError: If provider already exists + The admin route is expected to build the `OAuthProvider` with all + AgentCore-owned fields already populated from the registrar call. """ if not self._enabled: raise RuntimeError("OAuth provider repository is not enabled") - # Check if provider exists - existing = await self.get_provider(create_request.provider_id) - if existing: - raise ValueError(f"Provider '{create_request.provider_id}' already exists") - - try: - now = datetime.now(timezone.utc).isoformat() + "Z" + self._table.put_item(Item=provider.to_dynamo_item()) + logger.info("Upserted OAuth provider: %s", provider.provider_id) + return provider - # Create provider object - provider = OAuthProvider( - provider_id=create_request.provider_id, - display_name=create_request.display_name, - provider_type=create_request.provider_type, - authorization_endpoint=create_request.authorization_endpoint, - token_endpoint=create_request.token_endpoint, - client_id=create_request.client_id, - scopes=create_request.scopes, - allowed_roles=create_request.allowed_roles, - enabled=create_request.enabled, - icon_name=create_request.icon_name, - userinfo_endpoint=create_request.userinfo_endpoint, - revocation_endpoint=create_request.revocation_endpoint, - pkce_required=create_request.pkce_required, - authorization_params=create_request.authorization_params, - created_at=now, - updated_at=now, - ) - - # Store client secret in Secrets Manager - await self._store_client_secret( - create_request.provider_id, create_request.client_secret - ) - - # Store provider in DynamoDB - self._table.put_item( - Item=provider.to_dynamo_item(), - ConditionExpression="attribute_not_exists(PK)", - ) - - logger.info(f"Created OAuth provider: {provider.provider_id}") - return provider - - except ClientError as e: - if e.response["Error"]["Code"] == "ConditionalCheckFailedException": - raise ValueError( - f"Provider '{create_request.provider_id}' already exists" - ) - logger.error(f"Error creating provider: {e}") - raise - - async def update_provider( + async def apply_metadata_update( self, provider_id: str, updates: OAuthProviderUpdate ) -> Optional[OAuthProvider]: - """ - Update an existing provider. - - Args: - provider_id: Provider identifier - updates: Fields to update + """Apply a metadata-only update to an existing provider record. - Returns: - Updated OAuthProvider, or None if not found + Does not touch AgentCore — the admin route is responsible for + calling the registrar first when credentials or the discovery + config change, then passing the refreshed metadata through here. + Fields left `None` on `updates` are preserved. """ - if not self._enabled: - return None - existing = await self.get_provider(provider_id) if not existing: return None - try: - # Apply updates - if updates.display_name is not None: - existing.display_name = updates.display_name - if updates.authorization_endpoint is not None: - existing.authorization_endpoint = updates.authorization_endpoint - if updates.token_endpoint is not None: - existing.token_endpoint = updates.token_endpoint - if updates.client_id is not None: - existing.client_id = updates.client_id - if updates.scopes is not None: - existing.scopes = updates.scopes - if updates.allowed_roles is not None: - existing.allowed_roles = updates.allowed_roles - if updates.enabled is not None: - existing.enabled = updates.enabled - if updates.icon_name is not None: - existing.icon_name = updates.icon_name - if updates.userinfo_endpoint is not None: - existing.userinfo_endpoint = updates.userinfo_endpoint - if updates.revocation_endpoint is not None: - existing.revocation_endpoint = updates.revocation_endpoint - if updates.pkce_required is not None: - existing.pkce_required = updates.pkce_required - if updates.authorization_params is not None: - existing.authorization_params = updates.authorization_params - - existing.updated_at = datetime.now(timezone.utc).isoformat() + "Z" - - # Update client secret if provided - if updates.client_secret is not None: - await self._store_client_secret(provider_id, updates.client_secret) - - # Store updated provider - self._table.put_item(Item=existing.to_dynamo_item()) - - logger.info(f"Updated OAuth provider: {provider_id}") - return existing - - except ClientError as e: - logger.error(f"Error updating provider {provider_id}: {e}") - raise + if updates.display_name is not None: + existing.display_name = updates.display_name + if updates.scopes is not None: + existing.scopes = updates.scopes + if updates.allowed_roles is not None: + existing.allowed_roles = updates.allowed_roles + if updates.enabled is not None: + existing.enabled = updates.enabled + if updates.icon_name is not None: + existing.icon_name = updates.icon_name + if updates.oauth_discovery_url is not None: + existing.oauth_discovery_url = updates.oauth_discovery_url + if updates.authorization_server_metadata is not None: + existing.authorization_server_metadata = updates.authorization_server_metadata + + existing.updated_at = datetime.now(timezone.utc).isoformat() + "Z" + self._table.put_item(Item=existing.to_dynamo_item()) + logger.info("Updated OAuth provider metadata: %s", provider_id) + return existing async def delete_provider(self, provider_id: str) -> bool: - """ - Delete a provider. - - Args: - provider_id: Provider identifier - - Returns: - True if deleted, False if not found - """ if not self._enabled: return False @@ -305,139 +160,21 @@ async def delete_provider(self, provider_id: str) -> bool: return False try: - # Delete from DynamoDB self._table.delete_item( Key={"PK": f"PROVIDER#{provider_id}", "SK": "CONFIG"} ) - - # Remove client secret from Secrets Manager - await self._delete_client_secret(provider_id) - - logger.info(f"Deleted OAuth provider: {provider_id}") + logger.info("Deleted OAuth provider: %s", provider_id) return True - - except ClientError as e: - logger.error(f"Error deleting provider {provider_id}: {e}") - raise - - # ========================================================================= - # Client Secret Management (Secrets Manager) - # ========================================================================= - - async def get_client_secret(self, provider_id: str) -> Optional[str]: - """ - Get client secret for a provider from Secrets Manager. - - Args: - provider_id: Provider identifier - - Returns: - Client secret string, or None if not found - """ - if not self._secrets_arn: - logger.warning("Secrets ARN not configured, cannot retrieve client secret") - return None - - try: - response = self._secrets_client.get_secret_value( - SecretId=self._secrets_arn - ) - secrets = json.loads(response["SecretString"]) - return secrets.get(provider_id) - - except ClientError as e: - if e.response["Error"]["Code"] == "ResourceNotFoundException": - logger.warning("OAuth secrets not found in Secrets Manager") - return None - logger.error(f"Error getting client secret for {provider_id}: {e}") - raise - - async def _store_client_secret( - self, provider_id: str, client_secret: str - ) -> None: - """ - Store client secret in Secrets Manager. - - Args: - provider_id: Provider identifier - client_secret: Client secret to store - """ - if not self._secrets_arn: - logger.warning( - "Secrets ARN not configured, cannot store client secret. " - "This is only acceptable in development." - ) - return - - try: - # Get existing secrets - try: - response = self._secrets_client.get_secret_value( - SecretId=self._secrets_arn - ) - secrets = json.loads(response["SecretString"]) - except ClientError as e: - if e.response["Error"]["Code"] == "ResourceNotFoundException": - secrets = {} - else: - raise - - # Update with new secret - secrets[provider_id] = client_secret - - # Store back - self._secrets_client.put_secret_value( - SecretId=self._secrets_arn, - SecretString=json.dumps(secrets), - ) - - logger.info(f"Stored client secret for provider: {provider_id}") - except ClientError as e: - logger.error(f"Error storing client secret for {provider_id}: {e}") + logger.error("Error deleting provider %s: %s", provider_id, e) raise - async def _delete_client_secret(self, provider_id: str) -> None: - """ - Remove client secret from Secrets Manager. - - Args: - provider_id: Provider identifier - """ - if not self._secrets_arn: - return - - try: - # Get existing secrets - response = self._secrets_client.get_secret_value( - SecretId=self._secrets_arn - ) - secrets = json.loads(response["SecretString"]) - - # Remove provider's secret - if provider_id in secrets: - del secrets[provider_id] - - # Store back - self._secrets_client.put_secret_value( - SecretId=self._secrets_arn, - SecretString=json.dumps(secrets), - ) - - logger.info(f"Removed client secret for provider: {provider_id}") - - except ClientError as e: - if e.response["Error"]["Code"] != "ResourceNotFoundException": - logger.error(f"Error deleting client secret for {provider_id}: {e}") - raise - -# Singleton instance _provider_repository: Optional[OAuthProviderRepository] = None def get_provider_repository() -> OAuthProviderRepository: - """Get the provider repository singleton.""" + """Get the process-wide provider repository singleton.""" global _provider_repository if _provider_repository is None: _provider_repository = OAuthProviderRepository() diff --git a/backend/tests/shared/conftest.py b/backend/tests/shared/conftest.py index b00f2ca2..56f647c0 100644 --- a/backend/tests/shared/conftest.py +++ b/backend/tests/shared/conftest.py @@ -323,21 +323,14 @@ def auth_provider_repository(auth_providers_table, secrets_manager): @pytest.fixture() -def oauth_provider_repository(oauth_providers_table, secrets_manager): +def oauth_provider_repository(oauth_providers_table): from apis.shared.oauth.provider_repository import OAuthProviderRepository return OAuthProviderRepository( table_name="test-oauth-providers", - secrets_arn="oauth-client-secrets", region=AWS_REGION, ) -@pytest.fixture() -def oauth_token_repository(oauth_tokens_table): - from apis.shared.oauth.token_repository import OAuthTokenRepository - return OAuthTokenRepository(table_name="test-oauth-user-tokens", region=AWS_REGION) - - @pytest.fixture() def file_repository(files_table): from apis.shared.files.repository import FileUploadRepository diff --git a/backend/tests/shared/test_models_and_utils.py b/backend/tests/shared/test_models_and_utils.py index c3cb7c87..90e38a83 100644 --- a/backend/tests/shared/test_models_and_utils.py +++ b/backend/tests/shared/test_models_and_utils.py @@ -97,35 +97,28 @@ def test_compute_scopes_hash(self): def test_provider_scopes_hash_property(self): from apis.shared.oauth.models import OAuthProvider, OAuthProviderType - p = OAuthProvider(provider_id="p1", display_name="P", provider_type=OAuthProviderType.CUSTOM, client_id="c", authorization_endpoint="http://a", token_endpoint="http://t", scopes=["a", "b"], allowed_roles=[]) + p = OAuthProvider( + provider_id="p1", display_name="P", + provider_type=OAuthProviderType.GOOGLE, + scopes=["a", "b"], allowed_roles=[], + ) assert p.scopes_hash == p.scopes_hash # consistent - def test_token_is_expired(self): - from apis.shared.oauth.models import OAuthUserToken - t = OAuthUserToken(user_id="u1", provider_id="p1", access_token_encrypted="x", expires_at=946684800) - assert t.is_expired is True - - def test_token_not_expired(self): - from apis.shared.oauth.models import OAuthUserToken - t = OAuthUserToken(user_id="u1", provider_id="p1", access_token_encrypted="x", expires_at=4070908800) - assert t.is_expired is False - def test_provider_dynamo_roundtrip(self): from apis.shared.oauth.models import OAuthProvider, OAuthProviderType - p = OAuthProvider(provider_id="p1", display_name="P", provider_type=OAuthProviderType.CUSTOM, client_id="c", authorization_endpoint="http://a", token_endpoint="http://t", scopes=["a"], allowed_roles=[]) + p = OAuthProvider( + provider_id="p1", display_name="P", + provider_type=OAuthProviderType.GOOGLE, + scopes=["a"], allowed_roles=[], + credential_provider_arn="arn:aws:bedrock-agentcore:us-east-1:1:cp/p1", + callback_url="https://bedrock-agentcore.us-east-1.amazonaws.com/cb/p1", + ) item = p.to_dynamo_item() assert item["PK"] == "PROVIDER#p1" restored = OAuthProvider.from_dynamo_item(item) assert restored.provider_id == "p1" - - def test_token_dynamo_roundtrip(self): - from apis.shared.oauth.models import OAuthUserToken - t = OAuthUserToken(user_id="u1", provider_id="p1", access_token_encrypted="enc", expires_at=4070908800) - item = t.to_dynamo_item() - assert item["PK"] == "USER#u1" - assert item["SK"] == "PROVIDER#p1" - restored = OAuthUserToken.from_dynamo_item(item) - assert restored.user_id == "u1" + assert restored.callback_url == p.callback_url + assert restored.credential_provider_arn == p.credential_provider_arn # =================================================================== diff --git a/backend/tests/shared/test_oauth_repos_extended.py b/backend/tests/shared/test_oauth_repos_extended.py deleted file mode 100644 index 0c876fe5..00000000 --- a/backend/tests/shared/test_oauth_repos_extended.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Extended OAuth repository tests for deeper coverage.""" - -import pytest - - -class TestOAuthProviderRepositoryExtended: - @pytest.fixture(autouse=True) - def _setup(self, oauth_provider_repository): - self.repo = oauth_provider_repository - - def _make_create(self, pid="github", **kw): - from apis.shared.oauth.models import OAuthProviderCreate, OAuthProviderType - defaults = dict( - provider_id=pid, display_name="GitHub", provider_type=OAuthProviderType.GITHUB, - authorization_endpoint="https://github.com/login/oauth/authorize", - token_endpoint="https://github.com/login/oauth/access_token", - client_id="cid", client_secret="secret", scopes=["repo"], - allowed_roles=["viewer"], - ) - defaults.update(kw) - return OAuthProviderCreate(**defaults) - - @pytest.mark.asyncio - async def test_create_and_get(self): - p = await self.repo.create_provider(self._make_create()) - assert p.provider_id == "github" - got = await self.repo.get_provider("github") - assert got is not None - assert got.display_name == "GitHub" - - @pytest.mark.asyncio - async def test_get_nonexistent(self): - assert await self.repo.get_provider("nope") is None - - @pytest.mark.asyncio - async def test_create_duplicate_raises(self): - await self.repo.create_provider(self._make_create()) - with pytest.raises(ValueError, match="already exists"): - await self.repo.create_provider(self._make_create()) - - @pytest.mark.asyncio - async def test_list_all(self): - await self.repo.create_provider(self._make_create("a")) - await self.repo.create_provider(self._make_create("b")) - providers = await self.repo.list_providers() - assert len(providers) == 2 - - @pytest.mark.asyncio - async def test_list_enabled_only(self): - await self.repo.create_provider(self._make_create("a", enabled=True)) - await self.repo.create_provider(self._make_create("b", enabled=False)) - enabled = await self.repo.list_providers(enabled_only=True) - assert all(p.enabled for p in enabled) - - @pytest.mark.asyncio - async def test_update_provider(self): - from apis.shared.oauth.models import OAuthProviderUpdate - await self.repo.create_provider(self._make_create()) - updated = await self.repo.update_provider("github", OAuthProviderUpdate(display_name="GH")) - assert updated.display_name == "GH" - - @pytest.mark.asyncio - async def test_update_nonexistent(self): - from apis.shared.oauth.models import OAuthProviderUpdate - assert await self.repo.update_provider("nope", OAuthProviderUpdate(display_name="X")) is None - - @pytest.mark.asyncio - async def test_update_client_secret(self): - from apis.shared.oauth.models import OAuthProviderUpdate - await self.repo.create_provider(self._make_create()) - await self.repo.update_provider("github", OAuthProviderUpdate(client_secret="new_secret")) - secret = await self.repo.get_client_secret("github") - assert secret == "new_secret" - - @pytest.mark.asyncio - async def test_delete_provider(self): - await self.repo.create_provider(self._make_create()) - assert await self.repo.delete_provider("github") is True - assert await self.repo.get_provider("github") is None - - @pytest.mark.asyncio - async def test_delete_nonexistent(self): - assert await self.repo.delete_provider("nope") is False - - @pytest.mark.asyncio - async def test_get_client_secret(self): - await self.repo.create_provider(self._make_create()) - secret = await self.repo.get_client_secret("github") - assert secret == "secret" - - @pytest.mark.asyncio - async def test_get_client_secret_missing(self): - secret = await self.repo.get_client_secret("nope") - assert secret is None - - @pytest.mark.asyncio - async def test_disabled_repo(self, monkeypatch): - monkeypatch.delenv("DYNAMODB_OAUTH_PROVIDERS_TABLE_NAME", raising=False) - from apis.shared.oauth.provider_repository import OAuthProviderRepository - repo = OAuthProviderRepository(table_name=None) - assert repo.enabled is False - assert await repo.get_provider("x") is None - assert await repo.list_providers() == [] - assert await repo.delete_provider("x") is False - - -class TestOAuthTokenRepositoryExtended: - @pytest.fixture(autouse=True) - def _setup(self, oauth_token_repository): - self.repo = oauth_token_repository - - def _make_token(self, user_id="u1", provider_id="github", **kw): - from apis.shared.oauth.models import OAuthUserToken, OAuthConnectionStatus - defaults = dict( - user_id=user_id, provider_id=provider_id, - access_token_encrypted="enc_tok", - status=OAuthConnectionStatus.CONNECTED, - connected_at="2026-01-01", - ) - defaults.update(kw) - return OAuthUserToken(**defaults) - - @pytest.mark.asyncio - async def test_save_and_get(self): - token = self._make_token() - saved = await self.repo.save_token(token) - assert saved.updated_at is not None - got = await self.repo.get_token("u1", "github") - assert got is not None - assert got.access_token_encrypted == "enc_tok" - - @pytest.mark.asyncio - async def test_get_nonexistent(self): - assert await self.repo.get_token("u1", "nope") is None - - @pytest.mark.asyncio - async def test_list_user_tokens(self): - await self.repo.save_token(self._make_token(provider_id="a")) - await self.repo.save_token(self._make_token(provider_id="b")) - tokens = await self.repo.list_user_tokens("u1") - assert len(tokens) == 2 - - @pytest.mark.asyncio - async def test_list_provider_tokens(self): - await self.repo.save_token(self._make_token(user_id="u1")) - await self.repo.save_token(self._make_token(user_id="u2")) - tokens = await self.repo.list_provider_tokens("github") - assert len(tokens) == 2 - - @pytest.mark.asyncio - async def test_update_token_status(self): - from apis.shared.oauth.models import OAuthConnectionStatus - await self.repo.save_token(self._make_token()) - updated = await self.repo.update_token_status("u1", "github", OAuthConnectionStatus.EXPIRED) - assert updated.status == OAuthConnectionStatus.EXPIRED - - @pytest.mark.asyncio - async def test_update_token_status_nonexistent(self): - from apis.shared.oauth.models import OAuthConnectionStatus - assert await self.repo.update_token_status("u1", "nope", OAuthConnectionStatus.EXPIRED) is None - - @pytest.mark.asyncio - async def test_delete_token(self): - await self.repo.save_token(self._make_token()) - assert await self.repo.delete_token("u1", "github") is True - assert await self.repo.get_token("u1", "github") is None - - @pytest.mark.asyncio - async def test_delete_nonexistent(self): - assert await self.repo.delete_token("u1", "nope") is False - - @pytest.mark.asyncio - async def test_delete_user_tokens(self): - await self.repo.save_token(self._make_token(provider_id="a")) - await self.repo.save_token(self._make_token(provider_id="b")) - count = await self.repo.delete_user_tokens("u1") - assert count == 2 - assert await self.repo.list_user_tokens("u1") == [] - - @pytest.mark.asyncio - async def test_delete_provider_tokens(self): - await self.repo.save_token(self._make_token(user_id="u1")) - await self.repo.save_token(self._make_token(user_id="u2")) - count = await self.repo.delete_provider_tokens("github") - assert count == 2 - - @pytest.mark.asyncio - async def test_disabled_repo(self, monkeypatch): - monkeypatch.delenv("DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME", raising=False) - from apis.shared.oauth.token_repository import OAuthTokenRepository - repo = OAuthTokenRepository(table_name=None) - assert repo.enabled is False - assert await repo.get_token("u1", "x") is None - assert await repo.list_user_tokens("u1") == [] - assert await repo.list_provider_tokens("x") == [] - assert await repo.delete_token("u1", "x") is False - assert await repo.delete_user_tokens("u1") == 0 - assert await repo.delete_provider_tokens("x") == 0 diff --git a/backend/tests/shared/test_oauth_repositories.py b/backend/tests/shared/test_oauth_repositories.py index 9035b2b9..381ecf0f 100644 --- a/backend/tests/shared/test_oauth_repositories.py +++ b/backend/tests/shared/test_oauth_repositories.py @@ -1,44 +1,38 @@ -"""Task 6: OAuth provider + token repositories (moto DynamoDB + Secrets Manager).""" +"""OAuth provider repository tests (moto DynamoDB).""" import pytest -from apis.shared.oauth.models import OAuthProvider, OAuthProviderType, OAuthUserToken, OAuthConnectionStatus - -def _make_provider(provider_id="github", **kw): - defaults = dict( - provider_id=provider_id, display_name="GitHub", provider_type=OAuthProviderType.GITHUB, - authorization_endpoint="https://github.com/login/oauth/authorize", - token_endpoint="https://github.com/login/oauth/access_token", - client_id="cid", scopes=["repo"], allowed_roles=["editor"], - ) - defaults.update(kw) - return defaults +from apis.shared.oauth.models import ( + OAuthProvider, + OAuthProviderType, + OAuthProviderUpdate, +) +from apis.shared.oauth.provider_repository import OAuthProviderRepository -def _make_token(user_id="u1", provider_id="github", **kw): +def _make_provider(provider_id="github", **kw) -> OAuthProvider: defaults = dict( - user_id=user_id, provider_id=provider_id, - access_token_encrypted="enc-token", token_type="Bearer", - scopes_hash="abc", status=OAuthConnectionStatus.CONNECTED, + provider_id=provider_id, + display_name="GitHub", + provider_type=OAuthProviderType.GITHUB, + scopes=["repo"], + allowed_roles=["editor"], + credential_provider_arn=f"arn:aws:bedrock-agentcore:us-east-1:1:cp/{provider_id}", + callback_url=f"https://bedrock-agentcore.us-east-1.amazonaws.com/cb/{provider_id}", ) defaults.update(kw) - return OAuthUserToken(**defaults) + return OAuthProvider(**defaults) -# =================================================================== -# OAuthProviderRepository -# =================================================================== - class TestOAuthProviderRepository: @pytest.mark.asyncio - async def test_create_and_get(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - data = OAuthProviderCreate(**_make_provider(), client_secret="secret") - provider = await oauth_provider_repository.create_provider(data) - assert provider.provider_id == "github" + async def test_put_and_get(self, oauth_provider_repository): + await oauth_provider_repository.put_provider(_make_provider()) result = await oauth_provider_repository.get_provider("github") assert result is not None assert result.display_name == "GitHub" + assert result.callback_url.endswith("/cb/github") + assert result.credential_provider_arn @pytest.mark.asyncio async def test_get_nonexistent(self, oauth_provider_repository): @@ -46,101 +40,51 @@ async def test_get_nonexistent(self, oauth_provider_repository): @pytest.mark.asyncio async def test_list_all(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider("p1"), client_secret="s")) - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider("p2"), client_secret="s")) + await oauth_provider_repository.put_provider(_make_provider("p1")) + await oauth_provider_repository.put_provider(_make_provider("p2")) providers = await oauth_provider_repository.list_providers() assert len(providers) == 2 @pytest.mark.asyncio async def test_list_enabled_only(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider("p1"), client_secret="s")) - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider("p2", enabled=False), client_secret="s")) + await oauth_provider_repository.put_provider(_make_provider("p1")) + await oauth_provider_repository.put_provider(_make_provider("p2", enabled=False)) providers = await oauth_provider_repository.list_providers(enabled_only=True) assert len(providers) == 1 + assert providers[0].provider_id == "p1" + + @pytest.mark.asyncio + async def test_apply_metadata_update(self, oauth_provider_repository): + await oauth_provider_repository.put_provider(_make_provider()) + updated = await oauth_provider_repository.apply_metadata_update( + "github", + OAuthProviderUpdate(display_name="GH", scopes=["repo", "read:user"]), + ) + assert updated.display_name == "GH" + assert updated.scopes == ["repo", "read:user"] + + @pytest.mark.asyncio + async def test_apply_metadata_update_nonexistent(self, oauth_provider_repository): + updates = OAuthProviderUpdate(display_name="X") + assert await oauth_provider_repository.apply_metadata_update("nope", updates) is None @pytest.mark.asyncio async def test_delete_provider(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider(), client_secret="s")) + await oauth_provider_repository.put_provider(_make_provider()) assert await oauth_provider_repository.delete_provider("github") is True assert await oauth_provider_repository.get_provider("github") is None @pytest.mark.asyncio - async def test_client_secret(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider(), client_secret="my-secret")) - secret = await oauth_provider_repository.get_client_secret("github") - assert secret == "my-secret" + async def test_delete_nonexistent(self, oauth_provider_repository): + assert await oauth_provider_repository.delete_provider("nope") is False def test_disabled_when_no_table(self): - from apis.shared.oauth.provider_repository import OAuthProviderRepository repo = OAuthProviderRepository(table_name=None) assert repo.enabled is False - -# =================================================================== -# OAuthTokenRepository -# =================================================================== - -class TestOAuthTokenRepository: - @pytest.mark.asyncio - async def test_save_and_get(self, oauth_token_repository): - token = _make_token() - saved = await oauth_token_repository.save_token(token) - assert saved.user_id == "u1" - result = await oauth_token_repository.get_token("u1", "github") - assert result is not None - assert result.access_token_encrypted == "enc-token" - - @pytest.mark.asyncio - async def test_get_nonexistent(self, oauth_token_repository): - assert await oauth_token_repository.get_token("u1", "nope") is None - - @pytest.mark.asyncio - async def test_list_user_tokens(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token(provider_id="p1")) - await oauth_token_repository.save_token(_make_token(provider_id="p2")) - tokens = await oauth_token_repository.list_user_tokens("u1") - assert len(tokens) == 2 - - @pytest.mark.asyncio - async def test_list_provider_tokens(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token(user_id="u1")) - await oauth_token_repository.save_token(_make_token(user_id="u2")) - tokens = await oauth_token_repository.list_provider_tokens("github") - assert len(tokens) == 2 - - @pytest.mark.asyncio - async def test_update_token_status(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token()) - updated = await oauth_token_repository.update_token_status("u1", "github", OAuthConnectionStatus.EXPIRED) - assert updated is not None - assert updated.status == OAuthConnectionStatus.EXPIRED - @pytest.mark.asyncio - async def test_delete_token(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token()) - assert await oauth_token_repository.delete_token("u1", "github") is True - assert await oauth_token_repository.get_token("u1", "github") is None - - @pytest.mark.asyncio - async def test_delete_user_tokens(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token(provider_id="p1")) - await oauth_token_repository.save_token(_make_token(provider_id="p2")) - count = await oauth_token_repository.delete_user_tokens("u1") - assert count == 2 - assert len(await oauth_token_repository.list_user_tokens("u1")) == 0 - - @pytest.mark.asyncio - async def test_delete_provider_tokens(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token(user_id="u1")) - await oauth_token_repository.save_token(_make_token(user_id="u2")) - count = await oauth_token_repository.delete_provider_tokens("github") - assert count == 2 - - def test_disabled_when_no_table(self): - from apis.shared.oauth.token_repository import OAuthTokenRepository - repo = OAuthTokenRepository(table_name=None) - assert repo.enabled is False + async def test_disabled_repo_is_inert(self): + repo = OAuthProviderRepository(table_name=None) + assert await repo.get_provider("x") is None + assert await repo.list_providers() == [] + assert await repo.delete_provider("x") is False From 748e74f510a7c2bfa4f50b07b54bc3936b0d2573 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 10:41:50 -0600 Subject: [PATCH 08/35] refactor(connectors): route admin OAuth CRUD through AgentCore Identity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit POST now calls the registrar first and, on success, upserts the metadata record in DynamoDB. If the DB write fails after AgentCore has accepted the credentials, we best-effort delete the AgentCore provider to avoid orphans. PATCH distinguishes metadata-only edits (scopes, roles, display name, icon, enabled) from credential rotation. Rotation requires clientId + clientSecret together — partial updates are rejected by AgentCore's UpdateOauth2CredentialProvider contract. DELETE removes the AgentCore provider first (which revokes every user token stored in its vault), then the local record. Pre-existing connection- count checks are dropped since per-user tokens no longer live in our DB. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/apis/app_api/admin/oauth/routes.py | 346 ++++++++---------- 1 file changed, 160 insertions(+), 186 deletions(-) diff --git a/backend/src/apis/app_api/admin/oauth/routes.py b/backend/src/apis/app_api/admin/oauth/routes.py index e5dee9b1..76ea385d 100644 --- a/backend/src/apis/app_api/admin/oauth/routes.py +++ b/backend/src/apis/app_api/admin/oauth/routes.py @@ -1,11 +1,27 @@ -"""Admin API routes for OAuth provider management.""" +"""Admin API routes for OAuth provider management. + +Registration flows through AWS Bedrock AgentCore Identity. Our DynamoDB +record holds display metadata, scopes, and role gates; the AgentCore +credential provider owns `clientId`, `clientSecret`, endpoint config, and +the callback URL that the admin must register with the vendor. +""" import logging +from datetime import datetime, timezone from fastapi import APIRouter, Depends, HTTPException, Query, status from apis.shared.auth import User, require_admin +from apis.shared.oauth.agentcore_registrar import ( + AgentCoreRegistrar, + CredentialProviderConflictError, + CredentialProviderInfo, + CredentialProviderNotFoundError, + InvalidCustomProviderConfigError, + get_agentcore_registrar, +) from apis.shared.oauth.models import ( + OAuthProvider, OAuthProviderCreate, OAuthProviderListResponse, OAuthProviderResponse, @@ -15,11 +31,6 @@ OAuthProviderRepository, get_provider_repository, ) -from apis.shared.oauth.token_repository import ( - OAuthTokenRepository, - get_token_repository, -) -from apis.shared.oauth.token_cache import get_token_cache logger = logging.getLogger(__name__) @@ -37,22 +48,9 @@ async def list_providers( admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), ): - """ - List all OAuth providers. - - Requires admin access. - - Args: - enabled_only: If True, only return enabled providers - admin: Authenticated admin user (injected) - - Returns: - OAuthProviderListResponse with all providers - """ + """List all OAuth providers. Admin only.""" logger.info("Admin listing OAuth providers") - providers = await provider_repo.list_providers(enabled_only=enabled_only) - return OAuthProviderListResponse( providers=[OAuthProviderResponse.from_provider(p) for p in providers], total=len(providers), @@ -65,67 +63,73 @@ async def get_provider( admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), ): - """ - Get a provider by ID. - - Requires admin access. - - Args: - provider_id: Provider identifier - admin: Authenticated admin user (injected) - - Returns: - OAuthProviderResponse with provider details - - Raises: - HTTPException: 404 if provider not found - """ - logger.info("Admin getting OAuth provider") - + """Get a provider by ID. Admin only.""" provider = await provider_repo.get_provider(provider_id) - if not provider: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Provider '{provider_id}' not found", ) - return OAuthProviderResponse.from_provider(provider) -@router.post("/", response_model=OAuthProviderResponse, status_code=status.HTTP_201_CREATED) +@router.post( + "/", response_model=OAuthProviderResponse, status_code=status.HTTP_201_CREATED +) async def create_provider( provider_data: OAuthProviderCreate, admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + registrar: AgentCoreRegistrar = Depends(get_agentcore_registrar), ): - """ - Create a new OAuth provider. - - Requires admin access. - - Args: - provider_data: Provider creation data - admin: Authenticated admin user (injected) + """Register a new OAuth provider. - Returns: - Created OAuthProviderResponse - - Raises: - HTTPException: 400 if provider already exists or validation fails + Registers credentials with AgentCore Identity first; on success, writes + the metadata record to DynamoDB. If the DB write fails after AgentCore + has accepted the credentials, best-effort rolls back the AgentCore + provider so the two stores stay in sync. """ - logger.info("Admin creating OAuth provider") + logger.info("Admin creating OAuth provider %s", provider_data.provider_id) + + existing = await provider_repo.get_provider(provider_data.provider_id) + if existing: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Provider '{provider_data.provider_id}' already exists", + ) try: - provider = await provider_repo.create_provider(provider_data) - return OAuthProviderResponse.from_provider(provider) + credential_info = registrar.create_credential_provider( + provider_id=provider_data.provider_id, + provider_type=provider_data.provider_type, + client_id=provider_data.client_id, + client_secret=provider_data.client_secret, + discovery_url=provider_data.oauth_discovery_url, + authorization_server_metadata=provider_data.authorization_server_metadata, + ) + except CredentialProviderConflictError as err: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(err)) + except InvalidCustomProviderConfigError as err: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(err)) - except ValueError as e: - logger.warning(f"Provider creation failed: {e}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), + try: + provider = _build_provider_from_create(provider_data, credential_info) + await provider_repo.put_provider(provider) + except Exception: + logger.exception( + "DB write failed for %s; rolling back AgentCore credential provider", + provider_data.provider_id, ) + try: + registrar.delete_credential_provider(provider_data.provider_id) + except Exception: + logger.exception( + "Rollback delete failed for %s; orphaned AgentCore provider may exist", + provider_data.provider_id, + ) + raise + + return OAuthProviderResponse.from_provider(provider) @router.patch("/{provider_id}", response_model=OAuthProviderResponse) @@ -134,54 +138,82 @@ async def update_provider( updates: OAuthProviderUpdate, admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + registrar: AgentCoreRegistrar = Depends(get_agentcore_registrar), ): - """ - Update an OAuth provider. - - Requires admin access. - - Note: If scopes are updated, existing user connections may need to re-authenticate. - The system tracks scope changes via hash and will prompt users to re-auth. + """Update a provider's metadata, and optionally rotate credentials. - Args: - provider_id: Provider identifier - updates: Fields to update - admin: Authenticated admin user (injected) - - Returns: - Updated OAuthProviderResponse - - Raises: - HTTPException: - - 400 if validation fails - - 404 if provider not found + Metadata edits (display name, scopes, roles, icon, enabled) write + straight to DynamoDB. Credential or discovery-config changes require + a corresponding AgentCore update — this is done first, and only if it + succeeds do we persist the new metadata and cached pointers. """ - logger.info("Admin updating OAuth provider") - - # Track if scopes changed (will invalidate cached tokens) - old_provider = await provider_repo.get_provider(provider_id) - scopes_changed = ( - old_provider - and updates.scopes is not None - and set(updates.scopes) != set(old_provider.scopes) - ) + existing = await provider_repo.get_provider(provider_id) + if not existing: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Provider '{provider_id}' not found", + ) - provider = await provider_repo.update_provider(provider_id, updates) + rotating_credentials = bool(updates.client_id and updates.client_secret) + changing_discovery = ( + updates.oauth_discovery_url is not None + or updates.authorization_server_metadata is not None + ) - if not provider: + credential_info: CredentialProviderInfo | None = None + if rotating_credentials or changing_discovery: + discovery_url = ( + updates.oauth_discovery_url + if updates.oauth_discovery_url is not None + else existing.oauth_discovery_url + ) + authorization_server_metadata = ( + updates.authorization_server_metadata + if updates.authorization_server_metadata is not None + else existing.authorization_server_metadata + ) + if not rotating_credentials: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + "Discovery config can only be updated together with a " + "credential rotation (client_id + client_secret)." + ), + ) + try: + credential_info = registrar.update_credential_provider( + provider_id=provider_id, + provider_type=existing.provider_type, + client_id=updates.client_id, + client_secret=updates.client_secret, + discovery_url=discovery_url, + authorization_server_metadata=authorization_server_metadata, + ) + except CredentialProviderNotFoundError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=( + f"AgentCore credential provider for '{provider_id}' " + "was not found. The DynamoDB record may be stale." + ), + ) + except InvalidCustomProviderConfigError as err: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(err) + ) + + provider = await provider_repo.apply_metadata_update(provider_id, updates) + if provider is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Provider '{provider_id}' not found", ) - # Invalidate cached tokens for this provider if scopes changed - if scopes_changed: - cache = get_token_cache() - evicted = cache.delete_for_provider(provider_id) - logger.info( - "Scopes changed for provider, " - f"evicted {evicted} cached tokens" - ) + if credential_info is not None: + provider.credential_provider_arn = credential_info.credential_provider_arn + provider.callback_url = credential_info.callback_url + provider.updated_at = datetime.now(timezone.utc).isoformat() + "Z" + await provider_repo.put_provider(provider) return OAuthProviderResponse.from_provider(provider) @@ -189,107 +221,49 @@ async def update_provider( @router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_provider( provider_id: str, - force: bool = Query( - False, - description="Force delete even if users are connected (will delete their tokens)", - ), admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), - token_repo: OAuthTokenRepository = Depends(get_token_repository), + registrar: AgentCoreRegistrar = Depends(get_agentcore_registrar), ): - """ - Delete an OAuth provider. - - Requires admin access. - - Warning: If users are connected to this provider, their tokens will be deleted - unless force=False (default), in which case the deletion will fail. + """Delete a provider from AgentCore and DynamoDB. - Args: - provider_id: Provider identifier - force: If True, delete even if users are connected - admin: Authenticated admin user (injected) - - Raises: - HTTPException: - - 400 if users are connected and force=False - - 404 if provider not found + AgentCore's deletion also removes every user token stored in its vault + for this provider, so connected users must reconnect the next time + they invoke a tool that needs it. """ - logger.info("Admin deleting OAuth provider") - - # Check for connected users - connected_tokens = await token_repo.list_provider_tokens(provider_id) - - if connected_tokens and not force: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=( - f"Cannot delete provider with {len(connected_tokens)} connected users. " - "Use force=true to delete anyway (will remove user connections)." - ), - ) - - # Delete user tokens if any - if connected_tokens: - deleted_count = await token_repo.delete_provider_tokens(provider_id) - logger.info(f"Deleted {deleted_count} user tokens for provider") - - # Delete provider - deleted = await provider_repo.delete_provider(provider_id) - - if not deleted: + existing = await provider_repo.get_provider(provider_id) + if not existing: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Provider '{provider_id}' not found", ) - # Clear cached tokens - cache = get_token_cache() - cache.delete_for_provider(provider_id) - + registrar.delete_credential_provider(provider_id) + await provider_repo.delete_provider(provider_id) return None # ============================================================================= -# Provider Statistics +# Helpers # ============================================================================= -@router.get("/{provider_id}/connections/count") -async def get_provider_connection_count( - provider_id: str, - admin: User = Depends(require_admin), - provider_repo: OAuthProviderRepository = Depends(get_provider_repository), - token_repo: OAuthTokenRepository = Depends(get_token_repository), -): - """ - Get the number of users connected to a provider. - - Requires admin access. - - Args: - provider_id: Provider identifier - admin: Authenticated admin user (injected) - - Returns: - Dict with provider_id and connection_count - - Raises: - HTTPException: 404 if provider not found - """ - logger.info("Admin getting connection count for provider") - - # Verify provider exists - provider = await provider_repo.get_provider(provider_id) - if not provider: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Provider '{provider_id}' not found", - ) - - tokens = await token_repo.list_provider_tokens(provider_id) - - return { - "provider_id": provider_id, - "connection_count": len(tokens), - } +def _build_provider_from_create( + data: OAuthProviderCreate, credential_info: CredentialProviderInfo +) -> OAuthProvider: + now = datetime.now(timezone.utc).isoformat() + "Z" + return OAuthProvider( + provider_id=data.provider_id, + display_name=data.display_name, + provider_type=data.provider_type, + scopes=data.scopes, + allowed_roles=data.allowed_roles, + enabled=data.enabled, + icon_name=data.icon_name, + credential_provider_arn=credential_info.credential_provider_arn, + callback_url=credential_info.callback_url, + oauth_discovery_url=data.oauth_discovery_url, + authorization_server_metadata=data.authorization_server_metadata, + created_at=now, + updated_at=now, + ) From 677ca5a44b26e7d45bd9ea38983b939fb58f3f93 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 10:43:17 -0600 Subject: [PATCH 09/35] refactor(connectors): rewire frontend for AgentCore flow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Admin side: - Rename admin/oauth-providers → admin/connectors (file + route); old route path redirects for URL stability - Rewrite the admin model to the AgentCore-owned shape: drop endpoint fields, authorization_params, pkce_required, userinfo/revocation endpoints. Add credential_provider_arn, callback_url, and oauth_discovery_url / authorization_server_metadata for Custom vendors - Rewrite the admin form: preset picker simplified to display metadata only, Custom requires an OIDC discovery URL, credential rotation requires clientId + clientSecret together (AgentCore's update API is not partial), success screen after create displays the AgentCore callback URL with a copy button so the admin can paste it into the vendor console, edit mode shows the callback URL + ARN read-only User-facing retirement: - Delete settings/connectors (user "my connected accounts" page), settings/oauth-callback (legacy 3LO return handler), and the sidebar + route entries for them. AgentCore Identity owns the consent flow at runtime via the existing /oauth-complete landing page Co-Authored-By: Claude Opus 4.7 (1M context) --- .../ai.client/src/app/admin/admin.page.ts | 6 +- .../src/app/admin/connectors/index.ts | 9 + .../connectors/models/connector.model.ts | 148 ++++ .../pages/connector-form.page.ts} | 656 +++++++++--------- .../pages/connector-list.page.ts} | 152 ++-- .../services/connectors.service.spec.ts} | 56 +- .../connectors/services/connectors.service.ts | 113 +++ .../src/app/admin/oauth-providers/index.ts | 9 - .../models/oauth-provider.model.ts | 186 ----- .../services/oauth-providers.service.ts | 154 ---- frontend/ai.client/src/app/app.routes.ts | 39 +- .../src/app/settings/connectors/index.ts | 2 - .../app/settings/connectors/models/index.ts | 1 - .../models/oauth-connector.model.ts | 60 -- .../services/connectors.service.spec.ts | 81 --- .../connectors/services/connectors.service.ts | 103 --- .../app/settings/connectors/services/index.ts | 1 - .../oauth-callback/oauth-callback.page.ts | 655 ----------------- .../connectors-settings.page.ts | 308 -------- .../src/app/settings/settings.page.ts | 1 - .../src/app/settings/settings.routes.ts | 10 - 21 files changed, 728 insertions(+), 2022 deletions(-) create mode 100644 frontend/ai.client/src/app/admin/connectors/index.ts create mode 100644 frontend/ai.client/src/app/admin/connectors/models/connector.model.ts rename frontend/ai.client/src/app/admin/{oauth-providers/pages/provider-form.page.ts => connectors/pages/connector-form.page.ts} (50%) rename frontend/ai.client/src/app/admin/{oauth-providers/pages/provider-list.page.ts => connectors/pages/connector-list.page.ts} (77%) rename frontend/ai.client/src/app/admin/{oauth-providers/services/oauth-providers.service.spec.ts => connectors/services/connectors.service.spec.ts} (59%) create mode 100644 frontend/ai.client/src/app/admin/connectors/services/connectors.service.ts delete mode 100644 frontend/ai.client/src/app/admin/oauth-providers/index.ts delete mode 100644 frontend/ai.client/src/app/admin/oauth-providers/models/oauth-provider.model.ts delete mode 100644 frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.ts delete mode 100644 frontend/ai.client/src/app/settings/connectors/index.ts delete mode 100644 frontend/ai.client/src/app/settings/connectors/models/index.ts delete mode 100644 frontend/ai.client/src/app/settings/connectors/models/oauth-connector.model.ts delete mode 100644 frontend/ai.client/src/app/settings/connectors/services/connectors.service.spec.ts delete mode 100644 frontend/ai.client/src/app/settings/connectors/services/connectors.service.ts delete mode 100644 frontend/ai.client/src/app/settings/connectors/services/index.ts delete mode 100644 frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts delete mode 100644 frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts diff --git a/frontend/ai.client/src/app/admin/admin.page.ts b/frontend/ai.client/src/app/admin/admin.page.ts index dba672ef..c5b4166b 100644 --- a/frontend/ai.client/src/app/admin/admin.page.ts +++ b/frontend/ai.client/src/app/admin/admin.page.ts @@ -110,10 +110,10 @@ export class AdminPage { route: '/admin/auth-providers', }, { - title: 'OAuth Providers', - description: 'Configure third-party OAuth integrations for MCP tool authentication. Manage Google, Microsoft, GitHub, and custom providers.', + title: 'Connectors', + description: 'Configure third-party OAuth integrations that users can connect for MCP tool authentication. Manage Google, Microsoft, GitHub, and custom connectors.', icon: 'heroLink', - route: '/admin/oauth-providers', + route: '/admin/connectors', }, { title: 'Fine-Tuning Access', diff --git a/frontend/ai.client/src/app/admin/connectors/index.ts b/frontend/ai.client/src/app/admin/connectors/index.ts new file mode 100644 index 00000000..c87a34f0 --- /dev/null +++ b/frontend/ai.client/src/app/admin/connectors/index.ts @@ -0,0 +1,9 @@ +// Models +export * from './models/connector.model'; + +// Services +export * from './services/connectors.service'; + +// Pages +export * from './pages/connector-list.page'; +export * from './pages/connector-form.page'; diff --git a/frontend/ai.client/src/app/admin/connectors/models/connector.model.ts b/frontend/ai.client/src/app/admin/connectors/models/connector.model.ts new file mode 100644 index 00000000..1b6c4c19 --- /dev/null +++ b/frontend/ai.client/src/app/admin/connectors/models/connector.model.ts @@ -0,0 +1,148 @@ +/** + * Connector type enumeration. + * + * `canvas` routes through AgentCore's `CustomOauth2` vendor but is kept + * distinct so the UI can surface Canvas-specific guidance if we add a + * preset later. Today it is treated like `custom`. + */ +export type ConnectorType = 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; + +/** + * Connector record as returned by the admin API. + * + * AgentCore Identity is authoritative for `clientId`, `clientSecret`, the + * vendor-specific endpoint config, and `callbackUrl`. Our backend caches + * the ARN and callback URL on the record for admin convenience. + */ +export interface Connector { + providerId: string; + displayName: string; + providerType: ConnectorType; + scopes: string[]; + allowedRoles: string[]; + enabled: boolean; + iconName: string; + credentialProviderArn?: string | null; + callbackUrl?: string | null; + /** Custom/Canvas only — OIDC discovery URL or explicit server metadata. */ + oauthDiscoveryUrl?: string | null; + authorizationServerMetadata?: Record | null; + createdAt: string; + updatedAt: string; +} + +/** + * Response model for listing connectors. + * + * The backend still returns the array under `providers` — we preserve the + * field name to match the wire format exactly. + */ +export interface ConnectorListResponse { + providers: Connector[]; + total: number; +} + +/** + * Create request. `clientId` and `clientSecret` are forwarded to AgentCore + * Identity and are never stored in our DynamoDB. Custom/Canvas providers + * must supply exactly one of `oauthDiscoveryUrl` or + * `authorizationServerMetadata`. + */ +export interface ConnectorCreateRequest { + providerId: string; + displayName: string; + providerType: ConnectorType; + clientId: string; + clientSecret: string; + scopes: string[]; + allowedRoles?: string[]; + enabled?: boolean; + iconName?: string; + oauthDiscoveryUrl?: string; + authorizationServerMetadata?: Record; +} + +/** + * Update request. Credential rotation requires `clientId` and + * `clientSecret` together; metadata-only edits leave them undefined. + */ +export interface ConnectorUpdateRequest { + displayName?: string; + clientId?: string; + clientSecret?: string; + scopes?: string[]; + allowedRoles?: string[]; + enabled?: boolean; + iconName?: string; + oauthDiscoveryUrl?: string; + authorizationServerMetadata?: Record; +} + +/** + * Form data bound to the connector form. Scopes are a comma-separated + * string for admin entry; parsed into `string[]` before submit. + */ +export interface ConnectorFormData { + providerId: string; + displayName: string; + providerType: ConnectorType; + clientId: string; + clientSecret: string; + scopes: string; + allowedRoles: string[]; + enabled: boolean; + iconName: string; + oauthDiscoveryUrl: string; +} + +/** + * Preset configuration for the connector picker. Endpoints are owned by + * AgentCore Identity and not configurable here. + */ +export interface ConnectorPreset { + type: ConnectorType; + displayName: string; + defaultScopes: string[]; + iconName: string; + /** Optional hint shown to the admin when selecting the preset. */ + hint?: string; +} + +export const CONNECTOR_PRESETS: ConnectorPreset[] = [ + { + type: 'google', + displayName: 'Google', + defaultScopes: ['openid', 'email', 'profile'], + iconName: 'heroCloud', + }, + { + type: 'microsoft', + displayName: 'Microsoft', + defaultScopes: ['openid', 'email', 'profile', 'offline_access'], + iconName: 'heroCloud', + }, + { + type: 'github', + displayName: 'GitHub', + defaultScopes: ['read:user', 'user:email'], + iconName: 'heroCodeBracket', + }, + { + type: 'custom', + displayName: 'Custom (OIDC)', + defaultScopes: [], + iconName: 'heroLink', + hint: 'Requires an OpenID Connect discovery URL', + }, +]; + +export function getConnectorPreset(type: ConnectorType): ConnectorPreset | undefined { + return CONNECTOR_PRESETS.find(preset => preset.type === type); +} + +/** + * True when the provider type needs an OIDC discovery URL. + */ +export function requiresDiscovery(type: ConnectorType): boolean { + return type === 'custom' || type === 'canvas'; +} diff --git a/frontend/ai.client/src/app/admin/oauth-providers/pages/provider-form.page.ts b/frontend/ai.client/src/app/admin/connectors/pages/connector-form.page.ts similarity index 50% rename from frontend/ai.client/src/app/admin/oauth-providers/pages/provider-form.page.ts rename to frontend/ai.client/src/app/admin/connectors/pages/connector-form.page.ts index cf5dab75..566c8f79 100644 --- a/frontend/ai.client/src/app/admin/oauth-providers/pages/provider-form.page.ts +++ b/frontend/ai.client/src/app/admin/connectors/pages/connector-form.page.ts @@ -22,28 +22,30 @@ import { heroEyeSlash, heroExclamationTriangle, heroCheckCircle, + heroClipboard, + heroClipboardDocumentCheck, } from '@ng-icons/heroicons/outline'; -import { OAuthProvidersService } from '../services/oauth-providers.service'; +import { ConnectorsService } from '../services/connectors.service'; import { AppRolesService } from '../../roles/services/app-roles.service'; import { - OAuthProviderCreateRequest, - OAuthProviderUpdateRequest, - OAuthProviderType, - OAUTH_PROVIDER_PRESETS, - getProviderPreset, -} from '../models/oauth-provider.model'; + Connector, + ConnectorCreateRequest, + ConnectorUpdateRequest, + ConnectorType, + CONNECTOR_PRESETS, + getConnectorPreset, + requiresDiscovery, +} from '../models/connector.model'; import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'; -interface ProviderFormGroup { +interface ConnectorFormGroup { providerId: FormControl; displayName: FormControl; - providerType: FormControl; - authorizationEndpoint: FormControl; - tokenEndpoint: FormControl; + providerType: FormControl; clientId: FormControl; clientSecret: FormControl; + oauthDiscoveryUrl: FormControl; scopes: FormControl; - authorizationParams: FormControl; allowedRoles: FormControl; grantAllRoles: FormControl; enabled: FormControl; @@ -51,7 +53,7 @@ interface ProviderFormGroup { } @Component({ - selector: 'app-provider-form', + selector: 'app-connector-form', changeDetection: ChangeDetectionStrategy.OnPush, imports: [ReactiveFormsModule, NgIcon, TooltipDirective], providers: [ @@ -62,93 +64,131 @@ interface ProviderFormGroup { heroEyeSlash, heroExclamationTriangle, heroCheckCircle, + heroClipboard, + heroClipboardDocumentCheck, }), ], - host: { - class: 'block', - }, + host: { class: 'block' }, template: `
- -

{{ pageTitle() }}

- {{ isEditMode() ? 'Update OAuth provider settings and credentials' : 'Configure a new OAuth provider for tool authentication' }} + {{ isEditMode() ? 'Update connector settings and credentials' : 'Register a new OAuth connector' }}

- @if (loading()) {
-
-

- Loading provider... +

+

Loading connector...

+
+
+ } @else if (createdConnector(); as created) { + +
+
+
+ +
+

+ Connector created +

+

+ "{{ created.displayName }}" is registered with AgentCore Identity. +

+
+
+
+ +
+

+ Next step: register this callback URL +

+

+ Add the following URL to your OAuth provider's list of authorized redirect URIs + (in Google Cloud Console, Microsoft Entra, GitHub OAuth App settings, etc.). + Until this is done, users will see an error when they try to consent.

+
+ + +
+
+ +
+
} @else { - -
+ - @if (!isEditMode()) {
-

- Provider Type -

+

Connector Type

- Select a preset or configure a custom OAuth 2.0 provider. + Choose a preset or use Custom for any OIDC-compliant provider.

- -
+
@for (preset of presets; track preset.type) { }
} -
-

- Basic Information -

- +

Basic Information

-

Unique identifier. Lowercase letters, numbers, and hyphens only.

- @if (providerForm.controls.providerId.invalid && providerForm.controls.providerId.touched) { + @if (connectorForm.controls.providerId.invalid && connectorForm.controls.providerId.touched) {

- @if (providerForm.controls.providerId.errors?.['required']) { - Provider ID is required - } @else if (providerForm.controls.providerId.errors?.['pattern']) { - Must be lowercase letters, numbers, and hyphens only - } @else if (providerForm.controls.providerId.errors?.['maxlength']) { - Must be at most 64 characters - } + @if (connectorForm.controls.providerId.errors?.['required']) { Connector ID is required } + @else if (connectorForm.controls.providerId.errors?.['pattern']) { Must be lowercase letters, numbers, and hyphens only } + @else if (connectorForm.controls.providerId.errors?.['maxlength']) { Must be at most 64 characters }

}
-
-
- +
- + + @if (isEditMode() && loadedConnector(); as loaded) { +
+

AgentCore Identity

+
+
+ +
+ + +
+
+ @if (loaded.credentialProviderArn) { +
+ + +
+ } +
+
+ } +
-

- OAuth Configuration -

+

OAuth Credentials

- Configure the OAuth 2.0 endpoints and credentials. + @if (isEditMode()) { + Enter both fields to rotate credentials. Leave both blank to keep existing. + } @else { + Credentials are stored by AWS Bedrock AgentCore Identity — never by this application. + }

- -
- - - @if (providerForm.controls.authorizationEndpoint.invalid && providerForm.controls.authorizationEndpoint.touched) { -

- @if (providerForm.controls.authorizationEndpoint.errors?.['required']) { - Authorization endpoint is required - } @else { - Must be a valid URL - } -

- } -
- - -
- - - @if (providerForm.controls.tokenEndpoint.invalid && providerForm.controls.tokenEndpoint.touched) { -

- @if (providerForm.controls.tokenEndpoint.errors?.['required']) { - Token endpoint is required - } @else { - Must be a valid URL - } -

- } -
- -
- @if (providerForm.controls.clientId.invalid && providerForm.controls.clientId.touched) { -

Client ID is required

- }
-
- @if (isEditMode()) { +
+ + @if (credentialPairError()) { +

{{ credentialPairError() }}

+ } + + @if (needsDiscovery()) { +
+ +

- Leave blank to keep the existing secret. Enter a new value to update it. + AgentCore fetches this URL to resolve authorization and token endpoints.

- } - @if (!isEditMode() && providerForm.controls.clientSecret.invalid && providerForm.controls.clientSecret.touched) { -

Client secret is required

- } -
+
+ } -
- +
- - -
- - -

- Extra URL parameters for the authorization request. For Google, use "access_type=offline, prompt=consent" to enable refresh tokens. -

-
-
-

- Access Control -

+

Access Control

- Restrict which application roles can use this provider. + Restrict which application roles can use this connector.

-
- - - +
- @if (!providerForm.controls.grantAllRoles.value) { + @if (!connectorForm.controls.grantAllRoles.value) { @if (rolesResource.isLoading() || rolesResource.value() === undefined) {
@@ -407,7 +411,6 @@ interface ProviderFormGroup { [class.dark:text-gray-300]="!isRoleSelected(role.roleId)" class="rounded-sm px-3 py-1.5 text-sm/6 font-medium transition-colors hover:opacity-80 focus:outline-hidden focus:ring-3 focus:ring-purple-500/50" [appTooltip]="role.description || 'No description'" - appTooltipPosition="top" > {{ role.displayName }} @@ -420,33 +423,30 @@ interface ProviderFormGroup { } }

- Only users with selected roles will be able to connect to this provider. + Only users with selected roles will be able to use this connector.

- @if (isEditMode()) {
-

- Security Notice -

+

Security Notice

- Changing scopes may invalidate existing user tokens. Users may need to re-authenticate after scope changes. + Changing scopes forces connected users to re-consent on their next tool call. + Rotating credentials requires re-entering both Client ID and Client Secret.

} -
- @if (providersResource.isLoading() && providers().length === 0) { + @if (connectorsResource.isLoading() && connectors().length === 0) {

- Loading providers... + Loading connectors...

} - @if (providersResource.error()) { + @if (connectorsResource.error()) {
-

Failed to load providers

+

Failed to load connectors

Please check your connection and try again.

} - - @if (!providersResource.isLoading() || providers().length > 0) { + + @if (!connectorsResource.isLoading() || connectors().length > 0) {
- @for (provider of filteredProviders(); track provider.providerId) { + @for (connector of filteredConnectors(); track connector.providerId) { - +
- Provider + Connector
-
- +
+

- {{ provider.displayName }} + {{ connector.displayName }}

- {{ provider.providerId }} + {{ connector.providerId }}

@@ -217,23 +217,23 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'
- @if (provider.enabled) { + @if (connector.enabled) { Active @@ -276,17 +276,17 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'
- @if (filteredProviders().length === 0 && !providersResource.isLoading()) { + @if (filteredConnectors().length === 0 && !connectorsResource.isLoading()) {
@if (hasActiveFilters()) { -

No providers match your filters

+

No connectors match your filters

Try adjusting your search or filter criteria.

@@ -332,15 +332,15 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive' } - @if (providers().length > 0) { + @if (connectors().length > 0) {
-

About OAuth Providers

+

About Connectors

- Provider Types: Choose from common presets (Google, Microsoft, GitHub, Canvas) or configure a custom OAuth 2.0 provider. + Connector Types: Choose from common presets (Google, Microsoft, GitHub, Canvas) or configure a custom OAuth 2.0 connector.

- Role Restrictions: Control which application roles can connect to each provider. Leave empty for unrestricted access. + Role Restrictions: Control which application roles can use each connector. Leave empty for unrestricted access.

Security: Client secrets are encrypted and stored securely. They are never exposed to the frontend after creation. @@ -352,47 +352,44 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'

`, }) -export class ProviderListPage { - oauthProvidersService = inject(OAuthProvidersService); +export class ConnectorListPage { + connectorsService = inject(ConnectorsService); private router = inject(Router); - readonly providersResource = this.oauthProvidersService.providersResource; + readonly connectorsResource = this.connectorsService.connectorsResource; - // Local state searchQuery = signal(''); enabledFilter = signal(''); typeFilter = signal(''); - // Computed - readonly providers = computed(() => this.oauthProvidersService.getProviders()); + readonly connectors = computed(() => this.connectorsService.getConnectors()); - readonly filteredProviders = computed(() => { - let providers = this.providers(); + readonly filteredConnectors = computed(() => { + let connectors = this.connectors(); const query = this.searchQuery().toLowerCase(); const enabled = this.enabledFilter(); const type = this.typeFilter(); if (query) { - providers = providers.filter( - p => - p.displayName.toLowerCase().includes(query) || - p.providerId.toLowerCase().includes(query) || - p.providerType.toLowerCase().includes(query) + connectors = connectors.filter( + c => + c.displayName.toLowerCase().includes(query) || + c.providerId.toLowerCase().includes(query) || + c.providerType.toLowerCase().includes(query) ); } if (enabled === 'enabled') { - providers = providers.filter(p => p.enabled); + connectors = connectors.filter(c => c.enabled); } else if (enabled === 'disabled') { - providers = providers.filter(p => !p.enabled); + connectors = connectors.filter(c => !c.enabled); } if (type) { - providers = providers.filter(p => p.providerType === type); + connectors = connectors.filter(c => c.providerType === type); } - // Sort: enabled first, then alphabetically by display name - return providers.sort((a, b) => { + return connectors.sort((a, b) => { if (a.enabled !== b.enabled) { return a.enabled ? -1 : 1; } @@ -410,12 +407,11 @@ export class ProviderListPage { this.typeFilter.set(''); } - getProviderIcon(provider: OAuthProvider): string { - if (provider.iconName) { - return provider.iconName; + getConnectorIcon(connector: Connector): string { + if (connector.iconName) { + return connector.iconName; } - // Default icons by type - switch (provider.providerType) { + switch (connector.providerType) { case 'google': case 'microsoft': return 'heroCloud'; @@ -428,7 +424,7 @@ export class ProviderListPage { } } - getProviderIconClasses(type: OAuthProviderType): string { + getConnectorIconClasses(type: ConnectorType): string { const baseClasses = 'flex size-10 shrink-0 items-center justify-center rounded-sm'; switch (type) { case 'google': @@ -444,7 +440,7 @@ export class ProviderListPage { } } - getProviderTypeBadgeClasses(type: OAuthProviderType): string { + getConnectorTypeBadgeClasses(type: ConnectorType): string { const baseClasses = 'inline-flex items-center rounded-xs px-2 py-0.5 text-xs font-medium'; switch (type) { case 'google': @@ -460,7 +456,7 @@ export class ProviderListPage { } } - getProviderTypeLabel(type: OAuthProviderType): string { + getConnectorTypeLabel(type: ConnectorType): string { switch (type) { case 'google': return 'Google'; @@ -475,16 +471,16 @@ export class ProviderListPage { } } - async deleteProvider(provider: OAuthProvider): Promise { - if (!confirm(`Are you sure you want to delete the provider "${provider.displayName}"?\n\nThis will disconnect all users currently using this provider. This action cannot be undone.`)) { + async deleteConnector(connector: Connector): Promise { + if (!confirm(`Are you sure you want to delete the connector "${connector.displayName}"?\n\nThis will disconnect all users currently using this connector. This action cannot be undone.`)) { return; } try { - await this.oauthProvidersService.deleteProvider(provider.providerId); + await this.connectorsService.deleteConnector(connector.providerId); } catch (error: any) { - console.error('Error deleting provider:', error); - const message = error?.error?.detail || error?.message || 'Failed to delete provider.'; + console.error('Error deleting connector:', error); + const message = error?.error?.detail || error?.message || 'Failed to delete connector.'; alert(message); } } diff --git a/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.spec.ts b/frontend/ai.client/src/app/admin/connectors/services/connectors.service.spec.ts similarity index 59% rename from frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.spec.ts rename to frontend/ai.client/src/app/admin/connectors/services/connectors.service.spec.ts index bbea6c75..be7aee54 100644 --- a/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.spec.ts +++ b/frontend/ai.client/src/app/admin/connectors/services/connectors.service.spec.ts @@ -2,12 +2,12 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; import { TestBed } from '@angular/core/testing'; import { HttpClientTestingModule, HttpTestingController } from '@angular/common/http/testing'; import { signal } from '@angular/core'; -import { OAuthProvidersService } from './oauth-providers.service'; +import { ConnectorsService } from './connectors.service'; import { ConfigService } from '../../../services/config.service'; import { AuthService } from '../../../auth/auth.service'; -describe('OAuthProvidersService', () => { - let service: OAuthProvidersService; +describe('ConnectorsService', () => { + let service: ConnectorsService; let httpMock: HttpTestingController; beforeEach(() => { @@ -15,63 +15,63 @@ describe('OAuthProvidersService', () => { TestBed.configureTestingModule({ imports: [HttpClientTestingModule], providers: [ - OAuthProvidersService, + ConnectorsService, { provide: AuthService, useValue: { ensureAuthenticated: vi.fn().mockResolvedValue(undefined) } }, { provide: ConfigService, useValue: { appApiUrl: signal('http://localhost:8000') } }, ], }); - service = TestBed.inject(OAuthProvidersService); + service = TestBed.inject(ConnectorsService); httpMock = TestBed.inject(HttpTestingController); }); afterEach(() => { - httpMock.match(() => true); // discard pending requests + httpMock.match(() => true); TestBed.resetTestingModule(); }); - it('should fetch providers', async () => { + it('should fetch connectors', async () => { const mockResponse = { providers: [], total: 0 }; - const promise = service.fetchProviders(); + const promise = service.fetchConnectors(); await vi.waitFor(() => { httpMock.expectOne('http://localhost:8000/admin/oauth-providers/').flush(mockResponse); }); expect(await promise).toEqual(mockResponse); }); - it('should fetch provider by id', async () => { - const mockProvider = { provider_id: '1', name: 'Test Provider' }; - const promise = service.fetchProvider('1'); + it('should fetch connector by id', async () => { + const mockConnector = { provider_id: '1', name: 'Test Connector' }; + const promise = service.fetchConnector('1'); await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(mockProvider); + httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(mockConnector); }); - expect(await promise).toEqual({ providerId: '1', name: 'Test Provider' }); + expect(await promise).toEqual({ providerId: '1', name: 'Test Connector' }); }); - it('should create provider', async () => { - const providerData = { name: 'New Provider' } as any; - const mockProvider = { provider_id: '1', name: 'New Provider' }; - const promise = service.createProvider(providerData); + it('should create connector', async () => { + const data = { name: 'New Connector' } as any; + const mockConnector = { provider_id: '1', name: 'New Connector' }; + const promise = service.createConnector(data); await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/admin/oauth-providers/').flush(mockProvider); + httpMock.expectOne('http://localhost:8000/admin/oauth-providers/').flush(mockConnector); }); - expect(await promise).toEqual({ providerId: '1', name: 'New Provider' }); + expect(await promise).toEqual({ providerId: '1', name: 'New Connector' }); }); - it('should update provider', async () => { - const updates = { name: 'Updated Provider' } as any; - const mockProvider = { provider_id: '1', name: 'Updated Provider' }; - const promise = service.updateProvider('1', updates); + it('should update connector', async () => { + const updates = { name: 'Updated Connector' } as any; + const mockConnector = { provider_id: '1', name: 'Updated Connector' }; + const promise = service.updateConnector('1', updates); await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(mockProvider); + httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(mockConnector); }); - expect(await promise).toEqual({ providerId: '1', name: 'Updated Provider' }); + expect(await promise).toEqual({ providerId: '1', name: 'Updated Connector' }); }); - it('should delete provider', async () => { - const promise = service.deleteProvider('1'); + it('should delete connector', async () => { + const promise = service.deleteConnector('1'); await vi.waitFor(() => { httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(null); }); await promise; }); -}); \ No newline at end of file +}); diff --git a/frontend/ai.client/src/app/admin/connectors/services/connectors.service.ts b/frontend/ai.client/src/app/admin/connectors/services/connectors.service.ts new file mode 100644 index 00000000..fbf1ff78 --- /dev/null +++ b/frontend/ai.client/src/app/admin/connectors/services/connectors.service.ts @@ -0,0 +1,113 @@ +import { Injectable, inject, resource, computed } from '@angular/core'; +import { HttpClient } from '@angular/common/http'; +import { firstValueFrom } from 'rxjs'; +import { ConfigService } from '../../../services/config.service'; +import { AuthService } from '../../../auth/auth.service'; +import { + Connector, + ConnectorListResponse, + ConnectorCreateRequest, + ConnectorUpdateRequest, +} from '../models/connector.model'; + +function toSnakeCase(obj: Record): Record { + const result: Record = {}; + for (const [key, value] of Object.entries(obj)) { + if (value === undefined) continue; + const snakeKey = key.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`); + result[snakeKey] = value; + } + return result; +} + +function toCamelCase(obj: Record): Record { + const result: Record = {}; + for (const [key, value] of Object.entries(obj)) { + const camelKey = key.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()); + result[camelKey] = value; + } + return result; +} + +/** + * Admin service for managing connectors. + * + * The backend admin endpoint is still `/admin/oauth-providers` — that is the + * stable wire contract. We use the connectors vernacular throughout the + * frontend and translate at this layer. + */ +@Injectable({ + providedIn: 'root' +}) +export class ConnectorsService { + private http = inject(HttpClient); + private authService = inject(AuthService); + private config = inject(ConfigService); + + private readonly baseUrl = computed(() => `${this.config.appApiUrl()}/admin/oauth-providers`); + + readonly connectorsResource = resource({ + loader: async () => { + await this.authService.ensureAuthenticated(); + return this.fetchConnectors(); + } + }); + + getConnectors(): Connector[] { + return this.connectorsResource.value()?.providers ?? []; + } + + getEnabledConnectors(): Connector[] { + return this.getConnectors().filter(c => c.enabled); + } + + getConnectorById(providerId: string): Connector | undefined { + return this.getConnectors().find(c => c.providerId === providerId); + } + + async fetchConnectors(): Promise { + const response = await firstValueFrom( + this.http.get(`${this.baseUrl()}/`) + ); + return { + providers: response.providers.map((p: any) => toCamelCase(p) as Connector), + total: response.total, + }; + } + + async fetchConnector(providerId: string): Promise { + const response = await firstValueFrom( + this.http.get(`${this.baseUrl()}/${providerId}`) + ); + return toCamelCase(response) as Connector; + } + + async createConnector(data: ConnectorCreateRequest): Promise { + const snakeCaseData = toSnakeCase(data as unknown as Record); + const response = await firstValueFrom( + this.http.post(`${this.baseUrl()}/`, snakeCaseData) + ); + this.connectorsResource.reload(); + return toCamelCase(response) as Connector; + } + + async updateConnector(providerId: string, updates: ConnectorUpdateRequest): Promise { + const snakeCaseData = toSnakeCase(updates as unknown as Record); + const response = await firstValueFrom( + this.http.patch(`${this.baseUrl()}/${providerId}`, snakeCaseData) + ); + this.connectorsResource.reload(); + return toCamelCase(response) as Connector; + } + + async deleteConnector(providerId: string): Promise { + await firstValueFrom( + this.http.delete(`${this.baseUrl()}/${providerId}`) + ); + this.connectorsResource.reload(); + } + + reload(): void { + this.connectorsResource.reload(); + } +} diff --git a/frontend/ai.client/src/app/admin/oauth-providers/index.ts b/frontend/ai.client/src/app/admin/oauth-providers/index.ts deleted file mode 100644 index 7b71465d..00000000 --- a/frontend/ai.client/src/app/admin/oauth-providers/index.ts +++ /dev/null @@ -1,9 +0,0 @@ -// Models -export * from './models/oauth-provider.model'; - -// Services -export * from './services/oauth-providers.service'; - -// Pages -export * from './pages/provider-list.page'; -export * from './pages/provider-form.page'; diff --git a/frontend/ai.client/src/app/admin/oauth-providers/models/oauth-provider.model.ts b/frontend/ai.client/src/app/admin/oauth-providers/models/oauth-provider.model.ts deleted file mode 100644 index 8f2b5577..00000000 --- a/frontend/ai.client/src/app/admin/oauth-providers/models/oauth-provider.model.ts +++ /dev/null @@ -1,186 +0,0 @@ -/** - * OAuth provider type enumeration. - */ -export type OAuthProviderType = 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; - -/** - * OAuth Provider configuration. - */ -export interface OAuthProvider { - /** Unique provider identifier (lowercase alphanumeric + underscore) */ - providerId: string; - /** Human-readable display name */ - displayName: string; - /** Provider type for preset configurations */ - providerType: OAuthProviderType; - /** OAuth authorization endpoint URL */ - authorizationEndpoint: string; - /** OAuth token endpoint URL */ - tokenEndpoint: string; - /** OAuth client ID (public) */ - clientId: string; - /** OAuth scopes to request */ - scopes: string[]; - /** AppRole IDs that can use this provider */ - allowedRoles: string[]; - /** Whether this provider is active */ - enabled: boolean; - /** Icon name for UI display (heroicons) */ - iconName: string; - /** Additional authorization URL parameters (e.g., access_type=offline for Google) */ - authorizationParams: Record; - /** ISO 8601 creation timestamp */ - createdAt: string; - /** ISO 8601 update timestamp */ - updatedAt: string; -} - -/** - * Response model for listing OAuth providers. - */ -export interface OAuthProviderListResponse { - providers: OAuthProvider[]; - total: number; -} - -/** - * Request model for creating a new OAuth provider. - */ -export interface OAuthProviderCreateRequest { - /** Unique provider identifier (lowercase alphanumeric + underscore, 3-50 chars) */ - providerId: string; - /** Human-readable display name (1-100 chars) */ - displayName: string; - /** Provider type */ - providerType: OAuthProviderType; - /** OAuth authorization endpoint URL */ - authorizationEndpoint: string; - /** OAuth token endpoint URL */ - tokenEndpoint: string; - /** OAuth client ID */ - clientId: string; - /** OAuth client secret (only sent on create, never returned) */ - clientSecret: string; - /** OAuth scopes to request */ - scopes: string[]; - /** AppRole IDs that can use this provider */ - allowedRoles?: string[]; - /** Whether this provider is active */ - enabled?: boolean; - /** Icon name for UI display */ - iconName?: string; - /** Additional authorization URL parameters (e.g., access_type=offline for Google) */ - authorizationParams?: Record; -} - -/** - * Request model for updating an OAuth provider. - * All fields are optional for partial updates. - */ -export interface OAuthProviderUpdateRequest { - /** Human-readable display name (1-100 chars) */ - displayName?: string; - /** OAuth authorization endpoint URL */ - authorizationEndpoint?: string; - /** OAuth token endpoint URL */ - tokenEndpoint?: string; - /** OAuth client ID */ - clientId?: string; - /** OAuth client secret (only set if updating) */ - clientSecret?: string; - /** OAuth scopes to request */ - scopes?: string[]; - /** AppRole IDs that can use this provider */ - allowedRoles?: string[]; - /** Whether this provider is active */ - enabled?: boolean; - /** Icon name for UI display */ - iconName?: string; - /** Additional authorization URL parameters (e.g., access_type=offline for Google) */ - authorizationParams?: Record; -} - -/** - * Form data model for creating/editing an OAuth provider. - */ -export interface OAuthProviderFormData { - providerId: string; - displayName: string; - providerType: OAuthProviderType; - authorizationEndpoint: string; - tokenEndpoint: string; - clientId: string; - clientSecret: string; - scopes: string; - allowedRoles: string[]; - enabled: boolean; - iconName: string; - authorizationParams: string; -} - -/** - * Preset configurations for common OAuth providers. - */ -export interface OAuthProviderPreset { - type: OAuthProviderType; - displayName: string; - authorizationEndpoint: string; - tokenEndpoint: string; - defaultScopes: string[]; - iconName: string; - authorizationParams?: Record; -} - -/** - * Common OAuth provider presets. - */ -export const OAUTH_PROVIDER_PRESETS: OAuthProviderPreset[] = [ - { - type: 'google', - displayName: 'Google', - authorizationEndpoint: 'https://accounts.google.com/o/oauth2/v2/auth', - tokenEndpoint: 'https://oauth2.googleapis.com/token', - defaultScopes: ['openid', 'email', 'profile'], - iconName: 'heroCloud', - authorizationParams: { access_type: 'offline', prompt: 'consent' }, - }, - { - type: 'microsoft', - displayName: 'Microsoft', - authorizationEndpoint: 'https://login.microsoftonline.com/common/oauth2/v2.0/authorize', - tokenEndpoint: 'https://login.microsoftonline.com/common/oauth2/v2.0/token', - defaultScopes: ['openid', 'email', 'profile', 'offline_access'], - iconName: 'heroCloud', - }, - { - type: 'github', - displayName: 'GitHub', - authorizationEndpoint: 'https://github.com/login/oauth/authorize', - tokenEndpoint: 'https://github.com/login/oauth/access_token', - defaultScopes: ['read:user', 'user:email'], - iconName: 'heroCodeBracket', - }, - { - type: 'canvas', - displayName: 'Canvas LMS', - authorizationEndpoint: '', // User must configure - tokenEndpoint: '', // User must configure - defaultScopes: [], - iconName: 'heroAcademicCap', - }, - { - type: 'custom', - displayName: 'Custom Provider', - authorizationEndpoint: '', - tokenEndpoint: '', - defaultScopes: [], - iconName: 'heroLink', - }, -]; - -/** - * Get preset configuration for a provider type. - */ -export function getProviderPreset(type: OAuthProviderType): OAuthProviderPreset | undefined { - return OAUTH_PROVIDER_PRESETS.find(preset => preset.type === type); -} diff --git a/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.ts b/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.ts deleted file mode 100644 index b433014e..00000000 --- a/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.ts +++ /dev/null @@ -1,154 +0,0 @@ -import { Injectable, inject, resource, computed } from '@angular/core'; -import { HttpClient } from '@angular/common/http'; -import { firstValueFrom } from 'rxjs'; -import { ConfigService } from '../../../services/config.service'; -import { AuthService } from '../../../auth/auth.service'; -import { - OAuthProvider, - OAuthProviderListResponse, - OAuthProviderCreateRequest, - OAuthProviderUpdateRequest, -} from '../models/oauth-provider.model'; - -/** - * Convert camelCase to snake_case for backend API. - */ -function toSnakeCase(obj: Record): Record { - const result: Record = {}; - for (const [key, value] of Object.entries(obj)) { - if (value === undefined) continue; - const snakeKey = key.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`); - result[snakeKey] = value; - } - return result; -} - -/** - * Convert snake_case to camelCase for frontend models. - */ -function toCamelCase(obj: Record): Record { - const result: Record = {}; - for (const [key, value] of Object.entries(obj)) { - const camelKey = key.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()); - result[camelKey] = value; - } - return result; -} - -/** - * Service to manage OAuth Providers. - * - * Provides access to the provider list for use in forms and displays, - * as well as CRUD operations for provider management. - */ -@Injectable({ - providedIn: 'root' -}) -export class OAuthProvidersService { - private http = inject(HttpClient); - private authService = inject(AuthService); - private config = inject(ConfigService); - - private readonly baseUrl = computed(() => `${this.config.appApiUrl()}/admin/oauth-providers`); - - /** - * Reactive resource for fetching OAuth Providers. - */ - readonly providersResource = resource({ - loader: async () => { - await this.authService.ensureAuthenticated(); - return this.fetchProviders(); - } - }); - - /** - * Get all OAuth Providers (from resource). - */ - getProviders(): OAuthProvider[] { - return this.providersResource.value()?.providers ?? []; - } - - /** - * Get only enabled OAuth Providers. - */ - getEnabledProviders(): OAuthProvider[] { - return this.getProviders().filter(p => p.enabled); - } - - /** - * Get a provider by ID from the cached resource. - */ - getProviderById(providerId: string): OAuthProvider | undefined { - return this.getProviders().find(p => p.providerId === providerId); - } - - /** - * Fetch all OAuth Providers from the API. - */ - async fetchProviders(): Promise { - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/`) - ); - // Convert snake_case response to camelCase - return { - providers: response.providers.map((p: any) => toCamelCase(p) as OAuthProvider), - total: response.total, - }; - } - - /** - * Fetch a single provider by ID from the API. - */ - async fetchProvider(providerId: string): Promise { - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/${providerId}`) - ); - // Convert snake_case response to camelCase - return toCamelCase(response) as OAuthProvider; - } - - /** - * Create a new OAuth Provider. - */ - async createProvider(providerData: OAuthProviderCreateRequest): Promise { - // Convert camelCase request to snake_case - const snakeCaseData = toSnakeCase(providerData as unknown as Record); - const response = await firstValueFrom( - this.http.post(`${this.baseUrl()}/`, snakeCaseData) - ); - this.providersResource.reload(); - // Convert snake_case response to camelCase - return toCamelCase(response) as OAuthProvider; - } - - /** - * Update an existing OAuth Provider. - */ - async updateProvider(providerId: string, updates: OAuthProviderUpdateRequest): Promise { - // Convert camelCase request to snake_case - const snakeCaseData = toSnakeCase(updates as unknown as Record); - const response = await firstValueFrom( - this.http.patch(`${this.baseUrl()}/${providerId}`, snakeCaseData) - ); - this.providersResource.reload(); - // Convert snake_case response to camelCase - return toCamelCase(response) as OAuthProvider; - } - - /** - * Delete an OAuth Provider. - */ - async deleteProvider(providerId: string): Promise { - await firstValueFrom( - this.http.delete(`${this.baseUrl()}/${providerId}`) - ); - this.providersResource.reload(); - } - - /** - * Reload the providers resource. - */ - reload(): void { - this.providersResource.reload(); - } -} diff --git a/frontend/ai.client/src/app/app.routes.ts b/frontend/ai.client/src/app/app.routes.ts index 1be59cdd..3874cdd4 100644 --- a/frontend/ai.client/src/app/app.routes.ts +++ b/frontend/ai.client/src/app/app.routes.ts @@ -32,16 +32,6 @@ export const routes: Routes = [ path: 'auth/callback', loadComponent: () => import('./auth/callback/callback.page').then(m => m.CallbackPage), }, - { - path: 'connectors', - redirectTo: 'settings/connectors', - pathMatch: 'full', - }, - { - path: 'connections', - redirectTo: 'settings/connectors', - pathMatch: 'full', - }, { path: 'admin', loadComponent: () => import('./admin/admin.page').then(m => m.AdminPage), @@ -108,8 +98,8 @@ export const routes: Routes = [ canActivate: [authGuard], }, { - path: 'settings/oauth/callback', - loadComponent: () => import('./settings/oauth-callback/oauth-callback.page').then(m => m.OAuthCallbackPage), + path: 'oauth-complete', + loadComponent: () => import('./oauth-complete/oauth-complete.page').then(m => m.OAuthCompletePage), }, { path: 'settings', @@ -184,17 +174,32 @@ export const routes: Routes = [ }, { path: 'admin/oauth-providers', - loadComponent: () => import('./admin/oauth-providers/pages/provider-list.page').then(m => m.ProviderListPage), - canActivate: [adminGuard], + redirectTo: 'admin/connectors', + pathMatch: 'full', }, { path: 'admin/oauth-providers/new', - loadComponent: () => import('./admin/oauth-providers/pages/provider-form.page').then(m => m.ProviderFormPage), - canActivate: [adminGuard], + redirectTo: 'admin/connectors/new', + pathMatch: 'full', }, { path: 'admin/oauth-providers/edit/:providerId', - loadComponent: () => import('./admin/oauth-providers/pages/provider-form.page').then(m => m.ProviderFormPage), + redirectTo: 'admin/connectors/edit/:providerId', + pathMatch: 'full', + }, + { + path: 'admin/connectors', + loadComponent: () => import('./admin/connectors/pages/connector-list.page').then(m => m.ConnectorListPage), + canActivate: [adminGuard], + }, + { + path: 'admin/connectors/new', + loadComponent: () => import('./admin/connectors/pages/connector-form.page').then(m => m.ConnectorFormPage), + canActivate: [adminGuard], + }, + { + path: 'admin/connectors/edit/:providerId', + loadComponent: () => import('./admin/connectors/pages/connector-form.page').then(m => m.ConnectorFormPage), canActivate: [adminGuard], }, { diff --git a/frontend/ai.client/src/app/settings/connectors/index.ts b/frontend/ai.client/src/app/settings/connectors/index.ts deleted file mode 100644 index b01b9746..00000000 --- a/frontend/ai.client/src/app/settings/connectors/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export * from './models'; -export * from './services'; diff --git a/frontend/ai.client/src/app/settings/connectors/models/index.ts b/frontend/ai.client/src/app/settings/connectors/models/index.ts deleted file mode 100644 index eee4f455..00000000 --- a/frontend/ai.client/src/app/settings/connectors/models/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './oauth-connector.model'; diff --git a/frontend/ai.client/src/app/settings/connectors/models/oauth-connector.model.ts b/frontend/ai.client/src/app/settings/connectors/models/oauth-connector.model.ts deleted file mode 100644 index f4fe81fb..00000000 --- a/frontend/ai.client/src/app/settings/connectors/models/oauth-connector.model.ts +++ /dev/null @@ -1,60 +0,0 @@ -/** - * OAuth connector models for the user-facing Connectors UI. - * - * A "connector" is a single user-to-provider OAuth link surfaced in - * /settings/connectors. The underlying backend endpoint still returns a - * `connections` array — we translate that at the service layer. - */ - -/** Connection status for a user's OAuth connector */ -export type OAuthConnectorStatus = 'connected' | 'expired' | 'revoked' | 'needs_reauth'; - -/** Supported OAuth provider types */ -export type OAuthProviderType = 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; - -/** - * A user's OAuth connector for a single provider. - */ -export interface OAuthConnector { - providerId: string; - displayName: string; - providerType: OAuthProviderType; - iconName: string; - status: OAuthConnectorStatus; - connectedAt: string | null; - needsReauth: boolean; -} - -/** - * Response shape returned by {@link ConnectorsService.fetchConnectors}. - */ -export interface OAuthConnectorListResponse { - connectors: OAuthConnector[]; -} - -/** - * Available OAuth provider a user may connect to. - * Returned from GET /oauth/providers (filtered by user roles) - */ -export interface OAuthProvider { - providerId: string; - displayName: string; - providerType: OAuthProviderType; - iconName: string; - scopes: string[]; -} - -/** - * Response from GET /oauth/providers - */ -export interface OAuthProviderListResponse { - providers: OAuthProvider[]; - total: number; -} - -/** - * Response from GET /oauth/connect/{provider_id} - */ -export interface OAuthConnectResponse { - authorizationUrl: string; -} diff --git a/frontend/ai.client/src/app/settings/connectors/services/connectors.service.spec.ts b/frontend/ai.client/src/app/settings/connectors/services/connectors.service.spec.ts deleted file mode 100644 index a135a551..00000000 --- a/frontend/ai.client/src/app/settings/connectors/services/connectors.service.spec.ts +++ /dev/null @@ -1,81 +0,0 @@ -import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; -import { TestBed } from '@angular/core/testing'; -import { HttpClientTestingModule, HttpTestingController } from '@angular/common/http/testing'; -import { signal } from '@angular/core'; -import { ConnectorsService } from './connectors.service'; -import { ConfigService } from '../../../services/config.service'; -import { AuthService } from '../../../auth/auth.service'; - -describe('ConnectorsService', () => { - let service: ConnectorsService; - let httpMock: HttpTestingController; - - beforeEach(() => { - TestBed.resetTestingModule(); - TestBed.configureTestingModule({ - imports: [HttpClientTestingModule], - providers: [ - ConnectorsService, - { provide: AuthService, useValue: { ensureAuthenticated: vi.fn().mockResolvedValue(undefined) } }, - { provide: ConfigService, useValue: { appApiUrl: signal('http://localhost:8000') } }, - ], - }); - service = TestBed.inject(ConnectorsService); - httpMock = TestBed.inject(HttpTestingController); - }); - - afterEach(() => { - httpMock.match(() => true); - TestBed.resetTestingModule(); - }); - - it('should fetch connectors', async () => { - const mockResponse = { connections: [{ provider_id: 'google', status: 'connected' }] }; - - const connectorsPromise = service.fetchConnectors(); - - await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/oauth/connections').flush(mockResponse); - }); - - const connectors = await connectorsPromise; - expect(connectors.connectors[0].providerId).toBe('google'); - }); - - it('should fetch providers', async () => { - const mockResponse = { providers: [{ provider_id: 'google', name: 'Google' }], total: 1 }; - - const providersPromise = service.fetchProviders(); - - await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/oauth/providers').flush(mockResponse); - }); - - const providers = await providersPromise; - expect(providers.providers[0].providerId).toBe('google'); - expect(providers.total).toBe(1); - }); - - it('should connect to provider', async () => { - const mockResponse = { authorization_url: 'https://oauth.example.com/auth' }; - - const connectPromise = service.connect('google'); - - await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/oauth/connect/google').flush(mockResponse); - }); - - const authUrl = await connectPromise; - expect(authUrl).toBe('https://oauth.example.com/auth'); - }); - - it('should disconnect from provider', async () => { - const disconnectPromise = service.disconnect('google'); - - await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/oauth/connections/google').flush({}); - }); - - await disconnectPromise; - }); -}); diff --git a/frontend/ai.client/src/app/settings/connectors/services/connectors.service.ts b/frontend/ai.client/src/app/settings/connectors/services/connectors.service.ts deleted file mode 100644 index 55272e2f..00000000 --- a/frontend/ai.client/src/app/settings/connectors/services/connectors.service.ts +++ /dev/null @@ -1,103 +0,0 @@ -import { Injectable, inject, resource, computed } from '@angular/core'; -import { HttpClient } from '@angular/common/http'; -import { firstValueFrom } from 'rxjs'; -import { ConfigService } from '../../../services/config.service'; -import { AuthService } from '../../../auth/auth.service'; -import { - OAuthConnector, - OAuthConnectorListResponse, - OAuthProvider, - OAuthProviderListResponse, -} from '../models'; - -function toCamelCase(obj: Record): Record { - const result: Record = {}; - for (const [key, value] of Object.entries(obj)) { - const camelKey = key.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()); - result[camelKey] = value; - } - return result; -} - -/** - * Service for managing user OAuth connectors. - * - * Provides access to available providers and user's connectors, - * as well as connect/disconnect operations. - */ -@Injectable({ - providedIn: 'root' -}) -export class ConnectorsService { - private http = inject(HttpClient); - private authService = inject(AuthService); - private config = inject(ConfigService); - - private readonly baseUrl = computed(() => `${this.config.appApiUrl()}/oauth`); - - readonly connectorsResource = resource({ - loader: async () => { - await this.authService.ensureAuthenticated(); - return this.fetchConnectors(); - } - }); - - readonly providersResource = resource({ - loader: async () => { - await this.authService.ensureAuthenticated(); - return this.fetchProviders(); - } - }); - - getConnectors(): OAuthConnector[] { - return this.connectorsResource.value()?.connectors ?? []; - } - - getProviders(): OAuthProvider[] { - return this.providersResource.value()?.providers ?? []; - } - - getConnectorByProviderId(providerId: string): OAuthConnector | undefined { - return this.getConnectors().find(c => c.providerId === providerId); - } - - async fetchConnectors(): Promise { - // Backend endpoint still returns a `connections` array — translate at the service layer. - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/connections`) - ); - return { - connectors: response.connections.map((c: any) => toCamelCase(c) as OAuthConnector), - }; - } - - async fetchProviders(): Promise { - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/providers`) - ); - return { - providers: response.providers.map((p: any) => toCamelCase(p) as OAuthProvider), - total: response.total, - }; - } - - async connect(providerId: string, redirectUrl?: string): Promise { - const params = redirectUrl ? `?redirect=${encodeURIComponent(redirectUrl)}` : ''; - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/connect/${providerId}${params}`) - ); - return response.authorization_url; - } - - async disconnect(providerId: string): Promise { - await firstValueFrom( - this.http.delete(`${this.baseUrl()}/connections/${providerId}`) - ); - this.connectorsResource.reload(); - } - - reload(): void { - this.connectorsResource.reload(); - this.providersResource.reload(); - } -} diff --git a/frontend/ai.client/src/app/settings/connectors/services/index.ts b/frontend/ai.client/src/app/settings/connectors/services/index.ts deleted file mode 100644 index 32879454..00000000 --- a/frontend/ai.client/src/app/settings/connectors/services/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './connectors.service'; diff --git a/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts b/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts deleted file mode 100644 index df98aa94..00000000 --- a/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts +++ /dev/null @@ -1,655 +0,0 @@ -import { - Component, - ChangeDetectionStrategy, - inject, - OnInit, - OnDestroy, - signal, -} from '@angular/core'; -import { Router, ActivatedRoute } from '@angular/router'; -import { NgIcon, provideIcons } from '@ng-icons/core'; -import { - heroCheck, - heroXMark, - heroArrowPath, - heroLink, -} from '@ng-icons/heroicons/outline'; -import { SidenavService } from '../../services/sidenav/sidenav.service'; - -type CallbackState = 'processing' | 'success' | 'error'; - -@Component({ - selector: 'app-oauth-callback', - imports: [NgIcon], - providers: [ - provideIcons({ - heroCheck, - heroXMark, - heroArrowPath, - heroLink, - }), - ], - changeDetection: ChangeDetectionStrategy.OnPush, - template: ` -
- - - - - - - -
- - @if (state() === 'processing') { -
-
- -
-
-
-
-
-
-

Connecting

-

- Establishing secure connection - - . - . - . - -

-
- } - - - @if (state() === 'success') { -
-
- -
-
-
-
-

Connected

-

- @if (providerName()) { - Successfully linked to {{ providerName() }} - } @else { - Authorization complete - } -

-

- Redirecting to your connectors... -

-
- } - - - @if (state() === 'error') { -
-
- -
-
-
-
-

Connection Failed

-

- {{ errorMessage() }} -

-

- Redirecting back... -

-
- } - - - -
- - - -
- `, - styles: ` - :host { - display: block; - min-height: 100dvh; - background: var(--color-gray-50); - } - - :host-context(html.dark) { - background: var(--color-gray-900); - } - - .callback-container { - position: relative; - min-height: 100dvh; - display: flex; - flex-direction: column; - align-items: center; - justify-content: center; - overflow: hidden; - padding: 2rem; - } - - /* Animated grid background */ - .grid-background { - position: absolute; - inset: 0; - display: grid; - grid-template-columns: repeat(8, 1fr); - opacity: 0.04; - pointer-events: none; - } - - :host-context(html.dark) .grid-background { - opacity: 0.06; - } - - .grid-line { - border-right: 1px solid var(--color-primary-500); - height: 100%; - animation: pulse-line 4s ease-in-out infinite; - } - - @keyframes pulse-line { - 0%, 100% { opacity: 0.3; } - 50% { opacity: 1; } - } - - /* Floating accent shapes */ - .accent-shape { - position: absolute; - border-radius: 50%; - filter: blur(80px); - pointer-events: none; - animation: float 8s ease-in-out infinite; - } - - .shape-1 { - width: 350px; - height: 350px; - background: var(--color-primary-500); - opacity: 0.12; - top: -80px; - right: -80px; - animation-delay: 0s; - } - - .shape-2 { - width: 280px; - height: 280px; - background: var(--color-secondary-500); - opacity: 0.1; - bottom: -60px; - left: -60px; - animation-delay: 3s; - } - - @keyframes float { - 0%, 100% { transform: translate(0, 0) scale(1); } - 25% { transform: translate(10px, -20px) scale(1.05); } - 50% { transform: translate(-5px, 10px) scale(0.95); } - 75% { transform: translate(-15px, -10px) scale(1.02); } - } - - /* Main content */ - .content-wrapper { - position: relative; - z-index: 10; - display: flex; - flex-direction: column; - align-items: center; - text-align: center; - animation: fade-up 0.6s ease-out; - min-width: 320px; - } - - @keyframes fade-up { - from { - opacity: 0; - transform: translateY(20px); - } - to { - opacity: 1; - transform: translateY(0); - } - } - - /* Status display */ - .status-display { - margin-bottom: 2rem; - animation: scale-in 0.5s cubic-bezier(0.34, 1.56, 0.64, 1); - } - - @keyframes scale-in { - from { - opacity: 0; - transform: scale(0.8); - } - to { - opacity: 1; - transform: scale(1); - } - } - - .icon-container { - position: relative; - width: 120px; - height: 120px; - display: flex; - align-items: center; - justify-content: center; - border-radius: 50%; - } - - /* Processing state */ - .processing-icon { - background: linear-gradient(135deg, var(--color-primary-100) 0%, var(--color-primary-50) 100%); - color: var(--color-primary-600); - animation: rotate-subtle 8s linear infinite; - } - - :host-context(html.dark) .processing-icon { - background: linear-gradient(135deg, var(--color-primary-900) 0%, var(--color-primary-950) 100%); - color: var(--color-primary-400); - } - - @keyframes rotate-subtle { - from { transform: rotate(0deg); } - to { transform: rotate(360deg); } - } - - .pulse-ring { - position: absolute; - inset: -8px; - border-radius: 50%; - border: 2px solid var(--color-primary-400); - animation: pulse-out 2s ease-out infinite; - } - - .pulse-ring.delay-1 { - animation-delay: 0.6s; - } - - .pulse-ring.delay-2 { - animation-delay: 1.2s; - } - - @keyframes pulse-out { - 0% { - opacity: 0.6; - transform: scale(1); - } - 100% { - opacity: 0; - transform: scale(1.6); - } - } - - /* Success state */ - .success-icon { - background: linear-gradient(135deg, var(--color-green-100) 0%, var(--color-green-50) 100%); - color: var(--color-green-600); - animation: success-bounce 0.6s cubic-bezier(0.34, 1.56, 0.64, 1); - } - - :host-context(html.dark) .success-icon { - background: linear-gradient(135deg, var(--color-green-900) 0%, var(--color-green-950) 100%); - color: var(--color-green-400); - } - - @keyframes success-bounce { - 0% { - opacity: 0; - transform: scale(0.3); - } - 50% { - transform: scale(1.1); - } - 100% { - opacity: 1; - transform: scale(1); - } - } - - .check-ring { - position: absolute; - inset: -4px; - border-radius: 50%; - border: 3px solid var(--color-green-400); - animation: ring-appear 0.4s ease-out 0.3s both; - } - - @keyframes ring-appear { - from { - opacity: 0; - transform: scale(0.8); - } - to { - opacity: 1; - transform: scale(1); - } - } - - /* Error state */ - .error-icon { - background: linear-gradient(135deg, var(--color-red-100) 0%, var(--color-red-50) 100%); - color: var(--color-red-600); - animation: error-shake 0.5s ease-out; - } - - :host-context(html.dark) .error-icon { - background: linear-gradient(135deg, var(--color-red-900) 0%, var(--color-red-950) 100%); - color: var(--color-red-400); - } - - @keyframes error-shake { - 0%, 100% { transform: translateX(0); } - 20% { transform: translateX(-8px); } - 40% { transform: translateX(8px); } - 60% { transform: translateX(-4px); } - 80% { transform: translateX(4px); } - } - - .error-ring { - position: absolute; - inset: -4px; - border-radius: 50%; - border: 3px solid var(--color-red-400); - animation: ring-appear 0.4s ease-out 0.3s both; - } - - /* Message section */ - .message-section { - animation: fade-up 0.6s ease-out 0.2s both; - } - - .title { - font-family: 'Outfit', system-ui, sans-serif; - font-weight: 700; - font-size: clamp(1.75rem, 5vw, 2.5rem); - color: var(--color-gray-900); - margin: 0 0 0.75rem; - letter-spacing: -0.02em; - } - - :host-context(html.dark) .title { - color: var(--color-gray-100); - } - - .success-title { - color: var(--color-green-600); - } - - :host-context(html.dark) .success-title { - color: var(--color-green-400); - } - - .error-title { - color: var(--color-red-600); - } - - :host-context(html.dark) .error-title { - color: var(--color-red-400); - } - - .subtitle { - font-family: 'Space Mono', monospace; - font-size: clamp(0.875rem, 2vw, 1rem); - color: var(--color-gray-600); - margin: 0; - max-width: 320px; - display: flex; - align-items: center; - justify-content: center; - gap: 0; - } - - :host-context(html.dark) .subtitle { - color: var(--color-gray-400); - } - - .error-subtitle { - color: var(--color-red-600); - } - - :host-context(html.dark) .error-subtitle { - color: var(--color-red-400); - } - - .typing-text { - overflow: hidden; - white-space: nowrap; - animation: typing 1.5s steps(30) forwards; - } - - @keyframes typing { - from { width: 0; } - to { width: 100%; } - } - - .dots { - display: inline-flex; - margin-left: 2px; - } - - .dot { - animation: dot-bounce 1.4s ease-in-out infinite; - opacity: 0; - } - - .dot:nth-child(1) { animation-delay: 0s; } - .dot:nth-child(2) { animation-delay: 0.2s; } - .dot:nth-child(3) { animation-delay: 0.4s; } - - @keyframes dot-bounce { - 0%, 80%, 100% { - opacity: 0; - transform: translateY(0); - } - 40% { - opacity: 1; - transform: translateY(-4px); - } - } - - .redirect-notice { - font-family: 'Space Mono', monospace; - font-size: 0.75rem; - color: var(--color-gray-500); - margin-top: 1.5rem; - animation: fade-up 0.4s ease-out 0.5s both; - } - - :host-context(html.dark) .redirect-notice { - color: var(--color-gray-500); - } - - /* Progress track */ - .progress-track { - width: 200px; - height: 4px; - background: var(--color-gray-200); - border-radius: 2px; - margin-top: 2.5rem; - overflow: hidden; - animation: fade-up 0.4s ease-out 0.4s both; - } - - :host-context(html.dark) .progress-track { - background: var(--color-gray-700); - } - - .progress-fill { - height: 100%; - width: 0%; - background: var(--color-primary-500); - border-radius: 2px; - animation: progress-loading 2.5s ease-out forwards; - } - - .progress-fill.success { - background: var(--color-green-500); - animation: progress-complete 0.4s ease-out forwards; - } - - .progress-fill.error { - background: var(--color-red-500); - animation: progress-complete 0.4s ease-out forwards; - } - - @keyframes progress-loading { - 0% { width: 0%; } - 20% { width: 15%; } - 50% { width: 45%; } - 80% { width: 75%; } - 100% { width: 90%; } - } - - @keyframes progress-complete { - to { width: 100%; } - } - - /* Bottom accent bar */ - .bottom-bar { - position: absolute; - bottom: 0; - left: 0; - right: 0; - height: 6px; - display: flex; - } - - .bar-segment { - flex: 1; - background: var(--color-primary-500); - animation: bar-grow 0.8s ease-out both; - } - - .bar-segment.success { - background: var(--color-green-500); - } - - .bar-segment.error { - background: var(--color-red-500); - } - - @keyframes bar-grow { - from { transform: scaleX(0); } - to { transform: scaleX(1); } - } - `, -}) -export class OAuthCallbackPage implements OnInit, OnDestroy { - private router = inject(Router); - private route = inject(ActivatedRoute); - private sidenavService = inject(SidenavService); - - gridLines = Array.from({ length: 8 }, (_, i) => i); - - // State signals - state = signal('processing'); - providerName = signal(null); - errorMessage = signal('Authorization was denied or failed.'); - - ngOnInit(): void { - this.sidenavService.hide(); - this.handleCallback(); - } - - ngOnDestroy(): void { - this.sidenavService.show(); - } - - private handleCallback(): void { - const params = this.route.snapshot.queryParams; - - // Simulate brief processing delay for visual feedback - setTimeout(() => { - if (params['success'] === 'true') { - this.handleSuccess(params); - } else if (params['error']) { - this.handleError(params); - } else { - // No valid params, redirect to connectors - this.redirectToConnectors(); - } - }, 800); - } - - private handleSuccess(params: Record): void { - const provider = params['provider']; - if (provider) { - this.providerName.set(this.formatProviderName(provider)); - } - this.state.set('success'); - - // Redirect after showing success - setTimeout(() => { - this.redirectToConnectors({ success: 'true', provider }); - }, 1500); - } - - private handleError(params: Record): void { - const error = params['error']; - const description = params['error_description']; - const provider = params['provider']; - - let message = 'Authorization was denied or failed.'; - if (description) { - message = description; - } else if (error === 'access_denied') { - message = 'Authorization was denied. Please try again.'; - } else if (error === 'missing_params') { - message = 'Invalid callback parameters.'; - } else if (error === 'invalid_state') { - message = 'Session expired. Please try again.'; - } else if (error === 'token_exchange_failed') { - message = 'Failed to complete authorization.'; - } - - this.errorMessage.set(message); - if (provider) { - this.providerName.set(this.formatProviderName(provider)); - } - this.state.set('error'); - - // Redirect after showing error - setTimeout(() => { - this.redirectToConnectors({ error, provider }); - }, 2500); - } - - private redirectToConnectors(queryParams?: Record): void { - this.router.navigate(['/settings/connectors'], { - queryParams, - replaceUrl: true, - }); - } - - private formatProviderName(providerId: string): string { - // Convert provider_id to display name - return providerId - .replace(/_/g, ' ') - .replace(/-/g, ' ') - .split(' ') - .map(word => word.charAt(0).toUpperCase() + word.slice(1)) - .join(' '); - } -} diff --git a/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts deleted file mode 100644 index 282e3d06..00000000 --- a/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts +++ /dev/null @@ -1,308 +0,0 @@ -import { - Component, - ChangeDetectionStrategy, - inject, - signal, - computed, - OnInit, -} from '@angular/core'; -import { Router, ActivatedRoute } from '@angular/router'; -import { NgIcon, provideIcons } from '@ng-icons/core'; -import { - heroLink, - heroCloud, - heroCodeBracket, - heroAcademicCap, - heroCheck, - heroExclamationTriangle, - heroArrowPath, - heroKey, -} from '@ng-icons/heroicons/outline'; -import { ConnectorsService } from '../../connectors/services'; -import { OAuthConnector, OAuthProviderType } from '../../connectors/models'; -import { ToastService } from '../../../services/toast/toast.service'; - -@Component({ - selector: 'app-connectors-settings', - changeDetection: ChangeDetectionStrategy.OnPush, - imports: [NgIcon], - providers: [ - provideIcons({ - heroLink, - heroCloud, - heroCodeBracket, - heroAcademicCap, - heroCheck, - heroExclamationTriangle, - heroArrowPath, - heroKey, - }), - ], - host: { class: 'block' }, - template: ` -
- -
-

Connectors

-

- Connect your accounts to enable tools that require third-party authentication. -

-
- - - @if (connectorsResource.isLoading() && connectors().length === 0) { -
-
-
-

Loading connectors...

-
-
- } - - - @if (connectorsResource.error()) { -
-

Failed to load connectors

-

Please check your connection and try again.

- -
- } - - - @if (!connectorsResource.isLoading() || connectors().length > 0) { - @if (connectors().length === 0 && !connectorsResource.error()) { - -
-
- -
-

No connectors available

-

- There are no OAuth providers configured for your account. Contact an administrator if you need access to external tools. -

-
- } @else { -
- @for (connector of connectors(); track connector.providerId) { -
- -
-
- -
-
-

- {{ connector.displayName }} -

- - - @if (isConnected(connector)) { -
- - - Connected - - @if (connector.connectedAt) { - - since {{ formatDate(connector.connectedAt) }} - - } -
- } @else if (connector.needsReauth || connector.status === 'needs_reauth' || connector.status === 'expired') { -
- - - Needs Re-authorization - -
- } @else { -

- Not connected -

- } -
-
- - -
- @if (isConnected(connector) && !connector.needsReauth && connector.status !== 'needs_reauth' && connector.status !== 'expired') { - - } @else { - - } -
-
- } -
- } - } -
- `, -}) -export class ConnectorsSettingsPage implements OnInit { - readonly connectorsService = inject(ConnectorsService); - private router = inject(Router); - private route = inject(ActivatedRoute); - private toast = inject(ToastService); - - readonly connectorsResource = this.connectorsService.connectorsResource; - - connecting = signal(null); - disconnecting = signal(null); - - readonly connectors = computed(() => this.connectorsService.getConnectors()); - - ngOnInit(): void { - this.handleCallbackParams(); - } - - private handleCallbackParams(): void { - const params = this.route.snapshot.queryParams; - - if (params['success'] === 'true') { - const provider = params['provider'] || 'the service'; - this.toast.success('Connected!', `Successfully connected to ${provider}.`); - this.connectorsService.reload(); - this.router.navigate([], { - relativeTo: this.route, - queryParams: {}, - replaceUrl: true, - }); - } else if (params['error']) { - const error = params['error']; - const provider = params['provider'] || 'the service'; - const description = params['error_description']; - - let message = `Failed to connect to ${provider}.`; - if (description) { - message = description; - } else if (error === 'access_denied') { - message = 'Authorization was denied. Please try again.'; - } else if (error === 'missing_params') { - message = 'Invalid callback. Please try again.'; - } - - this.toast.error('Connection Failed', message); - this.router.navigate([], { - relativeTo: this.route, - queryParams: {}, - replaceUrl: true, - }); - } - } - - isConnected(connector: OAuthConnector): boolean { - return connector.status === 'connected'; - } - - async connect(connector: OAuthConnector): Promise { - this.connecting.set(connector.providerId); - - try { - const redirectUrl = window.location.origin + '/settings/oauth/callback'; - const authUrl = await this.connectorsService.connect(connector.providerId, redirectUrl); - window.location.href = authUrl; - } catch (error: unknown) { - const err = error as { error?: { detail?: string }; message?: string }; - const message = err?.error?.detail || err?.message || 'Failed to initiate connection.'; - this.toast.error('Connection Error', message); - this.connecting.set(null); - } - } - - async disconnect(connector: OAuthConnector): Promise { - if (!confirm(`Are you sure you want to disconnect from ${connector.displayName}?`)) { - return; - } - - this.disconnecting.set(connector.providerId); - - try { - await this.connectorsService.disconnect(connector.providerId); - this.toast.success('Disconnected', `Successfully disconnected from ${connector.displayName}.`); - } catch (error: unknown) { - const err = error as { error?: { detail?: string }; message?: string }; - const message = err?.error?.detail || err?.message || 'Failed to disconnect.'; - this.toast.error('Disconnect Error', message); - } finally { - this.disconnecting.set(null); - } - } - - getProviderIcon(connector: OAuthConnector): string { - if (connector.iconName && connector.iconName !== 'heroLink') { - return connector.iconName; - } - switch (connector.providerType) { - case 'google': - case 'microsoft': - return 'heroCloud'; - case 'github': - return 'heroCodeBracket'; - case 'canvas': - return 'heroAcademicCap'; - default: - return 'heroLink'; - } - } - - getProviderIconClasses(type: OAuthProviderType): string { - const baseClasses = 'flex size-12 shrink-0 items-center justify-center rounded-sm'; - switch (type) { - case 'google': - return `${baseClasses} bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400`; - case 'microsoft': - return `${baseClasses} bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400`; - case 'github': - return `${baseClasses} bg-gray-800 text-white dark:bg-gray-700`; - case 'canvas': - return `${baseClasses} bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400`; - default: - return `${baseClasses} bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400`; - } - } - - formatDate(dateString: string): string { - try { - const date = new Date(dateString); - return date.toLocaleDateString(undefined, { - month: 'short', - day: 'numeric', - year: 'numeric', - }); - } catch { - return dateString; - } - } -} diff --git a/frontend/ai.client/src/app/settings/settings.page.ts b/frontend/ai.client/src/app/settings/settings.page.ts index 5a273d55..7b4f0a5f 100644 --- a/frontend/ai.client/src/app/settings/settings.page.ts +++ b/frontend/ai.client/src/app/settings/settings.page.ts @@ -114,7 +114,6 @@ export class SettingsPage { { label: 'Profile', icon: 'heroUser', route: '/settings/profile', description: 'Your personal information' }, { label: 'Appearance', icon: 'heroPaintBrush', route: '/settings/appearance', description: 'Theme and display' }, { label: 'Chat', icon: 'heroChatBubbleLeftRight', route: '/settings/chat', description: 'Chat preferences' }, - { label: 'Connectors', icon: 'heroLink', route: '/settings/connectors', description: 'Connected apps' }, { label: 'API Keys', icon: 'heroKey', route: '/settings/api-keys', description: 'API key management' }, { label: 'Usage', icon: 'heroChartBar', route: '/settings/usage', description: 'Usage and billing' }, ]; diff --git a/frontend/ai.client/src/app/settings/settings.routes.ts b/frontend/ai.client/src/app/settings/settings.routes.ts index 121976fd..5c8e5b86 100644 --- a/frontend/ai.client/src/app/settings/settings.routes.ts +++ b/frontend/ai.client/src/app/settings/settings.routes.ts @@ -21,16 +21,6 @@ export const settingsRoutes: Routes = [ loadComponent: () => import('./pages/chat-preferences/chat-preferences-settings.page').then(m => m.ChatPreferencesSettingsPage), }, - { - path: 'connectors', - loadComponent: () => - import('./pages/connectors-settings/connectors-settings.page').then(m => m.ConnectorsSettingsPage), - }, - { - path: 'connections', - redirectTo: 'connectors', - pathMatch: 'full', - }, { path: 'api-keys', loadComponent: () => From 4657dafcdad13206b1e166fba94e3da784136efc Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 10:43:26 -0600 Subject: [PATCH 10/35] chore: gitignore .claude/scheduled_tasks.lock Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8d641be6..991aeb60 100644 --- a/.gitignore +++ b/.gitignore @@ -90,6 +90,7 @@ tmp/ # Claude - exclude personal settings .claude/settings.local.json +.claude/scheduled_tasks.lock # OS generated files ehthumbs.db From 782627c1b0ec132718118ed049fdd4b87f12b371 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 10:55:22 -0600 Subject: [PATCH 11/35] feat(connectors): emit oauth_required events + runtime consent UI When an external MCP tool needs OAuth consent, AgentCore Identity returns an authorization URL instead of a token. This wires that signal all the way to the user: Backend: - Inference route drains pending consent URLs from the external MCP integration after the agent stream finishes and emits one oauth_required SSE event per provider before done - IAM grants bedrock-agentcore:GetResourceOauth2Token on the runtime role so the AgentCore Identity client can reach the token vault - CLAUDE.MD + SSE_ERROR_MESSAGING.md document the new event Frontend: - Stream parser recognizes oauth_required and surfaces it as an OAuthRequiredEvent - New /oauth-complete landing page handles the AgentCore callback redirect and postMessages consent completion to the opener tab - OAuthConsentService orchestrates popup opening + postMessage receipt - OAuthConsentBanner renders the Connect button inside the chat input - chat-http and assistant preview pass OAuth2CallbackUrl header so AgentCore Runtime knows where to return after consent Also updates the admin Tool form reference from /admin/oauth-providers to /admin/connectors to match the renamed admin surface. Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.MD | 1 + backend/src/apis/inference_api/chat/routes.py | 15 + backend/tests/routes/test_inference.py | 87 ++++++ docs/SSE_ERROR_MESSAGING.md | 13 + .../app/admin/tools/pages/tool-form.page.ts | 12 +- .../services/preview-chat.service.ts | 1 + .../oauth-consent-banner.component.ts | 100 +++++++ .../app/oauth-complete/oauth-complete.page.ts | 275 ++++++++++++++++++ .../oauth-consent/oauth-consent.service.ts | 190 ++++++++++++ .../chat-input/chat-input.component.html | 3 + .../chat-input/chat-input.component.ts | 3 +- .../services/chat/chat-http.service.ts | 3 +- .../services/chat/stream-parser.service.ts | 6 + .../app/shared/utils/stream-parser/index.ts | 2 + .../utils/stream-parser/stream-parser-core.ts | 31 ++ .../stream-parser/stream-parser-types.ts | 15 +- infrastructure/lib/inference-api-stack.ts | 19 ++ 17 files changed, 767 insertions(+), 9 deletions(-) create mode 100644 frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts create mode 100644 frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts create mode 100644 frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts diff --git a/CLAUDE.MD b/CLAUDE.MD index a736d9d7..a1a09b5c 100644 --- a/CLAUDE.MD +++ b/CLAUDE.MD @@ -48,6 +48,7 @@ npx cdk deploy --all | `message_stop` | End of message | | `tool_use` / `tool_result` | Tool invocation and result | | `stream_error` | Conversational error | +| `oauth_required` | External MCP tool needs user consent — payload `{providerId, authorizationUrl}`, one event per provider emitted after `message_stop` | | `done` | Stream complete | ## Multi-Protocol Tool Architecture diff --git a/backend/src/apis/inference_api/chat/routes.py b/backend/src/apis/inference_api/chat/routes.py index ae4d6893..17b4a4ae 100644 --- a/backend/src/apis/inference_api/chat/routes.py +++ b/backend/src/apis/inference_api/chat/routes.py @@ -24,6 +24,7 @@ build_conversational_error_event, ) from apis.shared.files.file_resolver import get_file_resolver +from apis.shared.oauth.models import OAuthRequiredEvent from apis.shared.models.managed_models import list_managed_models from apis.shared.quota import ( QuotaExceededEvent, @@ -582,6 +583,20 @@ async def stream_with_quota_warning() -> AsyncGenerator[str, None]: ): yield event + # Surface any OAuth consent URLs collected while loading external + # MCP tools for this user. Draining is idempotent — entries are + # removed on read, so a later invocation won't re-prompt unless + # AgentCore Identity still reports consent is required. + from agents.main_agent.integrations.external_mcp_client import ( + get_external_mcp_integration, + ) + + for entry in get_external_mcp_integration().drain_pending_consent(user_id): + yield OAuthRequiredEvent( + provider_id=entry["provider_id"], + authorization_url=entry["authorization_url"], + ).to_sse_format() + # Stream response from agent as SSE (with optional files) # Note: Compression is handled by GZipMiddleware if configured in main.py return StreamingResponse( diff --git a/backend/tests/routes/test_inference.py b/backend/tests/routes/test_inference.py index fdfc59fa..0faaa140 100644 --- a/backend/tests/routes/test_inference.py +++ b/backend/tests/routes/test_inference.py @@ -147,6 +147,93 @@ async def fake_stream(*args, **kwargs): # --------------------------------------------------------------------------- +class TestInvocationsOAuthRequired: + """After a stream, pending OAuth consent URLs are emitted as oauth_required events.""" + + def test_emits_oauth_required_event_per_provider(self, authed_app, authed_client, trusted_user): + """Drain pending_consent for the user after stream_async and emit one event per provider.""" + mock_agent = MagicMock() + + async def fake_stream(*args, **kwargs): + yield 'event: message_start\ndata: {"role": "assistant"}\n\n' + yield 'event: message_stop\ndata: {"stopReason": "end_turn"}\n\n' + + mock_agent.stream_async = fake_stream + + mock_integration = MagicMock() + mock_integration.drain_pending_consent.return_value = [ + {"provider_id": "google", "authorization_url": "https://accounts.google.com/consent?x=1"}, + {"provider_id": "slack", "authorization_url": "https://slack.com/oauth?y=2"}, + ] + + with patch( + "apis.inference_api.chat.routes.get_agent", + return_value=mock_agent, + ), patch( + "apis.inference_api.chat.routes.is_quota_enforcement_enabled", + return_value=False, + ), patch( + "agents.main_agent.integrations.external_mcp_client.get_external_mcp_integration", + return_value=mock_integration, + ): + resp = authed_client.post( + "/invocations", + json={ + "session_id": "sess-oauth", + "message": "Do the thing", + }, + ) + + assert resp.status_code == 200 + body = resp.text + + assert body.count("event: oauth_required\n") == 2 + assert '"providerId": "google"' in body + assert '"authorizationUrl": "https://accounts.google.com/consent?x=1"' in body + assert '"providerId": "slack"' in body + assert '"authorizationUrl": "https://slack.com/oauth?y=2"' in body + + mock_integration.drain_pending_consent.assert_called_once_with(trusted_user.user_id) + + # oauth_required events must come AFTER message_stop so the UI + # renders consent prompts after the assistant message settles. + assert body.index("event: message_stop") < body.index("event: oauth_required") + + def test_no_event_when_no_pending_consent(self, authed_app, authed_client): + """When drain_pending_consent returns an empty list, no oauth_required event is emitted.""" + mock_agent = MagicMock() + + async def fake_stream(*args, **kwargs): + yield 'event: message_start\ndata: {"role": "assistant"}\n\n' + yield 'event: message_stop\ndata: {"stopReason": "end_turn"}\n\n' + + mock_agent.stream_async = fake_stream + + mock_integration = MagicMock() + mock_integration.drain_pending_consent.return_value = [] + + with patch( + "apis.inference_api.chat.routes.get_agent", + return_value=mock_agent, + ), patch( + "apis.inference_api.chat.routes.is_quota_enforcement_enabled", + return_value=False, + ), patch( + "agents.main_agent.integrations.external_mcp_client.get_external_mcp_integration", + return_value=mock_integration, + ): + resp = authed_client.post( + "/invocations", + json={ + "session_id": "sess-no-oauth", + "message": "Hi", + }, + ) + + assert resp.status_code == 200 + assert "event: oauth_required" not in resp.text + + class TestInvocationsInvalid: """POST /invocations with invalid payload returns 422.""" diff --git a/docs/SSE_ERROR_MESSAGING.md b/docs/SSE_ERROR_MESSAGING.md index e6a435ab..2d06af8b 100644 --- a/docs/SSE_ERROR_MESSAGING.md +++ b/docs/SSE_ERROR_MESSAGING.md @@ -138,6 +138,19 @@ if hasattr(session_manager, 'base_manager') and hasattr(session_manager.base_man - `StreamError` interface - Error state signals for UI +### OAuth Required + +**Backend**: `apis/shared/oauth/models.py` +- `OAuthRequiredEvent` - Pydantic model with `to_sse_format()` method +- Event name: `oauth_required`, payload: `{providerId, authorizationUrl}` + +**Routes**: `apis/inference_api/chat/routes.py` +- After the agent stream drains, pull consent URLs from + `ExternalMCPIntegration.drain_pending_consent(user_id)` and emit one + `oauth_required` event per provider before `done`. +- Consent URLs are populated during external MCP tool loading whenever + AgentCore Identity reports the user hasn't granted the required scopes. + ## Adding New Error Types 1. **Create event model** in `apis/shared/errors.py`: diff --git a/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts b/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts index dbb20b8d..0fecfa5a 100644 --- a/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts +++ b/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts @@ -19,7 +19,7 @@ import { heroShieldCheck, } from '@ng-icons/heroicons/outline'; import { AdminToolService } from '../services/admin-tool.service'; -import { OAuthProvidersService } from '../../oauth-providers/services/oauth-providers.service'; +import { ConnectorsService } from '../../connectors/services/connectors.service'; import { TOOL_CATEGORIES, TOOL_PROTOCOLS, @@ -361,8 +361,8 @@ import { }

- Users must connect this provider before using the tool. Manage providers in - OAuth Settings. + Users must connect this connector before using the tool. Manage connectors in + Connectors.

@@ -568,7 +568,7 @@ export class ToolFormPage implements OnInit { private router = inject(Router); private route = inject(ActivatedRoute); private adminToolService = inject(AdminToolService); - private oauthProvidersService = inject(OAuthProvidersService); + private connectorsService = inject(ConnectorsService); readonly categories = TOOL_CATEGORIES; readonly protocols = TOOL_PROTOCOLS; @@ -585,8 +585,8 @@ export class ToolFormPage implements OnInit { readonly isEditMode = computed(() => !!this.toolId()); readonly selectedProtocol = signal('local'); - /** Available OAuth providers for dropdown */ - readonly oauthProviders = computed(() => this.oauthProvidersService.getEnabledProviders()); + /** Available connectors for dropdown */ + readonly oauthProviders = computed(() => this.connectorsService.getEnabledConnectors()); form: FormGroup = this.fb.group({ toolId: ['', [Validators.required, Validators.pattern(/^[a-z][a-z0-9_]{2,49}$/)]], diff --git a/frontend/ai.client/src/app/assistants/assistant-form/services/preview-chat.service.ts b/frontend/ai.client/src/app/assistants/assistant-form/services/preview-chat.service.ts index 5c9dea3d..2727fa79 100644 --- a/frontend/ai.client/src/app/assistants/assistant-form/services/preview-chat.service.ts +++ b/frontend/ai.client/src/app/assistants/assistant-form/services/preview-chat.service.ts @@ -214,6 +214,7 @@ export class PreviewChatService { 'Content-Type': 'application/json', Authorization: `Bearer ${token}`, Accept: 'text/event-stream', + OAuth2CallbackUrl: `${window.location.origin}/oauth-complete`, }, body: JSON.stringify(requestBody), signal: this.abortController.signal, diff --git a/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts b/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts new file mode 100644 index 00000000..411931dd --- /dev/null +++ b/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts @@ -0,0 +1,100 @@ +import { Component, ChangeDetectionStrategy, inject } from '@angular/core'; +import { NgIcon, provideIcons } from '@ng-icons/core'; +import { heroLink, heroArrowTopRightOnSquare, heroXMark } from '@ng-icons/heroicons/outline'; +import { OAuthConsentService } from '../../services/oauth-consent/oauth-consent.service'; + +/** + * Renders a compact "Connect to X" prompt above the chat input whenever the + * SSE stream surfaces an `oauth_required` event. Clicking the button opens + * AgentCore Identity's consent URL in a popup; {@link OAuthConsentService} + * listens for the completion postMessage and clears the entry automatically. + */ +@Component({ + selector: 'app-oauth-consent-banner', + changeDetection: ChangeDetectionStrategy.OnPush, + imports: [NgIcon], + providers: [provideIcons({ heroLink, heroArrowTopRightOnSquare, heroXMark })], + template: ` + @if (consentService.hasPending()) { +
+ @for (request of consentService.pending(); track request.providerId) { +
+
+ } +
+ } + `, + styles: [` + @keyframes fadeIn { + from { + opacity: 0; + transform: translateY(4px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + + .animate-fade-in { + animation: fadeIn 0.15s ease-out; + } + `], +}) +export class OAuthConsentBannerComponent { + protected consentService = inject(OAuthConsentService); + + labelFor(providerId: string): string { + if (!providerId) { + return 'This tool'; + } + return providerId + .replace(/[-_]+/g, ' ') + .split(' ') + .filter((part) => part.length > 0) + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(' '); + } + + connect(providerId: string): void { + this.consentService.openConsentPopup(providerId); + } + + dismiss(providerId: string, event: Event): void { + event.stopPropagation(); + this.consentService.dismiss(providerId); + } +} diff --git a/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts b/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts new file mode 100644 index 00000000..eb7e1b07 --- /dev/null +++ b/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts @@ -0,0 +1,275 @@ +import { + ChangeDetectionStrategy, + Component, + OnDestroy, + OnInit, + computed, + inject, + signal, +} from '@angular/core'; +import { ActivatedRoute, Router } from '@angular/router'; +import { NgIcon, provideIcons } from '@ng-icons/core'; +import { + heroCheckCircle, + heroExclamationCircle, +} from '@ng-icons/heroicons/outline'; + +/** + * Landing page for AgentCore Identity's 3-legged OAuth flow. + * + * AgentCore Identity redirects the user here after consent completes (or + * fails). We detect whether this page was opened in a popup: + * + * - Popup: post a message to the opener so the chat can retry the tool + * call, then close the window. + * - Same tab: show a brief success/error message, then route back to chat. + * + * Query params AgentCore Identity may append are not strictly contractual, + * so we treat missing params as success and known error indicators as + * failure (`error`, `error_description`). This matches the defensive + * parsing pattern in `settings/oauth-callback`. + */ + +type CompleteState = 'success' | 'error'; + +/** postMessage payload shape — kept public so other code can type the listener. */ +export interface OAuthCompleteMessage { + type: 'agentcore-oauth-complete'; + status: CompleteState; + providerId: string | null; + error: string | null; +} + +@Component({ + selector: 'app-oauth-complete', + imports: [NgIcon], + providers: [ + provideIcons({ heroCheckCircle, heroExclamationCircle }), + ], + changeDetection: ChangeDetectionStrategy.OnPush, + template: ` +
+ @if (state() === 'success') { +
+
+ } @else { + + } +
+ `, + styles: ` + :host { + display: block; + min-height: 100dvh; + background: var(--color-gray-50); + } + + :host-context(html.dark) { + background: var(--color-gray-900); + } + + .page { + min-height: 100dvh; + display: flex; + align-items: center; + justify-content: center; + padding: 2rem; + } + + .card { + width: 100%; + max-width: 28rem; + padding: 2rem; + border-radius: 0.75rem; + background: var(--color-white); + border: 1px solid var(--color-gray-200); + box-shadow: 0 1px 2px rgb(0 0 0 / 0.05); + text-align: center; + } + + :host-context(html.dark) .card { + background: var(--color-gray-800); + border-color: var(--color-gray-700); + } + + .icon { + display: inline-flex; + width: 3rem; + height: 3rem; + margin-bottom: 1rem; + } + + .success-icon { + color: var(--color-green-600); + } + + :host-context(html.dark) .success-icon { + color: var(--color-green-400); + } + + .error-icon { + color: var(--color-red-600); + } + + :host-context(html.dark) .error-icon { + color: var(--color-red-400); + } + + .title { + font-weight: 600; + font-size: 1.5rem; + margin: 0 0 0.5rem; + color: var(--color-gray-900); + } + + :host-context(html.dark) .title { + color: var(--color-gray-100); + } + + .subtitle { + margin: 0 0 1rem; + color: var(--color-gray-700); + font-size: 0.95rem; + } + + :host-context(html.dark) .subtitle { + color: var(--color-gray-300); + } + + .hint { + margin: 0; + font-size: 0.8125rem; + color: var(--color-gray-500); + } + `, +}) +export class OAuthCompletePage implements OnInit, OnDestroy { + private readonly route = inject(ActivatedRoute); + private readonly router = inject(Router); + + private redirectTimer: ReturnType | null = null; + + readonly state = signal('success'); + readonly providerId = signal(null); + readonly errorMessage = signal('Authorization was denied or did not complete.'); + private readonly isPopup = signal(false); + + readonly providerLabel = computed(() => { + const id = this.providerId(); + if (!id) { + return null; + } + return id + .replace(/[-_]+/g, ' ') + .split(' ') + .filter((part) => part.length > 0) + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(' '); + }); + + readonly dismissHint = computed(() => + this.isPopup() + ? 'You can close this window.' + : 'Redirecting back to your chat…', + ); + + ngOnInit(): void { + const params = this.route.snapshot.queryParamMap; + const error = params.get('error'); + const errorDescription = params.get('error_description'); + const providerId = params.get('provider_id') ?? params.get('providerId'); + + if (providerId) { + this.providerId.set(providerId); + } + + if (error) { + this.state.set('error'); + this.errorMessage.set( + errorDescription?.trim() || this.describeError(error), + ); + } + + const inPopup = this.detectPopup(); + this.isPopup.set(inPopup); + + if (inPopup) { + this.notifyOpenerAndClose(); + } else { + this.redirectTimer = setTimeout(() => this.router.navigate(['/']), 2000); + } + } + + ngOnDestroy(): void { + if (this.redirectTimer !== null) { + clearTimeout(this.redirectTimer); + } + } + + /** + * A page is "in a popup" only when it has an opener from the same origin + * that it can actually postMessage to. We guard the property reads + * defensively because cross-origin `window.opener` access can throw. + */ + private detectPopup(): boolean { + try { + return typeof window !== 'undefined' && window.opener != null && window.opener !== window; + } catch { + return false; + } + } + + private notifyOpenerAndClose(): void { + const message: OAuthCompleteMessage = { + type: 'agentcore-oauth-complete', + status: this.state(), + providerId: this.providerId(), + error: this.state() === 'error' ? this.errorMessage() : null, + }; + try { + window.opener?.postMessage(message, window.location.origin); + } catch { + // Cross-origin opener — nothing we can do; leave the page open so the + // user can read the message and close manually. + return; + } + // Small delay so the opener has time to receive before the window closes + // (Chrome closes immediately otherwise on some platforms). + this.redirectTimer = setTimeout(() => { + try { + window.close(); + } catch { + // Some browsers refuse to close pages they didn't open; show the + // static "you can close this window" hint and let the user dismiss. + } + }, 400); + } + + private describeError(code: string): string { + switch (code) { + case 'access_denied': + return 'You declined the authorization request.'; + case 'invalid_scope': + return 'The requested permissions are not available for this account.'; + case 'server_error': + return 'The provider could not complete the request. Try again in a moment.'; + default: + return 'Authorization did not complete. Try again.'; + } + } +} diff --git a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts new file mode 100644 index 00000000..b3f0c735 --- /dev/null +++ b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts @@ -0,0 +1,190 @@ +import { Injectable, signal, computed, inject, DestroyRef } from '@angular/core'; +import { takeUntilDestroyed } from '@angular/core/rxjs-interop'; +import { fromEvent } from 'rxjs'; +import { filter } from 'rxjs/operators'; + +/** + * Pending OAuth consent request surfaced by the backend when an external + * MCP tool needs the user to authorize AgentCore Identity. + */ +export interface OAuthConsentRequest { + providerId: string; + authorizationUrl: string; + receivedAt: number; +} + +/** + * postMessage payload shape broadcast by the `/oauth-complete` landing + * page. Kept in sync with `OAuthCompleteMessage` in + * `src/app/oauth-complete/oauth-complete.page.ts`. + */ +export interface OAuthCompleteMessage { + type: 'agentcore-oauth-complete'; + status: 'success' | 'error'; + providerId: string | null; + error: string | null; +} + +function isOAuthCompleteMessage(data: unknown): data is OAuthCompleteMessage { + if (!data || typeof data !== 'object') { + return false; + } + const msg = data as Partial; + return msg.type === 'agentcore-oauth-complete'; +} + +/** + * Tracks OAuth consent requests surfaced by the SSE stream and coordinates + * the popup flow. + * + * The stream parser calls {@link requestConsent} when an `oauth_required` + * event arrives; components render a "Connect" affordance bound to + * {@link pending}. When the user clicks, {@link openConsentPopup} opens the + * AgentCore Identity URL, and this service listens for the + * `agentcore-oauth-complete` postMessage from the `/oauth-complete` landing + * page to resolve the provider. + */ +@Injectable({ providedIn: 'root' }) +export class OAuthConsentService { + private readonly destroyRef = inject(DestroyRef); + + /** Map of providerId → request. A provider only appears once, even if + * the backend emits duplicates mid-stream. */ + private readonly requests = signal>(new Map()); + + /** ProviderIds whose popup is currently open. */ + private readonly inFlight = signal>(new Set()); + + /** Most recent completion notice surfaced to the chat layer. */ + private readonly lastCompletion = signal(null); + + readonly pending = computed(() => + Array.from(this.requests().values()).sort((a, b) => a.receivedAt - b.receivedAt), + ); + + readonly hasPending = computed(() => this.requests().size > 0); + + readonly completion = this.lastCompletion.asReadonly(); + + constructor() { + // Listen for postMessages from the /oauth-complete landing page. The + // origin guard makes sure cross-origin pages can't spoof a completion. + fromEvent(window, 'message') + .pipe( + filter((event) => event.origin === window.location.origin), + filter((event) => isOAuthCompleteMessage(event.data)), + takeUntilDestroyed(this.destroyRef), + ) + .subscribe((event) => { + const message = event.data as OAuthCompleteMessage; + this.handleCompletion(message); + }); + } + + /** + * Register a consent request coming off the SSE stream. + * Duplicate providerIds refresh the existing entry (URLs can rotate). + */ + requestConsent(providerId: string, authorizationUrl: string): void { + this.requests.update((map) => { + const next = new Map(map); + next.set(providerId, { + providerId, + authorizationUrl, + receivedAt: Date.now(), + }); + return next; + }); + } + + /** + * Open the AgentCore Identity consent URL in a popup window. + * Falls back to a same-tab redirect if the popup is blocked. + */ + openConsentPopup(providerId: string): void { + const request = this.requests().get(providerId); + if (!request) { + return; + } + + const width = 520; + const height = 680; + const left = window.screenX + Math.max(0, (window.outerWidth - width) / 2); + const top = window.screenY + Math.max(0, (window.outerHeight - height) / 2); + + const features = [ + `width=${width}`, + `height=${height}`, + `left=${Math.round(left)}`, + `top=${Math.round(top)}`, + 'resizable=yes', + 'scrollbars=yes', + 'status=no', + 'toolbar=no', + 'menubar=no', + 'location=no', + ].join(','); + + const popup = window.open(request.authorizationUrl, `oauth-${providerId}`, features); + + if (!popup) { + // Popup blocked — fall back to opening in the current tab. The + // `/oauth-complete` page will route back to `/` once consent resolves. + window.location.href = request.authorizationUrl; + return; + } + + this.inFlight.update((set) => { + const next = new Set(set); + next.add(providerId); + return next; + }); + } + + /** Check whether a popup is still open for this provider. */ + isInFlight(providerId: string): boolean { + return this.inFlight().has(providerId); + } + + /** + * Clear a single consent request — called from the UI after the user + * completes or dismisses a provider, or when the chat is reset. + */ + dismiss(providerId: string): void { + this.requests.update((map) => { + if (!map.has(providerId)) { + return map; + } + const next = new Map(map); + next.delete(providerId); + return next; + }); + this.inFlight.update((set) => { + if (!set.has(providerId)) { + return set; + } + const next = new Set(set); + next.delete(providerId); + return next; + }); + } + + /** Reset all state (new session, logout). */ + clear(): void { + this.requests.set(new Map()); + this.inFlight.set(new Set()); + this.lastCompletion.set(null); + } + + /** Acknowledge the last completion signal after the UI has reacted. */ + acknowledgeCompletion(): void { + this.lastCompletion.set(null); + } + + private handleCompletion(message: OAuthCompleteMessage): void { + this.lastCompletion.set(message); + if (message.status === 'success' && message.providerId) { + this.dismiss(message.providerId); + } + } +} diff --git a/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.html b/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.html index d219ddf8..471cb627 100644 --- a/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.html +++ b/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.html @@ -1,3 +1,6 @@ + + + diff --git a/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.ts b/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.ts index 0d67ce96..5f4024a9 100644 --- a/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.ts +++ b/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.ts @@ -10,6 +10,7 @@ import { import { heroPaperAirplaneSolid, heroStopSolid } from '@ng-icons/heroicons/solid'; import { ModelDropdownComponent } from '../../../components/model-dropdown/model-dropdown.component'; import { QuotaWarningBannerComponent } from '../../../components/quota-warning-banner/quota-warning-banner.component'; +import { OAuthConsentBannerComponent } from '../../../components/oauth-consent-banner/oauth-consent-banner.component'; import { TooltipDirective } from '../../../components/tooltip'; import { FileCardComponent } from '../../../components/file-card'; import { StorageQuotaBannerComponent } from '../../../components/storage-quota-banner'; @@ -32,7 +33,7 @@ interface Message { @Component({ selector: 'app-chat-input', - imports: [FormsModule, ModelDropdownComponent, NgIcon, QuotaWarningBannerComponent, StorageQuotaBannerComponent, TooltipDirective, FileCardComponent], + imports: [FormsModule, ModelDropdownComponent, NgIcon, QuotaWarningBannerComponent, OAuthConsentBannerComponent, StorageQuotaBannerComponent, TooltipDirective, FileCardComponent], providers: [ provideIcons({ heroPlus, diff --git a/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts b/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts index 5743f91f..2c857f68 100644 --- a/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts @@ -65,7 +65,8 @@ export class ChatHttpService { headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}`, - Accept: 'text/event-stream', + Accept: 'text/event-stream', + OAuth2CallbackUrl: `${window.location.origin}/oauth-complete`, }, body: JSON.stringify(requestObject), signal: abortController.signal, diff --git a/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts b/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts index 14a343a8..c6c30252 100644 --- a/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts @@ -14,6 +14,8 @@ import { QuotaWarning, QuotaExceeded, } from '../../../services/quota/quota-warning.service'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; +import type { OAuthRequiredEvent } from '../../../shared/utils/stream-parser'; import { processStreamEvent, createStreamLineParser, @@ -48,6 +50,7 @@ export class StreamParserService { private chatStateService = inject(ChatStateService); private errorService = inject(ErrorService); private quotaWarningService = inject(QuotaWarningService); + private oauthConsentService = inject(OAuthConsentService); // ========================================================================= // State Signals @@ -289,6 +292,9 @@ export class StreamParserService { onQuotaWarning: (data) => this.quotaWarningService.setWarning(data as QuotaWarning), onQuotaExceeded: (data) => this.quotaWarningService.setQuotaExceeded(data as QuotaExceeded), + onOAuthRequired: (data: OAuthRequiredEvent) => + this.oauthConsentService.requestConsent(data.providerId, data.authorizationUrl), + onError: (data) => this.handleError(data), onStreamError: (data) => this.errorService.handleConversationalStreamError(data as ConversationalStreamError), diff --git a/frontend/ai.client/src/app/shared/utils/stream-parser/index.ts b/frontend/ai.client/src/app/shared/utils/stream-parser/index.ts index 9a2085eb..abfe4f72 100644 --- a/frontend/ai.client/src/app/shared/utils/stream-parser/index.ts +++ b/frontend/ai.client/src/app/shared/utils/stream-parser/index.ts @@ -55,6 +55,7 @@ export { validateQuotaExceededEvent, validateConversationalStreamError, validateCitation, + validateOAuthRequiredEvent, } from './stream-parser-core'; // Types @@ -74,6 +75,7 @@ export type { QuotaExceededEvent, StreamErrorEvent, ConversationalStreamErrorEvent, + OAuthRequiredEvent, StreamEventType, StreamEventData, ParsedStreamEvent, diff --git a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts index 339d697d..092a0ee7 100644 --- a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts +++ b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts @@ -36,6 +36,7 @@ import type { QuotaExceededEvent, StreamErrorEvent, ConversationalStreamErrorEvent, + OAuthRequiredEvent, ToolProgress, } from './stream-parser-types'; import type { MetadataEvent } from '../../../session/services/models/content-types'; @@ -75,6 +76,9 @@ export interface StreamParserCallbacks { onQuotaWarning?: (data: QuotaWarningEvent) => void; onQuotaExceeded?: (data: QuotaExceededEvent) => void; + // OAuth consent required (external MCP tool needs user authorization) + onOAuthRequired?: (data: OAuthRequiredEvent) => void; + // Error handling onError?: (data: StreamErrorEvent | ConversationalStreamErrorEvent | string) => void; onStreamError?: (data: ConversationalStreamErrorEvent) => void; @@ -318,6 +322,25 @@ export function validateConversationalStreamError( ); } +/** + * Validate OAuthRequiredEvent structure + */ +export function validateOAuthRequiredEvent(data: unknown): data is OAuthRequiredEvent { + if (!data || typeof data !== 'object') { + return false; + } + + const event = data as Partial; + + return ( + event.type === 'oauth_required' && + typeof event.providerId === 'string' && + event.providerId.length > 0 && + typeof event.authorizationUrl === 'string' && + event.authorizationUrl.length > 0 + ); +} + /** * Validate Citation structure */ @@ -483,6 +506,14 @@ export function processStreamEvent( } break; + case 'oauth_required': + if (validateOAuthRequiredEvent(data)) { + callbacks.onOAuthRequired?.(data); + } else { + callbacks.onParseError?.('oauth_required: invalid data structure'); + } + break; + default: // Ignore unknown events (ping, etc.) break; diff --git a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts index 8b9ed116..382d4842 100644 --- a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts +++ b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts @@ -86,6 +86,17 @@ export interface ReasoningEvent { reasoningText?: string; } +/** + * OAuth required event — emitted when an external MCP tool needs the user + * to grant consent via AgentCore Identity. The payload carries the provider + * slug and the consent URL to open. + */ +export interface OAuthRequiredEvent { + type: 'oauth_required'; + providerId: string; + authorizationUrl: string; +} + /** * Tool result event data structure */ @@ -123,7 +134,8 @@ export type StreamEventType = | 'quota_warning' | 'quota_exceeded' | 'stream_error' - | 'citation'; + | 'citation' + | 'oauth_required'; /** * Union type of all possible event data types @@ -143,6 +155,7 @@ export type StreamEventData = | StreamErrorEvent | ConversationalStreamErrorEvent | Citation + | OAuthRequiredEvent | null | undefined; diff --git a/infrastructure/lib/inference-api-stack.ts b/infrastructure/lib/inference-api-stack.ts index 35f3413c..f75ad903 100644 --- a/infrastructure/lib/inference-api-stack.ts +++ b/infrastructure/lib/inference-api-stack.ts @@ -482,6 +482,25 @@ export class InferenceApiStack extends cdk.Stack { ], })); + // AgentCore Identity OAuth2 token vault access — lets the Runtime fetch + // user-federated OAuth tokens for external MCP tools (e.g. Google, Slack) + // via IdentityClient.get_token. When the user has not consented, this + // call returns an authorization URL instead of a token, which the + // inference route surfaces as an `oauth_required` SSE event. + runtimeExecutionRole.addToPolicy(new iam.PolicyStatement({ + sid: 'GetResourceOauth2Token', + effect: iam.Effect.ALLOW, + actions: [ + 'bedrock-agentcore:GetResourceOauth2Token', + ], + resources: [ + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:token-vault/default`, + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:token-vault/default/*`, + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:workload-identity-directory/default`, + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:workload-identity-directory/default/workload-identity/hosted_agent_*`, + ], + })); + // DynamoDB Quota Tables permissions (imported from App API Stack) const userQuotasTableArn = ssm.StringParameter.valueForStringParameter( this, From 54bfe10fcd4cd314d2ca20874185f8ee94cf92bb Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 17:24:10 -0600 Subject: [PATCH 12/35] feat(connectors): user-facing settings page + AgentCore consent finalizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the Settings → Connectors page so users can browse and connect OAuth-backed external tools end-to-end: - New /connectors routers on app-api (list user-visible providers via RBAC) and inference-api (initiate-consent, complete-consent) — the inference-api side runs under the AgentCore Runtime proxy where the WorkloadAccessToken context is populated. - AgentCoreIdentityClient gains a workload-token mint fallback for local dev (GetWorkloadAccessTokenForUserId) and appends provider_id to the callback URL so the landing page can dismiss the right banner. - /oauth-complete page POSTs CompleteResourceTokenAuth back through the inference-api before notifying the opener, fixing the "consent finished but vault stayed empty" race. Uses BroadcastChannel to bridge popup → opener under Chrome's COOP isolation. - New connectors settings page with a Connect / Reconnect affordance per provider, wired to the OAuthConsentService popup flow. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../integrations/agentcore_identity.py | 151 ++++++++-- .../src/apis/app_api/connectors/__init__.py | 0 backend/src/apis/app_api/connectors/routes.py | 79 +++++ backend/src/apis/app_api/main.py | 2 + .../apis/inference_api/connectors/__init__.py | 0 .../apis/inference_api/connectors/routes.py | 172 +++++++++++ backend/src/apis/inference_api/main.py | 2 + .../integrations/test_agentcore_identity.py | 63 ++-- .../app/oauth-complete/oauth-complete.page.ts | 99 ++++++- .../connectors/models/user-connector.model.ts | 25 ++ .../services/user-connectors.service.ts | 75 +++++ .../connectors-settings.page.ts | 270 ++++++++++++++++++ .../src/app/settings/settings.page.ts | 1 + .../src/app/settings/settings.routes.ts | 5 + 14 files changed, 887 insertions(+), 57 deletions(-) create mode 100644 backend/src/apis/app_api/connectors/__init__.py create mode 100644 backend/src/apis/app_api/connectors/routes.py create mode 100644 backend/src/apis/inference_api/connectors/__init__.py create mode 100644 backend/src/apis/inference_api/connectors/routes.py create mode 100644 frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts create mode 100644 frontend/ai.client/src/app/settings/connectors/services/user-connectors.service.ts create mode 100644 frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts diff --git a/backend/src/agents/main_agent/integrations/agentcore_identity.py b/backend/src/agents/main_agent/integrations/agentcore_identity.py index 195464ec..193f9cce 100644 --- a/backend/src/agents/main_agent/integrations/agentcore_identity.py +++ b/backend/src/agents/main_agent/integrations/agentcore_identity.py @@ -31,12 +31,45 @@ from dataclasses import dataclass from typing import List, Optional +import boto3 from bedrock_agentcore.runtime import BedrockAgentCoreContext -from bedrock_agentcore.services.identity import IdentityClient +from bedrock_agentcore.services.identity import IdentityClient, TokenPoller logger = logging.getLogger(__name__) +class _ConsentRequired(Exception): + """Internal marker — raised by `_ShortCircuitPoller` once AgentCore hands + us an auth URL, so we can return it to the caller without waiting for + the user to actually complete consent.""" + + +class _ShortCircuitPoller(TokenPoller): + """Skip the SDK's default poll loop. + + The default poller hits `GetResourceOauth2Token` on a timer until the + user finishes consent (up to several minutes). We only care about the + URL — our caller returns it to the frontend, which drives the popup + flow on its own. Raising immediately short-circuits the wait. + """ + + async def poll_for_token(self) -> str: + raise _ConsentRequired() + +# In production, the AgentCore Runtime proxies every request to the inference +# API with a `WorkloadAccessToken` header bound to (runtime workload, user). +# `AgentCoreContextMiddleware` copies that header onto `BedrockAgentCoreContext` +# so downstream code can fetch user-federated OAuth tokens without threading +# it through function args. +# +# Local dev doesn't go through the runtime, so the header is absent. When +# `AGENTCORE_RUNTIME_WORKLOAD_NAME` is set, we fall back to minting a workload +# token against that runtime ourselves via +# `bedrock-agentcore:GetWorkloadAccessTokenForUserId`. The caller's AWS +# principal must be authorised for that action on the target workload. +_RUNTIME_WORKLOAD_ENV = "AGENTCORE_RUNTIME_WORKLOAD_NAME" + + @dataclass(frozen=True) class TokenResult: """Result of a token fetch attempt. @@ -77,20 +110,26 @@ class AgentCoreIdentityClient: def __init__(self, region: Optional[str] = None): self._region = region or os.environ.get("AWS_REGION", "us-east-1") self._client = IdentityClient(region=self._region) + self._control_client = boto3.client("bedrock-agentcore", region_name=self._region) - def get_token_for_user( + async def get_token_for_user( self, *, provider_name: str, scopes: List[str], callback_url: Optional[str] = None, force_authentication: bool = False, + user_id: Optional[str] = None, + custom_state: Optional[str] = None, ) -> TokenResult: """Fetch a user-federated OAuth2 access token for `provider_name`. - Pulls the workload identity token from `BedrockAgentCoreContext`, so - this must be called from inside an AgentCore Runtime invocation that - has been processed by `AgentCoreContextMiddleware`. + In production the workload identity token comes from + `BedrockAgentCoreContext` (populated by `AgentCoreContextMiddleware` + from the AgentCore Runtime's request header). For local dev the + header is absent — when `AGENTCORE_RUNTIME_WORKLOAD_NAME` is set + and `user_id` is provided, we mint a workload token ourselves + against that runtime. If the user has not consented (or re-consent is required), returns a `TokenResult` with `authorization_url` populated instead of raising. @@ -105,43 +144,62 @@ def get_token_for_user( force_authentication: If True, bypasses the token vault cache and forces the user through the consent flow again. Used for scope upgrades. + user_id: User identifier for the local-dev workload-token + fallback. Ignored in production where the context already + has a token. Returns: `TokenResult` with either `access_token` or `authorization_url`. Raises: - WorkloadTokenUnavailableError: No workload token on context. + WorkloadTokenUnavailableError: No token on context and the + local-dev fallback is unavailable (env var unset, user_id + missing, or IAM denies the mint call). """ - workload_token = BedrockAgentCoreContext.get_workload_access_token() - if not workload_token: - raise WorkloadTokenUnavailableError( - "No WorkloadAccessToken on context — ensure " - "AgentCoreContextMiddleware is installed and this call " - "runs inside an AgentCore Runtime invocation." - ) + workload_token = self._resolve_workload_token(user_id) resolved_callback_url = ( callback_url or BedrockAgentCoreContext.get_oauth2_callback_url() ) + # AgentCore's return-URL redirect doesn't include any hint about which + # provider resolved, so the frontend's /oauth-complete page has no way + # to tell the consent service which pending entry to dismiss. Append + # provider_id as a query param so the page can read it back. + if resolved_callback_url: + from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl + + parsed = urlparse(resolved_callback_url) + existing = dict(parse_qsl(parsed.query, keep_blank_values=True)) + existing.setdefault("provider_id", provider_name) + resolved_callback_url = urlunparse(parsed._replace(query=urlencode(existing))) + captured_url: dict[str, Optional[str]] = {"url": None} def _capture_auth_url(url: str) -> None: captured_url["url"] = url - token = self._client.get_token( - provider_name=provider_name, - scopes=scopes, - agent_identity_token=workload_token, - auth_flow="USER_FEDERATION", - callback_url=resolved_callback_url, - force_authentication=force_authentication, - on_auth_url=_capture_auth_url, - ) - - # `get_token` returns either the token string or triggers on_auth_url - # when consent is required. Guard both: if we captured a URL, surface - # it as a TokenResult even if the SDK also returned a (stale) token. + try: + sdk_kwargs = dict( + provider_name=provider_name, + scopes=scopes, + agent_identity_token=workload_token, + auth_flow="USER_FEDERATION", + callback_url=resolved_callback_url, + force_authentication=force_authentication, + on_auth_url=_capture_auth_url, + token_poller=_ShortCircuitPoller(), + ) + if custom_state is not None: + sdk_kwargs["custom_state"] = custom_state + token = await self._client.get_token(**sdk_kwargs) + except _ConsentRequired: + # Expected path when consent is required: the SDK invoked + # on_auth_url and then handed off to our poller, which raises. + token = None + + # If we captured a URL, return it — even if the SDK also returned + # a (stale) token, consent-required is the authoritative signal. if captured_url["url"]: logger.info( "AgentCore Identity requires user consent for provider=%s", @@ -157,6 +215,47 @@ def _capture_auth_url(url: str) -> None: return TokenResult(access_token=token) + def _resolve_workload_token(self, user_id: Optional[str]) -> str: + """Return a workload access token, preferring the runtime-supplied one. + + Falls back to minting via `GetWorkloadAccessTokenForUserId` when the + context has no token, the `AGENTCORE_RUNTIME_WORKLOAD_NAME` env var + is set, and the caller passed `user_id`. Any other combination + raises `WorkloadTokenUnavailableError`. + """ + context_token = BedrockAgentCoreContext.get_workload_access_token() + if context_token: + return context_token + + workload_name = os.environ.get(_RUNTIME_WORKLOAD_ENV) + if not workload_name: + raise WorkloadTokenUnavailableError( + "No WorkloadAccessToken on context. For local dev, set " + f"{_RUNTIME_WORKLOAD_ENV} to your deployed runtime's " + "workload identity name (e.g. hosted_agent_XXXXX)." + ) + if not user_id: + raise WorkloadTokenUnavailableError( + "No WorkloadAccessToken on context and no user_id provided " + "for the local-dev mint fallback." + ) + + logger.info( + "Minting workload access token for user=%s workload=%s (local dev)", + user_id, + workload_name, + ) + response = self._control_client.get_workload_access_token_for_user_id( + workloadName=workload_name, + userId=user_id, + ) + minted_token = response.get("workloadAccessToken") + if not minted_token: + raise WorkloadTokenUnavailableError( + "GetWorkloadAccessTokenForUserId returned no token" + ) + return minted_token + _default_client: Optional[AgentCoreIdentityClient] = None diff --git a/backend/src/apis/app_api/connectors/__init__.py b/backend/src/apis/app_api/connectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/src/apis/app_api/connectors/routes.py b/backend/src/apis/app_api/connectors/routes.py new file mode 100644 index 00000000..5f14715f --- /dev/null +++ b/backend/src/apis/app_api/connectors/routes.py @@ -0,0 +1,79 @@ +"""User-facing connector routes. + +Lets a signed-in user see which OAuth connectors are available to them +(role-filtered). Consent is initiated by the inference API, which has the +AgentCore Runtime workload context; this router is purely a data source. +""" + +import logging +from typing import List + +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from apis.shared.auth import User, get_current_user +from apis.shared.oauth.models import OAuthProvider, OAuthProviderType +from apis.shared.oauth.provider_repository import ( + OAuthProviderRepository, + get_provider_repository, +) +from apis.shared.rbac.service import AppRoleService, get_app_role_service + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/connectors", tags=["connectors"]) + + +class UserConnector(BaseModel): + """Connector as visible to a signed-in user. + + Drops admin-only fields (ARN, callback URL, role allow-list) and keeps + only what the settings page needs to render. + """ + + provider_id: str + display_name: str + provider_type: OAuthProviderType + icon_name: str + scopes: List[str] + + +class UserConnectorListResponse(BaseModel): + connectors: List[UserConnector] + + +def _visible_to_user(provider: OAuthProvider, user_role_ids: List[str]) -> bool: + """True when the user is allowed to use this connector. + + An empty `allowed_roles` list means unrestricted access. A non-empty + list grants access to users who share at least one AppRole ID. + """ + if not provider.enabled: + return False + if not provider.allowed_roles: + return True + return bool(set(provider.allowed_roles) & set(user_role_ids)) + + +@router.get("/", response_model=UserConnectorListResponse) +async def list_connectors( + current_user: User = Depends(get_current_user), + provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + role_service: AppRoleService = Depends(get_app_role_service), +) -> UserConnectorListResponse: + """List enabled connectors available to the current user.""" + permissions = await role_service.resolve_user_permissions(current_user) + providers = await provider_repo.list_providers(enabled_only=True) + visible = [p for p in providers if _visible_to_user(p, permissions.app_roles)] + return UserConnectorListResponse( + connectors=[ + UserConnector( + provider_id=p.provider_id, + display_name=p.display_name, + provider_type=p.provider_type, + icon_name=p.icon_name, + scopes=p.scopes, + ) + for p in visible + ] + ) diff --git a/backend/src/apis/app_api/main.py b/backend/src/apis/app_api/main.py index fca2dd8e..36e8665c 100644 --- a/backend/src/apis/app_api/main.py +++ b/backend/src/apis/app_api/main.py @@ -87,6 +87,7 @@ async def lifespan(app: FastAPI): from apis.app_api.documents.routes import router as documents_router from apis.app_api.users.routes import router as users_router from apis.app_api.user_settings.routes import router as user_settings_router +from apis.app_api.connectors.routes import router as connectors_router from apis.app_api.system.routes import router as system_router from apis.app_api.shares.routes import conversations_share_router, shares_router, shared_view_router @@ -107,6 +108,7 @@ async def lifespan(app: FastAPI): app.include_router(memory_router) # AgentCore Memory access endpoints app.include_router(tools_router) # Tool discovery and permissions app.include_router(files_router) # File upload via pre-signed URLs +app.include_router(connectors_router) # User-facing connector catalog app.include_router(system_router) # System status and first-boot endpoints app.include_router(conversations_share_router) # Share conversations endpoints app.include_router(shares_router) # Share management (update, revoke, export) diff --git a/backend/src/apis/inference_api/connectors/__init__.py b/backend/src/apis/inference_api/connectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/src/apis/inference_api/connectors/routes.py b/backend/src/apis/inference_api/connectors/routes.py new file mode 100644 index 00000000..233cd6db --- /dev/null +++ b/backend/src/apis/inference_api/connectors/routes.py @@ -0,0 +1,172 @@ +"""User-initiated OAuth consent for connectors. + +Lives on the inference API because the AgentCore Runtime injects the +workload access token via `AgentCoreContextMiddleware` on every request +proxied through `InvokeAgentRuntime`. `IdentityClient.get_token` reads +that token from `BedrockAgentCoreContext`, which is only populated here +— never on the app API. + +Flow: the settings page posts to `/connectors/{id}/initiate-consent`. +If AgentCore already has a valid token for this user + provider, we +return `{connected: true}` so the UI can show a success state. If +consent is required, AgentCore hands us back an authorization URL and we +forward it for the frontend popup. +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel + +from agents.main_agent.integrations.agentcore_identity import ( + WorkloadTokenUnavailableError, + get_agentcore_identity_client, +) +from apis.shared.auth.dependencies import get_current_user_trusted +from apis.shared.auth.models import User +from apis.shared.oauth.provider_repository import ( + OAuthProviderRepository, + get_provider_repository, +) +from apis.shared.rbac.service import AppRoleService, get_app_role_service + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/connectors", tags=["connectors"]) + + +class InitiateConsentResponse(BaseModel): + """Either a pending consent URL or a confirmation of existing access.""" + + connected: bool = False + authorization_url: str | None = None + + +class CompleteConsentRequest(BaseModel): + """Body for finalizing a consent flow after the popup returns.""" + + session_uri: str + provider_id: str | None = None + + +class CompleteConsentResponse(BaseModel): + ok: bool = True + + +def _is_visible_to_user(provider, user_role_ids: list[str]) -> bool: + if not provider.enabled: + return False + if not provider.allowed_roles: + return True + return bool(set(provider.allowed_roles) & set(user_role_ids)) + + +@router.post( + "/{provider_id}/initiate-consent", + response_model=InitiateConsentResponse, +) +async def initiate_consent( + provider_id: str, + current_user: User = Depends(get_current_user_trusted), + provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + role_service: AppRoleService = Depends(get_app_role_service), +) -> InitiateConsentResponse: + """Start (or verify) AgentCore consent for the given provider.""" + provider = await provider_repo.get_provider(provider_id) + if not provider or not provider.enabled: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Connector '{provider_id}' not found", + ) + + permissions = await role_service.resolve_user_permissions(current_user) + if not _is_visible_to_user(provider, permissions.app_roles): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have access to this connector", + ) + + identity = get_agentcore_identity_client() + try: + result = await identity.get_token_for_user( + provider_name=provider.provider_id, + scopes=provider.scopes, + user_id=current_user.user_id, + # No custom_state: AgentCore appears to treat its presence as a + # signal to start a fresh flow, never short-circuiting to the + # cached token. The frontend passes provider_id via the + # callback URL query string so /oauth-complete still knows + # which provider resolved. + ) + except WorkloadTokenUnavailableError as err: + # Only happens when the route is called outside an AgentCore Runtime + # invocation (e.g. local dev without the runtime proxy). Surface a + # clear error instead of a 500. + logger.warning("Consent initiation attempted without workload context: %s", err) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=( + "AgentCore workload context is not available. In prod, this " + "endpoint runs under the runtime proxy; locally, set " + "AGENTCORE_RUNTIME_WORKLOAD_NAME to enable the mint fallback." + ), + ) + + if result.requires_consent: + return InitiateConsentResponse(authorization_url=result.authorization_url) + return InitiateConsentResponse(connected=True) + + +@router.post( + "/complete-consent", + response_model=CompleteConsentResponse, +) +async def complete_consent( + body: CompleteConsentRequest, + current_user: User = Depends(get_current_user_trusted), +) -> CompleteConsentResponse: + """Finalize an OAuth consent flow after the popup redirects home. + + AgentCore's `/identities/oauth2/authorize` redirect comes back with the + same `request_uri` it was initiated with (as `session_id` on our landing + page). Until we call `CompleteResourceTokenAuth` with that URI and the + user's identity, AgentCore treats the flow as unfinished and the token + vault stays empty — the next `GetResourceOauth2Token` call returns a + fresh authorization URL. + + Returns `ok: true` on success; errors from AgentCore bubble up as 502. + """ + import boto3 + from agents.main_agent.integrations.agentcore_identity import ( + _RUNTIME_WORKLOAD_ENV, + ) + import os + + region = os.environ.get("AWS_REGION", "us-west-2") + control = boto3.client("bedrock-agentcore", region_name=region) + + try: + control.complete_resource_token_auth( + userIdentifier={"userId": current_user.user_id}, + sessionUri=body.session_uri, + ) + except Exception as err: + logger.error( + "CompleteResourceTokenAuth failed for user=%s provider=%s: %s", + current_user.user_id, + body.provider_id, + err, + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Failed to finalize OAuth consent: {err}", + ) + + logger.info( + "Completed OAuth consent for user=%s provider=%s", + current_user.user_id, + body.provider_id, + ) + return CompleteConsentResponse(ok=True) diff --git a/backend/src/apis/inference_api/main.py b/backend/src/apis/inference_api/main.py index e3f2177a..4572ddb4 100644 --- a/backend/src/apis/inference_api/main.py +++ b/backend/src/apis/inference_api/main.py @@ -140,11 +140,13 @@ async def lifespan(app: FastAPI): from apis.inference_api.chat.routes import router as agentcore_router from apis.inference_api.chat.converse_routes import router as converse_router from apis.inference_api.chat.voice_routes import router as voice_router +from apis.inference_api.connectors.routes import router as connectors_router # Include routers #app.include_router(health_router) app.include_router(agentcore_router) # AgentCore Runtime endpoints: /ping, /invocations app.include_router(converse_router) # API-key authenticated converse endpoint app.include_router(voice_router) # WebSocket voice streaming endpoint +app.include_router(connectors_router) # User-initiated OAuth consent # Mount static file directories for serving generated content # These are created by tools (visualization, code interpreter, etc.) diff --git a/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py b/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py index 40348df1..edbeac39 100644 --- a/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py +++ b/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py @@ -1,6 +1,6 @@ """Tests for AgentCoreIdentityClient.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -52,14 +52,15 @@ def mock_context(): class TestGetTokenForUserCacheHit: - def test_returns_access_token_when_vault_has_token( + @pytest.mark.asyncio + async def test_returns_access_token_when_vault_has_token( self, mock_identity_sdk: MagicMock, mock_context: MagicMock ) -> None: sdk_instance = mock_identity_sdk.return_value - sdk_instance.get_token.return_value = "ya29.access-token" + sdk_instance.get_token = AsyncMock(return_value="ya29.access-token") client = AgentCoreIdentityClient(region="us-east-1") - result = client.get_token_for_user( + result = await client.get_token_for_user( provider_name="google-workspace", scopes=["openid"] ) @@ -72,28 +73,36 @@ def test_returns_access_token_when_vault_has_token( assert kwargs["scopes"] == ["openid"] assert kwargs["auth_flow"] == "USER_FEDERATION" assert kwargs["agent_identity_token"] == "workload-token-xyz" - assert kwargs["callback_url"] == "https://cb.example.com/oauth" + # Wrapper appends provider_id to the callback so the /oauth-complete + # page knows which provider to dismiss in the consent banner. + assert kwargs["callback_url"] == ( + "https://cb.example.com/oauth?provider_id=google-workspace" + ) assert kwargs["force_authentication"] is False - def test_explicit_callback_url_overrides_context( + @pytest.mark.asyncio + async def test_explicit_callback_url_overrides_context( self, mock_identity_sdk: MagicMock, mock_context: MagicMock ) -> None: sdk_instance = mock_identity_sdk.return_value - sdk_instance.get_token.return_value = "t" + sdk_instance.get_token = AsyncMock(return_value="t") client = AgentCoreIdentityClient() - client.get_token_for_user( + await client.get_token_for_user( provider_name="p", scopes=["s"], callback_url="https://override.example.com/cb", ) kwargs = sdk_instance.get_token.call_args.kwargs - assert kwargs["callback_url"] == "https://override.example.com/cb" + assert kwargs["callback_url"] == ( + "https://override.example.com/cb?provider_id=p" + ) class TestGetTokenForUserConsentRequired: - def test_returns_authorization_url_when_sdk_invokes_callback( + @pytest.mark.asyncio + async def test_returns_authorization_url_when_sdk_invokes_callback( self, mock_identity_sdk: MagicMock, mock_context: MagicMock ) -> None: """When the user needs to consent, the SDK calls on_auth_url with the @@ -101,67 +110,71 @@ def test_returns_authorization_url_when_sdk_invokes_callback( authorization_url set rather than raising.""" sdk_instance = mock_identity_sdk.return_value - def fake_get_token(**kwargs): + async def fake_get_token(**kwargs): kwargs["on_auth_url"]("https://accounts.example.com/consent?x=1") return None - sdk_instance.get_token.side_effect = fake_get_token + sdk_instance.get_token = AsyncMock(side_effect=fake_get_token) client = AgentCoreIdentityClient() - result = client.get_token_for_user(provider_name="p", scopes=["s"]) + result = await client.get_token_for_user(provider_name="p", scopes=["s"]) assert result.requires_consent is True assert result.authorization_url == "https://accounts.example.com/consent?x=1" assert result.access_token is None - def test_auth_url_takes_precedence_over_stale_token( + @pytest.mark.asyncio + async def test_auth_url_takes_precedence_over_stale_token( self, mock_identity_sdk: MagicMock, mock_context: MagicMock ) -> None: """Defensive: if the SDK both returns a token AND invokes on_auth_url, we treat consent-required as the authoritative signal.""" sdk_instance = mock_identity_sdk.return_value - def fake_get_token(**kwargs): + async def fake_get_token(**kwargs): kwargs["on_auth_url"]("https://consent.example.com") return "stale-token" - sdk_instance.get_token.side_effect = fake_get_token + sdk_instance.get_token = AsyncMock(side_effect=fake_get_token) client = AgentCoreIdentityClient() - result = client.get_token_for_user(provider_name="p", scopes=["s"]) + result = await client.get_token_for_user(provider_name="p", scopes=["s"]) assert result.requires_consent is True assert result.authorization_url == "https://consent.example.com" class TestGetTokenForUserErrors: - def test_raises_when_no_workload_token_on_context( + @pytest.mark.asyncio + async def test_raises_when_no_workload_token_on_context( self, mock_identity_sdk: MagicMock, mock_context: MagicMock ) -> None: mock_context.get_workload_access_token.return_value = None client = AgentCoreIdentityClient() with pytest.raises(WorkloadTokenUnavailableError): - client.get_token_for_user(provider_name="p", scopes=["s"]) + await client.get_token_for_user(provider_name="p", scopes=["s"]) - def test_raises_when_sdk_returns_nothing_and_no_auth_url( + @pytest.mark.asyncio + async def test_raises_when_sdk_returns_nothing_and_no_auth_url( self, mock_identity_sdk: MagicMock, mock_context: MagicMock ) -> None: sdk_instance = mock_identity_sdk.return_value - sdk_instance.get_token.return_value = None + sdk_instance.get_token = AsyncMock(return_value=None) client = AgentCoreIdentityClient() with pytest.raises(RuntimeError, match="neither a token nor"): - client.get_token_for_user(provider_name="p", scopes=["s"]) + await client.get_token_for_user(provider_name="p", scopes=["s"]) - def test_force_authentication_flag_is_forwarded( + @pytest.mark.asyncio + async def test_force_authentication_flag_is_forwarded( self, mock_identity_sdk: MagicMock, mock_context: MagicMock ) -> None: sdk_instance = mock_identity_sdk.return_value - sdk_instance.get_token.return_value = "t" + sdk_instance.get_token = AsyncMock(return_value="t") client = AgentCoreIdentityClient() - client.get_token_for_user( + await client.get_token_for_user( provider_name="p", scopes=["s"], force_authentication=True ) diff --git a/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts b/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts index eb7e1b07..aa72b865 100644 --- a/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts +++ b/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts @@ -13,6 +13,8 @@ import { heroCheckCircle, heroExclamationCircle, } from '@ng-icons/heroicons/outline'; +import { AuthService } from '../auth/auth.service'; +import { ConfigService } from '../services/config.service'; /** * Landing page for AgentCore Identity's 3-legged OAuth flow. @@ -161,13 +163,17 @@ export interface OAuthCompleteMessage { export class OAuthCompletePage implements OnInit, OnDestroy { private readonly route = inject(ActivatedRoute); private readonly router = inject(Router); + private readonly authService = inject(AuthService); + private readonly config = inject(ConfigService); private redirectTimer: ReturnType | null = null; readonly state = signal('success'); readonly providerId = signal(null); + readonly sessionUri = signal(null); readonly errorMessage = signal('Authorization was denied or did not complete.'); private readonly isPopup = signal(false); + private readonly finalizing = signal(false); readonly providerLabel = computed(() => { const id = this.providerId(); @@ -192,11 +198,24 @@ export class OAuthCompletePage implements OnInit, OnDestroy { const params = this.route.snapshot.queryParamMap; const error = params.get('error'); const errorDescription = params.get('error_description'); - const providerId = params.get('provider_id') ?? params.get('providerId'); + // AgentCore echoes our custom_state back as `state`; we set it server-side + // to the providerId when initiating consent from the settings page. + const providerId = + params.get('provider_id') ?? + params.get('providerId') ?? + params.get('state'); + + // AgentCore's redirect also carries `session_id`, which is the + // `request_uri` from the initial authorize call. We must hand it back + // to CompleteResourceTokenAuth or the token vault stays empty. + const sessionUri = params.get('session_id') ?? params.get('sessionUri'); if (providerId) { this.providerId.set(providerId); } + if (sessionUri) { + this.sessionUri.set(sessionUri); + } if (error) { this.state.set('error'); @@ -208,6 +227,31 @@ export class OAuthCompletePage implements OnInit, OnDestroy { const inPopup = this.detectPopup(); this.isPopup.set(inPopup); + // Finalize AgentCore's OAuth session (exchanges the `request_uri` for a + // persisted token). Must happen BEFORE we tell the opener we're done — + // otherwise the opener's next tool call will still see "consent required". + if (this.state() === 'success' && sessionUri) { + this.finalizing.set(true); + this.finalizeConsent(sessionUri, providerId) + .catch((err) => { + this.state.set('error'); + this.errorMessage.set( + err instanceof Error + ? `Couldn't finalize authorization: ${err.message}` + : "Couldn't finalize authorization.", + ); + }) + .finally(() => { + this.finalizing.set(false); + if (inPopup) { + this.notifyOpenerAndClose(); + } else { + this.redirectTimer = setTimeout(() => this.router.navigate(['/']), 2000); + } + }); + return; + } + if (inPopup) { this.notifyOpenerAndClose(); } else { @@ -215,6 +259,36 @@ export class OAuthCompletePage implements OnInit, OnDestroy { } } + private async finalizeConsent( + sessionUri: string, + providerId: string | null, + ): Promise { + const baseUrl = this.config.inferenceApiUrl().replace(/\/invocations\/?$/, ''); + if (!baseUrl) { + throw new Error('inferenceApiUrl not configured'); + } + const token = this.authService.getAccessToken(); + if (!token) { + throw new Error('No access token available'); + } + const url = `${baseUrl}/connectors/complete-consent`; + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + session_uri: sessionUri, + provider_id: providerId, + }), + }); + if (!response.ok) { + const text = await response.text().catch(() => ''); + throw new Error(`HTTP ${response.status}: ${text || response.statusText}`); + } + } + ngOnDestroy(): void { if (this.redirectTimer !== null) { clearTimeout(this.redirectTimer); @@ -241,15 +315,28 @@ export class OAuthCompletePage implements OnInit, OnDestroy { providerId: this.providerId(), error: this.state() === 'error' ? this.errorMessage() : null, }; + + // Primary channel: BroadcastChannel. Survives the Cross-Origin-Opener-Policy + // split that severs window.opener once a popup navigates through external + // origins (AgentCore, Google) and back. Same-origin tabs sharing a channel + // name always see each other, regardless of opener relationship. + try { + const channel = new BroadcastChannel('agentcore-oauth-complete'); + channel.postMessage(message); + // Give the message a tick to propagate before closing the channel. + setTimeout(() => channel.close(), 200); + } catch { + // BroadcastChannel unavailable — fall back to postMessage below. + } + + // Fallback: postMessage to opener. Works when COOP isn't in play. try { window.opener?.postMessage(message, window.location.origin); } catch { - // Cross-origin opener — nothing we can do; leave the page open so the - // user can read the message and close manually. - return; + // Cross-origin or COOP-isolated opener — BroadcastChannel above + // handles the handoff in that case. } - // Small delay so the opener has time to receive before the window closes - // (Chrome closes immediately otherwise on some platforms). + this.redirectTimer = setTimeout(() => { try { window.close(); diff --git a/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts b/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts new file mode 100644 index 00000000..de884b6e --- /dev/null +++ b/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts @@ -0,0 +1,25 @@ +/** + * Connector shape returned by the user-facing catalog endpoint. + * Strips admin-only fields (ARN, callback URL, allow-list). + */ +export interface UserConnector { + providerId: string; + displayName: string; + providerType: 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; + iconName: string; + scopes: string[]; +} + +export interface UserConnectorListResponse { + connectors: UserConnector[]; +} + +/** + * Inference-API response for `/connectors/{id}/initiate-consent`. + * Exactly one of `connected` (true) or `authorizationUrl` (populated) will + * be meaningful — `connected: false` with a URL is the consent path. + */ +export interface InitiateConsentResponse { + connected: boolean; + authorizationUrl: string | null; +} diff --git a/frontend/ai.client/src/app/settings/connectors/services/user-connectors.service.ts b/frontend/ai.client/src/app/settings/connectors/services/user-connectors.service.ts new file mode 100644 index 00000000..7a5b7abd --- /dev/null +++ b/frontend/ai.client/src/app/settings/connectors/services/user-connectors.service.ts @@ -0,0 +1,75 @@ +import { Injectable, inject, resource, computed } from '@angular/core'; +import { HttpClient } from '@angular/common/http'; +import { firstValueFrom } from 'rxjs'; +import { ConfigService } from '../../../services/config.service'; +import { AuthService } from '../../../auth/auth.service'; +import { + InitiateConsentResponse, + UserConnector, +} from '../models/user-connector.model'; + +function toCamelCase(obj: Record): T { + const out: Record = {}; + for (const [k, v] of Object.entries(obj)) { + const camel = k.replace(/_([a-z])/g, (_, c: string) => c.toUpperCase()); + out[camel] = v; + } + return out as T; +} + +/** + * User-facing connectors service. + * + * Catalog lives on app-api (`GET /connectors`). Consent initiation lives + * on inference-api (`POST /connectors/{id}/initiate-consent`) because + * AgentCore's IdentityClient needs the workload token that the runtime + * middleware injects only on the inference path. + */ +@Injectable({ providedIn: 'root' }) +export class UserConnectorsService { + private readonly http = inject(HttpClient); + private readonly auth = inject(AuthService); + private readonly config = inject(ConfigService); + + private readonly appApiUrl = computed(() => `${this.config.appApiUrl()}/connectors`); + private readonly inferenceUrl = computed( + () => `${this.config.inferenceApiUrl()}/connectors`, + ); + + readonly connectorsResource = resource({ + loader: async () => { + await this.auth.ensureAuthenticated(); + const response = await firstValueFrom( + this.http.get<{ connectors: Record[] }>(`${this.appApiUrl()}/`), + ); + return response.connectors.map((c) => toCamelCase(c)); + }, + }); + + async initiateConsent(providerId: string): Promise { + await this.auth.ensureAuthenticated(); + // AgentCore's GetResourceOauth2Token requires a callback URL. The runtime + // reads this header in prod via AgentCoreContextMiddleware; the settings + // page bypasses the runtime, so we set it explicitly on the request. + // We also tack the provider_id onto the callback URL as a query param so + // /oauth-complete can surface the right providerId in its postMessage. + const callback = new URL('/oauth-complete', window.location.origin); + callback.searchParams.set('provider_id', providerId); + const raw = await firstValueFrom( + this.http.post>( + `${this.inferenceUrl()}/${providerId}/initiate-consent`, + {}, + { + headers: { + OAuth2CallbackUrl: callback.toString(), + }, + }, + ), + ); + return toCamelCase(raw); + } + + reload(): void { + this.connectorsResource.reload(); + } +} diff --git a/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts new file mode 100644 index 00000000..71da3c9f --- /dev/null +++ b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts @@ -0,0 +1,270 @@ +import { + Component, + ChangeDetectionStrategy, + inject, + signal, + computed, + effect, +} from '@angular/core'; +import { NgIcon, provideIcons } from '@ng-icons/core'; +import { + heroLink, + heroCloud, + heroCodeBracket, + heroAcademicCap, + heroCheckCircle, + heroArrowPath, + heroExclamationTriangle, +} from '@ng-icons/heroicons/outline'; +import { UserConnectorsService } from '../../connectors/services/user-connectors.service'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; +import { UserConnector } from '../../connectors/models/user-connector.model'; +import { ToastService } from '../../../services/toast/toast.service'; + +type ConnectState = 'idle' | 'initiating' | 'awaiting' | 'connected' | 'error'; + +@Component({ + selector: 'app-connectors-settings', + changeDetection: ChangeDetectionStrategy.OnPush, + imports: [NgIcon], + providers: [ + provideIcons({ + heroLink, + heroCloud, + heroCodeBracket, + heroAcademicCap, + heroCheckCircle, + heroArrowPath, + heroExclamationTriangle, + }), + ], + host: { class: 'block' }, + template: ` +
+
+

Connectors

+

+ Connect your third-party accounts so agents can call tools on your behalf. +

+
+ + @if (resource.isLoading()) { +
+
+ Loading connectors... +
+ } @else if (resource.error()) { +
+ +
+

Couldn't load connectors

+

+ {{ resource.error()?.message || 'Try again in a moment.' }} +

+ +
+
+ } @else if (connectors().length === 0) { +
+ +

+ No connectors are available to you yet. +

+

+ Ask an administrator to enable a connector for your role. +

+
+ } @else { +
    + @for (connector of connectors(); track connector.providerId) { +
  • +
    +
    + +
    +
    +

    + {{ connector.displayName }} +

    + @if (connector.scopes.length > 0) { +

    + Requests: {{ connector.scopes.join(', ') }} +

    + } +
    +
    + + @let state = getState(connector.providerId); +
    + @if (state === 'connected') { + + + Connected + + } @else if (state === 'error') { + + + Failed + + } + + +
    +
  • + } +
+ } +
+ `, +}) +export class ConnectorsSettingsPage { + private readonly connectorsService = inject(UserConnectorsService); + private readonly consentService = inject(OAuthConsentService); + private readonly toast = inject(ToastService); + + protected readonly resource = this.connectorsService.connectorsResource; + + protected readonly connectors = computed( + () => this.resource.value() ?? [], + ); + + private readonly states = signal>(new Map()); + + constructor() { + // Flip a provider to `connected` when the /oauth-complete landing page + // postMessages success. This is the same signal the chat-input banner + // listens to, so both UIs stay in sync. + effect(() => { + const completion = this.consentService.completion(); + if (!completion || !completion.providerId) return; + if (completion.status === 'success') { + this.setState(completion.providerId, 'connected'); + } else { + this.setState(completion.providerId, 'error'); + } + this.consentService.acknowledgeCompletion(); + }); + + // Probe AgentCore on load (and whenever the connector list changes) + // to restore the "Connected" badge without the user having to click. + // We call initiateConsent just for its `connected` flag — if it returns + // false we discard the URL, the user can still click Connect manually. + effect(() => { + const connectors = this.connectors(); + if (connectors.length === 0) return; + void this.probeConnectedStatus(connectors); + }); + } + + private async probeConnectedStatus(connectors: UserConnector[]): Promise { + const unknown = connectors.filter((c) => !this.states().has(c.providerId)); + await Promise.all( + unknown.map(async (c) => { + try { + const result = await this.connectorsService.initiateConsent(c.providerId); + if (result.connected && this.getState(c.providerId) === 'idle') { + this.setState(c.providerId, 'connected'); + } + } catch { + // Leave state as idle — user can still click Connect to retry. + } + }), + ); + } + + protected getState(providerId: string): ConnectState { + return this.states().get(providerId) ?? 'idle'; + } + + private setState(providerId: string, state: ConnectState): void { + this.states.update((map) => { + const next = new Map(map); + next.set(providerId, state); + return next; + }); + } + + protected async connect(providerId: string): Promise { + this.setState(providerId, 'initiating'); + try { + const result = await this.connectorsService.initiateConsent(providerId); + if (result.connected) { + this.setState(providerId, 'connected'); + this.toast.success(`${this.displayNameFor(providerId)} is already connected.`); + return; + } + if (!result.authorizationUrl) { + this.setState(providerId, 'error'); + this.toast.error('Unexpected response from the server.'); + return; + } + this.consentService.requestConsent(providerId, result.authorizationUrl); + this.consentService.openConsentPopup(providerId); + this.setState(providerId, 'awaiting'); + } catch (err: unknown) { + console.error('Consent initiation failed', err); + this.setState(providerId, 'error'); + const detail = (err as { error?: { detail?: string }; message?: string })?.error?.detail; + this.toast.error(detail ?? 'Could not start the consent flow.'); + } + } + + private displayNameFor(providerId: string): string { + return this.connectors().find((c) => c.providerId === providerId)?.displayName ?? providerId; + } + + protected defaultIcon(providerType: UserConnector['providerType']): string { + switch (providerType) { + case 'google': + case 'microsoft': + return 'heroCloud'; + case 'github': + return 'heroCodeBracket'; + case 'canvas': + return 'heroAcademicCap'; + default: + return 'heroLink'; + } + } + + protected iconClasses(providerType: UserConnector['providerType']): string { + const base = 'flex size-10 items-center justify-center rounded-sm'; + switch (providerType) { + case 'google': + return `${base} bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400`; + case 'microsoft': + return `${base} bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400`; + case 'github': + return `${base} bg-gray-800 text-white dark:bg-gray-600`; + case 'canvas': + return `${base} bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400`; + default: + return `${base} bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400`; + } + } +} diff --git a/frontend/ai.client/src/app/settings/settings.page.ts b/frontend/ai.client/src/app/settings/settings.page.ts index 7b4f0a5f..97079221 100644 --- a/frontend/ai.client/src/app/settings/settings.page.ts +++ b/frontend/ai.client/src/app/settings/settings.page.ts @@ -114,6 +114,7 @@ export class SettingsPage { { label: 'Profile', icon: 'heroUser', route: '/settings/profile', description: 'Your personal information' }, { label: 'Appearance', icon: 'heroPaintBrush', route: '/settings/appearance', description: 'Theme and display' }, { label: 'Chat', icon: 'heroChatBubbleLeftRight', route: '/settings/chat', description: 'Chat preferences' }, + { label: 'Connectors', icon: 'heroLink', route: '/settings/connectors', description: 'Connected accounts' }, { label: 'API Keys', icon: 'heroKey', route: '/settings/api-keys', description: 'API key management' }, { label: 'Usage', icon: 'heroChartBar', route: '/settings/usage', description: 'Usage and billing' }, ]; diff --git a/frontend/ai.client/src/app/settings/settings.routes.ts b/frontend/ai.client/src/app/settings/settings.routes.ts index 5c8e5b86..90d3243f 100644 --- a/frontend/ai.client/src/app/settings/settings.routes.ts +++ b/frontend/ai.client/src/app/settings/settings.routes.ts @@ -21,6 +21,11 @@ export const settingsRoutes: Routes = [ loadComponent: () => import('./pages/chat-preferences/chat-preferences-settings.page').then(m => m.ChatPreferencesSettingsPage), }, + { + path: 'connectors', + loadComponent: () => + import('./pages/connectors-settings/connectors-settings.page').then(m => m.ConnectorsSettingsPage), + }, { path: 'api-keys', loadComponent: () => From b55653dd8d4f127303656ed28c0ae6809875dd60 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 17:24:42 -0600 Subject: [PATCH 13/35] refactor(connectors): switch oauth gating from pre-flight to mid-turn interrupts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The agent used to pre-flight OAuth at tool-load time and abort the whole turn if any provider needed consent — the user then had to retype the prompt after authorizing. This switches to the Strands interrupt protocol: the consent gate runs lazily before each tool call, pauses the in-flight turn, and resumes it automatically once the user finishes the popup. Backend - New OAuthConsentHook (BeforeToolCallEvent + AfterToolCallEvent). - BeforeToolCall: looks up the OAuth provider for the selected MCPAgentTool's MCPClient (no name coupling), checks the in-process token cache, and either lets the tool run or calls event.interrupt(...) with the consent URL when AgentCore Identity reports consent required. - AfterToolCall: detects 401-style failures from MCP tool results, marks the (user, provider) for force_authentication on the next fetch, and sets event.retry = True so the BeforeToolCall hook re-fires and triggers a fresh consent. Closes the gap where a provider-side revocation leaves a stale token in AgentCore's vault. - New oauth_token_cache: per-(user, provider) tokens + force-reauth flags; lifecycle-managed by the hook. - ExternalMCPIntegration always loads MCP clients with a lazy token_provider that reads from the cache; the pending_consent / drain_pending_consent dict and the route's pre-LLM short-circuit branch are gone. - StreamCoordinator emits one oauth_required SSE event per pending interrupt before the final done event, carrying interruptId so the frontend can resume the same turn. - ChatAgent.stream_async accepts interrupt_responses and forwards them to Strands as the resume prompt; route accepts the same on /invocations and skips quota + RAG augmentation on resume. Frontend - OAuthRequiredEvent type + validator gain interruptId; settings-page consent path makes interruptId optional (no agent turn to resume). - OAuthConsentService tracks the interruptId per request and invokes a registered resume handler on broadcast success. - ChatRequestService snapshots the last turn's payload and replays it with interrupt_responses attached when a consent completes — the user never retypes the prompt. Smoke-tested end-to-end: Google revoke → whoami → 401 → AfterToolCall detects + retries → fresh consent banner → popup → auto-resume → tool returns greeting in the same turn. Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/src/agents/main_agent/base_agent.py | 81 +++- backend/src/agents/main_agent/chat_agent.py | 26 +- .../integrations/external_mcp_client.py | 249 +++-------- .../integrations/oauth_token_cache.py | 74 ++++ .../main_agent/session/hooks/__init__.py | 2 + .../main_agent/session/hooks/oauth_consent.py | 266 +++++++++++ .../streaming/stream_coordinator.py | 49 ++ backend/src/agents/main_agent/voice_agent.py | 8 +- backend/src/apis/inference_api/chat/models.py | 18 +- backend/src/apis/inference_api/chat/routes.py | 38 +- backend/src/apis/shared/oauth/models.py | 10 +- .../integrations/test_external_mcp_client.py | 177 ++------ .../integrations/test_oauth_token_cache.py | 61 +++ .../session/test_oauth_consent_hook.py | 419 ++++++++++++++++++ .../oauth-consent/oauth-consent.service.ts | 99 ++++- .../services/chat/chat-request.service.ts | 77 +++- .../services/chat/stream-parser.service.ts | 13 +- .../utils/stream-parser/stream-parser-core.ts | 4 +- .../stream-parser/stream-parser-types.ts | 6 +- 19 files changed, 1288 insertions(+), 389 deletions(-) create mode 100644 backend/src/agents/main_agent/integrations/oauth_token_cache.py create mode 100644 backend/src/agents/main_agent/session/hooks/oauth_consent.py create mode 100644 backend/tests/agents/main_agent/integrations/test_oauth_token_cache.py create mode 100644 backend/tests/agents/main_agent/session/test_oauth_consent_hook.py diff --git a/backend/src/agents/main_agent/base_agent.py b/backend/src/agents/main_agent/base_agent.py index 538bb423..136ac171 100644 --- a/backend/src/agents/main_agent/base_agent.py +++ b/backend/src/agents/main_agent/base_agent.py @@ -8,12 +8,13 @@ import logging from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from agents.main_agent.core import ModelConfig, SystemPromptBuilder, AgentFactory from agents.main_agent.session import SessionFactory from agents.main_agent.session.hooks import ( StopHook, + OAuthConsentHook, EmailApprovalHook, ExternalWriteApprovalHook, DangerousToolApprovalHook, @@ -131,9 +132,21 @@ def _create_agent(self) -> None: @abstractmethod async def stream_async( - self, message: str, session_id: Optional[str] = None, files: Optional[List] = None, citations: Optional[List] = None, original_message: Optional[str] = None + self, + message: str, + session_id: Optional[str] = None, + files: Optional[List] = None, + citations: Optional[List] = None, + original_message: Optional[str] = None, + interrupt_responses: Optional[List[Dict[str, Any]]] = None, ) -> AsyncGenerator[str, None]: - """Stream agent responses. Subclasses must implement.""" + """Stream agent responses. Subclasses must implement. + + When `interrupt_responses` is provided, the call resumes a paused + agent turn (Strands interrupt protocol) instead of starting a new + one. In that case `message`/`files` are ignored — the original turn + already has the user's prompt in its context. + """ ... def _register_external_mcp_tools(self) -> None: @@ -187,6 +200,8 @@ def _create_hooks(self) -> List: Includes: - StopHook: Always enabled, cancels tool execution on user stop + - OAuthConsentHook: Pauses the agent (Strands interrupt) when an + OAuth-gated MCP tool is about to run without a cached token - Approval hooks: Gate dangerous operations for user confirmation Returns: @@ -197,6 +212,10 @@ def _create_hooks(self) -> List: # Always-on: session cancellation hooks.append(StopHook(self.session_manager)) + # OAuth consent gate for external MCP tools. Registered unconditionally; + # the hook is a no-op for tools that don't have a registered provider. + hooks.append(self._build_oauth_consent_hook()) + # Approval gates for dangerous operations hooks.append(EmailApprovalHook()) hooks.append(ExternalWriteApprovalHook()) @@ -204,6 +223,34 @@ def _create_hooks(self) -> List: return hooks + def _build_oauth_consent_hook(self) -> OAuthConsentHook: + """Construct the OAuth consent hook with closures over the MCP + integration and provider repository so it stays decoupled from them. + """ + from agents.main_agent.integrations.external_mcp_client import ( + get_external_mcp_integration, + ) + from strands.tools.mcp import MCPAgentTool + + integration = get_external_mcp_integration() + + def provider_lookup(selected_tool: object) -> Optional[str]: + if not isinstance(selected_tool, MCPAgentTool): + return None + return integration.provider_for_client(selected_tool.mcp_client) + + async def scopes_lookup(provider_id: str) -> List[str]: + from apis.shared.oauth.provider_repository import get_provider_repository + + provider = await get_provider_repository().get_provider(provider_id) + return provider.scopes if provider else [] + + return OAuthConsentHook( + user_id=self.user_id, + provider_lookup=provider_lookup, + scopes_lookup=scopes_lookup, + ) + def _build_filtered_tools(self) -> List: """ Filter tools and load gateway/external MCP clients. @@ -226,22 +273,34 @@ def _build_filtered_tools(self) -> List: if external_mcp_tool_ids: import asyncio + from bedrock_agentcore.runtime import BedrockAgentCoreContext + from agents.main_agent.integrations.external_mcp_client import get_external_mcp_integration + # Capture request-scoped context values before crossing the thread + # boundary below. ContextVars do not propagate into the executor's + # fresh event loop, so anything we need there must be passed as args. + oauth2_callback_url = BedrockAgentCoreContext.get_oauth2_callback_url() + workload_access_token = BedrockAgentCoreContext.get_workload_access_token() + external_integration = get_external_mcp_integration() loop = asyncio.get_event_loop() if loop.is_running(): import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit( - asyncio.run, - external_integration.load_external_tools( - external_mcp_tool_ids, - user_id=self.user_id, - auth_token=self.auth_token, - ), + async def _load_with_context(): + if oauth2_callback_url: + BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url) + if workload_access_token: + BedrockAgentCoreContext.set_workload_access_token(workload_access_token) + return await external_integration.load_external_tools( + external_mcp_tool_ids, + user_id=self.user_id, + auth_token=self.auth_token, ) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, _load_with_context()) external_clients = future.result() else: external_clients = loop.run_until_complete( diff --git a/backend/src/agents/main_agent/chat_agent.py b/backend/src/agents/main_agent/chat_agent.py index 97fbb702..edf26f44 100644 --- a/backend/src/agents/main_agent/chat_agent.py +++ b/backend/src/agents/main_agent/chat_agent.py @@ -6,7 +6,7 @@ """ import logging -from typing import AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from agents.main_agent.base_agent import BaseAgent from agents.main_agent.core import AgentFactory @@ -43,17 +43,28 @@ def _create_agent(self) -> None: raise async def stream_async( - self, message: str, session_id: Optional[str] = None, files: Optional[List] = None, citations: Optional[List] = None, original_message: Optional[str] = None + self, + message: str, + session_id: Optional[str] = None, + files: Optional[List] = None, + citations: Optional[List] = None, + original_message: Optional[str] = None, + interrupt_responses: Optional[List[Dict[str, Any]]] = None, ) -> AsyncGenerator[str, None]: """ Stream agent responses. Args: - message: User message text + message: User message text. Ignored when resuming via + `interrupt_responses` — the paused turn already has the + original prompt in `_interrupt_state`. session_id: Session identifier (defaults to instance session_id) files: Optional list of FileContent objects (with base64 bytes) citations: Optional list of citation dicts from RAG retrieval original_message: Original user message before RAG augmentation + interrupt_responses: When set, resume a paused agent turn by + passing this list as the prompt to Strands. Each entry is + `{"interruptResponse": {"interruptId": str, "response": Any}}`. Yields: str: SSE formatted events @@ -61,7 +72,14 @@ async def stream_async( if not self.agent: self._create_agent() - prompt = self.multimodal_builder.build_prompt(message, files) + if interrupt_responses: + # Strands' resume protocol: passing a list of interrupt responses + # as the prompt re-enters the loop, populates the matching + # interrupts' `.response`, and continues from the paused tool + # call. multimodal_builder + files do not apply here. + prompt: Any = interrupt_responses + else: + prompt = self.multimodal_builder.build_prompt(message, files) async for event in self.stream_coordinator.stream_response( agent=self.agent, diff --git a/backend/src/agents/main_agent/integrations/external_mcp_client.py b/backend/src/agents/main_agent/integrations/external_mcp_client.py index 3a9c08bc..847a2afd 100644 --- a/backend/src/agents/main_agent/integrations/external_mcp_client.py +++ b/backend/src/agents/main_agent/integrations/external_mcp_client.py @@ -5,14 +5,17 @@ supporting various authentication methods (AWS IAM, API Key, OAuth, etc.) OAuth Support: - When a tool has `requires_oauth_provider` set, the MCP client will - automatically inject the user's OAuth token into requests. This requires - per-user client instances since tokens are user-specific. + Tools with `requires_oauth_provider` set get an `OAuthBearerAuth` whose + token is resolved lazily on every MCP request via `oauth_token_cache`. + The cache is warmed by `OAuthConsentHook` (which also pauses the agent + via Strands interrupts when consent is needed). This module never + pre-flights OAuth — the agent loads tools optimistically; the hook + gates execution. """ import logging import re -from typing import Optional, List, Any +from typing import Any, Callable, Optional, List from mcp.client.streamable_http import streamablehttp_client from strands.tools.mcp import MCPClient @@ -23,11 +26,7 @@ MCPTransport, ToolDefinition, ) -from agents.main_agent.integrations.agentcore_identity import ( - TokenResult, - WorkloadTokenUnavailableError, - get_agentcore_identity_client, -) +from agents.main_agent.integrations import oauth_token_cache from agents.main_agent.integrations.gateway_auth import get_sigv4_auth from agents.main_agent.integrations.oauth_auth import ( CompositeAuth, @@ -96,28 +95,22 @@ def create_external_mcp_client( config: MCPServerConfig, tool_definition: Optional[ToolDefinition] = None, oauth_token: Optional[str] = None, + token_provider: Optional[Callable[[], Optional[str]]] = None, ) -> Optional[MCPClient]: """ Create an MCP client for an externally deployed MCP server. + Pass either `oauth_token` (static, for OIDC forwarding) or `token_provider` + (callable, for OAuth tokens that the consent hook resolves lazily). + Args: config: MCP server configuration from tool catalog tool_definition: Optional tool definition for logging - oauth_token: Optional OAuth token to include in requests (for user-specific auth) + oauth_token: Optional static token (used for OIDC forwarding) + token_provider: Optional callable returning the current token Returns: MCPClient instance or None if configuration is invalid - - Example: - >>> config = MCPServerConfig( - ... server_url="https://xxx.lambda-url.us-west-2.on.aws/", - ... transport=MCPTransport.STREAMABLE_HTTP, - ... auth_type=MCPAuthType.AWS_IAM, - ... ) - >>> client = create_external_mcp_client(config) - - # With OAuth token for user-specific access: - >>> client = create_external_mcp_client(config, oauth_token="user_access_token") """ if not config.server_url: logger.warning("MCP server URL is required") @@ -125,26 +118,29 @@ def create_external_mcp_client( tool_id = tool_definition.tool_id if tool_definition else "unknown" requires_oauth = tool_definition.requires_oauth_provider if tool_definition else None - has_token = bool(oauth_token) + has_static_token = bool(oauth_token) + has_provider = bool(token_provider) logger.info(f"Creating external MCP client for tool: {tool_id}") logger.debug(f" Transport: {config.transport}") logger.debug(f" Auth Type: {config.auth_type}") if requires_oauth: logger.debug(" Requires OAuth Provider: yes") - logger.debug(f" OAuth Token Provided: {has_token}") + logger.debug(f" Token mode: {'provider' if has_provider else 'static' if has_static_token else 'none'}") try: # Build list of auth handlers (may combine multiple) auth_handlers = [] - # When an OAuth token is provided, use it exclusively as the auth method. - # SigV4 and OAuth both use the Authorization header and cannot coexist — - # SigV4 sets "AWS4-HMAC-SHA256 ..." while OAuth sets "Bearer ...". - # The Lambda Function URL auth type should be NONE for OAuth-authenticated tools. - if oauth_token: - oauth_auth = create_oauth_bearer_auth(token=oauth_token) - auth_handlers.append(oauth_auth) - logger.debug(" Using OAuth Bearer token auth (skipping SigV4)") + # OAuth/OIDC bearer auth takes precedence over SigV4. SigV4 and OAuth both + # use the Authorization header and cannot coexist (SigV4 sets + # "AWS4-HMAC-SHA256 ...", OAuth sets "Bearer ..."). Lambda Function URLs + # backing OAuth-authenticated MCP servers must use auth_type=NONE. + if token_provider: + auth_handlers.append(create_oauth_bearer_auth(token_provider=token_provider)) + logger.debug(" Using OAuth Bearer token provider (lazy)") + elif oauth_token: + auth_handlers.append(create_oauth_bearer_auth(token=oauth_token)) + logger.debug(" Using OAuth Bearer token (static)") # AWS IAM SigV4 authentication (for Lambda/API Gateway without OAuth) elif config.auth_type == MCPAuthType.AWS_IAM or config.auth_type == "aws-iam": @@ -209,82 +205,31 @@ class ExternalMCPIntegration: with protocol='mcp_external' in the tool catalog. OAuth Support: - Tools with `requires_oauth_provider` set will have their MCP clients - created with the user's OAuth token injected. Since tokens are user-specific, - OAuth-enabled tools use a per-user cache key. + For OAuth-gated tools, clients are created with a lazy token provider + that reads from the per-process token cache at request time. The + cache is populated by `OAuthConsentHook`, which gates execution by + raising a Strands interrupt when no token is available yet. + + This integration also maintains an MCPClient -> provider_id map so + the hook can look up which provider a tool's MCP server requires + without coupling on tool names. """ def __init__(self): - """Initialize external MCP integration.""" # Cache key: tool_id for non-OAuth tools, "user_id:tool_id" for OAuth tools self.clients: dict[str, MCPClient] = {} - # Consent URLs collected during load_external_tools, keyed by user_id. - # Consumed (and cleared) by the inference route on the next response so - # they surface as an oauth_required SSE event. Shape: - # { user_id: [ { "provider_id": str, "authorization_url": str }, ... ] } - self.pending_consent: dict[str, list[dict[str, str]]] = {} + # MCPClient object identity -> provider_id, populated alongside `clients`. + # Consumed by OAuthConsentHook via `provider_for_client`. + self._provider_for_client_id: dict[int, str] = {} def _get_cache_key(self, tool_id: str, user_id: Optional[str], requires_oauth: bool) -> str: - """Get the cache key for a tool client.""" if requires_oauth and user_id: return f"{user_id}:{tool_id}" return tool_id - async def _get_oauth_token( - self, - provider_id: str, - ) -> TokenResult: - """Fetch an OAuth token for `provider_id` via AgentCore Identity. - - The user is identified implicitly by the WorkloadAccessToken on - `BedrockAgentCoreContext` (populated from request headers by - `AgentCoreContextMiddleware`). Scopes are read from the platform's - OAuth provider record so organizations can change them without code - changes. - - Convention: `provider_id` is used verbatim as the AgentCore Identity - credential-provider name. Admins register providers with matching - names via `CreateOauth2CredentialProvider`. - - Returns: - `TokenResult` — either `.access_token` on cache hit or - `.authorization_url` when user consent is required. - - Raises: - WorkloadTokenUnavailableError: Not running inside an AgentCore - Runtime invocation (e.g. misconfigured middleware). - """ - from apis.shared.oauth.provider_repository import get_provider_repository - - provider = await get_provider_repository().get_provider(provider_id) - scopes = provider.scopes if provider else [] - - identity_client = get_agentcore_identity_client() - return identity_client.get_token_for_user( - provider_name=provider_id, scopes=scopes - ) - - def _record_pending_consent( - self, user_id: str, provider_id: str, authorization_url: str - ) -> None: - """Stash a consent URL to be surfaced to the user via SSE.""" - bucket = self.pending_consent.setdefault(user_id, []) - # Dedupe on provider_id — if the user has two tools needing the same - # provider, one consent covers both. - if any(entry["provider_id"] == provider_id for entry in bucket): - return - bucket.append( - {"provider_id": provider_id, "authorization_url": authorization_url} - ) - - def drain_pending_consent(self, user_id: str) -> list[dict[str, str]]: - """Consume and return pending consent prompts for a user. - - Called by the inference route on each response so the frontend can - render "Connect to X" affordances. Idempotent across repeated calls - because consent entries are removed once read. - """ - return self.pending_consent.pop(user_id, []) + def provider_for_client(self, client: Any) -> Optional[str]: + """Return the OAuth provider_id backing `client`, or None.""" + return self._provider_for_client_id.get(id(client)) async def load_external_tools( self, @@ -295,15 +240,16 @@ async def load_external_tools( """ Load external MCP clients for enabled tools. - This method queries the tool catalog for tools with protocol='mcp_external' - and creates MCP clients for them. For tools requiring OAuth, the user's - OAuth token is retrieved and injected. For tools with forward_auth_token, - the user's OIDC authentication token is forwarded instead. + For OAuth-gated tools, the client is created with a token provider + that reads from `oauth_token_cache` lazily at request time. Token + acquisition + consent prompting happen in `OAuthConsentHook` at + tool-call time, not here. For tools with `forward_auth_token`, the + user's OIDC token is injected statically. Args: enabled_tool_ids: List of enabled tool IDs - user_id: User ID for OAuth token retrieval (required for OAuth-enabled tools) - auth_token: Raw OIDC token for forwarding (required for forward_auth_token tools) + user_id: User ID (required for OAuth-gated and OIDC-forwarded tools) + auth_token: Raw OIDC token for forwarding Returns: List of MCPClient instances to add to the agent's tools @@ -319,7 +265,6 @@ async def load_external_tools( if not tool: continue - # Check if this is an external MCP tool if tool.protocol != "mcp_external": continue @@ -327,34 +272,30 @@ async def load_external_tools( logger.warning(f"Tool {tool_id} has protocol=mcp_external but no mcp_config") continue - # Determine auth mode: OIDC forwarding, OAuth, or none forward_auth = bool(getattr(tool, "forward_auth_token", False)) requires_oauth = bool(tool.requires_oauth_provider) requires_user_auth = forward_auth or requires_oauth cache_key = self._get_cache_key(tool_id, user_id, requires_user_auth) - # Check cache if cache_key in self.clients: clients.append(self.clients[cache_key]) continue - # Resolve token to use (OIDC forwarding takes precedence) - token_to_use = None + static_token: Optional[str] = None + token_provider: Optional[Callable[[], Optional[str]]] = None + provider_id: Optional[str] = None if forward_auth: - # Forward the user's OIDC authentication token if not auth_token: logger.warning( f"Tool {tool_id} has forward_auth_token=true but no auth_token provided" ) - # Still create the client - server will reject unauthorized requests else: - token_to_use = auth_token + static_token = auth_token logger.info(f"Using OIDC token forwarding for tool {tool_id}") elif requires_oauth: - # Fetch user-federated token via AgentCore Identity. if not user_id: logger.warning( f"Tool {tool_id} requires OAuth provider '{tool.requires_oauth_provider}' " @@ -362,58 +303,28 @@ async def load_external_tools( ) continue - try: - token_result = await self._get_oauth_token( - provider_id=tool.requires_oauth_provider, - ) - except WorkloadTokenUnavailableError: - logger.error( - "No workload token on context for tool %s — " - "AgentCoreContextMiddleware may be misconfigured", - tool_id, - ) - continue - except Exception as e: - logger.error( - "Failed to fetch OAuth token for tool %s: %s", - tool_id, - e, - ) - continue - - if token_result.requires_consent: - # Record the auth URL; the inference route will emit - # an oauth_required SSE event on the next response. - self._record_pending_consent( - user_id=user_id, - provider_id=tool.requires_oauth_provider, - authorization_url=token_result.authorization_url, - ) - logger.info( - "User consent required for tool %s (provider=%s); " - "skipping client creation until consent completes", - tool_id, - tool.requires_oauth_provider, - ) - # Skip loading this tool — the frontend will prompt - # the user to consent before the next invocation. - continue - - token_to_use = token_result.access_token + provider_id = tool.requires_oauth_provider + # Bind user_id and provider_id at closure time so the + # provider stays valid for the client's lifetime. + token_provider = ( + lambda u=user_id, p=provider_id: oauth_token_cache.get(u, p) + ) - # Create MCP client with optional token (works for both OAuth and OIDC) client = create_external_mcp_client( config=tool.mcp_config, tool_definition=tool, - oauth_token=token_to_use, + oauth_token=static_token, + token_provider=token_provider, ) if client: self.clients[cache_key] = client + if provider_id: + self._provider_for_client_id[id(client)] = provider_id clients.append(client) auth_label = ( - " (with OIDC forwarding)" if forward_auth and token_to_use - else " (with OAuth)" if requires_oauth and token_to_use + " (with OIDC forwarding)" if forward_auth and static_token + else f" (OAuth: {provider_id})" if provider_id else "" ) logger.info(f"✅ Loaded external MCP tool: {tool_id}{auth_label}") @@ -425,17 +336,6 @@ async def load_external_tools( return clients def get_client(self, tool_id: str, user_id: Optional[str] = None) -> Optional[MCPClient]: - """ - Get a specific MCP client by tool ID. - - Args: - tool_id: The tool ID - user_id: User ID for OAuth-enabled tools - - Returns: - MCPClient or None if not found - """ - # Try user-specific key first, then generic if user_id: user_key = f"{user_id}:{tool_id}" if user_key in self.clients: @@ -443,15 +343,6 @@ def get_client(self, tool_id: str, user_id: Optional[str] = None) -> Optional[MC return self.clients.get(tool_id) def add_to_tool_list(self, tools: List[Any]) -> List[Any]: - """ - Add all loaded external MCP clients to the tool list. - - Args: - tools: Existing list of tools - - Returns: - Updated tool list with MCP clients added - """ for client in self.clients.values(): if client not in tools: tools.append(client) @@ -461,24 +352,22 @@ def clear_user_clients(self, user_id: str) -> None: """ Clear cached MCP clients for a specific user. - Call this when a user disconnects from an OAuth provider - to ensure fresh clients are created on next use. - - Args: - user_id: User ID to clear clients for + Call this when a user disconnects from an OAuth provider so the next + agent build creates fresh clients (and the token cache miss forces a + new consent flow). """ keys_to_remove = [ key for key in self.clients.keys() if key.startswith(f"{user_id}:") ] for key in keys_to_remove: - del self.clients[key] + client = self.clients.pop(key) + self._provider_for_client_id.pop(id(client), None) if keys_to_remove: logger.info(f"Cleared {len(keys_to_remove)} cached MCP clients for user {user_id}") -# Global instance _external_mcp_integration: Optional[ExternalMCPIntegration] = None diff --git a/backend/src/agents/main_agent/integrations/oauth_token_cache.py b/backend/src/agents/main_agent/integrations/oauth_token_cache.py new file mode 100644 index 00000000..1296d7fc --- /dev/null +++ b/backend/src/agents/main_agent/integrations/oauth_token_cache.py @@ -0,0 +1,74 @@ +"""In-process cache of OAuth access tokens, keyed by (user_id, provider_id). + +Lives for the lifetime of the inference API process. The authoritative store +is AgentCore Identity's token vault — this cache is just a hot path so the +`OAuthBearerAuth` token provider doesn't have to call AgentCore on every +MCP request. + +Tokens are written when: + * `OAuthConsentHook` warms the cache after a successful vault lookup, or + * the resume path re-fetches a token after the user completes consent. + +Tokens are evicted explicitly via `clear_user_provider` when consent is +revoked or expires; we don't track expiry locally because AgentCore +Identity owns refresh. +""" + +from __future__ import annotations + +import threading +from typing import Optional + + +_lock = threading.Lock() +_cache: dict[tuple[str, str], str] = {} +# Per-(user, provider) sticky flag set by `mark_force_reauth` after a tool +# call returns a 401-style error. The consent hook reads this flag on the +# next BeforeToolCallEvent and asks AgentCore Identity for a fresh consent +# URL (`force_authentication=True`) instead of trusting the now-stale +# vault token. Cleared once the new token lands in the cache. +_force_reauth: set[tuple[str, str]] = set() + + +def get(user_id: str, provider_id: str) -> Optional[str]: + with _lock: + return _cache.get((user_id, provider_id)) + + +def set(user_id: str, provider_id: str, token: str) -> None: + with _lock: + _cache[(user_id, provider_id)] = token + _force_reauth.discard((user_id, provider_id)) + + +def clear_user_provider(user_id: str, provider_id: str) -> None: + with _lock: + _cache.pop((user_id, provider_id), None) + + +def clear_user(user_id: str) -> int: + with _lock: + keys = [k for k in _cache if k[0] == user_id] + for key in keys: + del _cache[key] + force_keys = [k for k in _force_reauth if k[0] == user_id] + for key in force_keys: + _force_reauth.discard(key) + return len(keys) + + +def mark_force_reauth(user_id: str, provider_id: str) -> None: + """Flag (user, provider) for forced re-consent on the next token fetch. + + Used by the OAuth error hook after an MCP tool returns a 401 — the + cached token is provably stale, so the next BeforeToolCallEvent must + bypass AgentCore's vault and trigger a fresh consent. + """ + with _lock: + _cache.pop((user_id, provider_id), None) + _force_reauth.add((user_id, provider_id)) + + +def needs_force_reauth(user_id: str, provider_id: str) -> bool: + with _lock: + return (user_id, provider_id) in _force_reauth diff --git a/backend/src/agents/main_agent/session/hooks/__init__.py b/backend/src/agents/main_agent/session/hooks/__init__.py index 9900a4a5..c1b1ce0f 100644 --- a/backend/src/agents/main_agent/session/hooks/__init__.py +++ b/backend/src/agents/main_agent/session/hooks/__init__.py @@ -1,5 +1,6 @@ """Hooks for Main Agent""" +from agents.main_agent.session.hooks.oauth_consent import OAuthConsentHook from agents.main_agent.session.hooks.stop import StopHook from agents.main_agent.session.hooks.tool_approval import ( ToolApprovalHook, @@ -9,6 +10,7 @@ ) __all__ = [ + "OAuthConsentHook", "StopHook", "ToolApprovalHook", "EmailApprovalHook", diff --git a/backend/src/agents/main_agent/session/hooks/oauth_consent.py b/backend/src/agents/main_agent/session/hooks/oauth_consent.py new file mode 100644 index 00000000..5ba1b489 --- /dev/null +++ b/backend/src/agents/main_agent/session/hooks/oauth_consent.py @@ -0,0 +1,266 @@ +"""OAuth consent gate for external MCP tools. + +Fires on every `BeforeToolCallEvent`. If the tool about to run is backed by +an MCP server that requires user-federated OAuth (per the tool catalog), +the hook ensures we have an access token in the in-process cache. If we +don't, it calls `event.interrupt(...)` to pause the agent mid-turn and +hand the authorization URL back to the caller. + +When the user completes consent in the popup and the frontend resumes the +turn, the hook fires a second time and `event.interrupt(...)` returns the +user's response (instead of raising). At that point AgentCore Identity has +the new token in its vault, so we re-fetch and warm the cache; the +`OAuthBearerAuth` token provider then injects it on the next MCP request. + +The hook never aborts the turn on its own — `cancel_tool` is reserved for +genuine refusal (e.g. consent declined). If the user closes the popup we +don't reach that path; the agent simply remains paused until a resume +arrives or the session times out. +""" + +from __future__ import annotations + +import asyncio +import inspect +import logging +import re +from typing import Any, Awaitable, Callable, Optional, Union + +from strands.hooks import ( + AfterToolCallEvent, + BeforeToolCallEvent, + HookProvider, + HookRegistry, +) + +from agents.main_agent.integrations import oauth_token_cache +from agents.main_agent.integrations.agentcore_identity import ( + WorkloadTokenUnavailableError, + get_agentcore_identity_client, +) + +logger = logging.getLogger(__name__) + + +# String markers that indicate an OAuth-style auth failure in a tool +# result. MCP servers vary in how they format errors, so we match a small +# set of unambiguous signals: the literal HTTP code, "Unauthorized", and +# explicit OAuth/token-rejected language. Tools that legitimately return +# the digit "401" in successful output are not at risk because we also +# require an error status (or a context word) — see `_looks_like_auth_failure`. +_AUTH_FAILURE_PATTERN = re.compile( + r"\b401\b|\bunauthorized\b|invalid[_\s-]token|expired[_\s-]token" + r"|token[\s_-]expired|rejected the oauth token|oauth token (?:has )?expired", + re.IGNORECASE, +) + + +def _looks_like_auth_failure(tool_result: Any) -> bool: + """Heuristic: does this tool result look like an OAuth 401? + + Inspects the result's status and content for one of the markers above. + False positives here just trigger a wasted retry; false negatives + leave the user stuck with a stale token, so we err on the side of + matching. + """ + if not isinstance(tool_result, dict): + return False + if tool_result.get("status") != "error": + return False + for block in tool_result.get("content", []) or []: + if not isinstance(block, dict): + continue + text = block.get("text") or "" + if isinstance(text, str) and _AUTH_FAILURE_PATTERN.search(text): + return True + return False + + +# Returns provider_id for a Strands `selected_tool`, or None if the tool +# isn't OAuth-gated. Encapsulates the MCPClient -> provider mapping. +ProviderLookup = Callable[[Any], Optional[str]] + +# Returns OAuth scopes for a provider_id. May be sync or async; the hook +# awaits the result either way so we can read from an async repository +# without forcing a sync wrapper. +ScopesLookup = Callable[[str], Union[list[str], Awaitable[list[str]]]] + + +class OAuthConsentHook(HookProvider): + """Pause the agent if a tool needs OAuth and we don't have a token yet.""" + + def __init__( + self, + user_id: str, + provider_lookup: ProviderLookup, + scopes_lookup: ScopesLookup, + ): + """Initialize. + + Args: + user_id: User the agent is running for. Used as cache key and + passed to AgentCore Identity for the local-dev workload-token + fallback (no-op in production). + provider_lookup: See `ProviderLookup`. + scopes_lookup: See `ScopesLookup`. + """ + self._user_id = user_id + self._provider_lookup = provider_lookup + self._scopes_lookup = scopes_lookup + # Cache scopes per provider for the lifetime of this hook (one agent + # invocation). Avoids repeated DB hits if the same provider is used + # across multiple tool calls in a single turn. + self._scopes_cache: dict[str, list[str]] = {} + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + registry.add_callback(BeforeToolCallEvent, self._gate) + registry.add_callback(AfterToolCallEvent, self._handle_auth_failure) + + async def _gate(self, event: BeforeToolCallEvent) -> None: + provider_id = self._provider_lookup(event.selected_tool) + if not provider_id: + return # Not an OAuth-gated tool + + force_reauth = oauth_token_cache.needs_force_reauth(self._user_id, provider_id) + + # Fast path: token already in cache (from a prior call this process, + # or warmed by a previous turn). Skipped when a prior tool call + # surfaced a 401 and asked us to bypass the cache. + if not force_reauth and oauth_token_cache.get(self._user_id, provider_id): + return + + # Slow path: ask AgentCore Identity. Either we get a token (vault + # hit, cache it and proceed) or a consent URL (interrupt the turn). + # `force_reauth` makes us bypass AgentCore's vault entirely so a + # stale post-revocation token doesn't get re-served. + token_or_url = await self._fetch_token_or_url( + provider_id, force_authentication=force_reauth + ) + if token_or_url is None: + # Couldn't resolve — let the tool run; the MCP server will return + # 401 and the resulting tool_error surfaces conversationally. + return + + if token_or_url["token"]: + oauth_token_cache.set(self._user_id, provider_id, token_or_url["token"]) + return + + # Consent required: pause the agent. The interrupt name is namespaced + # by provider so the SDK generates a stable interrupt id we can + # correlate with the user's response. + response = event.interrupt( + name=f"oauth:{provider_id}", + reason={ + "type": "oauth_required", + "providerId": provider_id, + "authorizationUrl": token_or_url["url"], + }, + ) + + # We're past the interrupt — the user resumed. Re-fetch from the + # vault (AgentCore Identity should now have the token after consent + # completion) and warm the cache. We ignore `response` content — + # successful resumption is itself the signal that consent happened. + del response + refreshed = await self._fetch_token_or_url(provider_id) + if refreshed and refreshed["token"]: + oauth_token_cache.set(self._user_id, provider_id, refreshed["token"]) + return + + # Resumed but still no token — treat as declined. cancel_tool emits a + # tool_error to the model so it can apologize/replan. + event.cancel_tool = ( + f"User did not complete authorization for {provider_id}; " + "the tool cannot run." + ) + + async def _fetch_token_or_url( + self, provider_id: str, *, force_authentication: bool = False + ) -> Optional[dict]: + """Return {'token': str|None, 'url': str|None} or None on hard error.""" + scopes = await self._resolve_scopes(provider_id) + identity_client = get_agentcore_identity_client() + + try: + result = await identity_client.get_token_for_user( + provider_name=provider_id, + scopes=scopes, + user_id=self._user_id, + force_authentication=force_authentication, + ) + except WorkloadTokenUnavailableError: + logger.error( + "No workload token on context for provider=%s — " + "AgentCoreContextMiddleware may be misconfigured", + provider_id, + ) + return None + except asyncio.CancelledError: + raise + except Exception: + logger.exception( + "Failed to fetch OAuth token for provider=%s", provider_id + ) + return None + + return { + "token": result.access_token, + "url": result.authorization_url, + } + + async def _handle_auth_failure(self, event: AfterToolCallEvent) -> None: + """Detect a 401 from an OAuth-gated MCP tool and retry with fresh consent. + + AgentCore Identity has no revoke API, so when the user revokes our + app at the provider (or the refresh token expires), AgentCore's + vault keeps serving the now-stale token. The MCP server rejects it + with a 401 — and that's where the staleness first becomes visible. + + We detect the 401 in the tool result, mark the (user, provider) + for forced re-consent in the cache, and set `event.retry = True`. + Strands' tool executor then re-fires `BeforeToolCallEvent`, our + `_gate` callback sees the force-reauth flag, asks AgentCore for a + fresh consent URL with `force_authentication=True`, and raises an + interrupt — same path as a first-time consent. + """ + provider_id = self._provider_lookup(event.selected_tool) + if not provider_id: + return + + if not _looks_like_auth_failure(event.result): + return + + # Avoid an infinite retry loop if the refreshed token also fails: + # retry once per (toolUseId, provider) per turn. We piggyback on + # invocation_state, which Strands carries across the retry inside + # the same BeforeToolCallEvent → AfterToolCallEvent cycle. + attempted = event.invocation_state.setdefault("_oauth_reauth_attempted", set()) + key = (event.tool_use.get("toolUseId"), provider_id) + if key in attempted: + logger.warning( + "OAuth re-auth already attempted for tool=%s provider=%s; not retrying again", + event.tool_use.get("name"), + provider_id, + ) + return + attempted.add(key) + + logger.info( + "Detected OAuth 401 for tool=%s provider=%s; clearing token cache and retrying", + event.tool_use.get("name"), + provider_id, + ) + oauth_token_cache.mark_force_reauth(self._user_id, provider_id) + event.retry = True + + async def _resolve_scopes(self, provider_id: str) -> list[str]: + if provider_id in self._scopes_cache: + return self._scopes_cache[provider_id] + result = self._scopes_lookup(provider_id) + if inspect.isawaitable(result): + scopes = await result + else: + scopes = result + scopes = list(scopes or []) + self._scopes_cache[provider_id] = scopes + return scopes diff --git a/backend/src/agents/main_agent/streaming/stream_coordinator.py b/backend/src/agents/main_agent/streaming/stream_coordinator.py index f50be71a..90592982 100644 --- a/backend/src/agents/main_agent/streaming/stream_coordinator.py +++ b/backend/src/agents/main_agent/streaming/stream_coordinator.py @@ -195,6 +195,16 @@ async def stream_response( # Don't yield this event to the client (will send final metadata before done) continue + # If the agent paused on an OAuth interrupt, surface one + # `oauth_required` SSE event per pending interrupt before the + # stream closes. The frontend uses these to drive the consent + # popup and then POSTs interrupt responses to resume the turn. + # Done before the metadata branch so the events land between + # message_stop and the final metadata/done block. + if event.get("type") == "done": + for sse in self._extract_oauth_required_events(agent): + yield sse + # Check if this is the "done" event - send final metadata before it if event.get("type") == "done": # Calculate end-to-end latency @@ -530,6 +540,45 @@ async def stream_response( except Exception as persist_error: logger.error(f"Failed to persist stream error to session: {persist_error}") + def _extract_oauth_required_events(self, agent: Any) -> List[str]: + """Yield one SSE-formatted `oauth_required` event per pending OAuth + interrupt on the agent. + + The Strands `_interrupt_state` is populated when `OAuthConsentHook` + calls `event.interrupt(...)`. We look for interrupts whose `reason` + carries `type: "oauth_required"` and translate them into the SSE + shape the frontend already understands. Non-OAuth interrupts (other + approval gates added later) are ignored here so they can be handled + by their own SSE event types. + """ + from apis.shared.oauth.models import OAuthRequiredEvent + + interrupt_state = getattr(agent, "_interrupt_state", None) + if not interrupt_state or not getattr(interrupt_state, "activated", False): + return [] + + events: List[str] = [] + for interrupt in interrupt_state.interrupts.values(): + reason = interrupt.reason or {} + if not isinstance(reason, dict) or reason.get("type") != "oauth_required": + continue + provider_id = reason.get("providerId") + authorization_url = reason.get("authorizationUrl") + if not provider_id or not authorization_url: + logger.warning( + "OAuth interrupt missing providerId or authorizationUrl: id=%s", + interrupt.id, + ) + continue + events.append( + OAuthRequiredEvent( + provider_id=provider_id, + authorization_url=authorization_url, + interrupt_id=interrupt.id, + ).to_sse_format() + ) + return events + def _format_sse_event(self, event: Dict[str, Any]) -> str: """ Format processed event as SSE (Server-Sent Event) diff --git a/backend/src/agents/main_agent/voice_agent.py b/backend/src/agents/main_agent/voice_agent.py index dde8a061..7ef326bf 100644 --- a/backend/src/agents/main_agent/voice_agent.py +++ b/backend/src/agents/main_agent/voice_agent.py @@ -335,7 +335,13 @@ async def receive_events(self) -> AsyncGenerator[dict, None]: yield event_dict async def stream_async( - self, message: str, session_id: Optional[str] = None, files: Optional[List] = None, citations: Optional[List] = None, original_message: Optional[str] = None + self, + message: str, + session_id: Optional[str] = None, + files: Optional[List] = None, + citations: Optional[List] = None, + original_message: Optional[str] = None, + interrupt_responses: Optional[List] = None, ) -> AsyncGenerator[str, None]: """ BaseAgent interface compatibility — not used for voice mode. diff --git a/backend/src/apis/inference_api/chat/models.py b/backend/src/apis/inference_api/chat/models.py index f2b4a5d9..09968ed6 100644 --- a/backend/src/apis/inference_api/chat/models.py +++ b/backend/src/apis/inference_api/chat/models.py @@ -17,11 +17,23 @@ class FileContent(BaseModel): bytes: str # Base64 encoded +class InterruptResponseEntry(BaseModel): + """One user response to a Strands interrupt, in the SDK's prompt shape. + + Posted by the frontend after the user completes (or declines) an OAuth + consent popup. The backend forwards the list verbatim to + `agent.stream_async(...)` to resume the paused turn. + """ + + interruptId: str + response: Any = None + + class InvocationRequest(BaseModel): """Input for /invocations endpoint with multi-provider support""" session_id: str - message: str + message: str = "" model_id: Optional[str] = None temperature: Optional[float] = None system_prompt: Optional[str] = None @@ -36,6 +48,10 @@ class InvocationRequest(BaseModel): # AgentCore Runtime returns 424 when it sees a non-empty 'assistant_id' field, # likely trying to resolve it as an AWS Bedrock Agent ID. rag_assistant_id: Optional[str] = None + # When set, the route resumes a paused agent turn instead of starting a + # new one. `message` is ignored in that case — the original prompt is + # already in the agent's interrupt context. + interrupt_responses: Optional[List[InterruptResponseEntry]] = None class InvocationResponse(BaseModel): diff --git a/backend/src/apis/inference_api/chat/routes.py b/backend/src/apis/inference_api/chat/routes.py index 17b4a4ae..b07cfea2 100644 --- a/backend/src/apis/inference_api/chat/routes.py +++ b/backend/src/apis/inference_api/chat/routes.py @@ -24,7 +24,6 @@ build_conversational_error_event, ) from apis.shared.files.file_resolver import get_file_resolver -from apis.shared.oauth.models import OAuthRequiredEvent from apis.shared.models.managed_models import list_managed_models from apis.shared.quota import ( QuotaExceededEvent, @@ -210,7 +209,13 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g input_data = request user_id = current_user.user_id auth_token = current_user.raw_token - logger.info("Invocation request received") + # Resume requests reuse the cached agent and its paused interrupt state; + # they bypass quota, file resolution, and RAG augmentation because those + # already ran on the original turn that got paused. + is_resume = bool(input_data.interrupt_responses) + logger.info( + "Invocation request received (resume=%s)" % is_resume + ) logger.info("Message received") if input_data.enabled_tools: @@ -246,7 +251,7 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g # Check quota if enforcement is enabled quota_warning_event = None quota_exceeded_event = None - if is_quota_enforcement_enabled(): + if is_quota_enforcement_enabled() and not is_resume: try: quota_checker = get_quota_checker() quota_result = await quota_checker.check_quota(user=current_user, session_id=input_data.session_id) @@ -305,7 +310,7 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g "Invocation request - processing with assistant context" ) - if input_data.rag_assistant_id: + if input_data.rag_assistant_id and not is_resume: # Local imports to avoid circular dependency from apis.shared.assistants.rag_service import ( augment_prompt_with_context, @@ -574,29 +579,26 @@ async def stream_with_quota_warning() -> AsyncGenerator[str, None]: augmented_message != input_data.message # RAG augmentation or bool(files_to_send) # File attachments ) + # Strands' resume protocol wants each entry wrapped as + # {"interruptResponse": {...}}. The InvocationRequest schema + # accepts the inner shape so callers don't have to think about + # the SDK's content-block convention. + interrupt_responses_payload = ( + [{"interruptResponse": entry.model_dump()} for entry in input_data.interrupt_responses] + if input_data.interrupt_responses + else None + ) + async for event in agent.stream_async( augmented_message, session_id=input_data.session_id, files=files_to_send if files_to_send else None, citations=citations_for_storage if citations_for_storage else None, original_message=input_data.message if message_will_be_modified else None, + interrupt_responses=interrupt_responses_payload, ): yield event - # Surface any OAuth consent URLs collected while loading external - # MCP tools for this user. Draining is idempotent — entries are - # removed on read, so a later invocation won't re-prompt unless - # AgentCore Identity still reports consent is required. - from agents.main_agent.integrations.external_mcp_client import ( - get_external_mcp_integration, - ) - - for entry in get_external_mcp_integration().drain_pending_consent(user_id): - yield OAuthRequiredEvent( - provider_id=entry["provider_id"], - authorization_url=entry["authorization_url"], - ).to_sse_format() - # Stream response from agent as SSE (with optional files) # Note: Compression is handled by GZipMiddleware if configured in main.py return StreamingResponse( diff --git a/backend/src/apis/shared/oauth/models.py b/backend/src/apis/shared/oauth/models.py index ebbf0aba..07ea00f4 100644 --- a/backend/src/apis/shared/oauth/models.py +++ b/backend/src/apis/shared/oauth/models.py @@ -246,9 +246,12 @@ class OAuthProviderListResponse(BaseModel): class OAuthRequiredEvent(BaseModel): """SSE event signalling that a tool needs user consent before it can run. - Emitted by the inference route after an agent response finishes, one per - provider with a pending consent URL. The frontend uses it to render a - "Connect to X" affordance that opens `authorizationUrl` in a popup. + Emitted mid-turn when `OAuthConsentHook` raises a Strands interrupt: the + agent's tool call is paused (its in-flight state is held in + `_interrupt_state`), the frontend receives this event, opens the + consent popup at `authorizationUrl`, and on completion POSTs an + interrupt response carrying `interruptId` back to `/invocations`. The + backend resumes the same turn — no retype, no replay. """ model_config = ConfigDict(populate_by_name=True) @@ -256,6 +259,7 @@ class OAuthRequiredEvent(BaseModel): type: str = "oauth_required" provider_id: str = Field(..., alias="providerId") authorization_url: str = Field(..., alias="authorizationUrl") + interrupt_id: str = Field(..., alias="interruptId") def to_sse_format(self) -> str: import json diff --git a/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py b/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py index b1e39992..b847ba23 100644 --- a/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py +++ b/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py @@ -1,16 +1,15 @@ """ -Tests for extract_region_from_url and detect_aws_service_from_url. +Tests for the external MCP client helpers. + +OAuth provisioning moved to `OAuthConsentHook` (see +`tests/agents/main_agent/session/hooks/test_oauth_consent.py`); this +module covers the URL-parsing helpers and the integration's +MCPClient -> provider_id map that the hook reads from. Requirements: 25.1–25.3 """ import pytest -from unittest.mock import AsyncMock, MagicMock, patch - -from agents.main_agent.integrations.agentcore_identity import ( - TokenResult, - WorkloadTokenUnavailableError, -) from agents.main_agent.integrations.external_mcp_client import ( ExternalMCPIntegration, detect_aws_service_from_url, @@ -71,158 +70,42 @@ def test_defaults_to_lambda_for_unknown_url(self): assert detect_aws_service_from_url(url) == "lambda" -class TestGetOAuthTokenViaAgentCoreIdentity: - """Tests for ExternalMCPIntegration._get_oauth_token delegating to AgentCore Identity.""" - - @pytest.mark.asyncio - async def test_fetches_scopes_from_provider_repo_and_calls_identity(self): - """Provider scopes from the platform's provider record are forwarded to AgentCore.""" - integration = ExternalMCPIntegration() - - mock_provider = MagicMock() - mock_provider.scopes = ["openid", "profile", "email"] - mock_repo = MagicMock() - mock_repo.get_provider = AsyncMock(return_value=mock_provider) - - mock_identity = MagicMock() - mock_identity.get_token_for_user.return_value = TokenResult(access_token="tok") - - with patch( - "apis.shared.oauth.provider_repository.get_provider_repository", - return_value=mock_repo, - ), patch( - "agents.main_agent.integrations.external_mcp_client.get_agentcore_identity_client", - return_value=mock_identity, - ): - result = await integration._get_oauth_token(provider_id="google") - - assert result.access_token == "tok" - mock_identity.get_token_for_user.assert_called_once_with( - provider_name="google", scopes=["openid", "profile", "email"] - ) - - @pytest.mark.asyncio - async def test_returns_authorization_url_when_consent_required(self): - integration = ExternalMCPIntegration() - - mock_repo = MagicMock() - mock_repo.get_provider = AsyncMock( - return_value=MagicMock(scopes=["openid"]) - ) - - mock_identity = MagicMock() - mock_identity.get_token_for_user.return_value = TokenResult( - authorization_url="https://accounts.example.com/consent" - ) - - with patch( - "apis.shared.oauth.provider_repository.get_provider_repository", - return_value=mock_repo, - ), patch( - "agents.main_agent.integrations.external_mcp_client.get_agentcore_identity_client", - return_value=mock_identity, - ): - result = await integration._get_oauth_token(provider_id="google") - - assert result.requires_consent is True - assert result.authorization_url == "https://accounts.example.com/consent" - - @pytest.mark.asyncio - async def test_empty_scopes_when_provider_record_missing(self): - """Missing provider record falls back to empty scopes so the call still succeeds - and AgentCore can apply its own provider defaults.""" - integration = ExternalMCPIntegration() - - mock_repo = MagicMock() - mock_repo.get_provider = AsyncMock(return_value=None) - - mock_identity = MagicMock() - mock_identity.get_token_for_user.return_value = TokenResult(access_token="t") +class TestProviderForClient: + """The integration's MCPClient -> provider_id map is what + `OAuthConsentHook.provider_lookup` consults.""" - with patch( - "apis.shared.oauth.provider_repository.get_provider_repository", - return_value=mock_repo, - ), patch( - "agents.main_agent.integrations.external_mcp_client.get_agentcore_identity_client", - return_value=mock_identity, - ): - await integration._get_oauth_token(provider_id="unknown") - - mock_identity.get_token_for_user.assert_called_once_with( - provider_name="unknown", scopes=[] - ) - - @pytest.mark.asyncio - async def test_propagates_workload_token_unavailable(self): - """Misconfigured middleware should surface as a typed error, not be swallowed.""" + def test_unknown_client_returns_none(self): integration = ExternalMCPIntegration() - mock_repo = MagicMock() - mock_repo.get_provider = AsyncMock(return_value=MagicMock(scopes=[])) - - mock_identity = MagicMock() - mock_identity.get_token_for_user.side_effect = WorkloadTokenUnavailableError( - "no ctx" - ) + class FakeClient: + pass - with patch( - "apis.shared.oauth.provider_repository.get_provider_repository", - return_value=mock_repo, - ), patch( - "agents.main_agent.integrations.external_mcp_client.get_agentcore_identity_client", - return_value=mock_identity, - ): - with pytest.raises(WorkloadTokenUnavailableError): - await integration._get_oauth_token(provider_id="google") + assert integration.provider_for_client(FakeClient()) is None - -class TestPendingConsent: - """Tests for the per-user consent URL stash consumed by the SSE emitter.""" - - def test_record_and_drain_roundtrip(self): + def test_records_and_resolves_provider_for_client(self): integration = ExternalMCPIntegration() - integration._record_pending_consent( - user_id="u1", provider_id="google", authorization_url="https://a/1" - ) - drained = integration.drain_pending_consent("u1") + class FakeClient: + pass - assert drained == [ - {"provider_id": "google", "authorization_url": "https://a/1"} - ] + client = FakeClient() + # Simulate what `load_external_tools` does after creating an + # OAuth-gated MCP client. + integration._provider_for_client_id[id(client)] = "google-workspace" - def test_drain_is_idempotent(self): - """Second drain returns empty — consent prompts are single-delivery.""" - integration = ExternalMCPIntegration() - integration._record_pending_consent("u1", "google", "https://a/1") + assert integration.provider_for_client(client) == "google-workspace" - integration.drain_pending_consent("u1") - second = integration.drain_pending_consent("u1") - - assert second == [] - - def test_dedupe_by_provider(self): - """Two tools needing the same provider produce one consent prompt.""" + def test_clear_user_clients_drops_provider_mapping(self): integration = ExternalMCPIntegration() - integration._record_pending_consent("u1", "google", "https://a/1") - integration._record_pending_consent("u1", "google", "https://a/1") - drained = integration.drain_pending_consent("u1") + class FakeClient: + pass - assert len(drained) == 1 + client = FakeClient() + integration.clients["alice:gmail"] = client + integration._provider_for_client_id[id(client)] = "google-workspace" - def test_per_user_isolation(self): - integration = ExternalMCPIntegration() - integration._record_pending_consent("u1", "google", "https://a/1") - integration._record_pending_consent("u2", "slack", "https://a/2") + integration.clear_user_clients("alice") - assert integration.drain_pending_consent("u1") == [ - {"provider_id": "google", "authorization_url": "https://a/1"} - ] - assert integration.drain_pending_consent("u2") == [ - {"provider_id": "slack", "authorization_url": "https://a/2"} - ] - - def test_drain_empty_when_no_prompts(self): - integration = ExternalMCPIntegration() - assert integration.drain_pending_consent("u-nobody") == [] + assert "alice:gmail" not in integration.clients + assert integration.provider_for_client(client) is None diff --git a/backend/tests/agents/main_agent/integrations/test_oauth_token_cache.py b/backend/tests/agents/main_agent/integrations/test_oauth_token_cache.py new file mode 100644 index 00000000..939ec006 --- /dev/null +++ b/backend/tests/agents/main_agent/integrations/test_oauth_token_cache.py @@ -0,0 +1,61 @@ +"""Tests for the in-process OAuth token cache.""" + +from agents.main_agent.integrations import oauth_token_cache + + +def _isolate(user: str = "tester") -> None: + oauth_token_cache.clear_user(user) + + +def test_get_returns_none_when_unset(): + _isolate() + assert oauth_token_cache.get("tester", "google") is None + + +def test_set_then_get_roundtrip(): + _isolate() + oauth_token_cache.set("tester", "google", "tok-1") + assert oauth_token_cache.get("tester", "google") == "tok-1" + + +def test_per_user_isolation(): + oauth_token_cache.clear_user("alice") + oauth_token_cache.clear_user("bob") + oauth_token_cache.set("alice", "google", "alice-tok") + oauth_token_cache.set("bob", "google", "bob-tok") + + assert oauth_token_cache.get("alice", "google") == "alice-tok" + assert oauth_token_cache.get("bob", "google") == "bob-tok" + + +def test_per_provider_isolation(): + _isolate() + oauth_token_cache.set("tester", "google", "g-tok") + oauth_token_cache.set("tester", "github", "gh-tok") + + assert oauth_token_cache.get("tester", "google") == "g-tok" + assert oauth_token_cache.get("tester", "github") == "gh-tok" + + +def test_clear_user_drops_only_that_user(): + oauth_token_cache.clear_user("alice") + oauth_token_cache.clear_user("bob") + oauth_token_cache.set("alice", "google", "a") + oauth_token_cache.set("bob", "google", "b") + + removed = oauth_token_cache.clear_user("alice") + + assert removed == 1 + assert oauth_token_cache.get("alice", "google") is None + assert oauth_token_cache.get("bob", "google") == "b" + + +def test_clear_user_provider_drops_only_that_pair(): + _isolate() + oauth_token_cache.set("tester", "google", "g") + oauth_token_cache.set("tester", "github", "gh") + + oauth_token_cache.clear_user_provider("tester", "google") + + assert oauth_token_cache.get("tester", "google") is None + assert oauth_token_cache.get("tester", "github") == "gh" diff --git a/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py b/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py new file mode 100644 index 00000000..bab2f085 --- /dev/null +++ b/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py @@ -0,0 +1,419 @@ +"""Tests for OAuthConsentHook. + +Covers the lazy token-resolution path: hook fires before each tool call, +asks AgentCore Identity for the user's token, caches it on a hit, and +raises a Strands interrupt with the consent URL on a miss. Resume is +exercised by pre-seeding the interrupt with a response so the second +`event.interrupt(...)` returns instead of raising. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from strands.interrupt import Interrupt, InterruptException + +from agents.main_agent.integrations import oauth_token_cache +from agents.main_agent.integrations.agentcore_identity import ( + TokenResult, + WorkloadTokenUnavailableError, +) +from agents.main_agent.session.hooks.oauth_consent import OAuthConsentHook + + +@pytest.fixture(autouse=True) +def _clear_cache(): + """Token cache is process-global; isolate between tests.""" + oauth_token_cache.clear_user("alice") + yield + oauth_token_cache.clear_user("alice") + + +def _make_event(provider_id: str | None, *, agent=None) -> MagicMock: + """Build a stand-in for `BeforeToolCallEvent`. + + The hook reads `event.selected_tool` (passed straight to + `provider_lookup`) and calls `event.interrupt(...)`. We forward + `interrupt` to a real `_Interruptible.interrupt` style implementation + so the test exercises the same raise/return semantics as the SDK. + """ + event = MagicMock() + event.selected_tool = MagicMock() + event.cancel_tool = None + + agent = agent or MagicMock() + agent._interrupt_state = MagicMock() + agent._interrupt_state.interrupts = {} + + def interrupt(name: str, reason=None, response=None): + # Mirror the SDK: deterministic id keyed on the name so the second + # call returns the response instead of raising. + interrupt_id = f"v1:before_tool_call:tu_test:{name}" + existing = agent._interrupt_state.interrupts.setdefault( + interrupt_id, Interrupt(interrupt_id, name, reason, response) + ) + if existing.response is not None: + return existing.response + raise InterruptException(existing) + + event.interrupt = interrupt + event._agent = agent # for tests that want to inspect interrupt state + return event + + +class TestOAuthConsentHookCacheHit: + @pytest.mark.asyncio + async def test_no_op_when_tool_not_oauth_gated(self): + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: None, + scopes_lookup=lambda _: [], + ) + event = _make_event(provider_id=None) + + await hook._gate(event) + + assert event.cancel_tool is None + + @pytest.mark.asyncio + async def test_uses_cached_token_without_calling_identity(self): + oauth_token_cache.set("alice", "google", "cached-token") + + identity = MagicMock() + identity.get_token_for_user = AsyncMock() + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + identity.get_token_for_user.assert_not_called() + assert event.cancel_tool is None + + +class TestOAuthConsentHookVaultHit: + @pytest.mark.asyncio + async def test_warms_cache_when_vault_returns_token(self): + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="tok-from-vault") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + assert oauth_token_cache.get("alice", "google") == "tok-from-vault" + identity.get_token_for_user.assert_called_once_with( + provider_name="google", + scopes=["openid"], + user_id="alice", + force_authentication=False, + ) + + +class TestOAuthConsentHookConsentRequired: + @pytest.mark.asyncio + async def test_raises_interrupt_with_oauth_required_reason(self): + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(authorization_url="https://accounts/consent") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + with pytest.raises(InterruptException) as excinfo: + await hook._gate(event) + + interrupt = excinfo.value.interrupt + assert interrupt.name == "oauth:google" + assert interrupt.reason == { + "type": "oauth_required", + "providerId": "google", + "authorizationUrl": "https://accounts/consent", + } + # Cache stays empty until consent actually completes. + assert oauth_token_cache.get("alice", "google") is None + + @pytest.mark.asyncio + async def test_resume_warms_cache_with_post_consent_token(self): + """On resume the SDK pre-populates the interrupt's response so + `event.interrupt(...)` returns. The hook then re-fetches from the + vault (which now has a token) and primes the cache so subsequent + MCP requests pick up the bearer token without another round trip.""" + identity = MagicMock() + # First call: consent required. Second call (post-consent): token. + identity.get_token_for_user = AsyncMock( + side_effect=[ + TokenResult(authorization_url="https://accounts/consent"), + TokenResult(access_token="post-consent-token"), + ] + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + # Pre-seed the interrupt with a response — simulates the SDK + # restoring `_interrupt_state` before re-running the hook on resume. + agent = event._agent + interrupt_id = "v1:before_tool_call:tu_test:oauth:google" + agent._interrupt_state.interrupts[interrupt_id] = Interrupt( + interrupt_id, "oauth:google", reason=None, response="consented" + ) + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + assert oauth_token_cache.get("alice", "google") == "post-consent-token" + assert event.cancel_tool is None + + @pytest.mark.asyncio + async def test_resume_without_token_cancels_tool(self): + """If the user closes the popup mid-flow, AgentCore's vault stays + empty. Resuming surfaces this as a cancel_tool so the model + gets a tool_error and can apologize/replan instead of looping.""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + side_effect=[ + TokenResult(authorization_url="https://accounts/consent"), + TokenResult(authorization_url="https://accounts/consent"), + ] + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + agent = event._agent + interrupt_id = "v1:before_tool_call:tu_test:oauth:google" + agent._interrupt_state.interrupts[interrupt_id] = Interrupt( + interrupt_id, "oauth:google", reason=None, response="consented" + ) + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + assert event.cancel_tool is not None + assert "google" in event.cancel_tool + + +class TestOAuthConsentHookAuthFailureRetry: + """The AfterToolCallEvent handler turns a 401-style tool error into + a retry that forces re-consent at AgentCore Identity.""" + + def _after_event( + self, + provider_id: str | None, + result_text: str, + *, + result_status: str = "error", + ) -> MagicMock: + event = MagicMock() + event.selected_tool = MagicMock() + event.tool_use = {"name": "whoami", "toolUseId": "tu_1"} + event.invocation_state = {} + event.result = { + "toolUseId": "tu_1", + "status": result_status, + "content": [{"text": result_text}], + } + event.retry = False + return event + + @pytest.mark.asyncio + async def test_401_marks_force_reauth_and_retries(self): + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: [], + ) + oauth_token_cache.set("alice", "google", "stale-token") + event = self._after_event( + "google", + "Error executing tool whoami: Google rejected the OAuth token (401).", + ) + + await hook._handle_auth_failure(event) + + assert event.retry is True + assert oauth_token_cache.needs_force_reauth("alice", "google") is True + # Cache cleared so the BeforeToolCallEvent retry doesn't short-circuit. + assert oauth_token_cache.get("alice", "google") is None + + @pytest.mark.asyncio + async def test_non_oauth_tool_is_ignored(self): + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: None, + scopes_lookup=lambda _: [], + ) + event = self._after_event(None, "401 Unauthorized") + + await hook._handle_auth_failure(event) + + assert event.retry is False + + @pytest.mark.asyncio + async def test_non_auth_error_is_ignored(self): + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: [], + ) + event = self._after_event("google", "Network unreachable") + + await hook._handle_auth_failure(event) + + assert event.retry is False + assert oauth_token_cache.needs_force_reauth("alice", "google") is False + + @pytest.mark.asyncio + async def test_does_not_retry_twice_for_same_tool_use(self): + """Second 401 in the same retry cycle must not loop forever.""" + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: [], + ) + event1 = self._after_event("google", "401 Unauthorized") + await hook._handle_auth_failure(event1) + assert event1.retry is True + + # Same tool_use_id, same invocation_state — second failure must + # surrender so the user sees the error. + event2 = self._after_event("google", "401 Unauthorized") + event2.invocation_state = event1.invocation_state + await hook._handle_auth_failure(event2) + assert event2.retry is False + + +class TestOAuthConsentHookErrors: + @pytest.mark.asyncio + async def test_workload_token_unavailable_lets_tool_proceed(self): + """A misconfigured runtime context shouldn't crash the agent; the + tool runs, the MCP server 401s, and the failure surfaces as a + normal tool_error the user can act on.""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + side_effect=WorkloadTokenUnavailableError("no ctx") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) # must not raise + + assert event.cancel_tool is None + assert oauth_token_cache.get("alice", "google") is None + + @pytest.mark.asyncio + async def test_scopes_lookup_can_be_async(self): + """Hook accepts async scopes_lookup so callers can read directly + from an async repository without a sync wrapper.""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="t") + ) + + async def async_scopes(_pid: str) -> list[str]: + return ["openid", "profile"] + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=async_scopes, + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + kwargs = identity.get_token_for_user.call_args.kwargs + assert kwargs["scopes"] == ["openid", "profile"] + + @pytest.mark.asyncio + async def test_scopes_lookup_is_cached_across_calls(self): + """Repeated tool calls for the same provider hit the scopes lookup + once per hook lifetime (one agent invocation).""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="t") + ) + + scopes_lookup = MagicMock(return_value=["openid"]) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=scopes_lookup, + ) + + # First call hits identity (and the lookup). + event1 = _make_event(provider_id="google") + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event1) + + # Cache now warm — second call short-circuits before identity. + # Force a vault fetch by clearing the token cache. + oauth_token_cache.clear_user("alice") + event2 = _make_event(provider_id="google") + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event2) + + assert scopes_lookup.call_count == 1 diff --git a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts index b3f0c735..1b7db3af 100644 --- a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts +++ b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts @@ -1,15 +1,20 @@ import { Injectable, signal, computed, inject, DestroyRef } from '@angular/core'; import { takeUntilDestroyed } from '@angular/core/rxjs-interop'; import { fromEvent } from 'rxjs'; -import { filter } from 'rxjs/operators'; /** * Pending OAuth consent request surfaced by the backend when an external * MCP tool needs the user to authorize AgentCore Identity. + * + * `interruptId` is set when the request comes from a paused agent turn + * (SSE `oauth_required` event) so the chat layer can resume the same turn + * after consent. It's omitted when the user proactively connects from the + * settings page — in that case there's no agent turn to resume. */ export interface OAuthConsentRequest { providerId: string; authorizationUrl: string; + interruptId?: string; receivedAt: number; } @@ -25,6 +30,14 @@ export interface OAuthCompleteMessage { error: string | null; } +/** + * Handler the chat layer registers to resume a paused agent turn after + * one or more OAuth consents complete. Receives the interrupt ids whose + * tokens are now available; the handler is expected to POST a resume + * request to `/invocations` with `interrupt_responses` populated. + */ +export type OAuthResumeHandler = (interruptIds: string[]) => void | Promise; + function isOAuthCompleteMessage(data: unknown): data is OAuthCompleteMessage { if (!data || typeof data !== 'object') { return false; @@ -35,14 +48,16 @@ function isOAuthCompleteMessage(data: unknown): data is OAuthCompleteMessage { /** * Tracks OAuth consent requests surfaced by the SSE stream and coordinates - * the popup flow. + * the popup + auto-resume flow. * * The stream parser calls {@link requestConsent} when an `oauth_required` * event arrives; components render a "Connect" affordance bound to * {@link pending}. When the user clicks, {@link openConsentPopup} opens the * AgentCore Identity URL, and this service listens for the * `agentcore-oauth-complete` postMessage from the `/oauth-complete` landing - * page to resolve the provider. + * page. On success it dismisses the request and asks the registered + * {@link OAuthResumeHandler} to fire a resume request — the user does NOT + * have to retype the original prompt. */ @Injectable({ providedIn: 'root' }) export class OAuthConsentService { @@ -58,6 +73,10 @@ export class OAuthConsentService { /** Most recent completion notice surfaced to the chat layer. */ private readonly lastCompletion = signal(null); + /** Resume handler registered by the chat layer. Replayed when a + * consent completes successfully. */ + private resumeHandler: OAuthResumeHandler | null = null; + readonly pending = computed(() => Array.from(this.requests().values()).sort((a, b) => a.receivedAt - b.receivedAt), ); @@ -67,30 +86,52 @@ export class OAuthConsentService { readonly completion = this.lastCompletion.asReadonly(); constructor() { - // Listen for postMessages from the /oauth-complete landing page. The - // origin guard makes sure cross-origin pages can't spoof a completion. + // Primary channel: BroadcastChannel. AgentCore's OAuth popup navigates + // through external origins (Google, AgentCore), which triggers Chrome's + // Cross-Origin-Opener-Policy and severs window.opener. window.postMessage + // from the /oauth-complete page is silently blocked in that case, so we + // rely on a same-origin BroadcastChannel to bridge popup → opener. + try { + const channel = new BroadcastChannel('agentcore-oauth-complete'); + channel.addEventListener('message', (event) => { + if (!isOAuthCompleteMessage(event.data)) { + return; + } + this.handleCompletion(event.data); + }); + this.destroyRef.onDestroy(() => channel.close()); + } catch { + // BroadcastChannel unavailable — fall back to postMessage below. + } + + // Fallback channel: window postMessage (pre-COOP browsers, or flows + // where the popup manages to retain window.opener). The origin guard + // makes sure cross-origin pages can't spoof a completion. fromEvent(window, 'message') - .pipe( - filter((event) => event.origin === window.location.origin), - filter((event) => isOAuthCompleteMessage(event.data)), - takeUntilDestroyed(this.destroyRef), - ) + .pipe(takeUntilDestroyed(this.destroyRef)) .subscribe((event) => { - const message = event.data as OAuthCompleteMessage; - this.handleCompletion(message); + if (event.origin !== window.location.origin) { + return; + } + if (!isOAuthCompleteMessage(event.data)) { + return; + } + this.handleCompletion(event.data); }); } /** * Register a consent request coming off the SSE stream. - * Duplicate providerIds refresh the existing entry (URLs can rotate). + * Duplicate providerIds refresh the existing entry — the backend may + * reissue an interrupt with a new id if the user retried. */ - requestConsent(providerId: string, authorizationUrl: string): void { + requestConsent(providerId: string, authorizationUrl: string, interruptId?: string): void { this.requests.update((map) => { const next = new Map(map); next.set(providerId, { providerId, authorizationUrl, + interruptId, receivedAt: Date.now(), }); return next; @@ -146,6 +187,16 @@ export class OAuthConsentService { return this.inFlight().has(providerId); } + /** + * Register the chat-layer handler that resumes the paused agent turn + * after one or more OAuth consents complete. The handler receives the + * interrupt ids whose tokens are ready; replacing it (set to null) + * disables auto-resume. + */ + setResumeHandler(handler: OAuthResumeHandler | null): void { + this.resumeHandler = handler; + } + /** * Clear a single consent request — called from the UI after the user * completes or dismisses a provider, or when the chat is reset. @@ -183,8 +234,24 @@ export class OAuthConsentService { private handleCompletion(message: OAuthCompleteMessage): void { this.lastCompletion.set(message); - if (message.status === 'success' && message.providerId) { - this.dismiss(message.providerId); + if (message.status !== 'success' || !message.providerId) { + return; } + + // Capture the paused interrupt id BEFORE dismissing the request, since + // dismiss removes the entry the handler needs. A user-initiated + // settings-page consent has no interruptId — nothing to resume. + const request = this.requests().get(message.providerId); + this.dismiss(message.providerId); + + if (!request?.interruptId || !this.resumeHandler) { + return; + } + + void Promise.resolve(this.resumeHandler([request.interruptId])).catch((err) => { + // Resume failures are surfaced through the resume request's own error + // handling — log here for diagnostics but don't crash the consent flow. + console.error('OAuth resume handler failed', err); + }); } } diff --git a/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts b/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts index e7ab45fa..b25e83e1 100644 --- a/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts @@ -1,4 +1,4 @@ -import { inject, Injectable } from '@angular/core'; +import { inject, Injectable, OnDestroy } from '@angular/core'; import { Router } from '@angular/router'; import { v4 as uuidv4 } from 'uuid'; import { ChatStateService } from './chat-state.service'; @@ -10,6 +10,8 @@ import { ModelService } from '../model/model.service'; import { ToolService } from '../../../services/tool/tool.service'; import { FileUploadService } from '../../../services/file-upload'; import { FileAttachmentData } from '../models/message.model'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; +import { StreamParserService } from './stream-parser.service'; export interface ContentFile { fileName: string; @@ -21,7 +23,7 @@ export interface ContentFile { @Injectable({ providedIn: 'root', }) -export class ChatRequestService { +export class ChatRequestService implements OnDestroy { // private conversationService = inject(ConversationService); private chatHttpService = inject(ChatHttpService); private chatStateService = inject(ChatStateService); @@ -31,9 +33,26 @@ export class ChatRequestService { private modelService = inject(ModelService); private toolService = inject(ToolService); private fileUploadService = inject(FileUploadService); + private oauthConsentService = inject(OAuthConsentService); + private streamParserService = inject(StreamParserService); private router = inject(Router); // TODO: Inject proper logging service + /** Last request payload — replayed (with `interrupt_responses` added) when + * the user completes an OAuth consent so the paused agent turn resumes + * without retyping. Cleared on a true new turn. */ + private lastRequestObject: Record | null = null; + + constructor() { + this.oauthConsentService.setResumeHandler((interruptIds) => + this.resumeFromOAuthConsent(interruptIds), + ); + } + + ngOnDestroy(): void { + this.oauthConsentService.setResumeHandler(null); + } + async submitChatRequest( userInput: string, sessionId: string | null, @@ -78,6 +97,12 @@ export class ChatRequestService { assistantId, ); + // Remember this turn's params so the OAuth resume handler can replay + // them with `interrupt_responses` attached. Snapshotting the *exact* + // payload keeps the agent cache key stable, so the resume hits the + // same paused agent instance. + this.lastRequestObject = { ...requestObject }; + try { await this.chatHttpService.sendChatRequest(requestObject); } catch (error) { @@ -149,6 +174,54 @@ export class ChatRequestService { return requestObject; } + /** + * Replay the last turn's request with `interrupt_responses` attached so + * the backend resumes the paused agent turn instead of starting a new + * one. Triggered by OAuthConsentService after the user completes a + * consent popup. + */ + private async resumeFromOAuthConsent(interruptIds: string[]): Promise { + if (!this.lastRequestObject || interruptIds.length === 0) { + return; + } + + const sessionId = this.lastRequestObject['session_id'] as string | undefined; + if (!sessionId) { + return; + } + + // Reset the parser so the resumed stream is treated as a fresh batch + // of events. Without this, the parser stays in Completed state from + // the prior `done` and ignores everything. + this.streamParserService.reset(sessionId); + this.messageMapService.startStreaming(sessionId); + this.chatStateService.createNewAbortController(); + this.chatStateService.setChatLoading(true); + + const resumeRequest: Record = { + ...this.lastRequestObject, + // The original prompt is already in the agent's interrupt context; + // sending an empty string keeps the request valid without + // re-augmenting or re-charging quota. + message: '', + interrupt_responses: interruptIds.map((interruptId) => ({ + interruptId, + // The token is already in AgentCore Identity's vault by the time + // we resume; the response payload itself doesn't carry a secret — + // it's just the signal that consent completed. + response: 'consented', + })), + }; + + try { + await this.chatHttpService.sendChatRequest(resumeRequest); + } catch (error) { + this.chatStateService.setChatLoading(false); + this.messageMapService.endStreaming(); + throw error; + } + } + /** * Get file attachment metadata for display in user messages. * Retrieves file metadata from FileUploadService for given upload IDs. diff --git a/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts b/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts index c6c30252..4e925592 100644 --- a/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts @@ -197,8 +197,11 @@ export class StreamParserService { } // Check if we should process this event - const isStartOrErrorEvent = event === 'message_start' || event === 'error'; - if (!isStartOrErrorEvent && !this.shouldProcessEvent()) { + // oauth_required arrives after message_stop/done by design (see CLAUDE.md SSE + // table) — allow it through even when the stream state is Completed. + const isAlwaysAllowedEvent = + event === 'message_start' || event === 'error' || event === 'oauth_required'; + if (!isAlwaysAllowedEvent && !this.shouldProcessEvent()) { return; } @@ -293,7 +296,11 @@ export class StreamParserService { onQuotaExceeded: (data) => this.quotaWarningService.setQuotaExceeded(data as QuotaExceeded), onOAuthRequired: (data: OAuthRequiredEvent) => - this.oauthConsentService.requestConsent(data.providerId, data.authorizationUrl), + this.oauthConsentService.requestConsent( + data.providerId, + data.authorizationUrl, + data.interruptId, + ), onError: (data) => this.handleError(data), onStreamError: (data) => diff --git a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts index 092a0ee7..de03d17c 100644 --- a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts +++ b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts @@ -337,7 +337,9 @@ export function validateOAuthRequiredEvent(data: unknown): data is OAuthRequired typeof event.providerId === 'string' && event.providerId.length > 0 && typeof event.authorizationUrl === 'string' && - event.authorizationUrl.length > 0 + event.authorizationUrl.length > 0 && + typeof event.interruptId === 'string' && + event.interruptId.length > 0 ); } diff --git a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts index 382d4842..59b6586b 100644 --- a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts +++ b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts @@ -88,13 +88,15 @@ export interface ReasoningEvent { /** * OAuth required event — emitted when an external MCP tool needs the user - * to grant consent via AgentCore Identity. The payload carries the provider - * slug and the consent URL to open. + * to grant consent via AgentCore Identity. The agent's tool call is paused + * (Strands interrupt) and the frontend resumes the same turn after the + * user completes consent by POSTing back the carried `interruptId`. */ export interface OAuthRequiredEvent { type: 'oauth_required'; providerId: string; authorizationUrl: string; + interruptId: string; } /** From 90733559bdd5a054b24add81b3c3fd0f57a40ab3 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 17:46:44 -0600 Subject: [PATCH 14/35] fix(connectors): bind complete_consent to initiating user + tighten auth-failure regex Hardens two gaps called out in review of the AgentCore OAuth flow. - `/connectors/complete-consent` now verifies the submitted `session_uri` was issued to the authenticated user at `initiate_consent`, rejecting cross-user replay with 403 before ever calling AgentCore. Backed by a thread-safe TTL cache (10 min, single-use). Soft-fails with a warning when AgentCore's authorize URL doesn't carry a recognised session parameter, so an SDK shape change logs rather than blocks. - `_AUTH_FAILURE_PATTERN` tightened with word boundaries on every clause and a non-path guard on `401` so tool errors containing `/v1/401/...` no longer trigger a spurious force-reauth. Also moves `import boto3`/`os` out of the `complete_consent` handler body and caches the control-plane client via `lru_cache`. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../main_agent/session/hooks/oauth_consent.py | 19 ++- .../apis/inference_api/connectors/routes.py | 57 ++++++- .../src/apis/shared/oauth/session_cache.py | 115 +++++++++++++ .../session/test_oauth_consent_hook.py | 86 +++++++++- .../inference_api/test_connectors_routes.py | 158 ++++++++++++++++++ .../tests/shared/oauth/test_session_cache.py | 137 +++++++++++++++ 6 files changed, 559 insertions(+), 13 deletions(-) create mode 100644 backend/src/apis/shared/oauth/session_cache.py create mode 100644 backend/tests/apis/inference_api/test_connectors_routes.py create mode 100644 backend/tests/shared/oauth/test_session_cache.py diff --git a/backend/src/agents/main_agent/session/hooks/oauth_consent.py b/backend/src/agents/main_agent/session/hooks/oauth_consent.py index 5ba1b489..203422c4 100644 --- a/backend/src/agents/main_agent/session/hooks/oauth_consent.py +++ b/backend/src/agents/main_agent/session/hooks/oauth_consent.py @@ -45,12 +45,21 @@ # String markers that indicate an OAuth-style auth failure in a tool # result. MCP servers vary in how they format errors, so we match a small # set of unambiguous signals: the literal HTTP code, "Unauthorized", and -# explicit OAuth/token-rejected language. Tools that legitimately return -# the digit "401" in successful output are not at risk because we also -# require an error status (or a context word) — see `_looks_like_auth_failure`. +# explicit OAuth/token-rejected language. +# +# Every alternative is word-bounded. `401` additionally excludes adjacent +# `/` so path segments like `/v1/401/...` in an error message do not +# trigger a false-positive reauth. We only run the pattern on results +# whose `status == "error"` (see `_looks_like_auth_failure`), so this +# plus `\b` on every other clause is tight enough in practice. _AUTH_FAILURE_PATTERN = re.compile( - r"\b401\b|\bunauthorized\b|invalid[_\s-]token|expired[_\s-]token" - r"|token[\s_-]expired|rejected the oauth token|oauth token (?:has )?expired", + r"(? (user_id, expires_at_monotonic) +_pending: dict[str, tuple[str, float]] = {} + + +def remember(user_id: str, session_uri: str) -> None: + """Record that `user_id` initiated consent with `session_uri`.""" + now = time.monotonic() + expires_at = now + _TTL_SECONDS + with _lock: + _prune_locked(now) + _pending[session_uri] = (user_id, expires_at) + + +def consume(user_id: str, session_uri: str) -> bool: + """Return True if `session_uri` was initiated by `user_id` and is unexpired. + + The entry is removed on a successful match (single-use). Misses are + left in place so a concurrent legitimate completer still succeeds. + """ + now = time.monotonic() + with _lock: + _prune_locked(now) + entry = _pending.get(session_uri) + if entry is None: + return False + owner_id, expires_at = entry + if expires_at <= now: + _pending.pop(session_uri, None) + return False + if owner_id != user_id: + return False + _pending.pop(session_uri, None) + return True + + +def forget_user(user_id: str) -> int: + """Drop every pending session initiated by `user_id`. Returns the count.""" + with _lock: + keys = [uri for uri, (owner, _) in _pending.items() if owner == user_id] + for key in keys: + _pending.pop(key, None) + return len(keys) + + +def _prune_locked(now: float) -> None: + """Evict expired entries. Caller must hold `_lock`.""" + expired = [uri for uri, (_, exp) in _pending.items() if exp <= now] + for uri in expired: + _pending.pop(uri, None) + + +# Query-parameter names AgentCore has been observed to use for the +# session identifier when it hands back an authorization URL. We try +# each in order; whichever hits first is returned. +_SESSION_URI_PARAMS = ("request_uri", "session_id", "sessionUri", "sessionId") + + +def extract_session_uri(authorization_url: str) -> Optional[str]: + """Parse the AgentCore session identifier out of an authorization URL. + + Returns None when none of the expected parameter names are present. + Callers should treat that as a soft failure (log + skip server-side + tracking) rather than blocking consent — AgentCore's own binding is + still in force. + """ + if not authorization_url: + return None + try: + parsed = urlparse(authorization_url) + except ValueError: + return None + params = parse_qs(parsed.query, keep_blank_values=False) + for key in _SESSION_URI_PARAMS: + values = params.get(key) + if values and values[0]: + return values[0] + return None diff --git a/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py b/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py index bab2f085..195ab452 100644 --- a/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py +++ b/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py @@ -19,7 +19,10 @@ TokenResult, WorkloadTokenUnavailableError, ) -from agents.main_agent.session.hooks.oauth_consent import OAuthConsentHook +from agents.main_agent.session.hooks.oauth_consent import ( + OAuthConsentHook, + _looks_like_auth_failure, +) @pytest.fixture(autouse=True) @@ -417,3 +420,84 @@ async def test_scopes_lookup_is_cached_across_calls(self): await hook._gate(event2) assert scopes_lookup.call_count == 1 + + +class TestLooksLikeAuthFailure: + """Detector must fire on genuine auth failures and ignore everything + else — paths containing `401`, non-error statuses, benign prose. + """ + + def _err(self, text: str) -> dict: + return {"status": "error", "content": [{"text": text}]} + + def _ok(self, text: str) -> dict: + return {"status": "success", "content": [{"text": text}]} + + @pytest.mark.parametrize( + "text", + [ + "HTTP 401 Unauthorized", + "Request failed: 401", + "status=401 message=unauthorized", + "The server rejected the OAuth token", + "invalid_token", + "invalid-token", + "invalid token", + "expired_token", + "token expired", + "token_expired", + "oauth token expired", + "oauth token has expired", + "Unauthorized", + ], + ) + def test_matches_genuine_auth_errors(self, text): + assert _looks_like_auth_failure(self._err(text)) is True + + @pytest.mark.parametrize( + "text", + [ + # Path segments containing 401 — previously false-positive. + "GET /v1/401/foo failed with 500", + "https://example.com/api/401/items returned empty", + # Digits embedded in other numbers. + "returned 4010 rows", + "status 14011", + # Token as substring of longer words should not match. + "refreshtokenRequired", + "ExpiredTokens", # plural — not \btoken\b + # Prose that shouldn't trigger. + "The weather today is unauthorized-feeling, but fine", # still matches \bunauthorized\b + # ^ confirms we accept this as a fair trigger; pure prose without + # the keyword should not. + "Everything is fine, nothing to see here", + "Rate limit exceeded", + "500 Internal Server Error", + ], + ) + def test_avoids_false_positives(self, text): + # All cases above either have no auth keyword OR have the keyword + # embedded in a longer word — the regex must reject them. + if "unauthorized-feeling" in text: + # Word boundary intentionally matches "unauthorized" even when + # followed by a hyphen. This is expected; skip from negative set. + pytest.skip("hyphen after keyword is a legitimate match") + assert _looks_like_auth_failure(self._err(text)) is False + + def test_ignores_non_error_status(self): + # Even an auth-shaped body doesn't count if status is success. + assert _looks_like_auth_failure(self._ok("401 Unauthorized")) is False + + def test_ignores_non_dict_result(self): + assert _looks_like_auth_failure("HTTP 401 Unauthorized") is False + assert _looks_like_auth_failure(None) is False + assert _looks_like_auth_failure(["401"]) is False + + def test_ignores_missing_content(self): + assert _looks_like_auth_failure({"status": "error"}) is False + assert _looks_like_auth_failure({"status": "error", "content": None}) is False + assert _looks_like_auth_failure({"status": "error", "content": []}) is False + + def test_ignores_non_dict_content_blocks(self): + result = {"status": "error", "content": ["401 Unauthorized"]} + assert _looks_like_auth_failure(result) is False diff --git a/backend/tests/apis/inference_api/test_connectors_routes.py b/backend/tests/apis/inference_api/test_connectors_routes.py new file mode 100644 index 00000000..e876d855 --- /dev/null +++ b/backend/tests/apis/inference_api/test_connectors_routes.py @@ -0,0 +1,158 @@ +"""Route-level tests for `/connectors/complete-consent`. + +Focused on the defence-in-depth check that rejects a `session_uri` that +was never remembered for the caller. The AgentCore control-plane client +is patched out — we care about whether our gate blocks the request +before it reaches AWS, not the downstream call itself. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from apis.inference_api.connectors import routes +from apis.shared.auth.models import User +from apis.shared.oauth import session_cache + + +@pytest.fixture(autouse=True) +def _clear_cache(): + session_cache._pending.clear() # noqa: SLF001 — test-only + yield + session_cache._pending.clear() # noqa: SLF001 — test-only + + +@pytest.fixture(autouse=True) +def _reset_control_client(): + """`_agentcore_control_client` is `lru_cache`d; reset between tests.""" + routes._agentcore_control_client.cache_clear() + yield + routes._agentcore_control_client.cache_clear() + + +def _make_user(user_id: str) -> User: + return User( + user_id=user_id, + email=f"{user_id}@example.com", + name=user_id.capitalize(), + roles=[], + raw_token="test-token", + ) + + +@pytest.fixture +def app_for_user(): + """Build a minimal FastAPI app with the connectors router mounted and + the `get_current_user_trusted` dependency stubbed to a specific user. + Returns a factory so each test picks the caller's identity. + """ + + def _build(user_id: str) -> FastAPI: + app = FastAPI() + app.include_router(routes.router) + app.dependency_overrides[routes.get_current_user_trusted] = lambda: _make_user(user_id) + return app + + return _build + + +class TestCompleteConsentAuthorization: + def test_rejects_unknown_session_uri(self, app_for_user, monkeypatch): + """No prior `remember` call → 403.""" + mock_client = MagicMock() + monkeypatch.setattr(routes, "_agentcore_control_client", lambda: mock_client) + + app = app_for_user("alice") + client = TestClient(app) + response = client.post( + "/connectors/complete-consent", + json={"session_uri": "never-issued", "provider_id": "google"}, + ) + + assert response.status_code == 403 + assert "not initiated by you" in response.json()["detail"] + # AgentCore must not be called when the gate rejects the request. + mock_client.complete_resource_token_auth.assert_not_called() + + def test_rejects_session_uri_from_different_user(self, app_for_user, monkeypatch): + """Alice's session cannot be completed by Bob.""" + mock_client = MagicMock() + monkeypatch.setattr(routes, "_agentcore_control_client", lambda: mock_client) + + # Alice starts a flow; Bob's request tries to steal it. + session_cache.remember("alice", "uri-abc") + + app = app_for_user("bob") + client = TestClient(app) + response = client.post( + "/connectors/complete-consent", + json={"session_uri": "uri-abc", "provider_id": "google"}, + ) + + assert response.status_code == 403 + mock_client.complete_resource_token_auth.assert_not_called() + # Alice's entry survives — Bob's rejection is non-destructive. + assert session_cache.consume("alice", "uri-abc") is True + + def test_accepts_session_uri_remembered_for_caller(self, app_for_user, monkeypatch): + mock_client = MagicMock() + monkeypatch.setattr(routes, "_agentcore_control_client", lambda: mock_client) + + session_cache.remember("alice", "uri-abc") + + app = app_for_user("alice") + client = TestClient(app) + response = client.post( + "/connectors/complete-consent", + json={"session_uri": "uri-abc", "provider_id": "google"}, + ) + + assert response.status_code == 200 + assert response.json() == {"ok": True} + mock_client.complete_resource_token_auth.assert_called_once_with( + userIdentifier={"userId": "alice"}, + sessionUri="uri-abc", + ) + + def test_single_use_rejects_replay(self, app_for_user, monkeypatch): + """A successful completion consumes the entry — replay gets 403.""" + mock_client = MagicMock() + monkeypatch.setattr(routes, "_agentcore_control_client", lambda: mock_client) + + session_cache.remember("alice", "uri-abc") + + app = app_for_user("alice") + client = TestClient(app) + first = client.post( + "/connectors/complete-consent", + json={"session_uri": "uri-abc", "provider_id": "google"}, + ) + second = client.post( + "/connectors/complete-consent", + json={"session_uri": "uri-abc", "provider_id": "google"}, + ) + + assert first.status_code == 200 + assert second.status_code == 403 + assert mock_client.complete_resource_token_auth.call_count == 1 + + def test_surfaces_agentcore_error_as_502(self, app_for_user, monkeypatch): + mock_client = MagicMock() + mock_client.complete_resource_token_auth.side_effect = RuntimeError("agentcore down") + monkeypatch.setattr(routes, "_agentcore_control_client", lambda: mock_client) + + session_cache.remember("alice", "uri-abc") + + app = app_for_user("alice") + client = TestClient(app) + response = client.post( + "/connectors/complete-consent", + json={"session_uri": "uri-abc", "provider_id": "google"}, + ) + + assert response.status_code == 502 + assert "agentcore down" in response.json()["detail"] diff --git a/backend/tests/shared/oauth/test_session_cache.py b/backend/tests/shared/oauth/test_session_cache.py new file mode 100644 index 00000000..36262aed --- /dev/null +++ b/backend/tests/shared/oauth/test_session_cache.py @@ -0,0 +1,137 @@ +"""Tests for the OAuth consent session cache.""" + +from __future__ import annotations + +import time + +import pytest + +from apis.shared.oauth import session_cache + + +@pytest.fixture(autouse=True) +def _isolate_cache(): + """Cache is module-global; drop all entries around each test.""" + session_cache.forget_user("alice") + session_cache.forget_user("bob") + # Also clear anything else left over from prior tests by bumping + # the clock forward inside a prune. + session_cache._pending.clear() # noqa: SLF001 — test-only + yield + session_cache._pending.clear() # noqa: SLF001 — test-only + + +class TestRememberAndConsume: + def test_consume_matches_remembered_pair(self): + session_cache.remember("alice", "uri-123") + assert session_cache.consume("alice", "uri-123") is True + + def test_consume_is_single_use(self): + session_cache.remember("alice", "uri-123") + assert session_cache.consume("alice", "uri-123") is True + # Second consume on the same uri fails — no replay. + assert session_cache.consume("alice", "uri-123") is False + + def test_consume_rejects_mismatched_user(self): + session_cache.remember("alice", "uri-123") + # Bob cannot complete a session Alice started. + assert session_cache.consume("bob", "uri-123") is False + # Alice's entry still there — rejection is non-destructive. + assert session_cache.consume("alice", "uri-123") is True + + def test_consume_rejects_unknown_session(self): + assert session_cache.consume("alice", "never-seen") is False + + def test_remember_overwrites_previous_owner(self): + # If a second user somehow remembers the same uri, the latest + # write wins. (AgentCore request_uris are opaque and unique in + # practice — this is just a deterministic-semantics test.) + session_cache.remember("alice", "uri-123") + session_cache.remember("bob", "uri-123") + assert session_cache.consume("alice", "uri-123") is False + assert session_cache.consume("bob", "uri-123") is True + + +class TestExpiry: + def test_expired_entries_are_rejected(self, monkeypatch): + # Freeze clock: remember at t=0, consume at t=TTL+1. + t = [0.0] + monkeypatch.setattr(session_cache.time, "monotonic", lambda: t[0]) + + session_cache.remember("alice", "uri-123") + t[0] = session_cache._TTL_SECONDS + 1.0 # noqa: SLF001 — test-only + assert session_cache.consume("alice", "uri-123") is False + + def test_entries_within_ttl_succeed(self, monkeypatch): + t = [0.0] + monkeypatch.setattr(session_cache.time, "monotonic", lambda: t[0]) + + session_cache.remember("alice", "uri-123") + t[0] = session_cache._TTL_SECONDS - 1.0 # noqa: SLF001 — test-only + assert session_cache.consume("alice", "uri-123") is True + + def test_prune_evicts_expired_entries(self, monkeypatch): + t = [0.0] + monkeypatch.setattr(session_cache.time, "monotonic", lambda: t[0]) + + session_cache.remember("alice", "old-uri") + t[0] = session_cache._TTL_SECONDS + 1.0 # noqa: SLF001 — test-only + # Any subsequent remember/consume triggers the internal prune. + session_cache.remember("bob", "new-uri") + assert "old-uri" not in session_cache._pending # noqa: SLF001 — test-only + + +class TestForgetUser: + def test_forget_clears_only_that_user(self): + session_cache.remember("alice", "a1") + session_cache.remember("alice", "a2") + session_cache.remember("bob", "b1") + + assert session_cache.forget_user("alice") == 2 + assert session_cache.consume("alice", "a1") is False + assert session_cache.consume("alice", "a2") is False + assert session_cache.consume("bob", "b1") is True + + +class TestExtractSessionUri: + @pytest.mark.parametrize( + ("url", "expected"), + [ + ( + "https://bedrock-agentcore.us-west-2.amazonaws.com/identities/oauth2/authorize?request_uri=urn%3Aietf%3Aparams%3Aoauth%3Arequest-uri%3Aabc123&client_id=x", + "urn:ietf:params:oauth:request-uri:abc123", + ), + ( + "https://example.com/authorize?session_id=sess_abc123&x=1", + "sess_abc123", + ), + ( + "https://example.com/authorize?sessionUri=sess_xyz", + "sess_xyz", + ), + ( + "https://example.com/authorize?sessionId=sess_xyz", + "sess_xyz", + ), + ], + ) + def test_extracts_known_params(self, url, expected): + assert session_cache.extract_session_uri(url) == expected + + def test_returns_none_when_no_known_param(self): + url = "https://example.com/authorize?client_id=x&state=y" + assert session_cache.extract_session_uri(url) is None + + def test_returns_none_for_empty_input(self): + assert session_cache.extract_session_uri("") is None + assert session_cache.extract_session_uri(None) is None # type: ignore[arg-type] + + def test_prefers_request_uri_when_multiple_present(self): + # If AgentCore ever emits both, request_uri is the canonical OAuth 2.0 + # PAR parameter — prefer it. + url = "https://example.com/authorize?session_id=shouldnt_win&request_uri=should_win" + assert session_cache.extract_session_uri(url) == "should_win" + + def test_handles_malformed_url(self): + # parse_qs is lenient, but urlparse should still not raise on junk. + assert session_cache.extract_session_uri("not a url") is None From 895f187080d1d09c1a0f36e3abd32259e00351ae Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 17:54:00 -0600 Subject: [PATCH 15/35] fix(connectors): type-assert AgentCore responses + harden create rollback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses the remaining two critical items from PR #174 review. Registrar response parsing (`_info_from_response`): fails loudly on contract violations rather than silently storing empty strings. Missing `clientSecretArn` still tolerated (some vendors won't persist one) but a wrong-shape `clientSecretArn` or absent `credentialProviderArn` now raises TypeError so an AgentCore API change surfaces as a real error. Admin create-provider rollback (`_rollback_orphaned_provider`): now retries the AgentCore delete twice with backoff before giving up. On exhaustion, emits a CloudWatch `Agentcore/OAuth::ProviderOrphaned` custom metric so ops can alarm on stranded credential providers. Secondary failures (CW down, registrar down after retries) never shadow the admin's original 5xx — they only log. The subsequent create attempt that hits `CredentialProviderConflictError` with no DB record now returns an actionable 409 pointing at the AWS CLI cleanup command instead of a bare "already exists". App API task role grants `cloudwatch:PutMetricData` scoped to the `Agentcore/OAuth` namespace via a condition key. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/apis/app_api/admin/oauth/routes.py | 121 ++++++++++++++++-- .../apis/shared/oauth/agentcore_registrar.py | 35 ++++- .../apis/app_api/test_admin_oauth_rollback.py | 119 +++++++++++++++++ .../shared/test_oauth_agentcore_registrar.py | 74 +++++++++++ infrastructure/lib/app-api-stack.ts | 18 +++ 5 files changed, 354 insertions(+), 13 deletions(-) create mode 100644 backend/tests/apis/app_api/test_admin_oauth_rollback.py diff --git a/backend/src/apis/app_api/admin/oauth/routes.py b/backend/src/apis/app_api/admin/oauth/routes.py index 76ea385d..0b5969c4 100644 --- a/backend/src/apis/app_api/admin/oauth/routes.py +++ b/backend/src/apis/app_api/admin/oauth/routes.py @@ -6,9 +6,13 @@ the callback URL that the admin must register with the vendor. """ +import asyncio import logging +import os from datetime import datetime, timezone +from functools import lru_cache +import boto3 from fastapi import APIRouter, Depends, HTTPException, Query, status from apis.shared.auth import User, require_admin @@ -36,6 +40,92 @@ router = APIRouter(prefix="/oauth-providers", tags=["admin-oauth"]) +# Rollback backoff schedule. Two retries after the initial attempt, ~2.5s +# total worst case — short enough to keep the create request responsive, +# long enough to absorb the common class of transient AWS errors. +_ROLLBACK_RETRY_DELAYS_SECONDS = (0.5, 2.0) + +# CloudWatch namespace for orphan telemetry. Kept distinct so ops can +# scope alarms without catching unrelated Bedrock metrics. +_ORPHAN_METRIC_NAMESPACE = "Agentcore/OAuth" +_ORPHAN_METRIC_NAME = "ProviderOrphaned" + + +@lru_cache(maxsize=1) +def _cloudwatch_client(): + region = os.environ.get("AWS_REGION", "us-west-2") + return boto3.client("cloudwatch", region_name=region) + + +async def _rollback_orphaned_provider( + registrar: AgentCoreRegistrar, provider_id: str +) -> None: + """Best-effort delete of an AgentCore provider after a DB write failed. + + Retries on transient AWS errors. If every attempt fails we emit a + CloudWatch `ProviderOrphaned` metric and log at ERROR — the AgentCore + record is now orphaned (no DynamoDB row) and needs manual cleanup, + but the admin's original 5xx still propagates so they know the + create didn't land. + """ + last_err: Exception | None = None + for attempt in range(1 + len(_ROLLBACK_RETRY_DELAYS_SECONDS)): + try: + # Registrar is sync; off-thread it so we don't block the event loop. + await asyncio.to_thread(registrar.delete_credential_provider, provider_id) + logger.info( + "Rolled back orphaned AgentCore provider %s (attempt %d)", + provider_id, + attempt + 1, + ) + return + except Exception as err: + last_err = err + if attempt < len(_ROLLBACK_RETRY_DELAYS_SECONDS): + delay = _ROLLBACK_RETRY_DELAYS_SECONDS[attempt] + logger.warning( + "Rollback attempt %d for %s failed (%s); retrying in %.1fs", + attempt + 1, + provider_id, + err, + delay, + ) + await asyncio.sleep(delay) + + logger.error( + "Rollback delete exhausted for %s after %d attempts; emitting " + "orphan metric. Last error: %s", + provider_id, + 1 + len(_ROLLBACK_RETRY_DELAYS_SECONDS), + last_err, + exc_info=last_err, + ) + _emit_orphan_metric(provider_id) + + +def _emit_orphan_metric(provider_id: str) -> None: + """Fire-and-forget CloudWatch metric for an orphaned credential provider. + + Failures here are swallowed — we're already inside a rollback path + and a secondary error would just shadow the admin-facing one. + """ + try: + _cloudwatch_client().put_metric_data( + Namespace=_ORPHAN_METRIC_NAMESPACE, + MetricData=[ + { + "MetricName": _ORPHAN_METRIC_NAME, + "Dimensions": [{"Name": "ProviderId", "Value": provider_id}], + "Value": 1, + "Unit": "Count", + } + ], + ) + except Exception: + logger.exception( + "Failed to emit CloudWatch orphan metric for %s", provider_id + ) + # ============================================================================= # Provider CRUD @@ -107,8 +197,27 @@ async def create_provider( discovery_url=provider_data.oauth_discovery_url, authorization_server_metadata=provider_data.authorization_server_metadata, ) - except CredentialProviderConflictError as err: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(err)) + except CredentialProviderConflictError: + # We already verified the DB has no record for this provider_id + # above, so a conflict from AgentCore means its vault carries an + # orphaned record — almost always from a prior failed rollback. + # Give the admin a cleanup pointer instead of a bare 409. + logger.error( + "Orphan detected for %s: DB empty but AgentCore has a credential " + "provider. Previous rollback likely failed.", + provider_data.provider_id, + ) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=( + f"An AgentCore credential provider named " + f"'{provider_data.provider_id}' already exists but has no " + "matching database record (likely from a previous failed " + "rollback). Delete it via the AWS CLI and retry: " + "`aws bedrock-agentcore-control delete-oauth2-credential-provider " + f"--name {provider_data.provider_id}`." + ), + ) except InvalidCustomProviderConfigError as err: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(err)) @@ -120,13 +229,7 @@ async def create_provider( "DB write failed for %s; rolling back AgentCore credential provider", provider_data.provider_id, ) - try: - registrar.delete_credential_provider(provider_data.provider_id) - except Exception: - logger.exception( - "Rollback delete failed for %s; orphaned AgentCore provider may exist", - provider_data.provider_id, - ) + await _rollback_orphaned_provider(registrar, provider_data.provider_id) raise return OAuthProviderResponse.from_provider(provider) diff --git a/backend/src/apis/shared/oauth/agentcore_registrar.py b/backend/src/apis/shared/oauth/agentcore_registrar.py index 3e140bac..c426b1ee 100644 --- a/backend/src/apis/shared/oauth/agentcore_registrar.py +++ b/backend/src/apis/shared/oauth/agentcore_registrar.py @@ -288,7 +288,28 @@ def _build_oauth_discovery( def _info_from_response( *, provider_id: str, vendor: str, response: Dict[str, Any] ) -> CredentialProviderInfo: - client_secret = response.get("clientSecretArn") or {} + # AgentCore's documented shape is `clientSecretArn: {secretArn: str}`. + # If the field is missing (possible for vendors that don't persist a + # secret) we tolerate that. If it's present but shaped differently, + # that's a contract change we want to fail loudly on rather than + # silently storing an empty string. + raw_secret = response.get("clientSecretArn") + if raw_secret is None: + client_secret_arn = "" + elif isinstance(raw_secret, dict): + secret_arn = raw_secret.get("secretArn", "") + if not isinstance(secret_arn, str): + raise TypeError( + f"AgentCore returned clientSecretArn.secretArn of unexpected " + f"type {type(secret_arn).__name__}; expected str" + ) + client_secret_arn = secret_arn + else: + raise TypeError( + f"AgentCore returned clientSecretArn of unexpected type " + f"{type(raw_secret).__name__}; expected dict or None" + ) + output_config = response.get("oauth2ProviderConfigOutput") or {} # Each vendor variant nests its own output object; the clientId lives # one level deeper when present. We tolerate its absence. @@ -298,12 +319,18 @@ def _info_from_response( client_id = nested["clientId"] break + credential_provider_arn = response.get("credentialProviderArn") + if not isinstance(credential_provider_arn, str) or not credential_provider_arn: + raise TypeError( + "AgentCore response missing credentialProviderArn or wrong type" + ) + return CredentialProviderInfo( provider_id=provider_id, vendor=vendor, - credential_provider_arn=response["credentialProviderArn"], - client_secret_arn=client_secret.get("secretArn", ""), - callback_url=response.get("callbackUrl", ""), + credential_provider_arn=credential_provider_arn, + client_secret_arn=client_secret_arn, + callback_url=response.get("callbackUrl", "") or "", client_id=client_id, ) diff --git a/backend/tests/apis/app_api/test_admin_oauth_rollback.py b/backend/tests/apis/app_api/test_admin_oauth_rollback.py new file mode 100644 index 00000000..898fc1f8 --- /dev/null +++ b/backend/tests/apis/app_api/test_admin_oauth_rollback.py @@ -0,0 +1,119 @@ +"""Tests for the admin OAuth create-provider rollback helpers. + +Verifies retry + CloudWatch-metric behaviour of +`_rollback_orphaned_provider` and `_emit_orphan_metric`. These run +inside the admin `create_provider` route after a DB write fails; if +both the rollback AND the metric emit fail, the admin's original error +still propagates — these helpers must never raise. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from botocore.exceptions import ClientError + +from apis.app_api.admin.oauth import routes + + +def _client_error(code: str = "ThrottlingException") -> ClientError: + return ClientError( + error_response={"Error": {"Code": code, "Message": code}}, + operation_name="DeleteOauth2CredentialProvider", + ) + + +@pytest.fixture(autouse=True) +def _reset_cw_client_cache(): + routes._cloudwatch_client.cache_clear() + yield + routes._cloudwatch_client.cache_clear() + + +@pytest.fixture +def fast_backoff(monkeypatch): + """Rollback retries sleep between attempts; zero those out in tests.""" + monkeypatch.setattr(routes, "_ROLLBACK_RETRY_DELAYS_SECONDS", (0.0, 0.0)) + + +class TestRollbackOrphanedProvider: + @pytest.mark.asyncio + async def test_succeeds_on_first_attempt(self, fast_backoff): + registrar = MagicMock() + registrar.delete_credential_provider.return_value = None + + await routes._rollback_orphaned_provider(registrar, "google") + + registrar.delete_credential_provider.assert_called_once_with("google") + + @pytest.mark.asyncio + async def test_retries_on_transient_error(self, fast_backoff): + registrar = MagicMock() + registrar.delete_credential_provider.side_effect = [ + _client_error("ThrottlingException"), + None, + ] + + await routes._rollback_orphaned_provider(registrar, "google") + + assert registrar.delete_credential_provider.call_count == 2 + + @pytest.mark.asyncio + async def test_emits_metric_when_all_retries_exhausted(self, fast_backoff): + registrar = MagicMock() + registrar.delete_credential_provider.side_effect = _client_error( + "InternalServerException" + ) + + with patch.object(routes, "_emit_orphan_metric") as emit: + await routes._rollback_orphaned_provider(registrar, "google") + + # 1 initial + 2 retries = 3 attempts with zero backoff schedule. + assert registrar.delete_credential_provider.call_count == 3 + emit.assert_called_once_with("google") + + @pytest.mark.asyncio + async def test_never_raises_even_when_everything_fails(self, fast_backoff): + """Rollback runs inside an `except` block that already re-raises + the admin-facing error. A secondary raise here would shadow it. + + Simulates the worst case: every registrar call fails AND + CloudWatch is down. `_emit_orphan_metric` swallows its own + errors, so the overall helper must return cleanly. + """ + registrar = MagicMock() + registrar.delete_credential_provider.side_effect = RuntimeError("boom") + + fake_cw = MagicMock() + fake_cw.put_metric_data.side_effect = RuntimeError("cw boom") + + with patch.object(routes, "_cloudwatch_client", lambda: fake_cw): + # Must not raise — secondary failures only log. + await routes._rollback_orphaned_provider(registrar, "google") + + +class TestEmitOrphanMetric: + def test_puts_metric_with_provider_dimension(self): + fake_cw = MagicMock() + with patch.object(routes, "_cloudwatch_client", lambda: fake_cw): + routes._emit_orphan_metric("google-workspace") + + fake_cw.put_metric_data.assert_called_once() + call_kwargs = fake_cw.put_metric_data.call_args.kwargs + assert call_kwargs["Namespace"] == "Agentcore/OAuth" + metric = call_kwargs["MetricData"][0] + assert metric["MetricName"] == "ProviderOrphaned" + assert metric["Dimensions"] == [ + {"Name": "ProviderId", "Value": "google-workspace"} + ] + assert metric["Value"] == 1 + assert metric["Unit"] == "Count" + + def test_swallows_cloudwatch_failure(self): + """CloudWatch outage must not shadow the admin's create failure.""" + fake_cw = MagicMock() + fake_cw.put_metric_data.side_effect = RuntimeError("cw down") + + with patch.object(routes, "_cloudwatch_client", lambda: fake_cw): + routes._emit_orphan_metric("google") # no raise diff --git a/backend/tests/shared/test_oauth_agentcore_registrar.py b/backend/tests/shared/test_oauth_agentcore_registrar.py index 09cf3443..f98a722a 100644 --- a/backend/tests/shared/test_oauth_agentcore_registrar.py +++ b/backend/tests/shared/test_oauth_agentcore_registrar.py @@ -262,3 +262,77 @@ def test_other_errors_bubble(self, registrar, boto_client): ) with pytest.raises(ClientError): registrar.delete_credential_provider("p") + + +class TestResponseParsing: + """`_info_from_response` must fail loudly on contract violations and + tolerate documented variations (missing secret, missing clientId).""" + + def test_tolerates_missing_client_secret_arn(self, registrar, boto_client): + response = _create_response() + response.pop("clientSecretArn") + boto_client.create_oauth2_credential_provider.return_value = response + + info = registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + assert info.client_secret_arn == "" + + def test_rejects_client_secret_arn_as_string(self, registrar, boto_client): + """AgentCore contract is `{secretArn: str}`; a raw string signals + an API contract change we should fail on loudly.""" + response = _create_response() + response["clientSecretArn"] = "arn:aws:secretsmanager:...:secret:s" + boto_client.create_oauth2_credential_provider.return_value = response + + with pytest.raises(TypeError, match="clientSecretArn of unexpected type"): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + + def test_rejects_non_string_nested_secret_arn(self, registrar, boto_client): + response = _create_response() + response["clientSecretArn"] = {"secretArn": 12345} + boto_client.create_oauth2_credential_provider.return_value = response + + with pytest.raises(TypeError, match="secretArn of unexpected type"): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + + def test_rejects_missing_credential_provider_arn(self, registrar, boto_client): + response = _create_response() + response.pop("credentialProviderArn") + boto_client.create_oauth2_credential_provider.return_value = response + + with pytest.raises(TypeError, match="credentialProviderArn"): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + + def test_tolerates_missing_callback_url(self, registrar, boto_client): + """Callback URL absence falls back to empty string — non-fatal for + vendors that don't declare one yet.""" + response = _create_response() + response["callbackUrl"] = None + boto_client.create_oauth2_credential_provider.return_value = response + + info = registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + assert info.callback_url == "" diff --git a/infrastructure/lib/app-api-stack.ts b/infrastructure/lib/app-api-stack.ts index 90b19e5d..b3176944 100644 --- a/infrastructure/lib/app-api-stack.ts +++ b/infrastructure/lib/app-api-stack.ts @@ -939,6 +939,24 @@ export class AppApiStack extends cdk.Stack { }) ); + // Custom metrics for OAuth admin flows (e.g. ProviderOrphaned emitted + // by `_emit_orphan_metric` when a failed DB write + failed rollback + // leaves an AgentCore credential provider stranded). PutMetricData + // cannot be resource-scoped; we scope via the namespace condition. + taskDefinition.taskRole.addToPrincipalPolicy( + new iam.PolicyStatement({ + sid: 'OAuthAdminMetrics', + effect: iam.Effect.ALLOW, + actions: ['cloudwatch:PutMetricData'], + resources: ['*'], + conditions: { + StringEquals: { + 'cloudwatch:namespace': 'Agentcore/OAuth', + }, + }, + }) + ); + // Grant permissions for API Keys table (imported from Infrastructure Stack) taskDefinition.taskRole.addToPrincipalPolicy( new iam.PolicyStatement({ From 824f8246987802aacdf8b113d6ab8c6e7965337b Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Wed, 22 Apr 2026 21:00:03 -0600 Subject: [PATCH 16/35] fix(connectors): harden oauth consent flow per code review - Reject non-https authorizationUrls at both intake and open time so a compromised backend can't smuggle javascript:/data: URIs into a user click. - Replace window.location.href hijack on popup-block with a blocked signal; the banner renders an "Open in new tab" anchor instead of tearing down the chat tab. - Reject resume requests whose interruptIds aren't present in the cached agent's _interrupt_state with 400, preventing silent acceptance after cache eviction, process restart, or forged payloads. Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/src/apis/inference_api/chat/routes.py | 23 ++++ .../oauth-consent-banner.component.ts | 54 +++++---- .../oauth-consent/oauth-consent.service.ts | 105 ++++++++++++++++-- 3 files changed, 156 insertions(+), 26 deletions(-) diff --git a/backend/src/apis/inference_api/chat/routes.py b/backend/src/apis/inference_api/chat/routes.py index b07cfea2..939056f7 100644 --- a/backend/src/apis/inference_api/chat/routes.py +++ b/backend/src/apis/inference_api/chat/routes.py @@ -539,6 +539,29 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g max_tokens=input_data.max_tokens, ) + # Resume requests must target interrupts that the cached agent + # actually has paused. Cache eviction, a process restart, or a + # forged request will otherwise be silently accepted by Strands + # and drop the client's response. Reject up front so the client + # sees a 400 and can restart the turn cleanly. + if is_resume: + strands_agent = getattr(agent, "agent", None) + interrupt_state = getattr(strands_agent, "_interrupt_state", None) if strands_agent else None + known_ids: set[str] = set() + if interrupt_state and getattr(interrupt_state, "activated", False): + interrupts = getattr(interrupt_state, "interrupts", None) or {} + known_ids = set(interrupts.keys()) + submitted_ids = [entry.interruptId for entry in (input_data.interrupt_responses or [])] + unknown_ids = [iid for iid in submitted_ids if iid not in known_ids] + if unknown_ids: + logger.warning( + "Resume rejected: submitted interrupt ids not in paused state" + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Unknown or expired interrupt ids; restart the turn.", + ) + # Build citations list for persistence (convert context chunks to citation format) citations_for_storage = [] if context_chunks: diff --git a/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts b/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts index 411931dd..17b24d8f 100644 --- a/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts +++ b/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts @@ -24,26 +24,42 @@ import { OAuthConsentService } from '../../services/oauth-consent/oauth-consent. aria-live="polite" >
-
- -
+ @if (connector.iconData) { +
+ +
+ } @else { +
+ +
+ }

{{ connector.displayName }} diff --git a/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts b/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts index 5fc30d1e..85557f46 100644 --- a/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts +++ b/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts @@ -7,6 +7,8 @@ export interface UserConnector { displayName: string; providerType: 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; iconName: string; + /** Optional admin-uploaded icon (base64 data URL). Wins over `iconName`. */ + iconData?: string | null; scopes: string[]; } diff --git a/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts index 2c71f5ac..e8dc615c 100644 --- a/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts +++ b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts @@ -99,9 +99,19 @@ type ConnectState = class="flex items-start justify-between gap-4 rounded-sm border border-gray-200 bg-white p-4 dark:border-gray-700 dark:bg-gray-800" >

-
- -
+ @if (connector.iconData) { +
+ +
+ } @else { +
+ +
+ }

{{ connector.displayName }} From 8561f89113e5151db21f2ada1da1c411e24ee842 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Sat, 25 Apr 2026 17:52:48 -0600 Subject: [PATCH 24/35] feat(connectors): inline OAuth consent prompt + persistence-backed restore Replaces the floating OAuth banner with an inline prompt anchored to the assistant turn that triggered consent, and persists pending interrupts to session metadata so a browser refresh rediscovers them instead of leaving the tool call orphaned in `pending` forever. Backend - New `PendingInterrupt` model on `apis.shared.sessions.models`; included on `MessagesListResponse` and `SessionMetadata`. - `metadata.add_pending_interrupt` / `remove_pending_interrupts` / `get_pending_interrupts` helpers using GSI lookup + targeted UpdateExpression. - `StreamCoordinator._extract_oauth_required_events` is now async and persists each interrupt before yielding the SSE event; failures log but never break the live stream. - `get_messages_from_cloud` fetches pending interrupts in parallel. - `/invocations` resume path clears resolved interrupts from metadata after `agent.stream_async` completes. - New `DELETE /sessions/{sid}/pending-interrupts/{iid:path}` endpoint for explicit dismiss; colon-bearing Strands ids preserved via `:path`. Frontend - New `OAuthConsentPromptComponent` with a refined inline card design, connector icon (admin base64 wins over heroicon, falls back to providerType default), eyebrow/lock motif, primary gradient action button, hover-revealed dismiss, fade+slide entrance. - `MessageMapService.loadMessagesForSession` hydrates pending interrupts on session load; anchors to triggering message id when present, else the most recent assistant message. - `OAuthConsentService.openConsentPopup` is async; lazy-fetches a fresh authorization URL via `initiate-consent` when the stored one is absent or expired (handles "already consented in another tab" by auto-resuming). - `OAuthConsentService.dismiss` syncs to backend by default; completion flow opts out so the resume path's own cleanup isn't double-fired. - `MessageListComponent` renders unanchored interrupts at end-of-list as a fallback for the "partial assistant message wasn't persisted" case. - `awaiting_auth` derived tool status renders as a primary-blue ring on the tool-rail dot instead of an indefinite amber spinner. - `ChatRequestService.resumeFromOAuthConsent` accepts a fallback session id (post-refresh case where `lastRequestObject` is null) and surfaces 400 `Unknown or expired interrupt ids` as a conversational error. - Old floating `OAuthConsentBannerComponent` removed. Known follow-up - First-turn-of-a-new-session OAuth: persistence currently no-ops because the session metadata row doesn't exist yet when the interrupt fires. Tracked separately; sidecar item or upsert pattern is the likely fix. Co-Authored-By: Claude Opus 4.7 --- .../streaming/stream_coordinator.py | 49 ++- backend/src/apis/app_api/sessions/routes.py | 42 ++- backend/src/apis/inference_api/chat/routes.py | 28 ++ backend/src/apis/shared/sessions/messages.py | 21 +- backend/src/apis/shared/sessions/metadata.py | 158 ++++++++- backend/src/apis/shared/sessions/models.py | 37 +++ .../oauth-consent-banner.component.ts | 116 ------- .../oauth-consent-prompt.component.ts | 304 ++++++++++++++++++ .../oauth-consent/oauth-consent.service.ts | 137 +++++++- .../chat-input/chat-input.component.html | 3 - .../chat-input/chat-input.component.ts | 3 +- .../components/assistant-message.component.ts | 56 +++- .../tool-rail/tool-rail.component.ts | 8 +- .../components/tool-rail/tool-rail.model.ts | 7 +- .../message-list/message-list.component.html | 9 + .../message-list/message-list.component.ts | 28 +- .../services/chat/chat-request.service.ts | 59 +++- .../services/chat/stream-parser.service.ts | 19 +- .../services/session/message-map.service.ts | 43 ++- .../services/session/session.service.ts | 39 ++- 20 files changed, 998 insertions(+), 168 deletions(-) delete mode 100644 frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts create mode 100644 frontend/ai.client/src/app/components/oauth-consent-prompt/oauth-consent-prompt.component.ts diff --git a/backend/src/agents/main_agent/streaming/stream_coordinator.py b/backend/src/agents/main_agent/streaming/stream_coordinator.py index 90592982..7577274f 100644 --- a/backend/src/agents/main_agent/streaming/stream_coordinator.py +++ b/backend/src/agents/main_agent/streaming/stream_coordinator.py @@ -200,9 +200,15 @@ async def stream_response( # stream closes. The frontend uses these to drive the consent # popup and then POSTs interrupt responses to resume the turn. # Done before the metadata branch so the events land between - # message_stop and the final metadata/done block. + # message_stop and the final metadata/done block. Persistence + # to session metadata happens inside the extractor so a refresh + # rediscovers the consent prompt. if event.get("type") == "done": - for sse in self._extract_oauth_required_events(agent): + for sse in await self._extract_oauth_required_events( + agent, + session_id=session_id, + user_id=user_id, + ): yield sse # Check if this is the "done" event - send final metadata before it @@ -540,9 +546,16 @@ async def stream_response( except Exception as persist_error: logger.error(f"Failed to persist stream error to session: {persist_error}") - def _extract_oauth_required_events(self, agent: Any) -> List[str]: + async def _extract_oauth_required_events( + self, + agent: Any, + session_id: Optional[str] = None, + user_id: Optional[str] = None, + triggering_message_id: Optional[str] = None, + ) -> List[str]: """Yield one SSE-formatted `oauth_required` event per pending OAuth - interrupt on the agent. + interrupt on the agent, persisting each one to session metadata so + the frontend can rediscover them after a refresh. The Strands `_interrupt_state` is populated when `OAuthConsentHook` calls `event.interrupt(...)`. We look for interrupts whose `reason` @@ -550,8 +563,13 @@ def _extract_oauth_required_events(self, agent: Any) -> List[str]: shape the frontend already understands. Non-OAuth interrupts (other approval gates added later) are ignored here so they can be handled by their own SSE event types. + + Persistence is best-effort: a DynamoDB write failure logs but does + not break the live SSE flow. """ from apis.shared.oauth.models import OAuthRequiredEvent + from apis.shared.sessions.metadata import add_pending_interrupt + from apis.shared.sessions.models import PendingInterrupt interrupt_state = getattr(agent, "_interrupt_state", None) if not interrupt_state or not getattr(interrupt_state, "activated", False): @@ -570,6 +588,29 @@ def _extract_oauth_required_events(self, agent: Any) -> List[str]: interrupt.id, ) continue + + # Persist the breadcrumb before yielding so a client that loads + # the session a moment later sees this interrupt. Only attempt + # when we have session/user context — preview/anonymous flows + # don't have a metadata record to write to. + if session_id and user_id: + try: + await add_pending_interrupt( + session_id=session_id, + user_id=user_id, + interrupt=PendingInterrupt( + interrupt_id=interrupt.id, + provider_id=provider_id, + triggering_message_id=triggering_message_id, + created_at=datetime.now(timezone.utc).isoformat(), + ), + ) + except Exception as e: + logger.error( + "Failed to persist pending_interrupt %s: %s", + interrupt.id, e, exc_info=True, + ) + events.append( OAuthRequiredEvent( provider_id=provider_id, diff --git a/backend/src/apis/app_api/sessions/routes.py b/backend/src/apis/app_api/sessions/routes.py index d540b797..0cac994f 100644 --- a/backend/src/apis/app_api/sessions/routes.py +++ b/backend/src/apis/app_api/sessions/routes.py @@ -19,7 +19,12 @@ MessagesListResponse ) from apis.shared.sessions.messages import get_messages -from apis.shared.sessions.metadata import store_session_metadata, get_session_metadata, list_user_sessions +from apis.shared.sessions.metadata import ( + list_user_sessions, + get_session_metadata, + remove_pending_interrupts, + store_session_metadata, +) from .services.session_service import SessionService from apis.app_api.shares.service import get_share_service from apis.shared.auth.dependencies import get_current_user @@ -546,3 +551,38 @@ async def get_session_messages_endpoint( status_code=500, detail=f"Failed to retrieve messages: {str(e)}" ) + + +@router.delete("/{session_id}/pending-interrupts/{interrupt_id:path}", status_code=204) +async def dismiss_pending_interrupt_endpoint( + session_id: str, + interrupt_id: str, + current_user: User = Depends(get_current_user), +): + """Dismiss a pending OAuth consent interrupt for the caller's session. + + The frontend calls this when the user clicks the dismiss button on an + inline consent prompt, so a refresh doesn't redisplay it. The id is + matched as-is — Strands generates ids like ``oauth:google-calendar``, + so we accept ``:path`` to keep the colon literal in the URL. + + No-op for unknown ids and missing sessions, returning 204 in both + cases (the user's intent is satisfied). + """ + user_id = current_user.user_id + + logger.info("DELETE /sessions/%s/pending-interrupts/%s", session_id, interrupt_id) + + try: + await remove_pending_interrupts( + session_id=session_id, + user_id=user_id, + interrupt_ids=[interrupt_id], + ) + return Response(status_code=204) + except Exception as e: + logger.error("Error dismissing pending interrupt", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Failed to dismiss interrupt: {str(e)}", + ) diff --git a/backend/src/apis/inference_api/chat/routes.py b/backend/src/apis/inference_api/chat/routes.py index 3e4c42d2..5b60ac21 100644 --- a/backend/src/apis/inference_api/chat/routes.py +++ b/backend/src/apis/inference_api/chat/routes.py @@ -622,6 +622,34 @@ async def stream_with_quota_warning() -> AsyncGenerator[str, None]: ): yield event + # Resume bookkeeping: any interrupt that was submitted in this + # request and is no longer present in the agent's interrupt state + # has been resolved — drop the persisted breadcrumb so a refresh + # doesn't redisplay a stale prompt. Interrupts that re-paused + # (same provider, new url) are left in place; the next event + # extractor will refresh them. + if is_resume and input_data.interrupt_responses: + try: + strands_agent = getattr(agent, "agent", None) + interrupt_state = getattr(strands_agent, "_interrupt_state", None) if strands_agent else None + still_paused: set[str] = set() + if interrupt_state and getattr(interrupt_state, "activated", False): + still_paused = set((getattr(interrupt_state, "interrupts", None) or {}).keys()) + resolved_ids = [ + entry.interruptId + for entry in input_data.interrupt_responses + if entry.interruptId not in still_paused + ] + if resolved_ids: + from apis.shared.sessions.metadata import remove_pending_interrupts + await remove_pending_interrupts( + session_id=input_data.session_id, + user_id=user_id, + interrupt_ids=resolved_ids, + ) + except Exception as cleanup_err: + logger.error("Failed to clear resolved pending_interrupts: %s", cleanup_err, exc_info=True) + # Stream response from agent as SSE (with optional files) # Note: Compression is handled by GZipMiddleware if configured in main.py return StreamingResponse( diff --git a/backend/src/apis/shared/sessions/messages.py b/backend/src/apis/shared/sessions/messages.py index 1ec668ab..32de1e95 100644 --- a/backend/src/apis/shared/sessions/messages.py +++ b/backend/src/apis/shared/sessions/messages.py @@ -411,9 +411,20 @@ async def fetch_metadata(): return await get_all_message_metadata(session_id, user_id) + async def fetch_pending_interrupts(): + """Fetch pending OAuth consent interrupts from session metadata. + + Returns an empty list when the session has none — the only signal + the frontend needs to know whether a refresh-restored conversation + still has a paused turn awaiting consent. + """ + from .metadata import get_pending_interrupts + + return await get_pending_interrupts(session_id, user_id) + # Run fetches in parallel - messages_raw, metadata_index = await asyncio.gather( - fetch_messages(), fetch_metadata() + messages_raw, metadata_index, pending_interrupts = await asyncio.gather( + fetch_messages(), fetch_metadata(), fetch_pending_interrupts() ) messages_raw = list(messages_raw or []) @@ -462,7 +473,11 @@ async def fetch_metadata(): message_responses = [_convert_message_to_response(msg, session_id, start_seq + idx) for idx, msg in enumerate(paginated_messages)] - return MessagesListResponse(messages=message_responses, next_token=next_page_token) + return MessagesListResponse( + messages=message_responses, + next_token=next_page_token, + pending_interrupts=pending_interrupts, + ) except Exception as e: logger.error(f"Error retrieving messages from AgentCore Memory: {e}") diff --git a/backend/src/apis/shared/sessions/metadata.py b/backend/src/apis/shared/sessions/metadata.py index acdb7b83..93aea6ee 100644 --- a/backend/src/apis/shared/sessions/metadata.py +++ b/backend/src/apis/shared/sessions/metadata.py @@ -11,11 +11,11 @@ import json import os import base64 -from typing import Optional, Tuple, Any, Dict +from typing import Iterable, List, Optional, Tuple, Any, Dict from decimal import Decimal # Relative imports from shared sessions module -from .models import MessageMetadata, SessionMetadata +from .models import MessageMetadata, PendingInterrupt, SessionMetadata # Import preview session helper from agents.main_agent.session.preview_session_manager import is_preview_session @@ -1183,3 +1183,157 @@ def _deep_merge(base: dict, updates: dict) -> dict: result[key] = value return result + + +# ============================================================================ +# Pending OAuth interrupts +# ============================================================================ +# +# Pending interrupts persist the breadcrumb the SSE stream emits when the +# agent pauses on `oauth_required`, so the frontend can rediscover them on +# reload. We do read-modify-write through the SessionLookupIndex GSI: +# OAuth flows are rare and one-at-a-time per user, so the simplicity wins +# over an UpdateExpression with list_append/REMOVE-by-index gymnastics. + + +def _interrupts_to_dynamo(interrupts: Iterable[PendingInterrupt]) -> List[Dict[str, Any]]: + """Serialize PendingInterrupt list for DynamoDB storage (camelCase keys).""" + return [item.model_dump(by_alias=True, exclude_none=True) for item in interrupts] + + +def _interrupts_from_dynamo(raw: Any) -> List[PendingInterrupt]: + """Best-effort parse of stored interrupt entries; tolerate missing/legacy items.""" + if not raw or not isinstance(raw, list): + return [] + parsed: List[PendingInterrupt] = [] + for entry in raw: + if not isinstance(entry, dict): + continue + try: + parsed.append(PendingInterrupt.model_validate(entry)) + except Exception as exc: # pragma: no cover — corrupted entry shouldn't break load + logger.warning("Skipping unparseable pending_interrupts entry: %s", exc) + return parsed + + +async def add_pending_interrupt( + session_id: str, + user_id: str, + interrupt: PendingInterrupt, +) -> None: + """Idempotently append a pending OAuth interrupt to the session record. + + If an entry with the same ``interrupt_id`` already exists, it's replaced + rather than duplicated — the agent may re-emit the same interrupt across + re-streams of a paused turn. + + No-op when the session metadata record is missing (preview sessions, + sessions deleted mid-turn). The frontend will fall back to its in-memory + consent state in that case. + """ + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + logger.warning("DYNAMODB_SESSIONS_METADATA_TABLE_NAME not set; skipping pending_interrupts persistence") + return + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + logger.info("Skipping pending_interrupts add — session %s not found", session_id) + return + + sk = existing.get("SK") + if not sk: + logger.warning("Session %s has no SK; cannot update pending_interrupts", session_id) + return + + current_raw = existing.get("pendingInterrupts") or [] + current = _interrupts_from_dynamo(current_raw) + + # Replace any existing entry with the same interrupt_id + merged = [p for p in current if p.interrupt_id != interrupt.interrupt_id] + merged.append(interrupt) + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="SET #pi = :pi", + ExpressionAttributeNames={"#pi": "pendingInterrupts"}, + ExpressionAttributeValues={":pi": _interrupts_to_dynamo(merged)}, + ) + logger.info( + "Persisted pending_interrupt %s (provider=%s) for session %s", + interrupt.interrupt_id, interrupt.provider_id, session_id, + ) + except Exception as e: + # Persistence failure must not break the live SSE flow — the in-memory + # consent on the live tab still works; refresh-resume just won't. + logger.error("Failed to persist pending_interrupt: %s", e, exc_info=True) + + +async def remove_pending_interrupts( + session_id: str, + user_id: str, + interrupt_ids: Iterable[str], +) -> None: + """Drop the given ``interrupt_ids`` from the session's pending list. + + No-op for unknown ids and missing sessions. Used by the resume path + (after the agent successfully completes the resumed turn) and by the + explicit dismiss endpoint. + """ + drop_set = {iid for iid in interrupt_ids if iid} + if not drop_set: + return + + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + return + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + return + + sk = existing.get("SK") + if not sk: + return + + current = _interrupts_from_dynamo(existing.get("pendingInterrupts") or []) + kept = [p for p in current if p.interrupt_id not in drop_set] + + if len(kept) == len(current): + return # Nothing matched + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="SET #pi = :pi", + ExpressionAttributeNames={"#pi": "pendingInterrupts"}, + ExpressionAttributeValues={":pi": _interrupts_to_dynamo(kept)}, + ) + logger.info( + "Cleared %d pending_interrupt(s) from session %s", + len(current) - len(kept), session_id, + ) + except Exception as e: + logger.error("Failed to remove pending_interrupts: %s", e, exc_info=True) + + +async def get_pending_interrupts(session_id: str, user_id: str) -> List[PendingInterrupt]: + """Return the current pending OAuth interrupts for a session. + + Returns an empty list when the session doesn't exist or has none. + """ + metadata = await get_session_metadata(session_id, user_id) + if not metadata: + return [] + return list(metadata.pending_interrupts or []) diff --git a/backend/src/apis/shared/sessions/models.py b/backend/src/apis/shared/sessions/models.py index d684b6c5..e760e83c 100644 --- a/backend/src/apis/shared/sessions/models.py +++ b/backend/src/apis/shared/sessions/models.py @@ -20,6 +20,31 @@ class VisualDisplayState(BaseModel): expanded: bool = Field(default=True, description="Visual is expanded vs collapsed") +class PendingInterrupt(BaseModel): + """An OAuth consent request that paused an agent turn and is awaiting user action. + + Persisted to session metadata when ``OAuthConsentHook`` fires the interrupt + so the frontend can rediscover pending consents on reload — without it, a + browser refresh leaves the consent prompt stuck and the tool call orphaned + in ``pending`` forever. + + Note: ``authorization_url`` is intentionally omitted. AgentCore Identity's + consent URLs are short-lived; storing them invites stale-URL bugs on + refresh-after-an-hour. Frontend re-fetches via ``initiate-consent`` on + Connect. + """ + + model_config = ConfigDict(populate_by_name=True) + interrupt_id: str = Field(..., alias="interruptId", description="Strands interrupt id used to resume the paused turn") + provider_id: str = Field(..., alias="providerId", description="Connector providerId needing consent") + triggering_message_id: Optional[str] = Field( + None, + alias="triggeringMessageId", + description="Id of the assistant message whose tool call triggered this interrupt, when known", + ) + created_at: str = Field(..., alias="createdAt", description="ISO 8601 timestamp when the interrupt was recorded") + + class SessionPreferences(BaseModel): """User preferences for a session""" @@ -78,6 +103,13 @@ class SessionMetadata(BaseModel): deleted: Optional[bool] = Field(False, description="Whether session is soft-deleted") deleted_at: Optional[str] = Field(None, alias="deletedAt", description="ISO 8601 timestamp of deletion") + # OAuth consent state + pending_interrupts: Optional[List[PendingInterrupt]] = Field( + default=None, + alias="pendingInterrupts", + description="Pending OAuth consent interrupts that paused agent turns in this session", + ) + class UpdateSessionMetadataRequest(BaseModel): """Request body for updating session metadata""" @@ -298,3 +330,8 @@ class MessagesListResponse(BaseModel): messages: List[MessageResponse] = Field(..., description="List of messages in the session") next_token: Optional[str] = Field(None, alias="nextToken", description="Pagination token for retrieving the next page of results") + pending_interrupts: List[PendingInterrupt] = Field( + default_factory=list, + alias="pendingInterrupts", + description="OAuth consent interrupts that paused agent turns in this session and are awaiting user action", + ) diff --git a/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts b/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts deleted file mode 100644 index 17b24d8f..00000000 --- a/frontend/ai.client/src/app/components/oauth-consent-banner/oauth-consent-banner.component.ts +++ /dev/null @@ -1,116 +0,0 @@ -import { Component, ChangeDetectionStrategy, inject } from '@angular/core'; -import { NgIcon, provideIcons } from '@ng-icons/core'; -import { heroLink, heroArrowTopRightOnSquare, heroXMark } from '@ng-icons/heroicons/outline'; -import { OAuthConsentService } from '../../services/oauth-consent/oauth-consent.service'; - -/** - * Renders a compact "Connect to X" prompt above the chat input whenever the - * SSE stream surfaces an `oauth_required` event. Clicking the button opens - * AgentCore Identity's consent URL in a popup; {@link OAuthConsentService} - * listens for the completion postMessage and clears the entry automatically. - */ -@Component({ - selector: 'app-oauth-consent-banner', - changeDetection: ChangeDetectionStrategy.OnPush, - imports: [NgIcon], - providers: [provideIcons({ heroLink, heroArrowTopRightOnSquare, heroXMark })], - template: ` - @if (consentService.hasPending()) { -
- @for (request of consentService.pending(); track request.providerId) { -
-
- } -
- } - `, - styles: [` - @keyframes fadeIn { - from { - opacity: 0; - transform: translateY(4px); - } - to { - opacity: 1; - transform: translateY(0); - } - } - - .animate-fade-in { - animation: fadeIn 0.15s ease-out; - } - `], -}) -export class OAuthConsentBannerComponent { - protected consentService = inject(OAuthConsentService); - - labelFor(providerId: string): string { - if (!providerId) { - return 'This tool'; - } - return providerId - .replace(/[-_]+/g, ' ') - .split(' ') - .filter((part) => part.length > 0) - .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) - .join(' '); - } - - connect(providerId: string): void { - this.consentService.openConsentPopup(providerId); - } - - dismiss(providerId: string, event: Event): void { - event.stopPropagation(); - this.consentService.dismiss(providerId); - } -} diff --git a/frontend/ai.client/src/app/components/oauth-consent-prompt/oauth-consent-prompt.component.ts b/frontend/ai.client/src/app/components/oauth-consent-prompt/oauth-consent-prompt.component.ts new file mode 100644 index 00000000..1ee60eb8 --- /dev/null +++ b/frontend/ai.client/src/app/components/oauth-consent-prompt/oauth-consent-prompt.component.ts @@ -0,0 +1,304 @@ +import { ChangeDetectionStrategy, Component, computed, inject, input } from '@angular/core'; +import { NgIcon, provideIcons } from '@ng-icons/core'; +import { + heroAcademicCap, + heroArrowTopRightOnSquare, + heroCloud, + heroCodeBracket, + heroLink, + heroLockClosed, + heroXMark, +} from '@ng-icons/heroicons/outline'; +import { + OAuthConsentRequest, + OAuthConsentService, +} from '../../services/oauth-consent/oauth-consent.service'; +import { UserConnector } from '../../settings/connectors/models/user-connector.model'; +import { UserConnectorsService } from '../../settings/connectors/services/user-connectors.service'; + +/** + * Inline OAuth consent prompt rendered alongside the assistant message whose + * tool call needed authorization. Looks up the connector definition to display + * its icon (admin-uploaded base64 wins over heroicon name) and friendly name, + * delegates click handling to {@link OAuthConsentService}. + */ +@Component({ + selector: 'app-oauth-consent-prompt', + changeDetection: ChangeDetectionStrategy.OnPush, + imports: [NgIcon], + providers: [ + provideIcons({ + heroAcademicCap, + heroArrowTopRightOnSquare, + heroCloud, + heroCodeBracket, + heroLink, + heroLockClosed, + heroXMark, + }), + ], + host: { class: 'block' }, + template: ` +
+ + +
+ @if (iconDataUrl(); as data) { + + } @else { +
+ +
+

+

+

+ @if (isBlocked()) { + Popup blocked. Open + {{ displayName() }} + in a new tab to continue. + } @else { + Connect + {{ displayName() }} + so the assistant can finish this request. + } +

+
+ +
+ @if (isBlocked() && request().authorizationUrl; as blockedUrl) { + + Open + + } @else { + + } + +
+
+ `, + styles: ` + @import 'tailwindcss'; + @custom-variant dark (&:where(.dark, .dark *)); + + :host { + display: block; + } + + .oauth-prompt { + animation: oauth-rise 0.32s cubic-bezier(0.16, 1, 0.3, 1); + } + + .action-btn { + display: inline-flex; + align-items: center; + gap: 0.375rem; + border-radius: 0.5rem; + padding: 0.375rem 0.75rem; + font-size: 0.8125rem; + font-weight: 600; + color: white; + background: linear-gradient( + 180deg, + var(--color-primary-500) 0%, + var(--color-primary-600) 100% + ); + box-shadow: + 0 1px 2px rgba(15, 23, 42, 0.12), + inset 0 1px 0 rgba(255, 255, 255, 0.12); + transition: + transform 120ms ease, + box-shadow 160ms ease, + filter 120ms ease; + } + + .action-btn:hover:not(:disabled) { + filter: brightness(1.05); + box-shadow: + 0 4px 14px -4px rgba(15, 23, 42, 0.25), + inset 0 1px 0 rgba(255, 255, 255, 0.18); + } + + .action-btn:active:not(:disabled) { + transform: translateY(1px); + } + + .action-btn:focus-visible { + outline: 2px solid var(--color-primary-500); + outline-offset: 2px; + } + + .action-btn:disabled { + opacity: 0.85; + cursor: default; + } + + /* Dismiss is subtle until the row is hovered or focused. */ + .dismiss-btn { + opacity: 0; + } + + .group:hover .dismiss-btn, + .group:focus-within .dismiss-btn { + opacity: 1; + } + + @keyframes oauth-rise { + from { + opacity: 0; + transform: translateY(6px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + + @media (prefers-reduced-motion: reduce) { + .oauth-prompt { + animation: none; + } + .action-btn { + transition: none; + } + } + `, +}) +export class OAuthConsentPromptComponent { + request = input.required(); + + protected consentService = inject(OAuthConsentService); + private connectorsService = inject(UserConnectorsService); + + /** Connector definition for this providerId, when the catalog is loaded. */ + private connector = computed(() => { + const connectors = this.connectorsService.connectorsResource.value(); + if (!connectors) return null; + return connectors.find((c) => c.providerId === this.request().providerId) ?? null; + }); + + /** Admin-uploaded base64 icon. Wins over heroicon when present. */ + protected iconDataUrl = computed(() => this.connector()?.iconData ?? null); + + /** Heroicon fallback when no iconData exists. Mirrors the cascade used by + * the connectors-settings page so the same connector renders identically. */ + protected iconName = computed(() => { + const c = this.connector(); + if (c?.iconName) return c.iconName; + switch (c?.providerType) { + case 'google': + case 'microsoft': + return 'heroCloud'; + case 'github': + return 'heroCodeBracket'; + case 'canvas': + return 'heroAcademicCap'; + default: + return 'heroLink'; + } + }); + + protected displayName = computed(() => { + const c = this.connector(); + if (c?.displayName) return c.displayName; + return this.titleCase(this.request().providerId); + }); + + protected isInFlight = computed(() => + this.consentService.isInFlight(this.request().providerId), + ); + + protected isBlocked = computed(() => + this.consentService.isBlocked(this.request().providerId), + ); + + connect(): void { + void this.consentService.openConsentPopup(this.request().providerId); + } + + dismiss(event: Event): void { + event.stopPropagation(); + this.consentService.dismiss(this.request().providerId); + } + + private titleCase(providerId: string): string { + if (!providerId) return 'This tool'; + return providerId + .replace(/[-_]+/g, ' ') + .split(' ') + .filter((part) => part.length > 0) + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(' '); + } +} diff --git a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts index d6d8e90e..84ce5e91 100644 --- a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts +++ b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts @@ -1,6 +1,8 @@ import { Injectable, signal, computed, inject, DestroyRef } from '@angular/core'; import { takeUntilDestroyed } from '@angular/core/rxjs-interop'; import { fromEvent } from 'rxjs'; +import { UserConnectorsService } from '../../settings/connectors/services/user-connectors.service'; +import { SessionService } from '../../session/services/session/session.service'; /** * Pending OAuth consent request surfaced by the backend when an external @@ -13,9 +15,21 @@ import { fromEvent } from 'rxjs'; */ export interface OAuthConsentRequest { providerId: string; - authorizationUrl: string; + /** Authorization URL captured from a live `oauth_required` SSE event. + * Absent on requests hydrated from session metadata after a refresh — + * AgentCore's URLs expire quickly, so the service re-fetches a fresh one + * via `initiate-consent` when the user clicks Connect. */ + authorizationUrl?: string; interruptId?: string; receivedAt: number; + /** Id of the assistant message whose tool call triggered this consent request. + * Used by the inline message renderer to anchor the prompt to the turn that + * needs it. Omitted for proactive consents from the settings page. */ + messageId?: string; + /** Session id the request belongs to. Required for the backend dismiss + * endpoint to clear the persisted breadcrumb so a refresh doesn't + * resurrect a dismissed prompt. */ + sessionId?: string; } /** @@ -33,10 +47,15 @@ export interface OAuthCompleteMessage { /** * Handler the chat layer registers to resume a paused agent turn after * one or more OAuth consents complete. Receives the interrupt ids whose - * tokens are now available; the handler is expected to POST a resume + * tokens are now available, plus the originating session id so the + * handler can resume even when the live ``lastRequestObject`` is gone + * (post-refresh hydration). The handler is expected to POST a resume * request to `/invocations` with `interrupt_responses` populated. */ -export type OAuthResumeHandler = (interruptIds: string[]) => void | Promise; +export type OAuthResumeHandler = ( + interruptIds: string[], + context?: { sessionId?: string }, +) => void | Promise; function isOAuthCompleteMessage(data: unknown): data is OAuthCompleteMessage { if (!data || typeof data !== 'object') { @@ -76,6 +95,8 @@ function isSafeConsentUrl(raw: string): boolean { @Injectable({ providedIn: 'root' }) export class OAuthConsentService { private readonly destroyRef = inject(DestroyRef); + private readonly userConnectorsService = inject(UserConnectorsService); + private readonly sessionService = inject(SessionService); /** Map of providerId → request. A provider only appears once, even if * the backend emits duplicates mid-stream. */ @@ -153,8 +174,17 @@ export class OAuthConsentService { * * Rejects non-https URLs — see {@link isSafeConsentUrl}. */ - requestConsent(providerId: string, authorizationUrl: string, interruptId?: string): void { - if (!isSafeConsentUrl(authorizationUrl)) { + requestConsent( + providerId: string, + authorizationUrl: string | undefined, + interruptId?: string, + messageId?: string, + sessionId?: string, + ): void { + // Hydration from session metadata passes undefined — the URL gets fetched + // lazily on Connect. Live SSE flows still pass the URL up front for the + // fast-path popup with no extra roundtrip. + if (authorizationUrl !== undefined && !isSafeConsentUrl(authorizationUrl)) { console.error( 'OAuth consent rejected: authorizationUrl is not https', { providerId }, @@ -167,6 +197,8 @@ export class OAuthConsentService { providerId, authorizationUrl, interruptId, + messageId, + sessionId, receivedAt: Date.now(), }); return next; @@ -193,19 +225,70 @@ export class OAuthConsentService { * Returns true if the popup opened, false if it was blocked or the URL * failed validation. Callers can use this to trigger a fallback UI. */ - openConsentPopup(providerId: string): boolean { + async openConsentPopup(providerId: string): Promise { const request = this.requests().get(providerId); if (!request) { return false; } + // Hydrated requests don't carry a URL — fetch a fresh one. Stored URLs + // can also be stale if the live SSE event fired more than a few minutes + // ago, so we treat any missing/expired URL as a refresh trigger. + let authorizationUrl = request.authorizationUrl; + if (!authorizationUrl) { + this.inFlight.update((set) => { + const next = new Set(set); + next.add(providerId); + return next; + }); + try { + const response = await this.userConnectorsService.initiateConsent(providerId); + if (response.connected || !response.authorizationUrl) { + // Already consented while paused (e.g. the user authorized in + // another tab). Drop the request and let the resume handler — if + // any — fire so the agent can finish the turn. + this.dismiss(providerId); + if (request.interruptId && this.resumeHandler) { + void Promise.resolve(this.resumeHandler([request.interruptId])).catch((err) => + console.error('OAuth resume handler failed after pre-consented refresh', err), + ); + } + return false; + } + authorizationUrl = response.authorizationUrl; + this.requests.update((map) => { + const next = new Map(map); + const current = next.get(providerId); + if (current) { + next.set(providerId, { ...current, authorizationUrl }); + } + return next; + }); + } catch (err) { + console.error('Failed to fetch fresh authorization URL', err); + this.inFlight.update((set) => { + if (!set.has(providerId)) return set; + const next = new Set(set); + next.delete(providerId); + return next; + }); + return false; + } + } + // Re-validate on the hot path even though requestConsent already // checked — defensive against anyone mutating the stored entry. - if (!isSafeConsentUrl(request.authorizationUrl)) { + if (!isSafeConsentUrl(authorizationUrl)) { console.error( 'OAuth consent rejected at open: authorizationUrl is not https', { providerId }, ); + this.inFlight.update((set) => { + if (!set.has(providerId)) return set; + const next = new Set(set); + next.delete(providerId); + return next; + }); return false; } @@ -227,7 +310,7 @@ export class OAuthConsentService { 'location=no', ].join(','); - const popup = window.open(request.authorizationUrl, `oauth-${providerId}`, features); + const popup = window.open(authorizationUrl, `oauth-${providerId}`, features); if (!popup) { this.blocked.update((set) => { @@ -324,7 +407,7 @@ export class OAuthConsentService { */ getAuthorizationUrl(providerId: string): string | null { const request = this.requests().get(providerId); - return request ? request.authorizationUrl : null; + return request?.authorizationUrl ?? null; } /** @@ -338,10 +421,20 @@ export class OAuthConsentService { } /** - * Clear a single consent request — called from the UI after the user - * completes or dismisses a provider, or when the chat is reset. + * Drop a single consent request from local state, and (when called from + * the UI's explicit dismiss button) clear the persisted breadcrumb so a + * refresh doesn't resurrect the prompt. + * + * On completion-driven cleanup ({@link handleCompletion}) we set + * ``syncServer: false`` because the resume request that follows will + * remove the same breadcrumb server-side — a separate DELETE would just + * be redundant network noise. */ - dismiss(providerId: string): void { + dismiss(providerId: string, options?: { syncServer?: boolean }): void { + const entry = this.requests().get(providerId); + const sessionId = entry?.sessionId; + const interruptId = entry?.interruptId; + this.requests.update((map) => { if (!map.has(providerId)) { return map; @@ -366,6 +459,18 @@ export class OAuthConsentService { next.delete(providerId); return next; }); + + if (options?.syncServer === false || !sessionId || !interruptId) { + return; + } + + // Best-effort: a backend cleanup failure shouldn't block the UI from + // hiding the prompt — the prompt is already gone locally. + void this.sessionService + .dismissPendingInterrupt(sessionId, interruptId) + .catch((err) => { + console.warn('Failed to clear persisted pending_interrupt; local dismiss still applied', err); + }); } /** Reset all state (new session, logout). */ @@ -396,14 +501,18 @@ export class OAuthConsentService { // Capture the paused interrupt id BEFORE dismissing the request, since // dismiss removes the entry the handler needs. A user-initiated // settings-page consent has no interruptId — nothing to resume. + // Skip server sync: the resume request fired below clears the persisted + // interrupt server-side, so a separate DELETE would just be redundant. const request = this.requests().get(message.providerId); - this.dismiss(message.providerId); + this.dismiss(message.providerId, { syncServer: false }); if (!request?.interruptId || !this.resumeHandler) { return; } - void Promise.resolve(this.resumeHandler([request.interruptId])).catch((err) => { + void Promise.resolve( + this.resumeHandler([request.interruptId], { sessionId: request.sessionId }), + ).catch((err) => { // Resume failures are surfaced through the resume request's own error // handling — log here for diagnostics but don't crash the consent flow. console.error('OAuth resume handler failed', err); diff --git a/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.html b/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.html index 471cb627..d219ddf8 100644 --- a/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.html +++ b/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.html @@ -1,6 +1,3 @@ - - - diff --git a/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.ts b/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.ts index 5f4024a9..0d67ce96 100644 --- a/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.ts +++ b/frontend/ai.client/src/app/session/components/chat-input/chat-input.component.ts @@ -10,7 +10,6 @@ import { import { heroPaperAirplaneSolid, heroStopSolid } from '@ng-icons/heroicons/solid'; import { ModelDropdownComponent } from '../../../components/model-dropdown/model-dropdown.component'; import { QuotaWarningBannerComponent } from '../../../components/quota-warning-banner/quota-warning-banner.component'; -import { OAuthConsentBannerComponent } from '../../../components/oauth-consent-banner/oauth-consent-banner.component'; import { TooltipDirective } from '../../../components/tooltip'; import { FileCardComponent } from '../../../components/file-card'; import { StorageQuotaBannerComponent } from '../../../components/storage-quota-banner'; @@ -33,7 +32,7 @@ interface Message { @Component({ selector: 'app-chat-input', - imports: [FormsModule, ModelDropdownComponent, NgIcon, QuotaWarningBannerComponent, OAuthConsentBannerComponent, StorageQuotaBannerComponent, TooltipDirective, FileCardComponent], + imports: [FormsModule, ModelDropdownComponent, NgIcon, QuotaWarningBannerComponent, StorageQuotaBannerComponent, TooltipDirective, FileCardComponent], providers: [ provideIcons({ heroPlus, diff --git a/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts index b0078598..68077336 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts @@ -1,4 +1,4 @@ -import { ChangeDetectionStrategy, Component, input, computed } from '@angular/core'; +import { ChangeDetectionStrategy, Component, computed, inject, input } from '@angular/core'; import { Message, ContentBlock, ToolUseData } from '../../../services/models/message.model'; import { ToolUseComponent } from './tool-use'; import { ToolRailComponent } from './tool-rail'; @@ -6,6 +6,11 @@ import { ToolCallGroup, ToolCallDisplay } from './tool-rail/tool-rail.model'; import { ReasoningContentComponent } from './reasoning-content'; import { StreamingTextComponent } from './streaming-text.component'; import { InlineVisualComponent } from './inline-visual'; +import { OAuthConsentPromptComponent } from '../../../../components/oauth-consent-prompt/oauth-consent-prompt.component'; +import { + OAuthConsentRequest, + OAuthConsentService, +} from '../../../../services/oauth-consent/oauth-consent.service'; // ────────────────────────────────────────────────────────────── // 🔧 MOCK FLAG — set to true to render 10 fake tool calls @@ -98,7 +103,13 @@ const MOCK_TOOL_GROUP: ToolCallGroup = { * promoted visuals and grouped tool rails. */ interface DisplayBlock { - type: 'text' | 'tool_group' | 'tool_use_minimized' | 'promoted_visual' | 'reasoningContent'; + type: + | 'text' + | 'tool_group' + | 'tool_use_minimized' + | 'promoted_visual' + | 'reasoningContent' + | 'oauth_required'; data?: ContentBlock; // For tool groups (inline rail) group?: ToolCallGroup; @@ -106,6 +117,8 @@ interface DisplayBlock { uiType?: string; payload?: unknown; toolUseId?: string; + // For inline OAuth consent prompts + oauthRequest?: OAuthConsentRequest; } @Component({ @@ -117,6 +130,7 @@ interface DisplayBlock { ReasoningContentComponent, StreamingTextComponent, InlineVisualComponent, + OAuthConsentPromptComponent, ], template: `
@@ -182,6 +196,14 @@ interface DisplayBlock { />
} + @case ('oauth_required') { +
+ +
+ } } }

@@ -236,6 +258,8 @@ export class AssistantMessageComponent { message = input.required(); isStreaming = input(false); + private consentService = inject(OAuthConsentService); + /** * Transforms content blocks into display blocks. * - Consecutive non-promoted tool-use blocks are grouped into a single ToolCallGroup @@ -253,6 +277,14 @@ export class AssistantMessageComponent { } const blocks = this.message().content; + const messageId = this.message().id; + // Pending interrupts anchored to this message. Used to flip the matching + // tool_use blocks to ``awaiting_auth`` so the row reads as "paused for + // authorization" instead of an indefinite spinner. + const pendingInterruptsHere = this.consentService + .pending() + .filter((req) => req.messageId === messageId); + const hasPendingInterruptHere = pendingInterruptsHere.length > 0; const result: DisplayBlock[] = []; let pendingToolCalls: ToolCallDisplay[] = []; @@ -307,13 +339,22 @@ export class AssistantMessageComponent { toolUseId: toolUse.toolUseId }); } else { - // Accumulate into the current tool group + // Accumulate into the current tool group. A tool_use with no result + // on a message that has a pending OAuth interrupt is the row that + // got paused — surface that distinct state instead of a forever- + // spinning ``pending``. + const baseStatus = toolUse.status || 'pending'; + const hasNoResult = !toolUse.result; + const status: ToolCallDisplay['status'] = + hasPendingInterruptHere && hasNoResult && baseStatus === 'pending' + ? 'awaiting_auth' + : baseStatus; pendingToolCalls.push({ id: toolUse.toolUseId, toolName: toolUse.name, input: toolUse.input || {}, result: toolUse.result, - status: toolUse.status || 'pending', + status, }); } continue; @@ -323,6 +364,13 @@ export class AssistantMessageComponent { // Flush any remaining tool calls flushToolGroup(); + // Append any pending OAuth consent prompts anchored to this message. + // Tracking through the consent service signal keeps the synthetic prompt + // out of message.content so it is never persisted to the backend. + for (const req of pendingInterruptsHere) { + result.push({ type: 'oauth_required', oauthRequest: req }); + } + return result; }); diff --git a/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.component.ts index 209c3b6a..35c51277 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.component.ts @@ -81,9 +81,11 @@ export class ToolRailComponent { /** CSS class for status dot */ statusDotClass(call: ToolCallDisplay): string { switch (call.status) { - case 'complete': return 'status-dot bg-green-500'; - case 'pending': return 'status-dot bg-amber-400 shimmer'; - case 'error': return 'status-dot bg-red-500'; + case 'complete': return 'status-dot bg-green-500'; + case 'pending': return 'status-dot bg-amber-400 shimmer'; + case 'error': return 'status-dot bg-red-500'; + case 'awaiting_auth': return 'status-dot bg-primary-500 ring-2 ring-primary-300/40 dark:ring-primary-400/30'; + default: return 'status-dot bg-gray-400'; } } diff --git a/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.model.ts b/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.model.ts index 4dd8fda7..c669c333 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.model.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.model.ts @@ -20,8 +20,11 @@ export interface ToolCallDisplay { content: ToolResultContent[]; }; - /** Execution status (from toolUseData.status, defaults to 'pending') */ - status: 'pending' | 'complete' | 'error'; + /** Execution status (from toolUseData.status, defaults to 'pending'). + * ``awaiting_auth`` is derived in the message renderer when the tool was + * paused on an OAuth consent gate — the tool didn't fail, it's waiting + * for the user to authorize. */ + status: 'pending' | 'complete' | 'error' | 'awaiting_auth'; /** Optional LLM-generated one-line summary of this tool call's result */ summary?: string; diff --git a/frontend/ai.client/src/app/session/components/message-list/message-list.component.html b/frontend/ai.client/src/app/session/components/message-list/message-list.component.html index 45751424..a76a0edb 100644 --- a/frontend/ai.client/src/app/session/components/message-list/message-list.component.html +++ b/frontend/ai.client/src/app/session/components/message-list/message-list.component.html @@ -30,6 +30,15 @@ } } + + @for (request of unanchoredInterrupts(); track request.providerId) { +
+ +
+ } + @if (isChatLoading()) {
diff --git a/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts b/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts index 8ca7f641..f0a8acdf 100644 --- a/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts @@ -1,4 +1,4 @@ -import { Component, input, signal, effect, OnDestroy, inject, PLATFORM_ID } from '@angular/core'; +import { Component, computed, input, signal, effect, OnDestroy, inject, PLATFORM_ID } from '@angular/core'; import { isPlatformBrowser } from '@angular/common'; import { Message } from '../../services/models/message.model'; import { UserMessageComponent } from './components/user-message.component'; @@ -6,10 +6,22 @@ import { AssistantMessageComponent } from './components/assistant-message.compon import { MessageMetadataBadgesComponent } from './components/message-metadata-badges.component'; import { CitationDisplayComponent } from '../citation-display/citation-display.component'; import { PulsatingLoaderComponent } from '../../../components/pulsating-loader.component'; +import { OAuthConsentPromptComponent } from '../../../components/oauth-consent-prompt/oauth-consent-prompt.component'; +import { + OAuthConsentRequest, + OAuthConsentService, +} from '../../../services/oauth-consent/oauth-consent.service'; @Component({ selector: 'app-message-list', - imports: [UserMessageComponent, AssistantMessageComponent, MessageMetadataBadgesComponent, CitationDisplayComponent, PulsatingLoaderComponent], + imports: [ + UserMessageComponent, + AssistantMessageComponent, + MessageMetadataBadgesComponent, + CitationDisplayComponent, + PulsatingLoaderComponent, + OAuthConsentPromptComponent, + ], templateUrl: './message-list.component.html', styleUrl: './message-list.component.css', }) @@ -27,6 +39,18 @@ export class MessageListComponent implements OnDestroy { streamingMessageId = input(null); embeddedMode = input(false); + private consentService = inject(OAuthConsentService); + + /** Pending consent prompts whose anchor message id isn't in the loaded + * message list — typically the case when an interrupt fires on a turn + * whose partial assistant message wasn't persisted to AgentCore Memory. + * Rendered at the end of the conversation so the user still sees the + * affordance instead of a silently stalled tool call. */ + protected unanchoredInterrupts = computed(() => { + const ids = new Set(this.messages().map((m) => m.id)); + return this.consentService.pending().filter((req) => !req.messageId || !ids.has(req.messageId)); + }); + // Calculate the spacer height dynamically // This creates space at the bottom so user messages can scroll to the top spacerHeight = signal(0); diff --git a/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts b/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts index b25e83e1..1ce55b26 100644 --- a/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts @@ -11,7 +11,9 @@ import { ToolService } from '../../../services/tool/tool.service'; import { FileUploadService } from '../../../services/file-upload'; import { FileAttachmentData } from '../models/message.model'; import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; +import { ErrorService } from '../../../services/error/error.service'; import { StreamParserService } from './stream-parser.service'; +import { HttpErrorResponse } from '@angular/common/http'; export interface ContentFile { fileName: string; @@ -35,6 +37,7 @@ export class ChatRequestService implements OnDestroy { private fileUploadService = inject(FileUploadService); private oauthConsentService = inject(OAuthConsentService); private streamParserService = inject(StreamParserService); + private errorService = inject(ErrorService); private router = inject(Router); // TODO: Inject proper logging service @@ -44,8 +47,8 @@ export class ChatRequestService implements OnDestroy { private lastRequestObject: Record | null = null; constructor() { - this.oauthConsentService.setResumeHandler((interruptIds) => - this.resumeFromOAuthConsent(interruptIds), + this.oauthConsentService.setResumeHandler((interruptIds, context) => + this.resumeFromOAuthConsent(interruptIds, context?.sessionId), ); } @@ -180,12 +183,20 @@ export class ChatRequestService implements OnDestroy { * one. Triggered by OAuthConsentService after the user completes a * consent popup. */ - private async resumeFromOAuthConsent(interruptIds: string[]): Promise { - if (!this.lastRequestObject || interruptIds.length === 0) { + private async resumeFromOAuthConsent( + interruptIds: string[], + fallbackSessionId?: string, + ): Promise { + if (interruptIds.length === 0) { return; } - const sessionId = this.lastRequestObject['session_id'] as string | undefined; + // Live flow: the same tab that originated the turn still has its + // payload — replay it with `interrupt_responses` attached. + // Refresh flow: ``lastRequestObject`` is null, so we synthesize a + // minimal resume payload from the consent service's session context. + const liveSessionId = this.lastRequestObject?.['session_id'] as string | undefined; + const sessionId = liveSessionId ?? fallbackSessionId; if (!sessionId) { return; } @@ -198,8 +209,12 @@ export class ChatRequestService implements OnDestroy { this.chatStateService.createNewAbortController(); this.chatStateService.setChatLoading(true); + const baseRequest: Record = this.lastRequestObject + ? { ...this.lastRequestObject } + : { session_id: sessionId }; + const resumeRequest: Record = { - ...this.lastRequestObject, + ...baseRequest, // The original prompt is already in the agent's interrupt context; // sending an empty string keeps the request valid without // re-augmenting or re-charging quota. @@ -218,10 +233,42 @@ export class ChatRequestService implements OnDestroy { } catch (error) { this.chatStateService.setChatLoading(false); this.messageMapService.endStreaming(); + + // 400 from the resume route means the agent's `_interrupt_state` no + // longer holds the submitted ids — the cache evicted, the pod + // restarted, or the breadcrumb outlived its agent. Surface a + // conversational error so the user knows to retry the prompt + // instead of staring at a stuck spinner. + if (this.isExpiredInterruptError(error)) { + this.errorService.addError( + 'Authorization expired', + 'The agent paused too long ago to resume this turn automatically. Please send your message again.', + ); + return; + } throw error; } } + /** Detect the 400 the inference-api returns for unknown/expired interrupt + * ids. Both fetch-based and HttpClient-based flows are checked because + * the resume path uses `fetch-event-source`, which surfaces errors as + * plain Error/Response objects rather than HttpErrorResponse. */ + private isExpiredInterruptError(error: unknown): boolean { + if (error instanceof HttpErrorResponse) { + return error.status === 400; + } + if (typeof error === 'object' && error !== null) { + const status = (error as { status?: unknown }).status; + if (status === 400) return true; + const message = (error as { message?: unknown }).message; + if (typeof message === 'string' && /expired interrupt/i.test(message)) { + return true; + } + } + return false; + } + /** * Get file attachment metadata for display in user messages. * Retrieves file metadata from FileUploadService for given upload IDs. diff --git a/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts b/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts index 4e925592..c8e4c61d 100644 --- a/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts @@ -295,12 +295,27 @@ export class StreamParserService { onQuotaWarning: (data) => this.quotaWarningService.setWarning(data as QuotaWarning), onQuotaExceeded: (data) => this.quotaWarningService.setQuotaExceeded(data as QuotaExceeded), - onOAuthRequired: (data: OAuthRequiredEvent) => + onOAuthRequired: (data: OAuthRequiredEvent) => { + // oauth_required arrives after message_stop, so the triggering + // assistant message is normally in completedMessages; fall back + // to the in-flight builder for tool_use stop reasons that keep + // the message active. + const messages = this.allMessages(); + let lastAssistantId: string | undefined; + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'assistant') { + lastAssistantId = messages[i].id; + break; + } + } this.oauthConsentService.requestConsent( data.providerId, data.authorizationUrl, data.interruptId, - ), + lastAssistantId, + this.sessionId ?? undefined, + ); + }, onError: (data) => this.handleError(data), onStreamError: (data) => diff --git a/frontend/ai.client/src/app/session/services/session/message-map.service.ts b/frontend/ai.client/src/app/session/services/session/message-map.service.ts index f4d7197d..2108ef42 100644 --- a/frontend/ai.client/src/app/session/services/session/message-map.service.ts +++ b/frontend/ai.client/src/app/session/services/session/message-map.service.ts @@ -2,8 +2,9 @@ import { Injectable, Signal, WritableSignal, signal, effect, inject } from '@angular/core'; import { ContentBlock, FileAttachmentData, Message } from '../models/message.model'; import { StreamParserService } from '../chat/stream-parser.service'; -import { SessionService } from './session.service'; +import { PendingInterrupt, SessionService } from './session.service'; import { FileUploadService, FileMetadata } from '../../../services/file-upload'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; /** Regex to match file attachment marker in message text: [Attached files: file1.pdf, file2.png] */ const ATTACHED_FILES_PATTERN = /\n\n\[Attached files: ([^\]]+)\]$/; @@ -42,6 +43,7 @@ export class MessageMapService { private streamParser = inject(StreamParserService); private sessionService = inject(SessionService); private fileUploadService = inject(FileUploadService); + private oauthConsentService = inject(OAuthConsentService); constructor() { // Reactive effect: automatically sync streaming messages to the message map @@ -278,6 +280,13 @@ export class MessageMapService { return updated; }); + + // Hydrate OAuth consent service from any persisted pending interrupts + // so a reload restores the consent prompt anchored to the right turn. + // Anchor falls back to the most recent assistant message in history, + // matching how the live SSE flow already attaches prompts. Authorization + // URL is intentionally omitted — fresh URL is fetched lazily on Connect. + this.hydratePendingInterrupts(sessionId, messagesResponse.pending_interrupts, processedMessages); } catch (error) { console.error('Failed to load messages for session:', sessionId, error); throw error; @@ -287,6 +296,38 @@ export class MessageMapService { } } + /** + * Replay persisted pending OAuth interrupts into the consent service so + * the inline prompt re-renders. If the backend records a triggering message + * id we use it directly; otherwise we anchor to the most recent assistant + * message (mirroring the live-stream behavior in stream-parser.service.ts). + */ + private hydratePendingInterrupts( + sessionId: string, + interrupts: PendingInterrupt[] | undefined, + messages: Message[], + ): void { + if (!interrupts || interrupts.length === 0) { + return; + } + let lastAssistantId: string | undefined; + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'assistant') { + lastAssistantId = messages[i].id; + break; + } + } + for (const interrupt of interrupts) { + this.oauthConsentService.requestConsent( + interrupt.provider_id, + undefined, // URL is fetched lazily on Connect — stored URLs go stale + interrupt.interrupt_id, + interrupt.triggering_message_id ?? lastAssistantId, + sessionId, + ); + } + } + /** * Fetch file metadata for a session from the backend. * Returns empty array on error to allow graceful degradation. diff --git a/frontend/ai.client/src/app/session/services/session/session.service.ts b/frontend/ai.client/src/app/session/services/session/session.service.ts index 96dd45e1..c19ef353 100644 --- a/frontend/ai.client/src/app/session/services/session/session.service.ts +++ b/frontend/ai.client/src/app/session/services/session/session.service.ts @@ -28,9 +28,27 @@ export interface SessionsListResponse { nextToken: string | null; } +/** + * Pending OAuth consent interrupt that paused an agent turn for this session. + * + * Returned from `GET /sessions/{id}/messages` so the frontend can rediscover + * pending consents after a browser refresh. `authorization_url` is intentionally + * absent — those expire quickly; the frontend re-fetches on Connect. + */ +export interface PendingInterrupt { + /** Strands interrupt id used to resume the paused turn */ + interrupt_id: string; + /** Connector providerId needing consent */ + provider_id: string; + /** Id of the assistant message whose tool call triggered this interrupt, if known */ + triggering_message_id?: string | null; + /** ISO 8601 timestamp when the interrupt was recorded */ + created_at: string; +} + /** * Response model for listing messages with pagination support. - * + * * Matches the MessagesListResponse model from the Python API. */ export interface MessagesListResponse { @@ -38,6 +56,8 @@ export interface MessagesListResponse { messages: Message[]; /** Pagination token for retrieving the next page of results */ next_token: string | null; + /** OAuth consent interrupts that paused agent turns and are awaiting user action */ + pending_interrupts?: PendingInterrupt[]; } /** @@ -423,11 +443,11 @@ export class SessionService { */ async getMessages(sessionId: string, params?: GetMessagesParams): Promise { let httpParams = new HttpParams(); - + if (params?.limit !== undefined) { httpParams = httpParams.set('limit', params.limit.toString()); } - + if (params?.next_token) { httpParams = httpParams.set('next_token', params.next_token); } @@ -446,6 +466,19 @@ export class SessionService { } } + /** + * Dismiss a persisted OAuth pending interrupt for a session. Idempotent — + * the backend returns 204 even if the entry is already gone. + */ + async dismissPendingInterrupt(sessionId: string, interruptId: string): Promise { + // The interrupt id contains a colon (e.g. ``oauth:google-calendar``); + // encode it so it survives URL parsing on the path. + const encoded = encodeURIComponent(interruptId); + await firstValueFrom( + this.http.delete(`${this.baseUrl()}/${sessionId}/pending-interrupts/${encoded}`), + ); + } + /** * Fetches metadata for a specific session from the Python API. * From 94fcf240fc90b45c27ace33ad53738d3b1bb125e Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Sat, 25 Apr 2026 21:28:43 -0600 Subject: [PATCH 25/35] feat(connectors): add OAuth consent prompt component for authorization handling --- .../oauth-consent-prompt/oauth-consent-prompt.component.ts | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename frontend/ai.client/src/app/{ => session/components/message-list}/components/oauth-consent-prompt/oauth-consent-prompt.component.ts (100%) diff --git a/frontend/ai.client/src/app/components/oauth-consent-prompt/oauth-consent-prompt.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts similarity index 100% rename from frontend/ai.client/src/app/components/oauth-consent-prompt/oauth-consent-prompt.component.ts rename to frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts From 18b96c90af6263a1eb96d8206e8228227e4224ce Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Sat, 25 Apr 2026 21:28:53 -0600 Subject: [PATCH 26/35] feat: enhance session metadata management and update handling - Add functions to ensure session metadata existence and update session title and activity. - Implement logic for handling session activity updates, including message count increments and preferences merging. - Introduce deduplication for pending interrupts to prevent duplicate entries during session updates. - Update frontend components to reflect changes in session management, including OAuth consent prompts and message handling. - Refactor session service interfaces to use camelCase for consistency with backend responses. - Enhance tests for session activity updates, pending interrupts, and ensure proper handling of session metadata. --- .../streaming/stream_coordinator.py | 160 ++-------- backend/src/apis/inference_api/chat/routes.py | 24 +- .../src/apis/inference_api/chat/service.py | 116 +------ backend/src/apis/shared/sessions/__init__.py | 6 + backend/src/apis/shared/sessions/metadata.py | 283 +++++++++++++++++- .../tests/shared/test_sessions_metadata.py | 264 ++++++++++++++++ .../components/assistant-message.component.ts | 2 +- .../oauth-consent-prompt.component.ts | 58 ++-- .../message-list/message-list.component.ts | 2 +- .../services/chat/chat-http.service.ts | 39 ++- .../session/message-map.service.spec.ts | 47 ++- .../services/session/message-map.service.ts | 8 +- .../services/session/session.service.spec.ts | 2 +- .../services/session/session.service.ts | 16 +- 14 files changed, 703 insertions(+), 324 deletions(-) diff --git a/backend/src/agents/main_agent/streaming/stream_coordinator.py b/backend/src/agents/main_agent/streaming/stream_coordinator.py index 7577274f..0a28207f 100644 --- a/backend/src/agents/main_agent/streaming/stream_coordinator.py +++ b/backend/src/agents/main_agent/streaming/stream_coordinator.py @@ -1180,149 +1180,39 @@ async def _calculate_streaming_cost(self, model_id: str, usage: Dict[str, Any]) return None async def _update_session_metadata(self, session_id: str, user_id: str, message_id: int, agent: Any = None) -> None: - """ - Update session-level metadata after each message - - This updates conversation-level tracking after each message: - - lastMessageAt: Timestamp of this message - - messageCount: Incremented by 1 - - preferences: Model/temperature/tools/system_prompt_hash from agent config - - Auto-creates session metadata on first message + """Update per-turn session activity (lastMessageAt, messageCount, preferences). - Args: - session_id: Session identifier - user_id: User identifier - message_id: Message ID that was just flushed - agent: Agent instance for extracting model preferences + Delegates to ``update_session_activity``, which uses targeted writes + so concurrent writers (title-gen, pending-interrupt persistence) + cannot be clobbered. Pre-create is handled at /invocations entry, so + no lazy-create branch is needed here. """ try: import hashlib - from apis.shared.sessions.models import SessionMetadata, SessionPreferences - from apis.shared.sessions.metadata import get_session_metadata, store_session_metadata - - logger.info(f"🔍 _update_session_metadata called for session {session_id}, message_id {message_id}") - - # Get existing metadata or create new - existing = await get_session_metadata(session_id, user_id) - - if existing: - logger.info(f"📄 Found existing metadata: messageCount={existing.message_count}, has_preferences={existing.preferences is not None}") - else: - logger.info(f"📄 No existing metadata found - creating new") - - # Calculate message count incrementally - # NOTE: We cannot query AgentCore Memory immediately after flush due to eventual consistency. - # The turn-based session manager calls create_message() then immediately calls list_messages(), - # but the newly created message is not yet available for reading (can take several seconds). - # - # Instead, we use incremental counting: - # - Each streaming turn creates 1 merged message in AgentCore Memory - # - We increment the count by 1 per turn - # - # This count represents "turns" (user-assistant exchanges), not individual message events. - # Tool use creates multiple content blocks within a single turn/message. - if not existing: - actual_message_count = 1 - logger.info(f"📊 First turn in session - message_count: {actual_message_count}") - else: - actual_message_count = existing.message_count + 1 - logger.info(f"📊 Incremental turn count: {existing.message_count} + 1 = {actual_message_count}") - - now = datetime.now(timezone.utc).isoformat() - - if not existing: - # First message - create session metadata - preferences = None - if agent and hasattr(agent, "model_config"): - logger.info(f"📦 Agent has model_config: model_id={agent.model_config.model_id}") - - # Generate system prompt hash for tracking exact prompt version - # This hash represents the FINAL rendered system prompt (after date injection, etc.) - system_prompt_hash = None - if hasattr(agent, "system_prompt") and agent.system_prompt: - system_prompt_hash = hashlib.md5(agent.system_prompt.encode()).hexdigest()[:16] # 16 char hash for uniqueness - logger.debug(f"Generated system_prompt_hash: {system_prompt_hash}") - - # Extract enabled tools from agent - enabled_tools = getattr(agent, "enabled_tools", None) - - preferences = SessionPreferences( - last_model=agent.model_config.model_id, - last_temperature=getattr(agent.model_config, "temperature", None), - enabled_tools=enabled_tools, - system_prompt_hash=system_prompt_hash, - ) - logger.info(f"✨ Created new preferences: last_model={preferences.last_model}") - else: - logger.warning(f"⚠️ Agent is None or missing model_config") + from apis.shared.sessions.metadata import update_session_activity - metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title="New Conversation", # Will be updated by frontend - status="active", - created_at=now, - last_message_at=now, - message_count=actual_message_count, - starred=False, - tags=[], - preferences=preferences, - ) + last_model = None + last_temperature = None + enabled_tools = None + system_prompt_hash = None + if agent and hasattr(agent, "model_config"): + last_model = agent.model_config.model_id + last_temperature = getattr(agent.model_config, "temperature", None) + enabled_tools = getattr(agent, "enabled_tools", None) + if hasattr(agent, "system_prompt") and agent.system_prompt: + system_prompt_hash = hashlib.md5(agent.system_prompt.encode()).hexdigest()[:16] else: - # Update existing - only update what changed - preferences = existing.preferences - if agent and hasattr(agent, "model_config"): - logger.info(f"📦 Updating preferences with model_id={agent.model_config.model_id}") - - # Update preferences if model/temperature/tools/system_prompt changed - prefs_dict = preferences.model_dump(by_alias=False) if preferences else {} - logger.info(f"📝 Existing prefs_dict: {prefs_dict}") - - prefs_dict["last_model"] = agent.model_config.model_id - prefs_dict["last_temperature"] = getattr(agent.model_config, "temperature", None) - - # Update enabled_tools from agent - prefs_dict["enabled_tools"] = getattr(agent, "enabled_tools", None) - - # Update system_prompt_hash if system prompt changed - # This allows tracking when the prompt was modified during a conversation - if hasattr(agent, "system_prompt") and agent.system_prompt: - new_hash = hashlib.md5(agent.system_prompt.encode()).hexdigest()[:16] - # Only update if hash changed (prompt was modified) - if prefs_dict.get("system_prompt_hash") != new_hash: - logger.info(f"System prompt changed - updating hash from {prefs_dict.get('system_prompt_hash')} to {new_hash}") - prefs_dict["system_prompt_hash"] = new_hash - - preferences = SessionPreferences(**prefs_dict) - logger.info(f"✨ Updated preferences: last_model={preferences.last_model}") - else: - logger.warning(f"⚠️ Agent is None or missing model_config - keeping existing preferences") - - metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=existing.title, - status=existing.status, - created_at=existing.created_at, - last_message_at=now, - message_count=actual_message_count, - starred=existing.starred, - tags=existing.tags, - preferences=preferences, - ) + logger.warning("⚠️ Agent is None or missing model_config — skipping preference update") - # Store updated metadata (uses deep merge in storage layer) - await store_session_metadata(session_id=session_id, user_id=user_id, session_metadata=metadata) - - logger.info( - f"✅ Updated session metadata - last_model: {metadata.preferences.last_model if metadata.preferences else 'None'}, message_count: {metadata.message_count}" + await update_session_activity( + session_id=session_id, + user_id=user_id, + last_model=last_model, + last_temperature=last_temperature, + enabled_tools=enabled_tools, + system_prompt_hash=system_prompt_hash, ) - - # Return message count for use as a fallback message_id - return metadata.message_count - except Exception as e: logger.error(f"Failed to update session metadata: {e}", exc_info=True) - # Don't raise - metadata failures shouldn't break streaming - return None + # Don't raise — metadata failures shouldn't break streaming. diff --git a/backend/src/apis/inference_api/chat/routes.py b/backend/src/apis/inference_api/chat/routes.py index 5b60ac21..fc09db1a 100644 --- a/backend/src/apis/inference_api/chat/routes.py +++ b/backend/src/apis/inference_api/chat/routes.py @@ -7,6 +7,7 @@ These endpoints are at the root level to comply with AWS Bedrock AgentCore Runtime requirements. """ +import asyncio import json import logging import os @@ -35,9 +36,10 @@ ) from apis.shared.rbac.service import get_app_role_service +from apis.shared.sessions.metadata import ensure_session_metadata_exists from .models import FileContent, InvocationRequest -from .service import get_agent +from .service import generate_conversation_title, get_agent logger = logging.getLogger(__name__) @@ -248,6 +250,26 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g logger.warning("Failed to resolve file upload IDs") # Continue without files rather than failing the request + # Pre-create session metadata so OAuth interrupts and other state can + # attach to the session row from turn one. Best-effort; on failure the + # post-stream lazy-create in StreamCoordinator still covers it. + is_new_session = False + if not is_resume: + is_new_session = await ensure_session_metadata_exists(input_data.session_id, user_id) + + # First turn → kick off title generation concurrently with the stream. + # Runs as a background task so it doesn't add latency to TTFT. The + # targeted UpdateExpression in update_session_title is race-safe with + # the post-stream _update_session_metadata write. + if is_new_session and input_data.message: + asyncio.create_task( + generate_conversation_title( + session_id=input_data.session_id, + user_id=user_id, + user_input=input_data.message, + ) + ) + # Check quota if enforcement is enabled quota_warning_event = None quota_exceeded_event = None diff --git a/backend/src/apis/inference_api/chat/service.py b/backend/src/apis/inference_api/chat/service.py index 0ab7a706..6a11e380 100644 --- a/backend/src/apis/inference_api/chat/service.py +++ b/backend/src/apis/inference_api/chat/service.py @@ -7,14 +7,12 @@ import hashlib import os from typing import Optional, List, Tuple -from datetime import datetime, timezone import boto3 # from agentcore.agent.agent import ChatbotAgent from agents.main_agent.main_agent import MainAgent -from apis.shared.sessions.models import SessionMetadata -from apis.shared.sessions.metadata import store_session_metadata +from apis.shared.sessions.metadata import update_session_title logger = logging.getLogger(__name__) @@ -289,113 +287,17 @@ async def generate_conversation_title( logger.info(f"✅ Generated title: '{title}' for session {session_id}") - # Update session metadata with the generated title - # IMPORTANT: We must read existing metadata first and only update the title field. - # The streaming coordinator has already set message_count correctly, and we must - # not overwrite it. This function is called async after streaming completes, - # so there's a race condition where we could overwrite the correct message_count - # with 0 if we don't preserve existing values. - from apis.shared.sessions.metadata import get_session_metadata - - logger.info(f"📖 Title generation: Reading existing metadata for session {session_id}") - existing_metadata = await get_session_metadata(session_id, user_id) - - if existing_metadata: - logger.info(f"📊 Title generation: Found existing metadata with message_count={existing_metadata.message_count}") - # Preserve existing metadata, only update title - session_metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=title, # Only update this field - status=existing_metadata.status, - created_at=existing_metadata.created_at, - last_message_at=existing_metadata.last_message_at, - message_count=existing_metadata.message_count, # PRESERVE existing count - starred=existing_metadata.starred, - tags=existing_metadata.tags, - preferences=existing_metadata.preferences - ) - else: - logger.warning(f"⚠️ Title generation: No existing metadata found - creating new with message_count=0") - # Fallback: If metadata doesn't exist yet (rare edge case), create it - # The streaming coordinator will update message_count shortly after - now = datetime.now(timezone.utc).isoformat() - session_metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=title, - status="active", - created_at=now, - last_message_at=now, - message_count=0, # Safe fallback - will be set by streaming coordinator - starred=False, - tags=[], - preferences=None - ) - - logger.info(f"📝 Title generation: About to store metadata with message_count={session_metadata.message_count}") - await store_session_metadata( - session_id=session_id, - user_id=user_id, - session_metadata=session_metadata - ) - - logger.info(f"💾 Title generation: Stored session metadata with title for session {session_id}") + # Targeted update — only writes the title attribute. The post-stream + # update_session_activity write is also targeted and disjoint, so the + # two cannot clobber each other on overlapping turns. + await update_session_title(session_id=session_id, user_id=user_id, title=title) return title except Exception as e: - # Log error but don't fail the request - # Title generation is nice-to-have, not critical + # Title generation is nice-to-have. Leave the existing "New Conversation" + # placeholder in place rather than writing a fallback; the row already + # exists from the pre-create. logger.error(f"Failed to generate title for session {session_id}: {e}", exc_info=True) - - # Return fallback title - fallback_title = "New Conversation" - - # Still try to store metadata with fallback title - # Same as above: preserve existing metadata to avoid race conditions - try: - from apis.shared.sessions.metadata import get_session_metadata - - existing_metadata = await get_session_metadata(session_id, user_id) - - if existing_metadata: - # Preserve existing metadata, only update title - session_metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=fallback_title, - status=existing_metadata.status, - created_at=existing_metadata.created_at, - last_message_at=existing_metadata.last_message_at, - message_count=existing_metadata.message_count, # PRESERVE - starred=existing_metadata.starred, - tags=existing_metadata.tags, - preferences=existing_metadata.preferences - ) - else: - # Fallback: metadata doesn't exist yet - now = datetime.now(timezone.utc).isoformat() - session_metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=fallback_title, - status="active", - created_at=now, - last_message_at=now, - message_count=0, - starred=False, - tags=[], - preferences=None - ) - - await store_session_metadata( - session_id=session_id, - user_id=user_id, - session_metadata=session_metadata - ) - except Exception as metadata_error: - logger.error(f"Failed to store fallback metadata: {metadata_error}") - - return fallback_title + return "New Conversation" diff --git a/backend/src/apis/shared/sessions/__init__.py b/backend/src/apis/shared/sessions/__init__.py index 5b8a7c44..5e479ca0 100644 --- a/backend/src/apis/shared/sessions/__init__.py +++ b/backend/src/apis/shared/sessions/__init__.py @@ -33,6 +33,9 @@ from .metadata import ( store_message_metadata, store_session_metadata, + ensure_session_metadata_exists, + update_session_title, + update_session_activity, get_session_metadata, get_all_message_metadata, list_user_sessions, @@ -69,6 +72,9 @@ # Metadata operations "store_message_metadata", "store_session_metadata", + "ensure_session_metadata_exists", + "update_session_title", + "update_session_activity", "get_session_metadata", "get_all_message_metadata", "list_user_sessions", diff --git a/backend/src/apis/shared/sessions/metadata.py b/backend/src/apis/shared/sessions/metadata.py index 93aea6ee..12f28d92 100644 --- a/backend/src/apis/shared/sessions/metadata.py +++ b/backend/src/apis/shared/sessions/metadata.py @@ -15,7 +15,7 @@ from decimal import Decimal # Relative imports from shared sessions module -from .models import MessageMetadata, PendingInterrupt, SessionMetadata +from .models import MessageMetadata, PendingInterrupt, SessionMetadata, SessionPreferences # Import preview session helper from agents.main_agent.session.preview_session_manager import is_preview_session @@ -697,6 +697,230 @@ async def _store_session_metadata_cloud( ) +async def ensure_session_metadata_exists(session_id: str, user_id: str) -> bool: + """Idempotently create a session metadata row if it doesn't exist yet. + + Uses a conditional ``put_item`` keyed on ``attribute_not_exists(PK)`` so + concurrent first-turn requests for the same session can't clobber each + other. Returns ``True`` when a new row was created (caller can use this + as the "first turn" signal, e.g. to fire title generation). + + No-op for preview sessions, which intentionally skip persistence. + """ + if is_preview_session(session_id): + return False + + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + raise RuntimeError("DYNAMODB_SESSIONS_METADATA_TABLE_NAME environment variable is required") + + try: + import boto3 + from botocore.exceptions import ClientError + from datetime import datetime, timezone + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + now = datetime.now(timezone.utc).isoformat() + item = { + "PK": f"USER#{user_id}", + "SK": f"S#ACTIVE#{now}#{session_id}", + "GSI_PK": f"SESSION#{session_id}", + "GSI_SK": "META", + "sessionId": session_id, + "userId": user_id, + "title": "New Conversation", + "status": "active", + "createdAt": now, + "lastMessageAt": now, + "messageCount": 0, + "starred": False, + "tags": [], + } + + try: + table.put_item(Item=item, ConditionExpression="attribute_not_exists(PK)") + logger.info(f"💾 Pre-created session metadata for {session_id}") + return True + except ClientError as e: + if e.response.get("Error", {}).get("Code") == "ConditionalCheckFailedException": + # Session already exists — expected for continuing conversations. + return False + raise + except Exception as e: + # Best-effort: failures must not block the stream. update_session_activity + # self-heals by retrying this call once if the row is missing post-stream. + logger.error(f"ensure_session_metadata_exists failed: {e}", exc_info=True) + return False + + +async def update_session_title(session_id: str, user_id: str, title: str) -> None: + """Update only the title attribute on the session row. + + Uses a targeted ``UpdateExpression`` so it can run concurrently with + ``store_session_metadata`` (which does a full-row merge) without racing + on other fields like ``messageCount`` or ``lastMessageAt``. Looks up the + current SK via the GSI because the SK contains a timestamp. + + No-op when the session row doesn't exist (preview sessions, sessions + deleted mid-turn). + """ + if is_preview_session(session_id): + return + + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + raise RuntimeError("DYNAMODB_SESSIONS_METADATA_TABLE_NAME environment variable is required") + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + logger.info(f"update_session_title: session {session_id} not found, skipping") + return + sk = existing.get("SK") + if not sk: + logger.warning(f"update_session_title: session {session_id} has no SK") + return + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="SET title = :t", + ExpressionAttributeValues={":t": title}, + ) + logger.info(f"💾 Updated title for session {session_id}") + except Exception as e: + logger.error(f"update_session_title failed: {e}", exc_info=True) + + +async def update_session_activity( + session_id: str, + user_id: str, + *, + last_model: Optional[str] = None, + last_temperature: Optional[float] = None, + enabled_tools: Optional[List[str]] = None, + system_prompt_hash: Optional[str] = None, +) -> bool: + """Per-turn session activity update with targeted writes. + + Increments ``messageCount``, advances ``lastMessageAt`` to now, and + merges agent-derived preferences. No other attributes are written, so + concurrent writers (``update_session_title``, ``add_pending_interrupt``) + cannot be clobbered by this path. + + Phase A is a targeted ``UpdateExpression`` on the current SK. Phase B + rotates the SK because ``lastMessageAt`` is encoded in it for recency + listing — fresh-read after Phase A, put at the new SK, delete the old. + The Phase B carry picks up any concurrent write that landed between + Phase A and the fresh read; the residual race window is bounded to + that small interval (full elimination requires the schema change in + issue #175). + + Self-heals when the row is missing by calling + ``ensure_session_metadata_exists`` and retrying the lookup once. + No-op for preview sessions. Returns ``True`` when the update applied. + """ + if is_preview_session(session_id): + return False + + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + raise RuntimeError("DYNAMODB_SESSIONS_METADATA_TABLE_NAME environment variable is required") + + try: + import boto3 + from datetime import datetime, timezone + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + # Pre-create may have failed at /invocations entry — try once + # more so we don't lose the session record entirely. + await ensure_session_metadata_exists(session_id, user_id) + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + logger.warning( + "update_session_activity: session %s missing and could not be created", + session_id, + ) + return False + + old_sk = existing.get("SK") + if not old_sk: + logger.warning("update_session_activity: session %s has no SK", session_id) + return False + + # Merge preferences: existing values take effect for keys the + # caller didn't pass (e.g. assistantId set by the assistant-attach + # flow). We replace the whole `preferences` map in one SET so the + # update works whether the attribute exists yet or not — DynamoDB + # disallows updating both a parent path and its children in the + # same expression. + existing_prefs_raw = existing.get("preferences") or {} + try: + existing_prefs = SessionPreferences.model_validate(existing_prefs_raw) + except Exception: + existing_prefs = SessionPreferences() + prefs_dict = existing_prefs.model_dump(by_alias=False, exclude_none=True) + if last_model is not None: + prefs_dict["last_model"] = last_model + if last_temperature is not None: + prefs_dict["last_temperature"] = last_temperature + if enabled_tools is not None: + prefs_dict["enabled_tools"] = enabled_tools + if system_prompt_hash is not None: + prefs_dict["system_prompt_hash"] = system_prompt_hash + merged_prefs = SessionPreferences(**prefs_dict).model_dump(by_alias=True, exclude_none=True) + + now = datetime.now(timezone.utc).isoformat() + pk = f"USER#{user_id}" + + # Phase A: targeted update of owned attributes on the current SK. + # Disjoint from title, starred, tags, pendingInterrupts. + table.update_item( + Key={"PK": pk, "SK": old_sk}, + UpdateExpression="ADD messageCount :one SET lastMessageAt = :t, preferences = :p", + ExpressionAttributeValues={ + ":one": 1, + ":t": now, + ":p": _convert_floats_to_decimal(merged_prefs), + }, + ) + + # Phase B: SK rotation. lastMessageAt is encoded in the SK for + # recency listing, so a per-turn change forces a row move. Fresh + # read carries any concurrent write (e.g. title-gen) that landed + # between Phase A and now. + new_sk = f"S#ACTIVE#{now}#{session_id}" + if new_sk != old_sk: + fresh_resp = table.get_item(Key={"PK": pk, "SK": old_sk}) + fresh = fresh_resp.get("Item") + if not fresh: + logger.warning( + "update_session_activity: row vanished between Phase A and Phase B for %s", + session_id, + ) + return True + carried = {k: v for k, v in fresh.items() if k not in ("PK", "SK")} + new_item = {"PK": pk, "SK": new_sk, **carried} + table.put_item(Item=new_item) + table.delete_item(Key={"PK": pk, "SK": old_sk}) + + logger.info("Updated session activity for %s (sk_rotated=%s)", session_id, new_sk != old_sk) + return True + except Exception as e: + logger.error("update_session_activity failed for %s: %s", session_id, e, exc_info=True) + return False + + async def _get_session_by_gsi(session_id: str, user_id: str, table) -> Optional[dict]: """ Get session record using GSI (SessionLookupIndex) @@ -929,6 +1153,11 @@ async def _get_session_metadata_cloud( for key in ['PK', 'SK', 'GSI_PK', 'GSI_SK']: item.pop(key, None) + # Dedupe pending interrupts at the storage boundary so list_append + # re-emits don't surface as duplicate consent prompts. + if "pendingInterrupts" in item: + item["pendingInterrupts"] = _dedupe_interrupt_dicts(item["pendingInterrupts"]) + return SessionMetadata.model_validate(item) except Exception as e: @@ -1110,6 +1339,9 @@ async def _list_user_sessions_cloud( if is_preview_session(session_id): continue + if "pendingInterrupts" in item: + item["pendingInterrupts"] = _dedupe_interrupt_dicts(item["pendingInterrupts"]) + metadata = SessionMetadata.model_validate(item) sessions.append(metadata) @@ -1201,14 +1433,36 @@ def _interrupts_to_dynamo(interrupts: Iterable[PendingInterrupt]) -> List[Dict[s return [item.model_dump(by_alias=True, exclude_none=True) for item in interrupts] -def _interrupts_from_dynamo(raw: Any) -> List[PendingInterrupt]: - """Best-effort parse of stored interrupt entries; tolerate missing/legacy items.""" +def _dedupe_interrupt_dicts(raw: Any) -> List[Dict[str, Any]]: + """Last-write-wins dedupe of raw interrupt dicts by ``interruptId``. + + ``add_pending_interrupt`` uses ``list_append`` to be race-free against + concurrent writers, which means re-emits of the same interrupt across + stream replays accumulate as duplicate list entries. Storage-layer + callers run this on the raw list before handing it to the model so + Pydantic validation sees a clean list. Insertion order of the first + occurrence is preserved. + """ if not raw or not isinstance(raw, list): return [] - parsed: List[PendingInterrupt] = [] + by_id: Dict[str, Dict[str, Any]] = {} + order: List[str] = [] for entry in raw: if not isinstance(entry, dict): continue + iid = entry.get("interruptId") or entry.get("interrupt_id") + if not iid: + continue + if iid not in by_id: + order.append(iid) + by_id[iid] = entry + return [by_id[iid] for iid in order] + + +def _interrupts_from_dynamo(raw: Any) -> List[PendingInterrupt]: + """Parse stored interrupt entries with dedupe and corrupted-entry tolerance.""" + parsed: List[PendingInterrupt] = [] + for entry in _dedupe_interrupt_dicts(raw): try: parsed.append(PendingInterrupt.model_validate(entry)) except Exception as exc: # pragma: no cover — corrupted entry shouldn't break load @@ -1221,11 +1475,13 @@ async def add_pending_interrupt( user_id: str, interrupt: PendingInterrupt, ) -> None: - """Idempotently append a pending OAuth interrupt to the session record. + """Append a pending OAuth interrupt to the session record. - If an entry with the same ``interrupt_id`` already exists, it's replaced - rather than duplicated — the agent may re-emit the same interrupt across - re-streams of a paused turn. + Uses ``list_append`` with ``if_not_exists`` so concurrent writers can't + lose each other's entries — no read-modify-write window. Re-emits of + the same ``interrupt_id`` across stream replays accumulate as duplicate + list entries and are collapsed last-write-wins by + ``_interrupts_from_dynamo`` on read. No-op when the session metadata record is missing (preview sessions, sessions deleted mid-turn). The frontend will fall back to its in-memory @@ -1252,18 +1508,13 @@ async def add_pending_interrupt( logger.warning("Session %s has no SK; cannot update pending_interrupts", session_id) return - current_raw = existing.get("pendingInterrupts") or [] - current = _interrupts_from_dynamo(current_raw) - - # Replace any existing entry with the same interrupt_id - merged = [p for p in current if p.interrupt_id != interrupt.interrupt_id] - merged.append(interrupt) + new_entry = interrupt.model_dump(by_alias=True, exclude_none=True) table.update_item( Key={"PK": f"USER#{user_id}", "SK": sk}, - UpdateExpression="SET #pi = :pi", + UpdateExpression="SET #pi = list_append(if_not_exists(#pi, :empty), :new)", ExpressionAttributeNames={"#pi": "pendingInterrupts"}, - ExpressionAttributeValues={":pi": _interrupts_to_dynamo(merged)}, + ExpressionAttributeValues={":empty": [], ":new": [new_entry]}, ) logger.info( "Persisted pending_interrupt %s (provider=%s) for session %s", diff --git a/backend/tests/shared/test_sessions_metadata.py b/backend/tests/shared/test_sessions_metadata.py index e1d64edb..3859579b 100644 --- a/backend/tests/shared/test_sessions_metadata.py +++ b/backend/tests/shared/test_sessions_metadata.py @@ -230,3 +230,267 @@ async def test_missing_env_raises(self, sessions_metadata_table, monkeypatch): await store_user_display_text( session_id="s1", user_id="u1", message_id=0, display_text="boom", ) + + +class TestUpdateSessionActivity: + """Per-turn metadata update via targeted writes — closes the merge-write race.""" + + @pytest.mark.asyncio + async def test_increments_message_count(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + update_session_activity, + get_session_metadata, + ) + await ensure_session_metadata_exists("s1", "u1") + before = await get_session_metadata("s1", "u1") + assert before.message_count == 0 + + applied = await update_session_activity( + session_id="s1", user_id="u1", last_model="claude-3", last_temperature=0.7, + ) + assert applied is True + after = await get_session_metadata("s1", "u1") + assert after.message_count == 1 + + await update_session_activity(session_id="s1", user_id="u1", last_model="claude-3") + after2 = await get_session_metadata("s1", "u1") + assert after2.message_count == 2 + + @pytest.mark.asyncio + async def test_preserves_title_set_by_title_gen(self, sessions_metadata_table): + """Race regression: post-stream activity update must not clobber title-gen's write.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + update_session_title, + update_session_activity, + get_session_metadata, + ) + await ensure_session_metadata_exists("s1", "u1") + await update_session_title("s1", "u1", "My Generated Title") + await update_session_activity( + session_id="s1", user_id="u1", last_model="claude-3", last_temperature=0.5, + ) + result = await get_session_metadata("s1", "u1") + assert result.title == "My Generated Title" + + @pytest.mark.asyncio + async def test_preserves_pending_interrupts(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + add_pending_interrupt, + update_session_activity, + get_pending_interrupts, + ) + from apis.shared.sessions.models import PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + await update_session_activity(session_id="s1", user_id="u1", last_model="claude-3") + interrupts = await get_pending_interrupts("s1", "u1") + assert len(interrupts) == 1 + assert interrupts[0].interrupt_id == "i1" + + @pytest.mark.asyncio + async def test_preserves_assistant_id_in_preferences(self, sessions_metadata_table): + """assistant_id set by the assistant-attach flow must survive per-turn updates.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + store_session_metadata, + update_session_activity, + get_session_metadata, + ) + from apis.shared.sessions.models import SessionMetadata, SessionPreferences + + await ensure_session_metadata_exists("s1", "u1") + existing = await get_session_metadata("s1", "u1") + seeded = SessionMetadata( + sessionId="s1", userId="u1", + title=existing.title, status="active", + createdAt=existing.created_at, + lastMessageAt=existing.last_message_at, + messageCount=existing.message_count, + preferences=SessionPreferences(assistantId="asst-abc"), + ) + await store_session_metadata("s1", "u1", seeded) + + await update_session_activity( + session_id="s1", user_id="u1", last_model="claude-3", last_temperature=0.5, + ) + result = await get_session_metadata("s1", "u1") + assert result.preferences.assistant_id == "asst-abc" + assert result.preferences.last_model == "claude-3" + assert result.preferences.last_temperature == 0.5 + + @pytest.mark.asyncio + async def test_rotates_sk_to_new_timestamp(self, sessions_metadata_table): + """SK rotation keeps recency listing correct — only one row remains.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + update_session_activity, + ) + await ensure_session_metadata_exists("s1", "u1") + items = sessions_metadata_table.scan()["Items"] + s_items = [i for i in items if i["SK"].startswith("S#ACTIVE#")] + assert len(s_items) == 1 + old_sk = s_items[0]["SK"] + + await update_session_activity(session_id="s1", user_id="u1", last_model="claude-3") + + items = sessions_metadata_table.scan()["Items"] + s_items_after = [i for i in items if i["SK"].startswith("S#ACTIVE#")] + assert len(s_items_after) == 1 + assert s_items_after[0]["SK"] != old_sk + + @pytest.mark.asyncio + async def test_self_heals_when_row_missing(self, sessions_metadata_table): + """If pre-create failed (or row was deleted), update self-heals via ensure_session_metadata_exists.""" + from apis.shared.sessions.metadata import update_session_activity, get_session_metadata + applied = await update_session_activity( + session_id="never-pre-created", user_id="u1", last_model="claude-3", + ) + assert applied is True + result = await get_session_metadata("never-pre-created", "u1") + assert result is not None + assert result.message_count == 1 + + @pytest.mark.asyncio + async def test_noop_for_preview_session(self, sessions_metadata_table): + from apis.shared.sessions.metadata import update_session_activity + applied = await update_session_activity( + session_id="preview-abc", user_id="u1", last_model="claude-3", + ) + assert applied is False + items = sessions_metadata_table.scan()["Items"] + assert items == [] + + +class TestAddPendingInterruptListAppend: + """list_append-based persistence — race-free with no read-modify-write.""" + + @pytest.mark.asyncio + async def test_first_interrupt(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, add_pending_interrupt, get_pending_interrupts, + ) + from apis.shared.sessions.models import PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + interrupts = await get_pending_interrupts("s1", "u1") + assert len(interrupts) == 1 + assert interrupts[0].interrupt_id == "i1" + assert interrupts[0].provider_id == "slack" + + @pytest.mark.asyncio + async def test_two_distinct_interrupts_accumulate(self, sessions_metadata_table): + """Two adds for different ids accumulate — list_append is atomic in DynamoDB.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, add_pending_interrupt, get_pending_interrupts, + ) + from apis.shared.sessions.models import PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i2", providerId="gmail", createdAt="2026-04-25T00:00:01Z", + ), + ) + interrupts = await get_pending_interrupts("s1", "u1") + ids = {p.interrupt_id for p in interrupts} + assert ids == {"i1", "i2"} + + @pytest.mark.asyncio + async def test_reemit_dedupes_on_read_last_write_wins(self, sessions_metadata_table): + """Same id added twice → one entry on read, last write's payload survives.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, add_pending_interrupt, get_pending_interrupts, + ) + from apis.shared.sessions.models import PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", + triggeringMessageId="msg-7", + createdAt="2026-04-25T00:00:05Z", + ), + ) + interrupts = await get_pending_interrupts("s1", "u1") + assert len(interrupts) == 1 + assert interrupts[0].interrupt_id == "i1" + assert interrupts[0].triggering_message_id == "msg-7" + assert interrupts[0].created_at == "2026-04-25T00:00:05Z" + + @pytest.mark.asyncio + async def test_noop_when_session_missing(self, sessions_metadata_table): + from apis.shared.sessions.metadata import add_pending_interrupt, get_pending_interrupts + from apis.shared.sessions.models import PendingInterrupt + + await add_pending_interrupt( + session_id="never-created", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + interrupts = await get_pending_interrupts("never-created", "u1") + assert interrupts == [] + + +class TestInterruptsFromDynamoDedupe: + """Read-side dedupe collapses duplicates produced by list_append re-emits.""" + + def test_dedupe_by_id_last_write_wins(self): + from apis.shared.sessions.metadata import _interrupts_from_dynamo + raw = [ + {"interruptId": "i1", "providerId": "slack", "createdAt": "2026-04-25T00:00:00Z"}, + {"interruptId": "i2", "providerId": "gmail", "createdAt": "2026-04-25T00:00:01Z"}, + {"interruptId": "i1", "providerId": "slack", + "triggeringMessageId": "msg-7", "createdAt": "2026-04-25T00:00:05Z"}, + ] + result = _interrupts_from_dynamo(raw) + assert [p.interrupt_id for p in result] == ["i1", "i2"] + i1 = next(p for p in result if p.interrupt_id == "i1") + assert i1.triggering_message_id == "msg-7" + assert i1.created_at == "2026-04-25T00:00:05Z" + + def test_skips_unparseable_entries(self): + from apis.shared.sessions.metadata import _interrupts_from_dynamo + raw = [ + {"interruptId": "i1", "providerId": "slack", "createdAt": "2026-04-25T00:00:00Z"}, + {"missing": "required-fields"}, + "not a dict", + ] + result = _interrupts_from_dynamo(raw) + assert len(result) == 1 + assert result[0].interrupt_id == "i1" + + def test_empty_input(self): + from apis.shared.sessions.metadata import _interrupts_from_dynamo + assert _interrupts_from_dynamo(None) == [] + assert _interrupts_from_dynamo([]) == [] + assert _interrupts_from_dynamo("not a list") == [] diff --git a/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts index 68077336..f0b13b65 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts @@ -6,7 +6,7 @@ import { ToolCallGroup, ToolCallDisplay } from './tool-rail/tool-rail.model'; import { ReasoningContentComponent } from './reasoning-content'; import { StreamingTextComponent } from './streaming-text.component'; import { InlineVisualComponent } from './inline-visual'; -import { OAuthConsentPromptComponent } from '../../../../components/oauth-consent-prompt/oauth-consent-prompt.component'; +import { OAuthConsentPromptComponent } from './oauth-consent-prompt/oauth-consent-prompt.component'; import { OAuthConsentRequest, OAuthConsentService, diff --git a/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts index 1ee60eb8..0a6fc33d 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts @@ -12,9 +12,9 @@ import { import { OAuthConsentRequest, OAuthConsentService, -} from '../../services/oauth-consent/oauth-consent.service'; -import { UserConnector } from '../../settings/connectors/models/user-connector.model'; -import { UserConnectorsService } from '../../settings/connectors/services/user-connectors.service'; +} from '../../../../../services/oauth-consent/oauth-consent.service'; +import { UserConnector } from '../../../../../settings/connectors/models/user-connector.model'; +import { UserConnectorsService } from '../../../../../settings/connectors/services/user-connectors.service'; /** * Inline OAuth consent prompt rendered alongside the assistant message whose @@ -40,24 +40,24 @@ import { UserConnectorsService } from '../../settings/connectors/services/user-c host: { class: 'block' }, template: `
@if (iconDataUrl(); as data) { } @else {

-

+

@if (isBlocked()) { Popup blocked. Open {{ displayName() }} @@ -88,7 +88,7 @@ import { UserConnectorsService } from '../../settings/connectors/services/user-c

-
+
@if (isBlocked() && request().authorizationUrl; as blockedUrl) { Open - } @else { }
@@ -163,31 +162,20 @@ import { UserConnectorsService } from '../../settings/connectors/services/user-c .action-btn { display: inline-flex; align-items: center; - gap: 0.375rem; - border-radius: 0.5rem; - padding: 0.375rem 0.75rem; - font-size: 0.8125rem; + gap: 0.25rem; + border-radius: 0.375rem; + padding: 0.25rem 0.625rem; + font-size: 0.75rem; font-weight: 600; color: white; - background: linear-gradient( - 180deg, - var(--color-primary-500) 0%, - var(--color-primary-600) 100% - ); - box-shadow: - 0 1px 2px rgba(15, 23, 42, 0.12), - inset 0 1px 0 rgba(255, 255, 255, 0.12); + background: var(--color-secondary-500); transition: - transform 120ms ease, - box-shadow 160ms ease, - filter 120ms ease; + background-color 120ms ease, + transform 120ms ease; } .action-btn:hover:not(:disabled) { - filter: brightness(1.05); - box-shadow: - 0 4px 14px -4px rgba(15, 23, 42, 0.25), - inset 0 1px 0 rgba(255, 255, 255, 0.18); + background: var(--color-secondary-600); } .action-btn:active:not(:disabled) { @@ -195,7 +183,7 @@ import { UserConnectorsService } from '../../settings/connectors/services/user-c } .action-btn:focus-visible { - outline: 2px solid var(--color-primary-500); + outline: 2px solid var(--color-secondary-500); outline-offset: 2px; } diff --git a/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts b/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts index f0a8acdf..e3d1ba12 100644 --- a/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts @@ -6,7 +6,7 @@ import { AssistantMessageComponent } from './components/assistant-message.compon import { MessageMetadataBadgesComponent } from './components/message-metadata-badges.component'; import { CitationDisplayComponent } from '../citation-display/citation-display.component'; import { PulsatingLoaderComponent } from '../../../components/pulsating-loader.component'; -import { OAuthConsentPromptComponent } from '../../../components/oauth-consent-prompt/oauth-consent-prompt.component'; +import { OAuthConsentPromptComponent } from './components/oauth-consent-prompt/oauth-consent-prompt.component'; import { OAuthConsentRequest, OAuthConsentService, diff --git a/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts b/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts index 2c857f68..f28485c7 100644 --- a/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts @@ -137,20 +137,10 @@ export class ChatHttpService { this.messageMapService.endStreaming(); this.chatStateService.setChatLoading(false); - // Generate title only for new sessions (fire and forget - don't block on this) + // Title is generated server-side concurrently with the stream + // (see /invocations). Refresh metadata so the sidebar reflects it. if (this.sessionService.isNewSession(requestObject.session_id)) { - this.generateTitle(requestObject.session_id, requestObject.message) - .then((response) => { - // Update the session title in the local cache - this.sessionService.updateSessionTitleInCache( - requestObject.session_id, - response.title, - ); - }) - .catch((error) => { - // Log error but don't block the user experience - console.error('Failed to generate session title:', error); - }); + this.refreshTitleFromServer(requestObject.session_id); } }, onerror: (err) => { @@ -187,6 +177,29 @@ export class ChatHttpService { this.chatStateService.resetState(); } + /** + * Pull the server-generated title into the local sidebar cache. + * + * Title generation runs concurrently with the agent stream on the backend. + * Nova Micro typically finishes well before the stream does, but on fast + * responses we may race past it — so on a "New Conversation" placeholder + * we retry once after a short delay before giving up. + */ + private async refreshTitleFromServer(sessionId: string, retried = false): Promise { + try { + const metadata = await this.sessionService.getSessionMetadata(sessionId); + if (metadata.title && metadata.title !== 'New Conversation') { + this.sessionService.updateSessionTitleInCache(sessionId, metadata.title); + return; + } + if (!retried) { + setTimeout(() => this.refreshTitleFromServer(sessionId, true), 1500); + } + } catch (error) { + console.error('Failed to refresh session title:', error); + } + } + /** * Generates a title for a session based on the user's input. * diff --git a/frontend/ai.client/src/app/session/services/session/message-map.service.spec.ts b/frontend/ai.client/src/app/session/services/session/message-map.service.spec.ts index 9e33f138..54b23c2e 100644 --- a/frontend/ai.client/src/app/session/services/session/message-map.service.spec.ts +++ b/frontend/ai.client/src/app/session/services/session/message-map.service.spec.ts @@ -4,6 +4,7 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; import { MessageMapService } from './message-map.service'; import { SessionService } from './session.service'; import { FileUploadService } from '../../../services/file-upload'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; import { signal } from '@angular/core'; describe('MessageMapService', () => { @@ -11,6 +12,7 @@ describe('MessageMapService', () => { let httpMock: HttpTestingController; let mockSessionService: any; let mockFileUploadService: any; + let mockOAuthConsentService: any; beforeEach(() => { TestBed.resetTestingModule(); @@ -22,12 +24,16 @@ describe('MessageMapService', () => { mockFileUploadService = { listSessionFiles: vi.fn().mockResolvedValue([]) }; + mockOAuthConsentService = { + requestConsent: vi.fn() + }; TestBed.configureTestingModule({ imports: [HttpClientTestingModule], providers: [ MessageMapService, { provide: SessionService, useValue: mockSessionService }, - { provide: FileUploadService, useValue: mockFileUploadService } + { provide: FileUploadService, useValue: mockFileUploadService }, + { provide: OAuthConsentService, useValue: mockOAuthConsentService } ] }); service = TestBed.inject(MessageMapService); @@ -85,11 +91,48 @@ describe('MessageMapService', () => { expect(mockSessionService.getMessages).toHaveBeenCalledWith('session-1'); expect(mockFileUploadService.listSessionFiles).toHaveBeenCalledWith('session-1'); - + const messagesSignal = service.getMessagesForSession('session-1'); expect(messagesSignal()).toEqual(mockMessages); }); + it('should hydrate pending OAuth interrupts from camelCase wire response', async () => { + // Regression: backend serializes with by_alias=True so the wire payload uses + // camelCase (pendingInterrupts, interruptId, providerId, ...). If the consumer + // reads snake_case fields, the consent prompt silently fails to re-render + // after a refresh. + const mockMessages = [ + { id: 'msg-assistant-7', role: 'assistant', content: [{ type: 'text', text: 'ok' }] }, + ]; + mockSessionService.getMessages.mockResolvedValue({ + messages: mockMessages, + pendingInterrupts: [ + { + interruptId: 'v1:before_tool_call:tooluse_abc:xyz', + providerId: 'google-calendar-employee', + createdAt: '2026-04-26T01:13:54.543143+00:00', + }, + ], + }); + + await service.loadMessagesForSession('session-with-interrupt'); + + expect(mockOAuthConsentService.requestConsent).toHaveBeenCalledTimes(1); + expect(mockOAuthConsentService.requestConsent).toHaveBeenCalledWith( + 'google-calendar-employee', + undefined, + 'v1:before_tool_call:tooluse_abc:xyz', + 'msg-assistant-7', + 'session-with-interrupt', + ); + }); + + it('should not call requestConsent when no pending interrupts are returned', async () => { + mockSessionService.getMessages.mockResolvedValue({ messages: [] }); + await service.loadMessagesForSession('session-clean'); + expect(mockOAuthConsentService.requestConsent).not.toHaveBeenCalled(); + }); + it('should set loading session state', () => { service.setLoadingSession('session-1'); expect(service.isLoadingSession()).toBe('session-1'); diff --git a/frontend/ai.client/src/app/session/services/session/message-map.service.ts b/frontend/ai.client/src/app/session/services/session/message-map.service.ts index 2108ef42..446f84dc 100644 --- a/frontend/ai.client/src/app/session/services/session/message-map.service.ts +++ b/frontend/ai.client/src/app/session/services/session/message-map.service.ts @@ -286,7 +286,7 @@ export class MessageMapService { // Anchor falls back to the most recent assistant message in history, // matching how the live SSE flow already attaches prompts. Authorization // URL is intentionally omitted — fresh URL is fetched lazily on Connect. - this.hydratePendingInterrupts(sessionId, messagesResponse.pending_interrupts, processedMessages); + this.hydratePendingInterrupts(sessionId, messagesResponse.pendingInterrupts, processedMessages); } catch (error) { console.error('Failed to load messages for session:', sessionId, error); throw error; @@ -319,10 +319,10 @@ export class MessageMapService { } for (const interrupt of interrupts) { this.oauthConsentService.requestConsent( - interrupt.provider_id, + interrupt.providerId, undefined, // URL is fetched lazily on Connect — stored URLs go stale - interrupt.interrupt_id, - interrupt.triggering_message_id ?? lastAssistantId, + interrupt.interruptId, + interrupt.triggeringMessageId ?? lastAssistantId, sessionId, ); } diff --git a/frontend/ai.client/src/app/session/services/session/session.service.spec.ts b/frontend/ai.client/src/app/session/services/session/session.service.spec.ts index cff98f49..8b90a112 100644 --- a/frontend/ai.client/src/app/session/services/session/session.service.spec.ts +++ b/frontend/ai.client/src/app/session/services/session/session.service.spec.ts @@ -55,7 +55,7 @@ describe('SessionService', () => { describe('getMessages (no ensureAuthenticated)', () => { it('should GET messages', async () => { - const resp: MessagesListResponse = { messages: [mockMessage], next_token: null }; + const resp: MessagesListResponse = { messages: [mockMessage], nextToken: null }; const promise = service.getMessages('s1'); httpMock.expectOne('http://localhost:8000/sessions/s1/messages').flush(resp); expect(await promise).toEqual(resp); diff --git a/frontend/ai.client/src/app/session/services/session/session.service.ts b/frontend/ai.client/src/app/session/services/session/session.service.ts index c19ef353..05a521ca 100644 --- a/frontend/ai.client/src/app/session/services/session/session.service.ts +++ b/frontend/ai.client/src/app/session/services/session/session.service.ts @@ -37,13 +37,13 @@ export interface SessionsListResponse { */ export interface PendingInterrupt { /** Strands interrupt id used to resume the paused turn */ - interrupt_id: string; + interruptId: string; /** Connector providerId needing consent */ - provider_id: string; + providerId: string; /** Id of the assistant message whose tool call triggered this interrupt, if known */ - triggering_message_id?: string | null; + triggeringMessageId?: string | null; /** ISO 8601 timestamp when the interrupt was recorded */ - created_at: string; + createdAt: string; } /** @@ -55,9 +55,9 @@ export interface MessagesListResponse { /** List of messages in the session */ messages: Message[]; /** Pagination token for retrieving the next page of results */ - next_token: string | null; + nextToken: string | null; /** OAuth consent interrupts that paused agent turns and are awaiting user action */ - pending_interrupts?: PendingInterrupt[]; + pendingInterrupts?: PendingInterrupt[]; } /** @@ -389,7 +389,7 @@ export class SessionService { * // Get next page * const nextPage = await sessionService.getSessions({ * limit: 20, - * next_token: response.next_token + * next_token: response.nextToken * }); * ``` */ @@ -437,7 +437,7 @@ export class SessionService { * // Get next page * const nextPage = await sessionService.getMessages( * '8e70ae89-93af-4db7-ba60-f13ea201f4cd', - * { limit: 20, next_token: response.next_token } + * { limit: 20, next_token: response.nextToken } * ); * ``` */ From 22416348dcb30abaed925f712d1e74cdc90ba2d1 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Sun, 26 Apr 2026 15:31:25 -0600 Subject: [PATCH 27/35] fix(connectors): durable OAuth resume across browser refresh MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resume after an OAuth-gated tool call only worked when the in-memory agent cache still held the original turn. After a browser refresh the frontend lost its request snapshot and the resume request landed with no enabled_tools / model_id, so the inference API rebuilt a fresh agent with an empty external-tool registry — the paused tool call had nothing to resume against and the LLM responded that the tool wasn't available. Resume contract now lives server-side. On pause, the stream coordinator captures a ``PausedTurnSnapshot`` (enabled_tools, model_id, provider, temperature, system_prompt, caching_enabled, max_tokens) onto the session row alongside the existing ``pendingInterrupts``. On resume, the inference API loads the snapshot and rebuilds the agent from it; Strands' SessionManager then restores ``_interrupt_state`` from AgentCore Memory, so the paused tool call picks up where it left off regardless of cache hit/miss, refresh, or pod restart. Frontend ``lastRequestObject`` snapshotting is gone — the resume payload is now ``{ session_id, message: '', interrupt_responses }``. Server-side snapshot has a 1h TTL; cleared on full turn completion and at the start of any new (non-resume) turn. Co-Authored-By: Claude Opus 4.7 --- backend/src/agents/main_agent/base_agent.py | 17 +++ .../streaming/stream_coordinator.py | 41 +++++- backend/src/apis/inference_api/chat/routes.py | 125 ++++++++++++---- backend/src/apis/shared/sessions/metadata.py | 101 ++++++++++++- backend/src/apis/shared/sessions/models.py | 32 +++++ .../tests/shared/test_sessions_metadata.py | 136 ++++++++++++++++++ .../oauth-consent/oauth-consent.service.ts | 4 +- .../services/chat/chat-request.service.ts | 48 ++----- 8 files changed, 439 insertions(+), 65 deletions(-) diff --git a/backend/src/agents/main_agent/base_agent.py b/backend/src/agents/main_agent/base_agent.py index 8fa9b4dc..eb57f6a6 100644 --- a/backend/src/agents/main_agent/base_agent.py +++ b/backend/src/agents/main_agent/base_agent.py @@ -89,6 +89,19 @@ def __init__( model_id=model_id, temperature=temperature, caching_enabled=caching_enabled, provider=provider, max_tokens=max_tokens ) + # Frozen snapshot of agent-construction params, used when the turn + # pauses on OAuth consent so the resume request can rebuild this exact + # agent shape without depending on the in-process agent cache. + # ``system_prompt`` is captured below after the prompt builder resolves. + self._construction_snapshot: dict = { + "enabled_tools": enabled_tools, + "model_id": model_id, + "provider": provider, + "temperature": temperature, + "caching_enabled": caching_enabled, + "max_tokens": max_tokens, + } + # Load retry configuration from environment variables from agents.main_agent.core.model_config import RetryConfig self.model_config.retry_config = RetryConfig.from_env() @@ -101,6 +114,10 @@ def __init__( self.prompt_builder = SystemPromptBuilder() self.system_prompt = self.prompt_builder.build(include_date=True) + # Capture the resolved system prompt — what we'd need to pass back to + # ``get_agent`` to land on the same cache key on resume. + self._construction_snapshot["system_prompt"] = self.system_prompt + # Initialize tool registry and filter self.tool_registry = create_default_registry() self.tool_filter = ToolFilter(self.tool_registry) diff --git a/backend/src/agents/main_agent/streaming/stream_coordinator.py b/backend/src/agents/main_agent/streaming/stream_coordinator.py index 0a28207f..3de70cac 100644 --- a/backend/src/agents/main_agent/streaming/stream_coordinator.py +++ b/backend/src/agents/main_agent/streaming/stream_coordinator.py @@ -208,6 +208,7 @@ async def stream_response( agent, session_id=session_id, user_id=user_id, + main_agent_wrapper=main_agent_wrapper, ): yield sse @@ -552,6 +553,7 @@ async def _extract_oauth_required_events( session_id: Optional[str] = None, user_id: Optional[str] = None, triggering_message_id: Optional[str] = None, + main_agent_wrapper: Any = None, ) -> List[str]: """Yield one SSE-formatted `oauth_required` event per pending OAuth interrupt on the agent, persisting each one to session metadata so @@ -564,17 +566,52 @@ async def _extract_oauth_required_events( approval gates added later) are ignored here so they can be handled by their own SSE event types. + Also persists a ``PausedTurnSnapshot`` capturing the agent's + construction params, so a resume after refresh / cache eviction + rebuilds the same agent shape (matching tool registry) and lets + Strands restore ``_interrupt_state`` from AgentCore Memory. + Persistence is best-effort: a DynamoDB write failure logs but does not break the live SSE flow. """ + from datetime import timedelta from apis.shared.oauth.models import OAuthRequiredEvent - from apis.shared.sessions.metadata import add_pending_interrupt - from apis.shared.sessions.models import PendingInterrupt + from apis.shared.sessions.metadata import add_pending_interrupt, set_paused_turn + from apis.shared.sessions.models import PausedTurnSnapshot, PendingInterrupt interrupt_state = getattr(agent, "_interrupt_state", None) if not interrupt_state or not getattr(interrupt_state, "activated", False): return [] + # Snapshot the turn-construction context once per pause, before the + # per-interrupt loop. Multiple OAuth interrupts in the same turn + # share one snapshot — they were all built against the same agent. + # TTL matches AgentCore Identity's consent window so stale snapshots + # don't pin storage and a too-late resume returns a clean 400. + snapshot_source = ( + getattr(main_agent_wrapper, "_construction_snapshot", None) if main_agent_wrapper else None + ) + if session_id and user_id and snapshot_source: + try: + now = datetime.now(timezone.utc) + snapshot = PausedTurnSnapshot( + enabled_tools=snapshot_source.get("enabled_tools"), + model_id=snapshot_source.get("model_id"), + provider=snapshot_source.get("provider"), + temperature=snapshot_source.get("temperature"), + system_prompt=snapshot_source.get("system_prompt"), + caching_enabled=snapshot_source.get("caching_enabled"), + max_tokens=snapshot_source.get("max_tokens"), + captured_at=now.isoformat(), + expires_at=(now + timedelta(hours=1)).isoformat(), + ) + await set_paused_turn(session_id, user_id, snapshot) + except Exception as e: + logger.error( + "Failed to persist paused_turn snapshot for session %s: %s", + session_id, e, exc_info=True, + ) + events: List[str] = [] for interrupt in interrupt_state.interrupts.values(): reason = interrupt.reason or {} diff --git a/backend/src/apis/inference_api/chat/routes.py b/backend/src/apis/inference_api/chat/routes.py index fc09db1a..b35b884f 100644 --- a/backend/src/apis/inference_api/chat/routes.py +++ b/backend/src/apis/inference_api/chat/routes.py @@ -253,9 +253,20 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g # Pre-create session metadata so OAuth interrupts and other state can # attach to the session row from turn one. Best-effort; on failure the # post-stream lazy-create in StreamCoordinator still covers it. + # + # Also clear any stale paused_turn snapshot at the start of a fresh turn. + # If the user abandoned a paused turn and started a new one, the prior + # snapshot is no longer authorized — letting it survive would let a + # later (mistaken) resume request pick up against a turn the user + # already moved past. is_new_session = False if not is_resume: is_new_session = await ensure_session_metadata_exists(input_data.session_id, user_id) + try: + from apis.shared.sessions.metadata import clear_paused_turn + await clear_paused_turn(input_data.session_id, user_id) + except Exception as e: + logger.error("Failed to clear stale paused_turn on new turn: %s", e, exc_info=True) # First turn → kick off title generation concurrently with the stream. # Runs as a background task so it doesn't add latency to TTFT. The @@ -537,29 +548,79 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g logger.info("Preview session - skipping assistant_id persistence") try: - # Resolve caching_enabled based on managed model configuration - # This allows admins to disable caching for models that don't support it - caching_enabled = await _resolve_caching_enabled(model_id=input_data.model_id, explicit_caching_enabled=input_data.caching_enabled) - - if caching_enabled is False: - logger.info("Prompt caching disabled for model") - - # Get agent instance with user-specific configuration - # AgentCore Memory tracks preferences across sessions per user_id - # Supports multiple LLM providers: AWS Bedrock, OpenAI, and Google Gemini - # Use augmented message and assistant system prompt if assistant RAG was applied - agent = await get_agent( - session_id=input_data.session_id, - user_id=user_id, - auth_token=auth_token, - enabled_tools=input_data.enabled_tools, - model_id=input_data.model_id, - temperature=input_data.temperature, - system_prompt=system_prompt, # Use assistant's instructions if available - caching_enabled=caching_enabled, - provider=input_data.provider, - max_tokens=input_data.max_tokens, - ) + # Resume requests rebuild the agent from the persisted PausedTurnSnapshot + # so a refresh / cache eviction / pod restart between pause and resume + # still lands on the same MainAgent shape (matching tool registry, + # model, prompt). Strands' SessionManager separately restores + # `_interrupt_state` from AgentCore Memory, so the paused tool call + # picks up where it left off. Non-resume requests use the request + # body as before. + if is_resume: + from datetime import datetime, timezone + from apis.shared.sessions.metadata import clear_paused_turn, get_paused_turn + + snapshot = await get_paused_turn(input_data.session_id, user_id) + if not snapshot: + logger.warning( + "Resume rejected: no paused_turn snapshot for session %s", + input_data.session_id, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No paused turn for this session; restart the turn.", + ) + try: + expires_at = datetime.fromisoformat(snapshot.expires_at) + except ValueError: + expires_at = None + if expires_at and datetime.now(timezone.utc) > expires_at: + logger.warning( + "Resume rejected: paused_turn snapshot expired for session %s", + input_data.session_id, + ) + await clear_paused_turn(input_data.session_id, user_id) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Paused turn expired; restart the turn.", + ) + + caching_enabled = snapshot.caching_enabled + agent = await get_agent( + session_id=input_data.session_id, + user_id=user_id, + auth_token=auth_token, + enabled_tools=snapshot.enabled_tools, + model_id=snapshot.model_id, + temperature=snapshot.temperature, + system_prompt=snapshot.system_prompt, + caching_enabled=snapshot.caching_enabled, + provider=snapshot.provider, + max_tokens=snapshot.max_tokens, + ) + else: + # Resolve caching_enabled based on managed model configuration + # This allows admins to disable caching for models that don't support it + caching_enabled = await _resolve_caching_enabled(model_id=input_data.model_id, explicit_caching_enabled=input_data.caching_enabled) + + if caching_enabled is False: + logger.info("Prompt caching disabled for model") + + # Get agent instance with user-specific configuration + # AgentCore Memory tracks preferences across sessions per user_id + # Supports multiple LLM providers: AWS Bedrock, OpenAI, and Google Gemini + # Use augmented message and assistant system prompt if assistant RAG was applied + agent = await get_agent( + session_id=input_data.session_id, + user_id=user_id, + auth_token=auth_token, + enabled_tools=input_data.enabled_tools, + model_id=input_data.model_id, + temperature=input_data.temperature, + system_prompt=system_prompt, # Use assistant's instructions if available + caching_enabled=caching_enabled, + provider=input_data.provider, + max_tokens=input_data.max_tokens, + ) # Resume requests must target interrupts that the cached agent # actually has paused. Cache eviction, a process restart, or a @@ -650,12 +711,22 @@ async def stream_with_quota_warning() -> AsyncGenerator[str, None]: # doesn't redisplay a stale prompt. Interrupts that re-paused # (same provider, new url) are left in place; the next event # extractor will refresh them. + # + # When the agent's interrupt state is no longer activated after + # streaming, the turn fully completed — clear ``paused_turn`` too + # so a stale snapshot doesn't authorize a phantom resume against + # an already-finished turn. If interrupts re-paused, the snapshot + # was overwritten by ``_extract_oauth_required_events`` for the + # next pause, so leave it alone. if is_resume and input_data.interrupt_responses: try: strands_agent = getattr(agent, "agent", None) interrupt_state = getattr(strands_agent, "_interrupt_state", None) if strands_agent else None still_paused: set[str] = set() - if interrupt_state and getattr(interrupt_state, "activated", False): + state_activated = bool( + interrupt_state and getattr(interrupt_state, "activated", False) + ) + if state_activated: still_paused = set((getattr(interrupt_state, "interrupts", None) or {}).keys()) resolved_ids = [ entry.interruptId @@ -669,6 +740,12 @@ async def stream_with_quota_warning() -> AsyncGenerator[str, None]: user_id=user_id, interrupt_ids=resolved_ids, ) + if not state_activated: + from apis.shared.sessions.metadata import clear_paused_turn + await clear_paused_turn( + session_id=input_data.session_id, + user_id=user_id, + ) except Exception as cleanup_err: logger.error("Failed to clear resolved pending_interrupts: %s", cleanup_err, exc_info=True) diff --git a/backend/src/apis/shared/sessions/metadata.py b/backend/src/apis/shared/sessions/metadata.py index 12f28d92..a4eb15ad 100644 --- a/backend/src/apis/shared/sessions/metadata.py +++ b/backend/src/apis/shared/sessions/metadata.py @@ -15,7 +15,7 @@ from decimal import Decimal # Relative imports from shared sessions module -from .models import MessageMetadata, PendingInterrupt, SessionMetadata, SessionPreferences +from .models import MessageMetadata, PausedTurnSnapshot, PendingInterrupt, SessionMetadata, SessionPreferences # Import preview session helper from agents.main_agent.session.preview_session_manager import is_preview_session @@ -1588,3 +1588,102 @@ async def get_pending_interrupts(session_id: str, user_id: str) -> List[PendingI if not metadata: return [] return list(metadata.pending_interrupts or []) + + +async def set_paused_turn( + session_id: str, + user_id: str, + snapshot: PausedTurnSnapshot, +) -> None: + """Persist (or replace) the agent-construction snapshot for a paused turn. + + Idempotent overwrite: re-emits within the same turn replace the prior + snapshot rather than accumulating, since the snapshot is turn-scoped + rather than interrupt-scoped — multiple OAuth interrupts in a single + turn share the same construction context. + + No-op when the session metadata record is missing or when the table + name env var is unset (preview/anonymous flows). + """ + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + logger.warning("DYNAMODB_SESSIONS_METADATA_TABLE_NAME not set; skipping paused_turn persistence") + return + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + logger.info("Skipping paused_turn write — session %s not found", session_id) + return + + sk = existing.get("SK") + if not sk: + logger.warning("Session %s has no SK; cannot update paused_turn", session_id) + return + + snapshot_dict = _convert_floats_to_decimal( + snapshot.model_dump(by_alias=True, exclude_none=True) + ) + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="SET #pt = :pt", + ExpressionAttributeNames={"#pt": "pausedTurn"}, + ExpressionAttributeValues={":pt": snapshot_dict}, + ) + logger.info("Persisted paused_turn snapshot for session %s", session_id) + except Exception as e: + # Best-effort: a write failure shouldn't break the live SSE flow. + # The same-process resume still works via the in-memory agent cache. + logger.error("Failed to persist paused_turn: %s", e, exc_info=True) + + +async def get_paused_turn(session_id: str, user_id: str) -> Optional[PausedTurnSnapshot]: + """Return the persisted paused-turn snapshot for a session, if any.""" + metadata = await get_session_metadata(session_id, user_id) + if not metadata: + return None + return metadata.paused_turn + + +async def clear_paused_turn(session_id: str, user_id: str) -> None: + """Drop the paused-turn snapshot for a session. + + Called on successful resume completion, on explicit dismiss, and at the + start of a non-resume invocation so a stale snapshot from an abandoned + turn doesn't poison a fresh one. + """ + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + return + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + return + + sk = existing.get("SK") + if not sk: + return + + if "pausedTurn" not in existing: + return # Already clear + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="REMOVE #pt", + ExpressionAttributeNames={"#pt": "pausedTurn"}, + ) + logger.info("Cleared paused_turn for session %s", session_id) + except Exception as e: + logger.error("Failed to clear paused_turn: %s", e, exc_info=True) diff --git a/backend/src/apis/shared/sessions/models.py b/backend/src/apis/shared/sessions/models.py index e760e83c..baaa2510 100644 --- a/backend/src/apis/shared/sessions/models.py +++ b/backend/src/apis/shared/sessions/models.py @@ -45,6 +45,33 @@ class PendingInterrupt(BaseModel): created_at: str = Field(..., alias="createdAt", description="ISO 8601 timestamp when the interrupt was recorded") +class PausedTurnSnapshot(BaseModel): + """Frozen agent-construction context for a turn that paused on OAuth consent. + + Written once per paused turn so the resume request can rebuild the same + ``MainAgent`` shape (matching tool registry, model, prompt) regardless of + whether the in-process agent cache still holds it. Strands' session + manager separately persists ``_interrupt_state`` to AgentCore Memory, so + once the agent is rebuilt with the right shape the interrupt restores + automatically and the paused tool call can resume. + + Snapshot wins over current request state on resume: a turn the user + already authorized completes with the connector set it was authorized + against, even if the user toggled connectors mid-pause. + """ + + model_config = ConfigDict(populate_by_name=True) + enabled_tools: Optional[List[str]] = Field(default=None, alias="enabledTools") + model_id: Optional[str] = Field(default=None, alias="modelId") + provider: Optional[str] = Field(default=None) + temperature: Optional[float] = Field(default=None) + system_prompt: Optional[str] = Field(default=None, alias="systemPrompt") + caching_enabled: Optional[bool] = Field(default=None, alias="cachingEnabled") + max_tokens: Optional[int] = Field(default=None, alias="maxTokens") + captured_at: str = Field(..., alias="capturedAt", description="ISO 8601 timestamp when the turn paused") + expires_at: str = Field(..., alias="expiresAt", description="ISO 8601 timestamp after which the snapshot is no longer valid for resume") + + class SessionPreferences(BaseModel): """User preferences for a session""" @@ -109,6 +136,11 @@ class SessionMetadata(BaseModel): alias="pendingInterrupts", description="Pending OAuth consent interrupts that paused agent turns in this session", ) + paused_turn: Optional[PausedTurnSnapshot] = Field( + default=None, + alias="pausedTurn", + description="Agent-construction snapshot for a turn paused on OAuth consent; cleared on successful resume or when a new turn supersedes it", + ) class UpdateSessionMetadataRequest(BaseModel): diff --git a/backend/tests/shared/test_sessions_metadata.py b/backend/tests/shared/test_sessions_metadata.py index 3859579b..4e46e40f 100644 --- a/backend/tests/shared/test_sessions_metadata.py +++ b/backend/tests/shared/test_sessions_metadata.py @@ -494,3 +494,139 @@ def test_empty_input(self): assert _interrupts_from_dynamo(None) == [] assert _interrupts_from_dynamo([]) == [] assert _interrupts_from_dynamo("not a list") == [] + + +class TestPausedTurnSnapshot: + """PausedTurnSnapshot persistence — singleton, idempotent, round-trippable. + + The snapshot is the durable contract that lets a refresh / cache eviction + resume a paused agent turn — without it, the resume rebuilds an agent + with an empty tool registry and the paused tool call has nothing to + resume against. + """ + + @pytest.mark.asyncio + async def test_set_get_round_trip(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, set_paused_turn, get_paused_turn, + ) + from apis.shared.sessions.models import PausedTurnSnapshot + + await ensure_session_metadata_exists("s1", "u1") + snap = PausedTurnSnapshot( + enabledTools=["calendar", "gmail"], modelId="claude-sonnet-4-6", + provider="bedrock", temperature=0.2, systemPrompt="prompt-text", + cachingEnabled=True, maxTokens=4096, + capturedAt="2026-04-25T00:00:00Z", expiresAt="2026-04-25T01:00:00Z", + ) + await set_paused_turn("s1", "u1", snap) + got = await get_paused_turn("s1", "u1") + assert got is not None + assert got.enabled_tools == ["calendar", "gmail"] + assert got.model_id == "claude-sonnet-4-6" + assert got.system_prompt == "prompt-text" + assert got.temperature == 0.2 + assert got.caching_enabled is True + + @pytest.mark.asyncio + async def test_idempotent_overwrite(self, sessions_metadata_table): + """Multiple OAuth interrupts in one turn share a single snapshot — + re-writing replaces in place rather than accumulating.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, set_paused_turn, get_paused_turn, + ) + from apis.shared.sessions.models import PausedTurnSnapshot + + await ensure_session_metadata_exists("s1", "u1") + first = PausedTurnSnapshot( + enabledTools=["calendar"], capturedAt="2026-04-25T00:00:00Z", + expiresAt="2026-04-25T01:00:00Z", + ) + second = PausedTurnSnapshot( + enabledTools=["calendar", "gmail"], capturedAt="2026-04-25T00:00:01Z", + expiresAt="2026-04-25T01:00:01Z", + ) + await set_paused_turn("s1", "u1", first) + await set_paused_turn("s1", "u1", second) + got = await get_paused_turn("s1", "u1") + assert got is not None + assert got.enabled_tools == ["calendar", "gmail"] + assert got.captured_at == "2026-04-25T00:00:01Z" + + @pytest.mark.asyncio + async def test_clear_removes_snapshot(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, set_paused_turn, + get_paused_turn, clear_paused_turn, + ) + from apis.shared.sessions.models import PausedTurnSnapshot + + await ensure_session_metadata_exists("s1", "u1") + await set_paused_turn( + "s1", "u1", + PausedTurnSnapshot( + enabledTools=["calendar"], capturedAt="2026-04-25T00:00:00Z", + expiresAt="2026-04-25T01:00:00Z", + ), + ) + assert await get_paused_turn("s1", "u1") is not None + await clear_paused_turn("s1", "u1") + assert await get_paused_turn("s1", "u1") is None + + @pytest.mark.asyncio + async def test_clear_is_noop_when_already_clear(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, clear_paused_turn, get_paused_turn, + ) + + await ensure_session_metadata_exists("s1", "u1") + await clear_paused_turn("s1", "u1") + assert await get_paused_turn("s1", "u1") is None + + @pytest.mark.asyncio + async def test_set_noop_when_session_missing(self, sessions_metadata_table): + """Preview/anonymous sessions don't have a metadata row — write must + not crash and a subsequent get returns None.""" + from apis.shared.sessions.metadata import set_paused_turn, get_paused_turn + from apis.shared.sessions.models import PausedTurnSnapshot + + await set_paused_turn( + "never-created", "u1", + PausedTurnSnapshot( + enabledTools=["calendar"], capturedAt="2026-04-25T00:00:00Z", + expiresAt="2026-04-25T01:00:00Z", + ), + ) + assert await get_paused_turn("never-created", "u1") is None + + @pytest.mark.asyncio + async def test_paused_turn_independent_of_pending_interrupts(self, sessions_metadata_table): + """``paused_turn`` and ``pending_interrupts`` live on the same row + but their lifecycles don't intrude on each other — clearing one + leaves the other intact.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, set_paused_turn, clear_paused_turn, + add_pending_interrupt, get_pending_interrupts, get_paused_turn, + ) + from apis.shared.sessions.models import PausedTurnSnapshot, PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await set_paused_turn( + "s1", "u1", + PausedTurnSnapshot( + enabledTools=["calendar"], capturedAt="2026-04-25T00:00:00Z", + expiresAt="2026-04-25T01:00:00Z", + ), + ) + await add_pending_interrupt( + "s1", "u1", + PendingInterrupt( + interruptId="i1", providerId="calendar", createdAt="2026-04-25T00:00:00Z", + ), + ) + + await clear_paused_turn("s1", "u1") + assert await get_paused_turn("s1", "u1") is None + interrupts = await get_pending_interrupts("s1", "u1") + assert len(interrupts) == 1 + assert interrupts[0].interrupt_id == "i1" diff --git a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts index 84ce5e91..de832640 100644 --- a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts +++ b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts @@ -249,7 +249,9 @@ export class OAuthConsentService { // any — fire so the agent can finish the turn. this.dismiss(providerId); if (request.interruptId && this.resumeHandler) { - void Promise.resolve(this.resumeHandler([request.interruptId])).catch((err) => + void Promise.resolve( + this.resumeHandler([request.interruptId], { sessionId: request.sessionId }), + ).catch((err) => console.error('OAuth resume handler failed after pre-consented refresh', err), ); } diff --git a/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts b/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts index 1ce55b26..be128a3f 100644 --- a/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts @@ -41,11 +41,6 @@ export class ChatRequestService implements OnDestroy { private router = inject(Router); // TODO: Inject proper logging service - /** Last request payload — replayed (with `interrupt_responses` added) when - * the user completes an OAuth consent so the paused agent turn resumes - * without retyping. Cleared on a true new turn. */ - private lastRequestObject: Record | null = null; - constructor() { this.oauthConsentService.setResumeHandler((interruptIds, context) => this.resumeFromOAuthConsent(interruptIds, context?.sessionId), @@ -100,12 +95,6 @@ export class ChatRequestService implements OnDestroy { assistantId, ); - // Remember this turn's params so the OAuth resume handler can replay - // them with `interrupt_responses` attached. Snapshotting the *exact* - // payload keeps the agent cache key stable, so the resume hits the - // same paused agent instance. - this.lastRequestObject = { ...requestObject }; - try { await this.chatHttpService.sendChatRequest(requestObject); } catch (error) { @@ -178,26 +167,17 @@ export class ChatRequestService implements OnDestroy { } /** - * Replay the last turn's request with `interrupt_responses` attached so - * the backend resumes the paused agent turn instead of starting a new - * one. Triggered by OAuthConsentService after the user completes a - * consent popup. + * Resume the paused agent turn by POSTing the interrupt responses. The + * backend rebuilds the agent from its persisted ``PausedTurnSnapshot``, + * so this request only needs to identify the session and the interrupts — + * no model / tools / prompt context is sent or required. Triggered by + * OAuthConsentService after the user completes a consent popup. */ private async resumeFromOAuthConsent( interruptIds: string[], - fallbackSessionId?: string, + sessionId?: string, ): Promise { - if (interruptIds.length === 0) { - return; - } - - // Live flow: the same tab that originated the turn still has its - // payload — replay it with `interrupt_responses` attached. - // Refresh flow: ``lastRequestObject`` is null, so we synthesize a - // minimal resume payload from the consent service's session context. - const liveSessionId = this.lastRequestObject?.['session_id'] as string | undefined; - const sessionId = liveSessionId ?? fallbackSessionId; - if (!sessionId) { + if (interruptIds.length === 0 || !sessionId) { return; } @@ -209,12 +189,8 @@ export class ChatRequestService implements OnDestroy { this.chatStateService.createNewAbortController(); this.chatStateService.setChatLoading(true); - const baseRequest: Record = this.lastRequestObject - ? { ...this.lastRequestObject } - : { session_id: sessionId }; - const resumeRequest: Record = { - ...baseRequest, + session_id: sessionId, // The original prompt is already in the agent's interrupt context; // sending an empty string keeps the request valid without // re-augmenting or re-charging quota. @@ -234,11 +210,9 @@ export class ChatRequestService implements OnDestroy { this.chatStateService.setChatLoading(false); this.messageMapService.endStreaming(); - // 400 from the resume route means the agent's `_interrupt_state` no - // longer holds the submitted ids — the cache evicted, the pod - // restarted, or the breadcrumb outlived its agent. Surface a - // conversational error so the user knows to retry the prompt - // instead of staring at a stuck spinner. + // 400 from the resume route means either the persisted snapshot is + // missing/expired, or the agent's `_interrupt_state` doesn't recognize + // the submitted ids. Either way the user needs to retry the prompt. if (this.isExpiredInterruptError(error)) { this.errorService.addError( 'Authorization expired', From c87a8202ed8019631f55883bcab5eee39b0a72c4 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Sun, 26 Apr 2026 15:33:43 -0600 Subject: [PATCH 28/35] fix(connectors): pre-flight external MCP clients so one bad server can't fail a turn MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, ``load_external_tools`` cached newly-created MCP clients without verifying the server was actually reachable. A single connector that wasn't running locally (or whose endpoint was misconfigured) would sit in the registry and fail the whole turn the first time Strands called ``load_tools()`` on it. Pre-flight each new client immediately after construction. On failure, log a warning, skip the tool, and continue — the user keeps their other tools. On success the call also primes the client's tool cache, so Strands' later ``load_tools()`` becomes a no-op. Co-Authored-By: Claude Opus 4.7 --- .../integrations/external_mcp_client.py | 15 ++++ .../integrations/test_external_mcp_client.py | 68 ++++++++++++++++++- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/backend/src/agents/main_agent/integrations/external_mcp_client.py b/backend/src/agents/main_agent/integrations/external_mcp_client.py index 02b709a8..bc6b08f6 100644 --- a/backend/src/agents/main_agent/integrations/external_mcp_client.py +++ b/backend/src/agents/main_agent/integrations/external_mcp_client.py @@ -337,6 +337,21 @@ async def load_external_tools( ) if client: + # Pre-flight the MCP session so a single unreachable + # server (e.g. a connector that isn't running locally) + # drops out of the registry instead of failing the + # whole turn when Strands later calls load_tools(). + # On success this also primes the client's tool cache, + # so Strands' subsequent load_tools() is a no-op. + try: + await client.load_tools() + except Exception as exc: + logger.warning( + f"Skipping external MCP tool {tool_id}: " + f"failed to start client ({exc})" + ) + continue + self.clients[cache_key] = client self._client_versions[cache_key] = tool_version if provider_id: diff --git a/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py b/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py index d816d096..fef59b2b 100644 --- a/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py +++ b/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py @@ -202,12 +202,14 @@ async def test_reuses_client_when_updated_at_unchanged(self): tool = _fake_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) repo = SimpleNamespace(get_tool=AsyncMock(return_value=tool)) + client = SimpleNamespace(load_tools=AsyncMock(return_value=[])) + with patch( "apis.app_api.tools.repository.get_tool_catalog_repository", return_value=repo, ), patch( "agents.main_agent.integrations.external_mcp_client.create_external_mcp_client", - return_value=object(), + return_value=client, ) as create_mock: first = await integration.load_external_tools(["gmail"]) second = await integration.load_external_tools(["gmail"]) @@ -223,8 +225,8 @@ async def test_rebuilds_client_when_updated_at_changes(self): repo = SimpleNamespace(get_tool=AsyncMock(side_effect=[old, new])) - client_old = object() - client_new = object() + client_old = SimpleNamespace(load_tools=AsyncMock(return_value=[])) + client_new = SimpleNamespace(load_tools=AsyncMock(return_value=[])) with patch( "apis.app_api.tools.repository.get_tool_catalog_repository", @@ -241,3 +243,63 @@ async def test_rebuilds_client_when_updated_at_changes(self): assert integration.clients["gmail"] is client_new # Old client must be evicted, not left dangling under the same key. assert client_old not in integration.clients.values() + + +class TestLoadExternalToolsPreflight: + """A single unreachable MCP server must not fail the whole turn — + `load_external_tools` pre-flights each new client and silently drops + the ones whose session can't be opened.""" + + @pytest.mark.asyncio + async def test_skips_client_when_preflight_fails(self): + integration = ExternalMCPIntegration() + tool = _fake_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + repo = SimpleNamespace(get_tool=AsyncMock(return_value=tool)) + + bad_client = SimpleNamespace( + load_tools=AsyncMock(side_effect=RuntimeError("connection refused")) + ) + + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.create_external_mcp_client", + return_value=bad_client, + ): + result = await integration.load_external_tools(["gmail"]) + + assert result == [] + # Failed clients must not be cached — otherwise we'd serve a + # broken client back on subsequent turns. + assert "gmail" not in integration.clients + + @pytest.mark.asyncio + async def test_one_failing_client_does_not_block_others(self): + integration = ExternalMCPIntegration() + bad_tool = _fake_tool( + datetime(2025, 1, 1, tzinfo=timezone.utc), tool_id="calendar" + ) + good_tool = _fake_tool( + datetime(2025, 1, 1, tzinfo=timezone.utc), tool_id="gmail" + ) + repo = SimpleNamespace( + get_tool=AsyncMock(side_effect=[bad_tool, good_tool]) + ) + + bad_client = SimpleNamespace( + load_tools=AsyncMock(side_effect=RuntimeError("connection refused")) + ) + good_client = SimpleNamespace(load_tools=AsyncMock(return_value=[])) + + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.create_external_mcp_client", + side_effect=[bad_client, good_client], + ): + result = await integration.load_external_tools(["calendar", "gmail"]) + + assert result == [good_client] + assert integration.clients == {"gmail": good_client} From 21d96f1370f0ad6a5a0d9f9aef8e99bba5bc4973 Mon Sep 17 00:00:00 2001 From: Phil Merrell Date: Sun, 26 Apr 2026 15:52:09 -0600 Subject: [PATCH 29/35] fix: Update oauth consent prompt styling --- .../oauth-consent-prompt.component.ts | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts index 0a6fc33d..9cc9cfbb 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts @@ -133,6 +133,11 @@ import { UserConnectorsService } from '../../../../../settings/connectors/servic Waiting… } @else { Connect +