diff --git a/.gitignore b/.gitignore index 8d641be6..991aeb60 100644 --- a/.gitignore +++ b/.gitignore @@ -90,6 +90,7 @@ tmp/ # Claude - exclude personal settings .claude/settings.local.json +.claude/scheduled_tasks.lock # OS generated files ehthumbs.db diff --git a/CLAUDE.MD b/CLAUDE.MD index a736d9d7..a1a09b5c 100644 --- a/CLAUDE.MD +++ b/CLAUDE.MD @@ -48,6 +48,7 @@ npx cdk deploy --all | `message_stop` | End of message | | `tool_use` / `tool_result` | Tool invocation and result | | `stream_error` | Conversational error | +| `oauth_required` | External MCP tool needs user consent — payload `{providerId, authorizationUrl}`, one event per provider emitted after `message_stop` | | `done` | Stream complete | ## Multi-Protocol Tool Architecture diff --git a/backend/run-app-api.sh b/backend/run-app-api.sh new file mode 100755 index 00000000..1aa0db50 --- /dev/null +++ b/backend/run-app-api.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -euo pipefail +cd "$(dirname "$0")/src/apis/app_api" +exec uv run python main.py diff --git a/backend/run-inference-api.sh b/backend/run-inference-api.sh new file mode 100755 index 00000000..3ce58d97 --- /dev/null +++ b/backend/run-inference-api.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -euo pipefail +cd "$(dirname "$0")/src/apis/inference_api" +exec uv run python main.py diff --git a/backend/src/.env.example b/backend/src/.env.example index 6bea8a6e..a5390470 100644 --- a/backend/src/.env.example +++ b/backend/src/.env.example @@ -226,13 +226,18 @@ SHARED_CONVERSATIONS_TABLE_NAME= # Example: bsu-agentcore-oauth-providers DYNAMODB_OAUTH_PROVIDERS_TABLE_NAME= -# DynamoDB table for OAuth user tokens (OPTIONAL) -# Purpose: Store user OAuth tokens (encrypted with KMS at rest) -# Local Development: Leave empty to disable OAuth connections +# DynamoDB table for OAuth user state (OPTIONAL) +# Purpose: Per-user OAuth state — currently the durable disconnect flag +# read by /status and the agent loop's OAuth consent hook so a +# /disconnect on 
one inference-API replica is honored from any other. +# Local Development: Leave empty to disable durable disconnect (the hook +# falls back to AgentCore's vault state — fine for single-process dev). # Production: Set to your DynamoDB table name from CDK deployment -# Schema: PK=USER#{user_id}, SK=PROVIDER#{provider_id} +# Schema: +# - PK=USER#{user_id}, SK=DISCONNECT#{provider_id} — disconnect flag +# - (reserved) PK=USER#{user_id}, SK=PROVIDER#{provider_id} — token storage # GSI: ProviderUsersIndex for admin view of connected users -# Security: All tokens encrypted with customer-managed KMS key +# Security: All items encrypted with customer-managed KMS key # CDK Deployment: Created by AppApiStack with KMS encryption # Example: bsu-agentcore-oauth-user-tokens DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME= diff --git a/backend/src/agents/main_agent/base_agent.py b/backend/src/agents/main_agent/base_agent.py index 538bb423..09a989ab 100644 --- a/backend/src/agents/main_agent/base_agent.py +++ b/backend/src/agents/main_agent/base_agent.py @@ -8,12 +8,13 @@ import logging from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from agents.main_agent.core import ModelConfig, SystemPromptBuilder, AgentFactory from agents.main_agent.session import SessionFactory from agents.main_agent.session.hooks import ( StopHook, + OAuthConsentHook, EmailApprovalHook, ExternalWriteApprovalHook, DangerousToolApprovalHook, @@ -88,6 +89,19 @@ def __init__( model_id=model_id, temperature=temperature, caching_enabled=caching_enabled, provider=provider, max_tokens=max_tokens ) + # Frozen snapshot of agent-construction params, used when the turn + # pauses on OAuth consent so the resume request can rebuild this exact + # agent shape without depending on the in-process agent cache. + # ``system_prompt`` is captured below after the prompt builder resolves. 
+ self._construction_snapshot: dict = { + "enabled_tools": enabled_tools, + "model_id": model_id, + "provider": provider, + "temperature": temperature, + "caching_enabled": caching_enabled, + "max_tokens": max_tokens, + } + # Load retry configuration from environment variables from agents.main_agent.core.model_config import RetryConfig self.model_config.retry_config = RetryConfig.from_env() @@ -100,6 +114,10 @@ def __init__( self.prompt_builder = SystemPromptBuilder() self.system_prompt = self.prompt_builder.build(include_date=True) + # Capture the resolved system prompt — what we'd need to pass back to + # ``get_agent`` to land on the same cache key on resume. + self._construction_snapshot["system_prompt"] = self.system_prompt + # Initialize tool registry and filter self.tool_registry = create_default_registry() self.tool_filter = ToolFilter(self.tool_registry) @@ -131,9 +149,21 @@ def _create_agent(self) -> None: @abstractmethod async def stream_async( - self, message: str, session_id: Optional[str] = None, files: Optional[List] = None, citations: Optional[List] = None, original_message: Optional[str] = None + self, + message: str, + session_id: Optional[str] = None, + files: Optional[List] = None, + citations: Optional[List] = None, + original_message: Optional[str] = None, + interrupt_responses: Optional[List[Dict[str, Any]]] = None, ) -> AsyncGenerator[str, None]: - """Stream agent responses. Subclasses must implement.""" + """Stream agent responses. Subclasses must implement. + + When `interrupt_responses` is provided, the call resumes a paused + agent turn (Strands interrupt protocol) instead of starting a new + one. In that case `message`/`files` are ignored — the original turn + already has the user's prompt in its context. + """ ... 
def _register_external_mcp_tools(self) -> None: @@ -187,6 +217,8 @@ def _create_hooks(self) -> List: Includes: - StopHook: Always enabled, cancels tool execution on user stop + - OAuthConsentHook: Pauses the agent (Strands interrupt) when an + OAuth-gated MCP tool is about to run without a cached token - Approval hooks: Gate dangerous operations for user confirmation Returns: @@ -197,6 +229,10 @@ def _create_hooks(self) -> List: # Always-on: session cancellation hooks.append(StopHook(self.session_manager)) + # OAuth consent gate for external MCP tools. Registered unconditionally; + # the hook is a no-op for tools that don't have a registered provider. + hooks.append(self._build_oauth_consent_hook()) + # Approval gates for dangerous operations hooks.append(EmailApprovalHook()) hooks.append(ExternalWriteApprovalHook()) @@ -204,6 +240,84 @@ def _create_hooks(self) -> List: return hooks + def _build_oauth_consent_hook(self) -> OAuthConsentHook: + """Construct the OAuth consent hook with closures over the MCP + integration and provider repository so it stays decoupled from them. + """ + from agents.main_agent.integrations.external_mcp_client import ( + get_external_mcp_integration, + ) + from strands.tools.mcp import MCPAgentTool + + integration = get_external_mcp_integration() + + def provider_lookup(selected_tool: object) -> Optional[str]: + if not isinstance(selected_tool, MCPAgentTool): + return None + return integration.provider_for_client(selected_tool.mcp_client) + + async def scopes_lookup(provider_id: str) -> List[str]: + from apis.shared.oauth.provider_repository import get_provider_repository + + provider = await get_provider_repository().get_provider(provider_id) + return provider.scopes if provider else [] + + async def provider_type_lookup(provider_id: str) -> Optional[str]: + # AgentCore Identity needs vendor-specific OAuth params + # forwarded via `customParameters` (e.g. Google's + # `access_type=offline` for refresh tokens). 
The hook reads + # this to forward those. + from apis.shared.oauth.provider_repository import get_provider_repository + + provider = await get_provider_repository().get_provider(provider_id) + return provider.provider_type.value if provider else None + + async def custom_parameters_lookup( + provider_id: str, + ) -> Optional[dict[str, str]]: + # Admin-supplied OAuth extras (e.g. `hd=mycorp.com` for + # Google Workspace domain restriction). Merged with the + # vendor baseline by `custom_parameters_for`; baseline wins + # on conflict. + from apis.shared.oauth.provider_repository import get_provider_repository + + provider = await get_provider_repository().get_provider(provider_id) + return provider.custom_parameters if provider else None + + async def disconnected_lookup(provider_id: str) -> bool: + # Durable per-(user, provider) disconnect intent. Read from DDB + # on every gate call so a /disconnect on another replica is + # picked up before the next tool runs. + from apis.shared.oauth.disconnect_repository import ( + get_disconnect_repository, + ) + + return await get_disconnect_repository().is_disconnected( + self.user_id, provider_id + ) + + async def mark_disconnected(provider_id: str) -> None: + # Persist a disconnect from the AfterToolCallEvent 401-retry + # path so subsequent requests (potentially on other replicas) + # also force a fresh consent. + from apis.shared.oauth.disconnect_repository import ( + get_disconnect_repository, + ) + + await get_disconnect_repository().mark_disconnected( + self.user_id, provider_id + ) + + return OAuthConsentHook( + user_id=self.user_id, + provider_lookup=provider_lookup, + scopes_lookup=scopes_lookup, + provider_type_lookup=provider_type_lookup, + custom_parameters_lookup=custom_parameters_lookup, + disconnected_lookup=disconnected_lookup, + mark_disconnected=mark_disconnected, + ) + def _build_filtered_tools(self) -> List: """ Filter tools and load gateway/external MCP clients. 
@@ -226,22 +340,34 @@ def _build_filtered_tools(self) -> List: if external_mcp_tool_ids: import asyncio + from bedrock_agentcore.runtime import BedrockAgentCoreContext + from agents.main_agent.integrations.external_mcp_client import get_external_mcp_integration + # Capture request-scoped context values before crossing the thread + # boundary below. ContextVars do not propagate into the executor's + # fresh event loop, so anything we need there must be passed as args. + oauth2_callback_url = BedrockAgentCoreContext.get_oauth2_callback_url() + workload_access_token = BedrockAgentCoreContext.get_workload_access_token() + external_integration = get_external_mcp_integration() loop = asyncio.get_event_loop() if loop.is_running(): import concurrent.futures - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit( - asyncio.run, - external_integration.load_external_tools( - external_mcp_tool_ids, - user_id=self.user_id, - auth_token=self.auth_token, - ), + async def _load_with_context(): + if oauth2_callback_url: + BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url) + if workload_access_token: + BedrockAgentCoreContext.set_workload_access_token(workload_access_token) + return await external_integration.load_external_tools( + external_mcp_tool_ids, + user_id=self.user_id, + auth_token=self.auth_token, ) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, _load_with_context()) external_clients = future.result() else: external_clients = loop.run_until_complete( diff --git a/backend/src/agents/main_agent/chat_agent.py b/backend/src/agents/main_agent/chat_agent.py index 97fbb702..edf26f44 100644 --- a/backend/src/agents/main_agent/chat_agent.py +++ b/backend/src/agents/main_agent/chat_agent.py @@ -6,7 +6,7 @@ """ import logging -from typing import AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from agents.main_agent.base_agent import 
BaseAgent from agents.main_agent.core import AgentFactory @@ -43,17 +43,28 @@ def _create_agent(self) -> None: raise async def stream_async( - self, message: str, session_id: Optional[str] = None, files: Optional[List] = None, citations: Optional[List] = None, original_message: Optional[str] = None + self, + message: str, + session_id: Optional[str] = None, + files: Optional[List] = None, + citations: Optional[List] = None, + original_message: Optional[str] = None, + interrupt_responses: Optional[List[Dict[str, Any]]] = None, ) -> AsyncGenerator[str, None]: """ Stream agent responses. Args: - message: User message text + message: User message text. Ignored when resuming via + `interrupt_responses` — the paused turn already has the + original prompt in `_interrupt_state`. session_id: Session identifier (defaults to instance session_id) files: Optional list of FileContent objects (with base64 bytes) citations: Optional list of citation dicts from RAG retrieval original_message: Original user message before RAG augmentation + interrupt_responses: When set, resume a paused agent turn by + passing this list as the prompt to Strands. Each entry is + `{"interruptResponse": {"interruptId": str, "response": Any}}`. Yields: str: SSE formatted events @@ -61,7 +72,14 @@ async def stream_async( if not self.agent: self._create_agent() - prompt = self.multimodal_builder.build_prompt(message, files) + if interrupt_responses: + # Strands' resume protocol: passing a list of interrupt responses + # as the prompt re-enters the loop, populates the matching + # interrupts' `.response`, and continues from the paused tool + # call. multimodal_builder + files do not apply here. 
+ prompt: Any = interrupt_responses + else: + prompt = self.multimodal_builder.build_prompt(message, files) async for event in self.stream_coordinator.stream_response( agent=self.agent, diff --git a/backend/src/agents/main_agent/integrations/agentcore_identity.py b/backend/src/agents/main_agent/integrations/agentcore_identity.py new file mode 100644 index 00000000..f2d84e88 --- /dev/null +++ b/backend/src/agents/main_agent/integrations/agentcore_identity.py @@ -0,0 +1,352 @@ +"""AgentCore Identity integration for external MCP tool authorization. + +Wraps `bedrock_agentcore.services.identity.IdentityClient` with a narrower, +platform-friendly surface for retrieving OAuth2 access tokens on behalf of a +user via the USER_FEDERATION (3LO) flow. + +The client pulls the per-invocation workload identity token from +`BedrockAgentCoreContext`, which is populated by `AgentCoreContextMiddleware` +on the Inference API request path. No workload token has to be threaded +through function arguments. + +Two results are possible when fetching a token: + +1. A valid token exists in the AgentCore Token Vault for this user+provider + → returned synchronously as `TokenResult(access_token=...)`. +2. The user has never consented (or consent has been revoked, or scopes have + changed) → the caller receives `TokenResult(authorization_url=...)`. The + URL must be surfaced to the user; after they complete the consent flow the + frontend calls `CompleteResourceTokenAuthCommand` and the next tool call + will hit case 1. + +This module intentionally does not raise on "consent required" — it returns +a structured result because surfacing an auth URL is a normal, expected +outcome that flows through our SSE stream, not an error. 
+""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from typing import Dict, List, Optional + +import boto3 +from bedrock_agentcore.runtime import BedrockAgentCoreContext +from bedrock_agentcore.services.identity import IdentityClient, TokenPoller + +logger = logging.getLogger(__name__) + + +class _ConsentRequired(Exception): + """Internal marker — raised by `_ShortCircuitPoller` once AgentCore hands + us an auth URL, so we can return it to the caller without waiting for + the user to actually complete consent.""" + + +class _ShortCircuitPoller(TokenPoller): + """Skip the SDK's default poll loop. + + The default poller hits `GetResourceOauth2Token` on a timer until the + user finishes consent (up to several minutes). We only care about the + URL — our caller returns it to the frontend, which drives the popup + flow on its own. Raising immediately short-circuits the wait. + """ + + async def poll_for_token(self) -> str: + raise _ConsentRequired() + +# In production, the AgentCore Runtime proxies every request to the inference +# API with a `WorkloadAccessToken` header bound to (runtime workload, user). +# `AgentCoreContextMiddleware` copies that header onto `BedrockAgentCoreContext` +# so downstream code can fetch user-federated OAuth tokens without threading +# it through function args. +# +# Local dev doesn't go through the runtime, so the header is absent. When +# `AGENTCORE_RUNTIME_WORKLOAD_NAME` is set, we fall back to minting a workload +# token against that runtime ourselves via +# `bedrock-agentcore:GetWorkloadAccessTokenForUserId`. The caller's AWS +# principal must be authorised for that action on the target workload. +_RUNTIME_WORKLOAD_ENV = "AGENTCORE_RUNTIME_WORKLOAD_NAME" + +# Same shape as above for the OAuth2 callback URL. 
The runtime injects an +# `OAuth2CallbackUrl` header on every proxied request; in local dev that +# header is absent and `BedrockAgentCoreContext.get_oauth2_callback_url()` +# returns None. Without a callback URL the SDK builds an authorize URL whose +# redirect points to a default that never reaches our `/oauth-complete` +# page, so consent silently fails to finalize and the user is re-prompted on +# every request. Set this env var to your frontend's `/oauth-complete` URL +# (e.g. `http://localhost:4200/oauth-complete`) to make chat-triggered +# consent work outside the runtime. +_LOCAL_CALLBACK_URL_ENV = "AGENTCORE_LOCAL_OAUTH_CALLBACK_URL" + + +def _vendor_baseline_params(provider_type: Optional[str]) -> Dict[str, str]: + """Hardcoded params AgentCore Identity *requires* for a given vendor. + + Per the AgentCore Identity authentication docs + (https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/identity-authentication.html), + Google must receive `access_type=offline` to issue a refresh token — + without it the vault entry expires after ~1 hour with no refresh + path. This is non-negotiable: it always wins over admin-supplied + extras to prevent an admin from accidentally turning it off. + """ + if not provider_type: + return {} + if provider_type.lower() == "google": + return {"access_type": "offline"} + return {} + + +def custom_parameters_for( + provider_type: Optional[str], + admin_extras: Optional[Dict[str, str]] = None, +) -> Optional[Dict[str, str]]: + """Build the `customParameters` payload AgentCore Identity wants forwarded. + + Merges admin-supplied extras (e.g. Google `hd=mycorp.com` for domain + restriction, `prompt=consent` for stricter UX) with the hardcoded + vendor baseline. Baseline keys win on conflict — admins cannot turn + off a documented requirement. + + Returns None when the merged result would be empty, so callers can + pass the value through to the SDK unconditionally without sending + an empty `customParameters` map. 
+ """ + baseline = _vendor_baseline_params(provider_type) + merged = {**(admin_extras or {}), **baseline} + return merged or None + + +@dataclass(frozen=True) +class TokenResult: + """Result of a token fetch attempt. + + Exactly one of `access_token` or `authorization_url` will be populated. + """ + + access_token: Optional[str] = None + authorization_url: Optional[str] = None + + @property + def requires_consent(self) -> bool: + return self.access_token is None and self.authorization_url is not None + + def __post_init__(self) -> None: + if bool(self.access_token) == bool(self.authorization_url): + raise ValueError( + "TokenResult must have exactly one of access_token or authorization_url" + ) + + +class WorkloadTokenUnavailableError(RuntimeError): + """Raised when no workload access token is present on the current context. + + This indicates the caller is running outside an AgentCore Runtime + invocation, or the `AgentCoreContextMiddleware` was not applied. + """ + + +class CallbackUrlUnavailableError(RuntimeError): + """Raised when no OAuth2 callback URL can be resolved for an authorize call. + + Surfaced instead of silently passing `None` to the SDK, which builds an + authorize URL whose redirect never reaches `/oauth-complete` — the user + finishes consent at the provider but the token is never persisted to + AgentCore's vault, so the next request prompts them to consent again. + """ + + +class AgentCoreIdentityClient: + """Thin async-friendly wrapper around `IdentityClient` for 3LO tokens. + + The underlying `IdentityClient` is synchronous and uses boto3; callers + should treat `get_token_for_user` as potentially blocking and run it via + `asyncio.to_thread` when invoked from async code. 
+ """ + + def __init__(self, region: Optional[str] = None): + self._region = region or os.environ.get("AWS_REGION", "us-east-1") + self._client = IdentityClient(region=self._region) + self._control_client = boto3.client("bedrock-agentcore", region_name=self._region) + + async def get_token_for_user( + self, + *, + provider_name: str, + scopes: List[str], + callback_url: Optional[str] = None, + force_authentication: bool = False, + user_id: Optional[str] = None, + custom_state: Optional[str] = None, + custom_parameters: Optional[Dict[str, str]] = None, + ) -> TokenResult: + """Fetch a user-federated OAuth2 access token for `provider_name`. + + In production the workload identity token comes from + `BedrockAgentCoreContext` (populated by `AgentCoreContextMiddleware` + from the AgentCore Runtime's request header). For local dev the + header is absent — when `AGENTCORE_RUNTIME_WORKLOAD_NAME` is set + and `user_id` is provided, we mint a workload token ourselves + against that runtime. + + If the user has not consented (or re-consent is required), returns a + `TokenResult` with `authorization_url` populated instead of raising. + + Args: + provider_name: Credential provider name registered with AgentCore + Identity (e.g. "google-workspace"). + scopes: OAuth2 scopes to request for this token. + callback_url: OAuth2 return URL. Defaults to the callback URL on + the current context (injected by Runtime via the + `OAuth2CallbackUrl` header). + force_authentication: If True, bypasses the token vault cache and + forces the user through the consent flow again. Used for + scope upgrades. + user_id: User identifier for the local-dev workload-token + fallback. Ignored in production where the context already + has a token. + + Returns: + `TokenResult` with either `access_token` or `authorization_url`. + + Raises: + WorkloadTokenUnavailableError: No token on context and the + local-dev fallback is unavailable (env var unset, user_id + missing, or IAM denies the mint call). 
+ CallbackUrlUnavailableError: No callback URL on context and the + local-dev fallback env var is unset. + """ + workload_token = self._resolve_workload_token(user_id) + resolved_callback_url = self._resolve_callback_url(callback_url, provider_name) + + captured_url: dict[str, Optional[str]] = {"url": None} + + def _capture_auth_url(url: str) -> None: + captured_url["url"] = url + + try: + sdk_kwargs = dict( + provider_name=provider_name, + scopes=scopes, + agent_identity_token=workload_token, + auth_flow="USER_FEDERATION", + callback_url=resolved_callback_url, + force_authentication=force_authentication, + on_auth_url=_capture_auth_url, + token_poller=_ShortCircuitPoller(), + ) + if custom_state is not None: + sdk_kwargs["custom_state"] = custom_state + if custom_parameters: + sdk_kwargs["custom_parameters"] = custom_parameters + token = await self._client.get_token(**sdk_kwargs) + except _ConsentRequired: + # Expected path when consent is required: the SDK invoked + # on_auth_url and then handed off to our poller, which raises. + token = None + + # If we captured a URL, return it — even if the SDK also returned + # a (stale) token, consent-required is the authoritative signal. + if captured_url["url"]: + logger.info( + "AgentCore Identity requires user consent for provider=%s", + provider_name, + ) + return TokenResult(authorization_url=captured_url["url"]) + + if not token: + raise RuntimeError( + f"AgentCore Identity returned neither a token nor an " + f"authorization URL for provider={provider_name}" + ) + + return TokenResult(access_token=token) + + def _resolve_callback_url( + self, explicit: Optional[str], provider_name: str + ) -> str: + """Pick the OAuth2 callback URL and tag it with `provider_id`. + + Resolution order: explicit arg → request-scoped context (Runtime + header) → `AGENTCORE_LOCAL_OAUTH_CALLBACK_URL` env var (local-dev + escape hatch). 
Raises `CallbackUrlUnavailableError` when none is + available — passing None to the SDK silently breaks consent. + + AgentCore's redirect doesn't echo any provider hint, so we append + `provider_id` as a query param so `/oauth-complete` can dismiss the + right pending consent entry. + """ + from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl + + base = ( + explicit + or BedrockAgentCoreContext.get_oauth2_callback_url() + or os.environ.get(_LOCAL_CALLBACK_URL_ENV) + ) + if not base: + raise CallbackUrlUnavailableError( + "No OAuth2 callback URL available. In production the " + "AgentCore Runtime injects this via the `OAuth2CallbackUrl` " + "header; for local dev set " + f"{_LOCAL_CALLBACK_URL_ENV} to your frontend's " + "/oauth-complete URL (e.g. " + "http://localhost:4200/oauth-complete)." + ) + + parsed = urlparse(base) + existing = dict(parse_qsl(parsed.query, keep_blank_values=True)) + existing.setdefault("provider_id", provider_name) + return urlunparse(parsed._replace(query=urlencode(existing))) + + def _resolve_workload_token(self, user_id: Optional[str]) -> str: + """Return a workload access token, preferring the runtime-supplied one. + + Falls back to minting via `GetWorkloadAccessTokenForUserId` when the + context has no token, the `AGENTCORE_RUNTIME_WORKLOAD_NAME` env var + is set, and the caller passed `user_id`. Any other combination + raises `WorkloadTokenUnavailableError`. + """ + context_token = BedrockAgentCoreContext.get_workload_access_token() + if context_token: + return context_token + + workload_name = os.environ.get(_RUNTIME_WORKLOAD_ENV) + if not workload_name: + raise WorkloadTokenUnavailableError( + "No WorkloadAccessToken on context. For local dev, set " + f"{_RUNTIME_WORKLOAD_ENV} to your deployed runtime's " + "workload identity name (e.g. hosted_agent_XXXXX)." + ) + if not user_id: + raise WorkloadTokenUnavailableError( + "No WorkloadAccessToken on context and no user_id provided " + "for the local-dev mint fallback." 
+ ) + + logger.info( + "Minting workload access token for user=%s workload=%s (local dev)", + user_id, + workload_name, + ) + response = self._control_client.get_workload_access_token_for_user_id( + workloadName=workload_name, + userId=user_id, + ) + minted_token = response.get("workloadAccessToken") + if not minted_token: + raise WorkloadTokenUnavailableError( + "GetWorkloadAccessTokenForUserId returned no token" + ) + return minted_token + + +_default_client: Optional[AgentCoreIdentityClient] = None + + +def get_agentcore_identity_client() -> AgentCoreIdentityClient: + """Return the process-wide `AgentCoreIdentityClient` singleton.""" + global _default_client + if _default_client is None: + _default_client = AgentCoreIdentityClient() + return _default_client diff --git a/backend/src/agents/main_agent/integrations/external_mcp_client.py b/backend/src/agents/main_agent/integrations/external_mcp_client.py index 403bc9ea..bc6b08f6 100644 --- a/backend/src/agents/main_agent/integrations/external_mcp_client.py +++ b/backend/src/agents/main_agent/integrations/external_mcp_client.py @@ -5,14 +5,17 @@ supporting various authentication methods (AWS IAM, API Key, OAuth, etc.) OAuth Support: - When a tool has `requires_oauth_provider` set, the MCP client will - automatically inject the user's OAuth token into requests. This requires - per-user client instances since tokens are user-specific. + Tools with `requires_oauth_provider` set get an `OAuthBearerAuth` whose + token is resolved lazily on every MCP request via `oauth_token_cache`. + The cache is warmed by `OAuthConsentHook` (which also pauses the agent + via Strands interrupts when consent is needed). This module never + pre-flights OAuth — the agent loads tools optimistically; the hook + gates execution. 
""" import logging import re -from typing import Optional, List, Any +from typing import Any, Callable, Optional, List from mcp.client.streamable_http import streamablehttp_client from strands.tools.mcp import MCPClient @@ -23,6 +26,7 @@ MCPTransport, ToolDefinition, ) +from agents.main_agent.integrations import oauth_token_cache from agents.main_agent.integrations.gateway_auth import get_sigv4_auth from agents.main_agent.integrations.oauth_auth import ( CompositeAuth, @@ -91,28 +95,22 @@ def create_external_mcp_client( config: MCPServerConfig, tool_definition: Optional[ToolDefinition] = None, oauth_token: Optional[str] = None, + token_provider: Optional[Callable[[], Optional[str]]] = None, ) -> Optional[MCPClient]: """ Create an MCP client for an externally deployed MCP server. + Pass either `oauth_token` (static, for OIDC forwarding) or `token_provider` + (callable, for OAuth tokens that the consent hook resolves lazily). + Args: config: MCP server configuration from tool catalog tool_definition: Optional tool definition for logging - oauth_token: Optional OAuth token to include in requests (for user-specific auth) + oauth_token: Optional static token (used for OIDC forwarding) + token_provider: Optional callable returning the current token Returns: MCPClient instance or None if configuration is invalid - - Example: - >>> config = MCPServerConfig( - ... server_url="https://xxx.lambda-url.us-west-2.on.aws/", - ... transport=MCPTransport.STREAMABLE_HTTP, - ... auth_type=MCPAuthType.AWS_IAM, - ... 
) - >>> client = create_external_mcp_client(config) - - # With OAuth token for user-specific access: - >>> client = create_external_mcp_client(config, oauth_token="user_access_token") """ if not config.server_url: logger.warning("MCP server URL is required") @@ -120,26 +118,29 @@ def create_external_mcp_client( tool_id = tool_definition.tool_id if tool_definition else "unknown" requires_oauth = tool_definition.requires_oauth_provider if tool_definition else None - has_token = bool(oauth_token) + has_static_token = bool(oauth_token) + has_provider = bool(token_provider) logger.info(f"Creating external MCP client for tool: {tool_id}") logger.debug(f" Transport: {config.transport}") logger.debug(f" Auth Type: {config.auth_type}") if requires_oauth: logger.debug(" Requires OAuth Provider: yes") - logger.debug(f" OAuth Token Provided: {has_token}") + logger.debug(f" Token mode: {'provider' if has_provider else 'static' if has_static_token else 'none'}") try: # Build list of auth handlers (may combine multiple) auth_handlers = [] - # When an OAuth token is provided, use it exclusively as the auth method. - # SigV4 and OAuth both use the Authorization header and cannot coexist — - # SigV4 sets "AWS4-HMAC-SHA256 ..." while OAuth sets "Bearer ...". - # The Lambda Function URL auth type should be NONE for OAuth-authenticated tools. - if oauth_token: - oauth_auth = create_oauth_bearer_auth(token=oauth_token) - auth_handlers.append(oauth_auth) - logger.debug(" Using OAuth Bearer token auth (skipping SigV4)") + # OAuth/OIDC bearer auth takes precedence over SigV4. SigV4 and OAuth both + # use the Authorization header and cannot coexist (SigV4 sets + # "AWS4-HMAC-SHA256 ...", OAuth sets "Bearer ..."). Lambda Function URLs + # backing OAuth-authenticated MCP servers must use auth_type=NONE. 
+ if token_provider: + auth_handlers.append(create_oauth_bearer_auth(token_provider=token_provider)) + logger.debug(" Using OAuth Bearer token provider (lazy)") + elif oauth_token: + auth_handlers.append(create_oauth_bearer_auth(token=oauth_token)) + logger.debug(" Using OAuth Bearer token (static)") # AWS IAM SigV4 authentication (for Lambda/API Gateway without OAuth) elif config.auth_type == MCPAuthType.AWS_IAM or config.auth_type == "aws-iam": @@ -204,46 +205,36 @@ class ExternalMCPIntegration: with protocol='mcp_external' in the tool catalog. OAuth Support: - Tools with `requires_oauth_provider` set will have their MCP clients - created with the user's OAuth token injected. Since tokens are user-specific, - OAuth-enabled tools use a per-user cache key. + For OAuth-gated tools, clients are created with a lazy token provider + that reads from the per-process token cache at request time. The + cache is populated by `OAuthConsentHook`, which gates execution by + raising a Strands interrupt when no token is available yet. + + This integration also maintains an MCPClient -> provider_id map so + the hook can look up which provider a tool's MCP server requires + without coupling on tool names. """ def __init__(self): - """Initialize external MCP integration.""" # Cache key: tool_id for non-OAuth tools, "user_id:tool_id" for OAuth tools self.clients: dict[str, MCPClient] = {} + # Parallel map of cache_key -> tool updated_at at the time the + # client was built. On lookup we compare against the tool's + # current updated_at; mismatch means the admin edited the tool + # (URL, auth, provider, etc.) and we must rebuild the client. + self._client_versions: dict[str, str] = {} + # MCPClient object identity -> provider_id, populated alongside `clients`. + # Consumed by OAuthConsentHook via `provider_for_client`. 
+ self._provider_for_client_id: dict[int, str] = {} def _get_cache_key(self, tool_id: str, user_id: Optional[str], requires_oauth: bool) -> str: - """Get the cache key for a tool client.""" if requires_oauth and user_id: return f"{user_id}:{tool_id}" return tool_id - async def _get_oauth_token( - self, - user_id: str, - provider_id: str, - ) -> Optional[str]: - """ - Get decrypted OAuth token for a user and provider. - - Args: - user_id: The user's ID - provider_id: The OAuth provider ID - - Returns: - Decrypted access token or None if not connected - """ - try: - from apis.shared.oauth.service import get_oauth_service - - oauth_service = get_oauth_service() - token = await oauth_service.get_decrypted_token(user_id, provider_id) - return token - except Exception as e: - logger.error("Error getting OAuth token") - return None + def provider_for_client(self, client: Any) -> Optional[str]: + """Return the OAuth provider_id backing `client`, or None.""" + return self._provider_for_client_id.get(id(client)) async def load_external_tools( self, @@ -254,15 +245,16 @@ async def load_external_tools( """ Load external MCP clients for enabled tools. - This method queries the tool catalog for tools with protocol='mcp_external' - and creates MCP clients for them. For tools requiring OAuth, the user's - OAuth token is retrieved and injected. For tools with forward_auth_token, - the user's OIDC authentication token is forwarded instead. + For OAuth-gated tools, the client is created with a token provider + that reads from `oauth_token_cache` lazily at request time. Token + acquisition + consent prompting happen in `OAuthConsentHook` at + tool-call time, not here. For tools with `forward_auth_token`, the + user's OIDC token is injected statically. 
Args: enabled_tool_ids: List of enabled tool IDs - user_id: User ID for OAuth token retrieval (required for OAuth-enabled tools) - auth_token: Raw OIDC token for forwarding (required for forward_auth_token tools) + user_id: User ID (required for OAuth-gated and OIDC-forwarded tools) + auth_token: Raw OIDC token for forwarding Returns: List of MCPClient instances to add to the agent's tools @@ -278,7 +270,6 @@ async def load_external_tools( if not tool: continue - # Check if this is an external MCP tool if tool.protocol != "mcp_external": continue @@ -286,34 +277,44 @@ async def load_external_tools( logger.warning(f"Tool {tool_id} has protocol=mcp_external but no mcp_config") continue - # Determine auth mode: OIDC forwarding, OAuth, or none forward_auth = bool(getattr(tool, "forward_auth_token", False)) requires_oauth = bool(tool.requires_oauth_provider) requires_user_auth = forward_auth or requires_oauth cache_key = self._get_cache_key(tool_id, user_id, requires_user_auth) + tool_version = ( + tool.updated_at.isoformat() + "Z" if tool.updated_at else "" + ) - # Check cache - if cache_key in self.clients: + if ( + cache_key in self.clients + and self._client_versions.get(cache_key) == tool_version + ): clients.append(self.clients[cache_key]) continue - # Resolve token to use (OIDC forwarding takes precedence) - token_to_use = None + # Stale entry — admin edited this tool since the client + # was built. Drop it so the block below creates a fresh + # client with the current config. 
+ if cache_key in self.clients: + stale = self.clients.pop(cache_key) + self._client_versions.pop(cache_key, None) + self._provider_for_client_id.pop(id(stale), None) + + static_token: Optional[str] = None + token_provider: Optional[Callable[[], Optional[str]]] = None + provider_id: Optional[str] = None if forward_auth: - # Forward the user's OIDC authentication token if not auth_token: logger.warning( f"Tool {tool_id} has forward_auth_token=true but no auth_token provided" ) - # Still create the client - server will reject unauthorized requests else: - token_to_use = auth_token + static_token = auth_token logger.info(f"Using OIDC token forwarding for tool {tool_id}") elif requires_oauth: - # Use stored OAuth token from provider if not user_id: logger.warning( f"Tool {tool_id} requires OAuth provider '{tool.requires_oauth_provider}' " @@ -321,31 +322,44 @@ async def load_external_tools( ) continue - token_to_use = await self._get_oauth_token( - user_id=user_id, - provider_id=tool.requires_oauth_provider, + provider_id = tool.requires_oauth_provider + # Bind user_id and provider_id at closure time so the + # provider stays valid for the client's lifetime. + token_provider = ( + lambda u=user_id, p=provider_id: oauth_token_cache.get(u, p) ) - if not token_to_use: - logger.warning( - "User not connected to required OAuth provider for tool" - ) - # Still create the client - it will fail gracefully when used - # The MCP server should return an appropriate error - - # Create MCP client with optional token (works for both OAuth and OIDC) client = create_external_mcp_client( config=tool.mcp_config, tool_definition=tool, - oauth_token=token_to_use, + oauth_token=static_token, + token_provider=token_provider, ) if client: + # Pre-flight the MCP session so a single unreachable + # server (e.g. a connector that isn't running locally) + # drops out of the registry instead of failing the + # whole turn when Strands later calls load_tools(). 
+ # On success this also primes the client's tool cache, + # so Strands' subsequent load_tools() is a no-op. + try: + await client.load_tools() + except Exception as exc: + logger.warning( + f"Skipping external MCP tool {tool_id}: " + f"failed to start client ({exc})" + ) + continue + self.clients[cache_key] = client + self._client_versions[cache_key] = tool_version + if provider_id: + self._provider_for_client_id[id(client)] = provider_id clients.append(client) auth_label = ( - " (with OIDC forwarding)" if forward_auth and token_to_use - else " (with OAuth)" if requires_oauth and token_to_use + " (with OIDC forwarding)" if forward_auth and static_token + else " (OAuth)" if provider_id else "" ) logger.info(f"✅ Loaded external MCP tool: {tool_id}{auth_label}") @@ -357,17 +371,6 @@ async def load_external_tools( return clients def get_client(self, tool_id: str, user_id: Optional[str] = None) -> Optional[MCPClient]: - """ - Get a specific MCP client by tool ID. - - Args: - tool_id: The tool ID - user_id: User ID for OAuth-enabled tools - - Returns: - MCPClient or None if not found - """ - # Try user-specific key first, then generic if user_id: user_key = f"{user_id}:{tool_id}" if user_key in self.clients: @@ -375,15 +378,6 @@ def get_client(self, tool_id: str, user_id: Optional[str] = None) -> Optional[MC return self.clients.get(tool_id) def add_to_tool_list(self, tools: List[Any]) -> List[Any]: - """ - Add all loaded external MCP clients to the tool list. - - Args: - tools: Existing list of tools - - Returns: - Updated tool list with MCP clients added - """ for client in self.clients.values(): if client not in tools: tools.append(client) @@ -393,24 +387,44 @@ def clear_user_clients(self, user_id: str) -> None: """ Clear cached MCP clients for a specific user. - Call this when a user disconnects from an OAuth provider - to ensure fresh clients are created on next use. 
- - Args: - user_id: User ID to clear clients for + Call this when a user disconnects from an OAuth provider so the next + agent build creates fresh clients (and the token cache miss forces a + new consent flow). """ keys_to_remove = [ key for key in self.clients.keys() if key.startswith(f"{user_id}:") ] for key in keys_to_remove: - del self.clients[key] + client = self.clients.pop(key) + self._client_versions.pop(key, None) + self._provider_for_client_id.pop(id(client), None) if keys_to_remove: logger.info(f"Cleared {len(keys_to_remove)} cached MCP clients for user {user_id}") + def clear_tool_clients(self, tool_id: str) -> None: + """ + Clear cached MCP clients for a specific tool across all users. + + Call this when a tool's config changes (e.g. admin updates the MCP + server URL) so the next agent build reconnects using the fresh + config. Without this, clients cached at process start continue to + point at the old URL for the lifetime of the process. + """ + keys_to_remove = [ + key for key in self.clients.keys() + if key == tool_id or key.endswith(f":{tool_id}") + ] + for key in keys_to_remove: + client = self.clients.pop(key) + self._client_versions.pop(key, None) + self._provider_for_client_id.pop(id(client), None) + + if keys_to_remove: + logger.info(f"Cleared {len(keys_to_remove)} cached MCP clients for tool {tool_id}") + -# Global instance _external_mcp_integration: Optional[ExternalMCPIntegration] = None diff --git a/backend/src/agents/main_agent/integrations/oauth_token_cache.py b/backend/src/agents/main_agent/integrations/oauth_token_cache.py new file mode 100644 index 00000000..5feb18aa --- /dev/null +++ b/backend/src/agents/main_agent/integrations/oauth_token_cache.py @@ -0,0 +1,51 @@ +"""In-process cache of OAuth access tokens, keyed by (user_id, provider_id). + +Lives for the lifetime of the inference API process. 
The authoritative store +is AgentCore Identity's token vault — this cache is just a hot path so the +`OAuthBearerAuth` token provider doesn't have to call AgentCore on every +MCP request. + +Tokens are written when: + * `OAuthConsentHook` warms the cache after a successful vault lookup, or + * the resume path re-fetches a token after the user completes consent. + +Tokens are evicted explicitly via `clear_user_provider` when consent is +revoked or expires; we don't track expiry locally because AgentCore +Identity owns refresh. + +Disconnect intent ("user pressed Disconnect" / "tool returned 401") is *not* +held here — it lives in the DDB-backed `OAuthDisconnectRepository` so it's +visible across replicas. The cache only holds tokens. +""" + +from __future__ import annotations + +import threading +from typing import Optional + + +_lock = threading.Lock() +_cache: dict[tuple[str, str], str] = {} + + +def get(user_id: str, provider_id: str) -> Optional[str]: + with _lock: + return _cache.get((user_id, provider_id)) + + +def set(user_id: str, provider_id: str, token: str) -> None: + with _lock: + _cache[(user_id, provider_id)] = token + + +def clear_user_provider(user_id: str, provider_id: str) -> None: + with _lock: + _cache.pop((user_id, provider_id), None) + + +def clear_user(user_id: str) -> int: + with _lock: + keys = [k for k in _cache if k[0] == user_id] + for key in keys: + del _cache[key] + return len(keys) diff --git a/backend/src/agents/main_agent/session/hooks/__init__.py b/backend/src/agents/main_agent/session/hooks/__init__.py index 9900a4a5..c1b1ce0f 100644 --- a/backend/src/agents/main_agent/session/hooks/__init__.py +++ b/backend/src/agents/main_agent/session/hooks/__init__.py @@ -1,5 +1,6 @@ """Hooks for Main Agent""" +from agents.main_agent.session.hooks.oauth_consent import OAuthConsentHook from agents.main_agent.session.hooks.stop import StopHook from agents.main_agent.session.hooks.tool_approval import ( ToolApprovalHook, @@ -9,6 +10,7 @@ ) __all__ 
= [ + "OAuthConsentHook", "StopHook", "ToolApprovalHook", "EmailApprovalHook", diff --git a/backend/src/agents/main_agent/session/hooks/oauth_consent.py b/backend/src/agents/main_agent/session/hooks/oauth_consent.py new file mode 100644 index 00000000..307a00d9 --- /dev/null +++ b/backend/src/agents/main_agent/session/hooks/oauth_consent.py @@ -0,0 +1,443 @@ +"""OAuth consent gate for external MCP tools. + +Fires on every `BeforeToolCallEvent`. If the tool about to run is backed by +an MCP server that requires user-federated OAuth (per the tool catalog), +the hook ensures we have an access token in the in-process cache. If we +don't, it calls `event.interrupt(...)` to pause the agent mid-turn and +hand the authorization URL back to the caller. + +When the user completes consent in the popup and the frontend resumes the +turn, the hook fires a second time and `event.interrupt(...)` returns the +user's response (instead of raising). At that point AgentCore Identity has +the new token in its vault, so we re-fetch and warm the cache; the +`OAuthBearerAuth` token provider then injects it on the next MCP request. + +The hook never aborts the turn on its own — `cancel_tool` is reserved for +genuine refusal (e.g. consent declined). If the user closes the popup we +don't reach that path; the agent simply remains paused until a resume +arrives or the session times out. 
+""" + +from __future__ import annotations + +import asyncio +import inspect +import logging +import re +from typing import Any, Awaitable, Callable, Optional, Union + +from strands.hooks import ( + AfterToolCallEvent, + BeforeInvocationEvent, + BeforeToolCallEvent, + HookProvider, + HookRegistry, +) + +from agents.main_agent.integrations import oauth_token_cache +from agents.main_agent.integrations.agentcore_identity import ( + CallbackUrlUnavailableError, + WorkloadTokenUnavailableError, + custom_parameters_for, + get_agentcore_identity_client, +) + +logger = logging.getLogger(__name__) + + +# Markers that indicate an OAuth-style auth failure in a tool result. +# A false positive triggers an unnecessary OAuth popup — far more +# disruptive than a missed match (which surfaces the underlying error to +# the user). So we err on the side of high-confidence signals only. +# +# Tiers: +# 1. HTTP 401 with negative lookarounds for path segments / adjacent +# digits. Bare "401" in MCP error text is almost always an HTTP +# status code in practice. +# 2. "Unauthorized" only when paired with an HTTP/status/code keyword. +# The bare word fires on prose like "you are not authorized to view +# this calendar" — which is application-level, not OAuth. +# 3. Unambiguous OAuth/token signals stand alone — `invalid_token`, +# `invalid_grant` (refresh-token revocation), Google API's +# `UNAUTHENTICATED` and `invalid authentication credentials`. +# +# We only run this on results whose `status == "error"` +# (see `_looks_like_auth_failure`), so even the broader patterns above +# are gated by an explicit failure signal from the MCP framework. +_AUTH_FAILURE_PATTERN = re.compile( + r"(? bool: + """Heuristic: does this tool result look like an OAuth 401? + + Inspects the result's status and content for one of the markers above. + False positives here just trigger a wasted retry; false negatives + leave the user stuck with a stale token, so we err on the side of + matching. 
+ """ + if not isinstance(tool_result, dict): + return False + if tool_result.get("status") != "error": + return False + for block in tool_result.get("content", []) or []: + if not isinstance(block, dict): + continue + text = block.get("text") or "" + if isinstance(text, str) and _AUTH_FAILURE_PATTERN.search(text): + return True + return False + + +# Returns provider_id for a Strands `selected_tool`, or None if the tool +# isn't OAuth-gated. Encapsulates the MCPClient -> provider mapping. +ProviderLookup = Callable[[Any], Optional[str]] + +# Returns OAuth scopes for a provider_id. May be sync or async; the hook +# awaits the result either way so we can read from an async repository +# without forcing a sync wrapper. +ScopesLookup = Callable[[str], Union[list[str], Awaitable[list[str]]]] + +# Returns the provider's vendor type (e.g. "google", "microsoft") for a +# provider_id, or None if unknown / no per-vendor params needed. Optional — +# omitted in older tests; without it AgentCore Identity gets no +# `customParameters`, which means Google won't issue a refresh token and +# the vault entry expires after ~1 hour. +ProviderTypeLookup = Callable[[str], Union[Optional[str], Awaitable[Optional[str]]]] + +# Returns admin-supplied OAuth params (e.g. `hd=mycorp.com` for Google +# Workspace domain restriction) for a provider_id. Merged with the +# vendor baseline by `custom_parameters_for`; baseline wins on conflict. +CustomParametersLookup = Callable[ + [str], Union[Optional[dict[str, str]], Awaitable[Optional[dict[str, str]]]] +] + +# Returns whether the caller has been marked disconnected from this provider +# (set by the /disconnect route or by a prior 401 retry). When True, the +# hook bypasses the local token cache and asks AgentCore Identity for a +# fresh consent URL with `force_authentication=True`. 
+DisconnectedLookup = Callable[[str], Union[bool, Awaitable[bool]]] + +# Records a disconnect for the caller — invoked from the AfterToolCallEvent +# path when a tool returns a 401 against the cached vault token. Called +# instead of mutating per-process state so the intent is durable across +# replicas. +MarkDisconnected = Callable[[str], Union[None, Awaitable[None]]] + + +class OAuthConsentHook(HookProvider): + """Pause the agent if a tool needs OAuth and we don't have a token yet.""" + + def __init__( + self, + user_id: str, + provider_lookup: ProviderLookup, + scopes_lookup: ScopesLookup, + provider_type_lookup: Optional[ProviderTypeLookup] = None, + custom_parameters_lookup: Optional[CustomParametersLookup] = None, + disconnected_lookup: Optional[DisconnectedLookup] = None, + mark_disconnected: Optional[MarkDisconnected] = None, + ): + """Initialize. + + Args: + user_id: User the agent is running for. Used as cache key and + passed to AgentCore Identity for the local-dev workload-token + fallback (no-op in production). + provider_lookup: See `ProviderLookup`. + scopes_lookup: See `ScopesLookup`. + provider_type_lookup: See `ProviderTypeLookup`. Optional. When + provided, the hook forwards vendor-specific OAuth params + (e.g. Google's `access_type=offline`) to AgentCore Identity. + custom_parameters_lookup: See `CustomParametersLookup`. + Optional. Admin-supplied extras to merge with the vendor + baseline. + disconnected_lookup: See `DisconnectedLookup`. Optional. When + omitted, the hook never bypasses the local token cache — + effectively assumes the user has not disconnected. Wire + this to the durable disconnect repository in production so + a /disconnect on one replica is visible from any other. + mark_disconnected: See `MarkDisconnected`. Optional. 
Invoked + from the 401-retry path; without it, a 401 still flips + `event.retry = True` but leaves no durable record, so the + next BeforeToolCallEvent on a different replica won't know + to force a fresh consent. + """ + self._user_id = user_id + self._provider_lookup = provider_lookup + self._scopes_lookup = scopes_lookup + self._provider_type_lookup = provider_type_lookup + self._custom_parameters_lookup = custom_parameters_lookup + self._disconnected_lookup = disconnected_lookup + self._mark_disconnected = mark_disconnected + # Cache scopes per provider for the lifetime of this hook (one agent + # invocation). Avoids repeated DB hits if the same provider is used + # across multiple tool calls in a single turn. + self._scopes_cache: dict[str, list[str]] = {} + # Same cache shape for provider_type. `None` is a legitimate value + # (vendor without extra params), so we use a separate sentinel set + # to distinguish "unknown" from "looked up, no extras needed". + self._provider_type_cache: dict[str, Optional[str]] = {} + self._provider_type_cache_keys: set[str] = set() + self._custom_parameters_cache: dict[str, Optional[dict[str, str]]] = {} + self._custom_parameters_cache_keys: set[str] = set() + # Providers that already burned their one 401-retry in the current + # turn. The agent instance is cached across turns by `get_agent`, so + # this set must be reset on `BeforeInvocationEvent`. Without the cap, + # a misconfigured provider (wrong scope, perma-401) would surface a + # consent prompt on every tool call in the turn — `_record_disconnect` + # forces fresh consent on the next BeforeToolCallEvent, the user + # consents, the tool 401s again, and the loop repeats per tool use. 
+ self._reauth_attempted_providers: set[str] = set() + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + registry.add_callback(BeforeInvocationEvent, self._on_invocation_start) + registry.add_callback(BeforeToolCallEvent, self._gate) + registry.add_callback(AfterToolCallEvent, self._handle_auth_failure) + + def _on_invocation_start(self, event: BeforeInvocationEvent) -> None: + """Reset per-turn state at the start of each agent invocation. + + Both fresh turns and resumes (with `interrupt_responses`) trigger + BeforeInvocationEvent. Resetting on resume is intentional: the user + just took an action (consent), so they've signaled they want to + keep trying — start their retry budget fresh. + """ + self._reauth_attempted_providers.clear() + + async def _gate(self, event: BeforeToolCallEvent) -> None: + provider_id = self._provider_lookup(event.selected_tool) + if not provider_id: + return # Not an OAuth-gated tool + + force_reauth = await self._is_disconnected(provider_id) + + # Fast path: token already in cache (from a prior call this process, + # or warmed by a previous turn). Skipped when the durable disconnect + # repository says this user wants a fresh consent — either because + # they pressed "Disconnect" (possibly on a different replica) or + # because a prior tool call returned 401. + if not force_reauth and oauth_token_cache.get(self._user_id, provider_id): + return + + # Slow path: ask AgentCore Identity. Either we get a token (vault + # hit, cache it and proceed) or a consent URL (interrupt the turn). + # `force_reauth` makes us bypass AgentCore's vault entirely so a + # stale post-revocation token doesn't get re-served. + token_or_url = await self._fetch_token_or_url( + provider_id, force_authentication=force_reauth + ) + if token_or_url is None: + # Couldn't resolve — let the tool run; the MCP server will return + # 401 and the resulting tool_error surfaces conversationally. 
+ return + + if token_or_url["token"]: + oauth_token_cache.set(self._user_id, provider_id, token_or_url["token"]) + return + + # Consent required: pause the agent. The interrupt name is + # provider-scoped, but Strands' BeforeToolCallEvent._interrupt_id + # also folds in `tool_use.toolUseId` (see strands/hooks/events.py), + # so two parallel tool calls to the same provider in one turn + # produce distinct interrupt ids and surface as separate + # `oauth_required` events. If Strands ever changes that ID scheme + # we'd need to incorporate toolUseId here ourselves — + # `test_parallel_tool_calls_same_provider_produce_distinct_interrupts` + # is the regression guard. + response = event.interrupt( + name=f"oauth:{provider_id}", + reason={ + "type": "oauth_required", + "providerId": provider_id, + "authorizationUrl": token_or_url["url"], + }, + ) + + # We're past the interrupt — the user resumed. Re-fetch from the + # vault (AgentCore Identity should now have the token after consent + # completion) and warm the cache. We ignore `response` content — + # successful resumption is itself the signal that consent happened. + del response + refreshed = await self._fetch_token_or_url(provider_id) + if refreshed and refreshed["token"]: + oauth_token_cache.set(self._user_id, provider_id, refreshed["token"]) + return + + # Resumed but still no token — treat as declined. cancel_tool emits a + # tool_error to the model so it can apologize/replan. + event.cancel_tool = ( + f"User did not complete authorization for {provider_id}; " + "the tool cannot run." 
+        )
+
+    async def _fetch_token_or_url(
+        self, provider_id: str, *, force_authentication: bool = False
+    ) -> Optional[dict]:
+        """Return {'token': str|None, 'url': str|None} or None on hard error."""
+        scopes = await self._resolve_scopes(provider_id)
+        provider_type = await self._resolve_provider_type(provider_id)
+        admin_extras = await self._resolve_custom_parameters(provider_id)
+        identity_client = get_agentcore_identity_client()
+
+        try:
+            result = await identity_client.get_token_for_user(
+                provider_name=provider_id,
+                scopes=scopes,
+                user_id=self._user_id,
+                force_authentication=force_authentication,
+                custom_parameters=custom_parameters_for(provider_type, admin_extras),
+            )
+        except WorkloadTokenUnavailableError:
+            logger.error(
+                "No workload token on context for provider=%s — "
+                "AgentCoreContextMiddleware may be misconfigured",
+                provider_id,
+            )
+            return None
+        except CallbackUrlUnavailableError as err:
+            logger.error(
+                "No OAuth2 callback URL for provider=%s: %s",
+                provider_id,
+                err,
+            )
+            return None
+        except asyncio.CancelledError:
+            raise
+        except Exception:
+            logger.exception(
+                "Failed to fetch OAuth token for provider=%s", provider_id
+            )
+            return None
+
+        return {
+            "token": result.access_token,
+            "url": result.authorization_url,
+        }
+
+    async def _handle_auth_failure(self, event: AfterToolCallEvent) -> None:
+        """Detect a 401 from an OAuth-gated MCP tool and retry with fresh consent.
+
+        AgentCore Identity has no revoke API, so when the user revokes our
+        app at the provider (or the refresh token expires), AgentCore's
+        vault keeps serving the now-stale token. The MCP server rejects it
+        with a 401 — and that's where the staleness first becomes visible.
+
+        We detect the 401, drop the cached token, durably record the
+        disconnect for re-consent, and set `event.retry = True`.
+ Strands' tool executor then re-fires `BeforeToolCallEvent`, our + `_gate` callback sees the force-reauth flag, asks AgentCore for a + fresh consent URL with `force_authentication=True`, and raises an + interrupt — same path as a first-time consent. + """ + provider_id = self._provider_lookup(event.selected_tool) + if not provider_id: + return + + if not _looks_like_auth_failure(event.result): + return + + # Avoid both an infinite retry loop within a single tool call and a + # consent-prompt storm across multiple tool calls in the same turn: + # retry at most once per provider per turn. The set is reset on + # `BeforeInvocationEvent`, so a fresh turn (or a resume after the + # user re-consented) gets a fresh budget. + if provider_id in self._reauth_attempted_providers: + logger.warning( + "OAuth re-auth already attempted this turn for provider=%s " + "(tool=%s); not retrying again", + provider_id, + event.tool_use.get("name"), + ) + return + self._reauth_attempted_providers.add(provider_id) + + logger.info( + "Detected OAuth 401 for tool=%s provider=%s; clearing token cache and retrying", + event.tool_use.get("name"), + provider_id, + ) + # Drop the local hot-path token so the BeforeToolCallEvent retry + # doesn't short-circuit to it, and record the intent durably so + # other replicas (and subsequent requests on this one) also force a + # fresh consent. 
+ oauth_token_cache.clear_user_provider(self._user_id, provider_id) + await self._record_disconnect(provider_id) + event.retry = True + + async def _resolve_scopes(self, provider_id: str) -> list[str]: + if provider_id in self._scopes_cache: + return self._scopes_cache[provider_id] + result = self._scopes_lookup(provider_id) + if inspect.isawaitable(result): + scopes = await result + else: + scopes = result + scopes = list(scopes or []) + self._scopes_cache[provider_id] = scopes + return scopes + + async def _resolve_provider_type(self, provider_id: str) -> Optional[str]: + if self._provider_type_lookup is None: + return None + if provider_id in self._provider_type_cache_keys: + return self._provider_type_cache.get(provider_id) + result = self._provider_type_lookup(provider_id) + if inspect.isawaitable(result): + provider_type = await result + else: + provider_type = result + self._provider_type_cache[provider_id] = provider_type + self._provider_type_cache_keys.add(provider_id) + return provider_type + + async def _resolve_custom_parameters( + self, provider_id: str + ) -> Optional[dict[str, str]]: + if self._custom_parameters_lookup is None: + return None + if provider_id in self._custom_parameters_cache_keys: + return self._custom_parameters_cache.get(provider_id) + result = self._custom_parameters_lookup(provider_id) + if inspect.isawaitable(result): + extras = await result + else: + extras = result + self._custom_parameters_cache[provider_id] = extras + self._custom_parameters_cache_keys.add(provider_id) + return extras + + async def _is_disconnected(self, provider_id: str) -> bool: + """Read the durable disconnect flag (DDB-backed in production). + + Not memoized: a disconnect request can land on this replica between + two tool calls in the same turn, and we want the second tool call + to honor it. The DDB read is a single GetItem keyed on + `(user_id, provider_id)`, so the cost is negligible. 
+ """ + if self._disconnected_lookup is None: + return False + result = self._disconnected_lookup(provider_id) + if inspect.isawaitable(result): + return bool(await result) + return bool(result) + + async def _record_disconnect(self, provider_id: str) -> None: + """Persist a disconnect from the AfterToolCallEvent retry path.""" + if self._mark_disconnected is None: + return + result = self._mark_disconnected(provider_id) + if inspect.isawaitable(result): + await result diff --git a/backend/src/agents/main_agent/streaming/stream_coordinator.py b/backend/src/agents/main_agent/streaming/stream_coordinator.py index f50be71a..3de70cac 100644 --- a/backend/src/agents/main_agent/streaming/stream_coordinator.py +++ b/backend/src/agents/main_agent/streaming/stream_coordinator.py @@ -195,6 +195,23 @@ async def stream_response( # Don't yield this event to the client (will send final metadata before done) continue + # If the agent paused on an OAuth interrupt, surface one + # `oauth_required` SSE event per pending interrupt before the + # stream closes. The frontend uses these to drive the consent + # popup and then POSTs interrupt responses to resume the turn. + # Done before the metadata branch so the events land between + # message_stop and the final metadata/done block. Persistence + # to session metadata happens inside the extractor so a refresh + # rediscovers the consent prompt. 
+ if event.get("type") == "done": + for sse in await self._extract_oauth_required_events( + agent, + session_id=session_id, + user_id=user_id, + main_agent_wrapper=main_agent_wrapper, + ): + yield sse + # Check if this is the "done" event - send final metadata before it if event.get("type") == "done": # Calculate end-to-end latency @@ -530,6 +547,116 @@ async def stream_response( except Exception as persist_error: logger.error(f"Failed to persist stream error to session: {persist_error}") + async def _extract_oauth_required_events( + self, + agent: Any, + session_id: Optional[str] = None, + user_id: Optional[str] = None, + triggering_message_id: Optional[str] = None, + main_agent_wrapper: Any = None, + ) -> List[str]: + """Yield one SSE-formatted `oauth_required` event per pending OAuth + interrupt on the agent, persisting each one to session metadata so + the frontend can rediscover them after a refresh. + + The Strands `_interrupt_state` is populated when `OAuthConsentHook` + calls `event.interrupt(...)`. We look for interrupts whose `reason` + carries `type: "oauth_required"` and translate them into the SSE + shape the frontend already understands. Non-OAuth interrupts (other + approval gates added later) are ignored here so they can be handled + by their own SSE event types. + + Also persists a ``PausedTurnSnapshot`` capturing the agent's + construction params, so a resume after refresh / cache eviction + rebuilds the same agent shape (matching tool registry) and lets + Strands restore ``_interrupt_state`` from AgentCore Memory. + + Persistence is best-effort: a DynamoDB write failure logs but does + not break the live SSE flow. 
+ """ + from datetime import timedelta + from apis.shared.oauth.models import OAuthRequiredEvent + from apis.shared.sessions.metadata import add_pending_interrupt, set_paused_turn + from apis.shared.sessions.models import PausedTurnSnapshot, PendingInterrupt + + interrupt_state = getattr(agent, "_interrupt_state", None) + if not interrupt_state or not getattr(interrupt_state, "activated", False): + return [] + + # Snapshot the turn-construction context once per pause, before the + # per-interrupt loop. Multiple OAuth interrupts in the same turn + # share one snapshot — they were all built against the same agent. + # TTL matches AgentCore Identity's consent window so stale snapshots + # don't pin storage and a too-late resume returns a clean 400. + snapshot_source = ( + getattr(main_agent_wrapper, "_construction_snapshot", None) if main_agent_wrapper else None + ) + if session_id and user_id and snapshot_source: + try: + now = datetime.now(timezone.utc) + snapshot = PausedTurnSnapshot( + enabled_tools=snapshot_source.get("enabled_tools"), + model_id=snapshot_source.get("model_id"), + provider=snapshot_source.get("provider"), + temperature=snapshot_source.get("temperature"), + system_prompt=snapshot_source.get("system_prompt"), + caching_enabled=snapshot_source.get("caching_enabled"), + max_tokens=snapshot_source.get("max_tokens"), + captured_at=now.isoformat(), + expires_at=(now + timedelta(hours=1)).isoformat(), + ) + await set_paused_turn(session_id, user_id, snapshot) + except Exception as e: + logger.error( + "Failed to persist paused_turn snapshot for session %s: %s", + session_id, e, exc_info=True, + ) + + events: List[str] = [] + for interrupt in interrupt_state.interrupts.values(): + reason = interrupt.reason or {} + if not isinstance(reason, dict) or reason.get("type") != "oauth_required": + continue + provider_id = reason.get("providerId") + authorization_url = reason.get("authorizationUrl") + if not provider_id or not authorization_url: + logger.warning( 
+ "OAuth interrupt missing providerId or authorizationUrl: id=%s", + interrupt.id, + ) + continue + + # Persist the breadcrumb before yielding so a client that loads + # the session a moment later sees this interrupt. Only attempt + # when we have session/user context — preview/anonymous flows + # don't have a metadata record to write to. + if session_id and user_id: + try: + await add_pending_interrupt( + session_id=session_id, + user_id=user_id, + interrupt=PendingInterrupt( + interrupt_id=interrupt.id, + provider_id=provider_id, + triggering_message_id=triggering_message_id, + created_at=datetime.now(timezone.utc).isoformat(), + ), + ) + except Exception as e: + logger.error( + "Failed to persist pending_interrupt %s: %s", + interrupt.id, e, exc_info=True, + ) + + events.append( + OAuthRequiredEvent( + provider_id=provider_id, + authorization_url=authorization_url, + interrupt_id=interrupt.id, + ).to_sse_format() + ) + return events + def _format_sse_event(self, event: Dict[str, Any]) -> str: """ Format processed event as SSE (Server-Sent Event) @@ -1090,149 +1217,39 @@ async def _calculate_streaming_cost(self, model_id: str, usage: Dict[str, Any]) return None async def _update_session_metadata(self, session_id: str, user_id: str, message_id: int, agent: Any = None) -> None: - """ - Update session-level metadata after each message - - This updates conversation-level tracking after each message: - - lastMessageAt: Timestamp of this message - - messageCount: Incremented by 1 - - preferences: Model/temperature/tools/system_prompt_hash from agent config - - Auto-creates session metadata on first message + """Update per-turn session activity (lastMessageAt, messageCount, preferences). 
- Args: - session_id: Session identifier - user_id: User identifier - message_id: Message ID that was just flushed - agent: Agent instance for extracting model preferences + Delegates to ``update_session_activity``, which uses targeted writes + so concurrent writers (title-gen, pending-interrupt persistence) + cannot be clobbered. Pre-create is handled at /invocations entry, so + no lazy-create branch is needed here. """ try: import hashlib - from apis.shared.sessions.models import SessionMetadata, SessionPreferences - from apis.shared.sessions.metadata import get_session_metadata, store_session_metadata - - logger.info(f"🔍 _update_session_metadata called for session {session_id}, message_id {message_id}") - - # Get existing metadata or create new - existing = await get_session_metadata(session_id, user_id) - - if existing: - logger.info(f"📄 Found existing metadata: messageCount={existing.message_count}, has_preferences={existing.preferences is not None}") - else: - logger.info(f"📄 No existing metadata found - creating new") - - # Calculate message count incrementally - # NOTE: We cannot query AgentCore Memory immediately after flush due to eventual consistency. - # The turn-based session manager calls create_message() then immediately calls list_messages(), - # but the newly created message is not yet available for reading (can take several seconds). - # - # Instead, we use incremental counting: - # - Each streaming turn creates 1 merged message in AgentCore Memory - # - We increment the count by 1 per turn - # - # This count represents "turns" (user-assistant exchanges), not individual message events. - # Tool use creates multiple content blocks within a single turn/message. 
- if not existing: - actual_message_count = 1 - logger.info(f"📊 First turn in session - message_count: {actual_message_count}") - else: - actual_message_count = existing.message_count + 1 - logger.info(f"📊 Incremental turn count: {existing.message_count} + 1 = {actual_message_count}") - - now = datetime.now(timezone.utc).isoformat() - - if not existing: - # First message - create session metadata - preferences = None - if agent and hasattr(agent, "model_config"): - logger.info(f"📦 Agent has model_config: model_id={agent.model_config.model_id}") - - # Generate system prompt hash for tracking exact prompt version - # This hash represents the FINAL rendered system prompt (after date injection, etc.) - system_prompt_hash = None - if hasattr(agent, "system_prompt") and agent.system_prompt: - system_prompt_hash = hashlib.md5(agent.system_prompt.encode()).hexdigest()[:16] # 16 char hash for uniqueness - logger.debug(f"Generated system_prompt_hash: {system_prompt_hash}") - - # Extract enabled tools from agent - enabled_tools = getattr(agent, "enabled_tools", None) - - preferences = SessionPreferences( - last_model=agent.model_config.model_id, - last_temperature=getattr(agent.model_config, "temperature", None), - enabled_tools=enabled_tools, - system_prompt_hash=system_prompt_hash, - ) - logger.info(f"✨ Created new preferences: last_model={preferences.last_model}") - else: - logger.warning(f"⚠️ Agent is None or missing model_config") + from apis.shared.sessions.metadata import update_session_activity - metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title="New Conversation", # Will be updated by frontend - status="active", - created_at=now, - last_message_at=now, - message_count=actual_message_count, - starred=False, - tags=[], - preferences=preferences, - ) + last_model = None + last_temperature = None + enabled_tools = None + system_prompt_hash = None + if agent and hasattr(agent, "model_config"): + last_model = agent.model_config.model_id + 
last_temperature = getattr(agent.model_config, "temperature", None) + enabled_tools = getattr(agent, "enabled_tools", None) + if hasattr(agent, "system_prompt") and agent.system_prompt: + system_prompt_hash = hashlib.md5(agent.system_prompt.encode()).hexdigest()[:16] else: - # Update existing - only update what changed - preferences = existing.preferences - if agent and hasattr(agent, "model_config"): - logger.info(f"📦 Updating preferences with model_id={agent.model_config.model_id}") - - # Update preferences if model/temperature/tools/system_prompt changed - prefs_dict = preferences.model_dump(by_alias=False) if preferences else {} - logger.info(f"📝 Existing prefs_dict: {prefs_dict}") - - prefs_dict["last_model"] = agent.model_config.model_id - prefs_dict["last_temperature"] = getattr(agent.model_config, "temperature", None) - - # Update enabled_tools from agent - prefs_dict["enabled_tools"] = getattr(agent, "enabled_tools", None) - - # Update system_prompt_hash if system prompt changed - # This allows tracking when the prompt was modified during a conversation - if hasattr(agent, "system_prompt") and agent.system_prompt: - new_hash = hashlib.md5(agent.system_prompt.encode()).hexdigest()[:16] - # Only update if hash changed (prompt was modified) - if prefs_dict.get("system_prompt_hash") != new_hash: - logger.info(f"System prompt changed - updating hash from {prefs_dict.get('system_prompt_hash')} to {new_hash}") - prefs_dict["system_prompt_hash"] = new_hash - - preferences = SessionPreferences(**prefs_dict) - logger.info(f"✨ Updated preferences: last_model={preferences.last_model}") - else: - logger.warning(f"⚠️ Agent is None or missing model_config - keeping existing preferences") - - metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=existing.title, - status=existing.status, - created_at=existing.created_at, - last_message_at=now, - message_count=actual_message_count, - starred=existing.starred, - tags=existing.tags, - 
preferences=preferences, - ) + logger.warning("⚠️ Agent is None or missing model_config — skipping preference update") - # Store updated metadata (uses deep merge in storage layer) - await store_session_metadata(session_id=session_id, user_id=user_id, session_metadata=metadata) - - logger.info( - f"✅ Updated session metadata - last_model: {metadata.preferences.last_model if metadata.preferences else 'None'}, message_count: {metadata.message_count}" + await update_session_activity( + session_id=session_id, + user_id=user_id, + last_model=last_model, + last_temperature=last_temperature, + enabled_tools=enabled_tools, + system_prompt_hash=system_prompt_hash, ) - - # Return message count for use as a fallback message_id - return metadata.message_count - except Exception as e: logger.error(f"Failed to update session metadata: {e}", exc_info=True) - # Don't raise - metadata failures shouldn't break streaming - return None + # Don't raise — metadata failures shouldn't break streaming. diff --git a/backend/src/agents/main_agent/tools/oauth_tool_service.py b/backend/src/agents/main_agent/tools/oauth_tool_service.py deleted file mode 100644 index 2d42f763..00000000 --- a/backend/src/agents/main_agent/tools/oauth_tool_service.py +++ /dev/null @@ -1,336 +0,0 @@ -""" -OAuth Tool Service - Enables tools to securely access user OAuth tokens - -This service provides a clean interface for tools to: -1. Check if a user has connected to an OAuth provider -2. Retrieve decrypted access tokens for API calls -3. Handle token refresh automatically -4. 
Generate connection URLs for user guidance - -Usage in tools: - from agents.main_agent.tools.oauth_tool_service import get_oauth_tool_service - - @tool(context=True) - async def my_oauth_tool(query: str, tool_context: ToolContext) -> dict: - oauth_service = get_oauth_tool_service() - - # Get user_id from session manager - user_id = tool_context.agent._session_manager.user_id - - # Get token for this provider - result = await oauth_service.get_token_for_tool( - user_id=user_id, - provider_id="google_workspace" - ) - - if not result.connected: - return result.not_connected_response() - - # Use the token - headers = {"Authorization": f"Bearer {result.access_token}"} - # ... make API calls -""" -import logging -import os -from dataclasses import dataclass -from typing import Optional - -from agents.main_agent.config.constants import EnvVars, Defaults - -logger = logging.getLogger(__name__) - - -@dataclass -class OAuthTokenResult: - """Result of requesting an OAuth token for a tool.""" - - connected: bool - """Whether the user is connected to the provider.""" - - access_token: Optional[str] = None - """The decrypted access token (if connected).""" - - provider_id: str = "" - """The provider ID requested.""" - - provider_name: str = "" - """Human-readable provider name.""" - - error: Optional[str] = None - """Error message if token retrieval failed.""" - - needs_reauth: bool = False - """Whether the user needs to re-authorize (expired/revoked).""" - - def not_connected_response(self, tool_name: str = "this tool") -> dict: - """ - Generate a user-friendly response when not connected. - - Returns a dict suitable for returning from a tool. - """ - frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - connect_url = f"{frontend_url}/settings/connections" - - if self.needs_reauth: - message = f"""⚠️ **Re-authorization Required** - -Your connection to **{self.provider_name}** has expired or been revoked. 
- -Please reconnect to continue using {tool_name}: - -👉 [Reconnect to {self.provider_name}]({connect_url}) - -After reconnecting, try your request again.""" - else: - message = f"""🔗 **Connection Required** - -To use {tool_name}, you need to connect your **{self.provider_name}** account. - -This allows me to securely access your data on your behalf. - -👉 [Connect {self.provider_name}]({connect_url}) - -After connecting, try your request again.""" - - return { - "status": "not_connected", - "content": [{"text": message}], - "requires_oauth": True, - "provider_id": self.provider_id, - "provider_name": self.provider_name, - "connect_url": connect_url, - "needs_reauth": self.needs_reauth, - } - - -class OAuthToolService: - """ - Service for tools to access OAuth tokens. - - This service wraps the OAuth service and provides a simplified - interface optimized for tool usage. - """ - - def __init__(self): - self._oauth_service = None - self._provider_repo = None - - async def _get_oauth_service(self): - """Lazy-load OAuth service to avoid circular imports.""" - if self._oauth_service is None: - from apis.shared.oauth.service import get_oauth_service - self._oauth_service = get_oauth_service() - return self._oauth_service - - async def _get_provider_repo(self): - """Lazy-load provider repository.""" - if self._provider_repo is None: - from apis.shared.oauth.provider_repository import get_provider_repository - self._provider_repo = get_provider_repository() - return self._provider_repo - - async def get_token_for_tool( - self, - user_id: str, - provider_id: str, - ) -> OAuthTokenResult: - """ - Get an OAuth token for a tool to use. 
- - Args: - user_id: The user's ID (from session manager) - provider_id: The OAuth provider ID (e.g., "google_workspace") - - Returns: - OAuthTokenResult with token or connection guidance - """ - try: - oauth_service = await self._get_oauth_service() - provider_repo = await self._get_provider_repo() - - # Get provider info for display name - provider = await provider_repo.get_provider(provider_id) - provider_name = provider.display_name if provider else provider_id - - # Try to get the token — isolated to limit taint scope - result = await self._try_get_token(oauth_service, user_id, provider_id, provider_name) - if result: - return result - - # Check if user has a connection but needs re-auth - from apis.shared.oauth.token_repository import get_token_repository - token_repo = get_token_repository() - user_token = await token_repo.get_user_token(user_id, provider_id) - - if user_token and user_token.status in ("expired", "needs_reauth", "revoked"): - token_status = user_token.status - logger.debug("User needs re-auth for an OAuth provider") - return OAuthTokenResult( - connected=False, - provider_id=provider_id, - provider_name=provider_name, - needs_reauth=True, - error=f"Token {token_status}", - ) - - # User not connected - logger.debug("User not connected to an OAuth provider") - return OAuthTokenResult( - connected=False, - provider_id=provider_id, - provider_name=provider_name, - ) - - except Exception as e: - logger.error("Error checking OAuth connection: %s", type(e).__name__, exc_info=True) - return OAuthTokenResult( - connected=False, - provider_id=provider_id, - provider_name=provider_id, - error=str(e), - ) - - @staticmethod - async def _try_get_token( - oauth_service: "OAuthService", - user_id: str, - provider_id: str, - provider_name: str, - ) -> Optional[OAuthTokenResult]: - """Fetch decrypted token in isolated scope to avoid taint leaking to callers.""" - token = await oauth_service.get_decrypted_token( - user_id=user_id, - provider_id=provider_id, - 
) - if token: - logger.debug("OAuth token successfully retrieved") - return OAuthTokenResult( - connected=True, - access_token=token, - provider_id=provider_id, - provider_name=provider_name, - ) - return None - - async def check_connection( - self, - user_id: str, - provider_id: str, - ) -> bool: - """ - Quick check if user is connected to a provider. - - Args: - user_id: The user's ID - provider_id: The OAuth provider ID - - Returns: - True if connected and token is valid - """ - result = await self.get_token_for_tool(user_id, provider_id) - return result.connected - - def get_connect_url(self, provider_id: str) -> str: - """ - Get the URL for the user to connect to a provider. - - Args: - provider_id: The OAuth provider ID - - Returns: - URL to the connections page - """ - frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - return f"{frontend_url}/settings/connections" - - -# Singleton instance -_oauth_tool_service: Optional[OAuthToolService] = None - - -def get_oauth_tool_service() -> OAuthToolService: - """Get the singleton OAuthToolService instance.""" - global _oauth_tool_service - if _oauth_tool_service is None: - _oauth_tool_service = OAuthToolService() - return _oauth_tool_service - - -async def check_oauth_requirements_for_tools( - user_id: str, - enabled_tool_ids: list[str], -) -> dict[str, OAuthTokenResult]: - """ - Check OAuth connection status for all tools that require OAuth. - - This is useful for determining which tools will work and providing - guidance to users about missing connections. 
- - Args: - user_id: The user's ID - enabled_tool_ids: List of enabled tool IDs - - Returns: - Dict mapping provider_id to OAuthTokenResult for tools needing OAuth - """ - from apis.app_api.tools.repository import get_tool_catalog_repository - - results: dict[str, OAuthTokenResult] = {} - repository = get_tool_catalog_repository() - oauth_service = get_oauth_tool_service() - - # Find all unique OAuth providers required by enabled tools - providers_needed: set[str] = set() - - for tool_id in enabled_tool_ids: - try: - tool = await repository.get_tool(tool_id) - if tool and tool.requires_oauth_provider: - providers_needed.add(tool.requires_oauth_provider) - except Exception as e: - logger.warning(f"Could not check tool {tool_id}: {e}") - - # Check connection status for each provider - for provider_id in providers_needed: - result = await oauth_service.get_token_for_tool(user_id, provider_id) - results[provider_id] = result - - return results - - -def format_oauth_connection_guidance( - missing_connections: list[OAuthTokenResult], -) -> str: - """ - Format a user-friendly message about missing OAuth connections. - - Args: - missing_connections: List of OAuthTokenResult for unconnected providers - - Returns: - Markdown formatted message for the user - """ - if not missing_connections: - return "" - - frontend_url = os.getenv(EnvVars.FRONTEND_URL, Defaults.FRONTEND_URL) - connect_url = f"{frontend_url}/settings/connections" - - if len(missing_connections) == 1: - conn = missing_connections[0] - if conn.needs_reauth: - return f"""⚠️ Your connection to **{conn.provider_name}** has expired. 
- -Please [reconnect]({connect_url}) to use tools that require {conn.provider_name} access.""" - else: - return f"""🔗 To use tools that require **{conn.provider_name}** access, please [connect your account]({connect_url}).""" - - # Multiple missing connections - provider_names = [c.provider_name for c in missing_connections] - names_str = ", ".join(provider_names[:-1]) + f" and {provider_names[-1]}" - - return f"""🔗 Some tools require account connections. - -To use all your enabled tools, please connect: **{names_str}** - -👉 [Manage Connections]({connect_url})""" diff --git a/backend/src/agents/main_agent/voice_agent.py b/backend/src/agents/main_agent/voice_agent.py index dde8a061..7ef326bf 100644 --- a/backend/src/agents/main_agent/voice_agent.py +++ b/backend/src/agents/main_agent/voice_agent.py @@ -335,7 +335,13 @@ async def receive_events(self) -> AsyncGenerator[dict, None]: yield event_dict async def stream_async( - self, message: str, session_id: Optional[str] = None, files: Optional[List] = None, citations: Optional[List] = None, original_message: Optional[str] = None + self, + message: str, + session_id: Optional[str] = None, + files: Optional[List] = None, + citations: Optional[List] = None, + original_message: Optional[str] = None, + interrupt_responses: Optional[List] = None, ) -> AsyncGenerator[str, None]: """ BaseAgent interface compatibility — not used for voice mode. diff --git a/backend/src/apis/app_api/admin/oauth/routes.py b/backend/src/apis/app_api/admin/oauth/routes.py index e5dee9b1..3fb125bf 100644 --- a/backend/src/apis/app_api/admin/oauth/routes.py +++ b/backend/src/apis/app_api/admin/oauth/routes.py @@ -1,11 +1,31 @@ -"""Admin API routes for OAuth provider management.""" +"""Admin API routes for OAuth provider management. +Registration flows through AWS Bedrock AgentCore Identity. 
Our DynamoDB +record holds display metadata, scopes, and role gates; the AgentCore +credential provider owns `clientId`, `clientSecret`, endpoint config, and +the callback URL that the admin must register with the vendor. +""" + +import asyncio import logging +import os +from datetime import datetime, timezone +from functools import lru_cache +import boto3 from fastapi import APIRouter, Depends, HTTPException, Query, status from apis.shared.auth import User, require_admin +from apis.shared.oauth.agentcore_registrar import ( + AgentCoreRegistrar, + CredentialProviderConflictError, + CredentialProviderInfo, + CredentialProviderNotFoundError, + InvalidCustomProviderConfigError, + get_agentcore_registrar, +) from apis.shared.oauth.models import ( + OAuthProvider, OAuthProviderCreate, OAuthProviderListResponse, OAuthProviderResponse, @@ -15,16 +35,97 @@ OAuthProviderRepository, get_provider_repository, ) -from apis.shared.oauth.token_repository import ( - OAuthTokenRepository, - get_token_repository, -) -from apis.shared.oauth.token_cache import get_token_cache logger = logging.getLogger(__name__) router = APIRouter(prefix="/oauth-providers", tags=["admin-oauth"]) +# Rollback backoff schedule. Two retries after the initial attempt, ~2.5s +# total worst case — short enough to keep the create request responsive, +# long enough to absorb the common class of transient AWS errors. +_ROLLBACK_RETRY_DELAYS_SECONDS = (0.5, 2.0) + +# CloudWatch namespace for orphan telemetry. Kept distinct so ops can +# scope alarms without catching unrelated Bedrock metrics. 
+_ORPHAN_METRIC_NAMESPACE = "Agentcore/OAuth" +_ORPHAN_METRIC_NAME = "ProviderOrphaned" + + +@lru_cache(maxsize=1) +def _cloudwatch_client(): + region = os.environ.get("AWS_REGION", "us-west-2") + return boto3.client("cloudwatch", region_name=region) + + +async def _rollback_orphaned_provider( + registrar: AgentCoreRegistrar, provider_id: str +) -> None: + """Best-effort delete of an AgentCore provider after a DB write failed. + + Retries on transient AWS errors. If every attempt fails we emit a + CloudWatch `ProviderOrphaned` metric and log at ERROR — the AgentCore + record is now orphaned (no DynamoDB row) and needs manual cleanup, + but the admin's original 5xx still propagates so they know the + create didn't land. + """ + last_err: Exception | None = None + for attempt in range(1 + len(_ROLLBACK_RETRY_DELAYS_SECONDS)): + try: + # Registrar is sync; off-thread it so we don't block the event loop. + await asyncio.to_thread(registrar.delete_credential_provider, provider_id) + logger.info( + "Rolled back orphaned AgentCore provider %s (attempt %d)", + provider_id, + attempt + 1, + ) + return + except Exception as err: + last_err = err + if attempt < len(_ROLLBACK_RETRY_DELAYS_SECONDS): + delay = _ROLLBACK_RETRY_DELAYS_SECONDS[attempt] + logger.warning( + "Rollback attempt %d for %s failed (%s); retrying in %.1fs", + attempt + 1, + provider_id, + err, + delay, + ) + await asyncio.sleep(delay) + + logger.error( + "Rollback delete exhausted for %s after %d attempts; emitting " + "orphan metric. Last error: %s", + provider_id, + 1 + len(_ROLLBACK_RETRY_DELAYS_SECONDS), + last_err, + exc_info=last_err, + ) + _emit_orphan_metric(provider_id) + + +def _emit_orphan_metric(provider_id: str) -> None: + """Fire-and-forget CloudWatch metric for an orphaned credential provider. + + Failures here are swallowed — we're already inside a rollback path + and a secondary error would just shadow the admin-facing one. 
+ """ + try: + _cloudwatch_client().put_metric_data( + Namespace=_ORPHAN_METRIC_NAMESPACE, + MetricData=[ + { + "MetricName": _ORPHAN_METRIC_NAME, + "Dimensions": [{"Name": "ProviderId", "Value": provider_id}], + "Value": 1, + "Unit": "Count", + } + ], + ) + except Exception: + logger.exception( + "Failed to emit CloudWatch orphan metric for %s", provider_id + ) + # ============================================================================= # Provider CRUD @@ -37,22 +138,9 @@ async def list_providers( admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), ): - """ - List all OAuth providers. - - Requires admin access. - - Args: - enabled_only: If True, only return enabled providers - admin: Authenticated admin user (injected) - - Returns: - OAuthProviderListResponse with all providers - """ + """List all OAuth providers. Admin only.""" logger.info("Admin listing OAuth providers") - providers = await provider_repo.list_providers(enabled_only=enabled_only) - return OAuthProviderListResponse( providers=[OAuthProviderResponse.from_provider(p) for p in providers], total=len(providers), @@ -65,67 +153,86 @@ async def get_provider( admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), ): - """ - Get a provider by ID. - - Requires admin access. - - Args: - provider_id: Provider identifier - admin: Authenticated admin user (injected) - - Returns: - OAuthProviderResponse with provider details - - Raises: - HTTPException: 404 if provider not found - """ - logger.info("Admin getting OAuth provider") - + """Get a provider by ID. 
Admin only.""" provider = await provider_repo.get_provider(provider_id) - if not provider: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Provider '{provider_id}' not found", ) - return OAuthProviderResponse.from_provider(provider) -@router.post("/", response_model=OAuthProviderResponse, status_code=status.HTTP_201_CREATED) +@router.post( + "/", response_model=OAuthProviderResponse, status_code=status.HTTP_201_CREATED +) async def create_provider( provider_data: OAuthProviderCreate, admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + registrar: AgentCoreRegistrar = Depends(get_agentcore_registrar), ): - """ - Create a new OAuth provider. - - Requires admin access. - - Args: - provider_data: Provider creation data - admin: Authenticated admin user (injected) + """Register a new OAuth provider. - Returns: - Created OAuthProviderResponse - - Raises: - HTTPException: 400 if provider already exists or validation fails + Registers credentials with AgentCore Identity first; on success, writes + the metadata record to DynamoDB. If the DB write fails after AgentCore + has accepted the credentials, best-effort rolls back the AgentCore + provider so the two stores stay in sync. 
""" - logger.info("Admin creating OAuth provider") + logger.info("Admin creating OAuth provider %s", provider_data.provider_id) - try: - provider = await provider_repo.create_provider(provider_data) - return OAuthProviderResponse.from_provider(provider) + existing = await provider_repo.get_provider(provider_data.provider_id) + if existing: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Provider '{provider_data.provider_id}' already exists", + ) - except ValueError as e: - logger.warning(f"Provider creation failed: {e}") + try: + credential_info = registrar.create_credential_provider( + provider_id=provider_data.provider_id, + provider_type=provider_data.provider_type, + client_id=provider_data.client_id, + client_secret=provider_data.client_secret, + discovery_url=provider_data.oauth_discovery_url, + authorization_server_metadata=provider_data.authorization_server_metadata, + ) + except CredentialProviderConflictError: + # We already verified the DB has no record for this provider_id + # above, so a conflict from AgentCore means its vault carries an + # orphaned record — almost always from a prior failed rollback. + # Give the admin a cleanup pointer instead of a bare 409. + logger.error( + "Orphan detected for %s: DB empty but AgentCore has a credential " + "provider. Previous rollback likely failed.", + provider_data.provider_id, + ) raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), + status_code=status.HTTP_409_CONFLICT, + detail=( + f"An AgentCore credential provider named " + f"'{provider_data.provider_id}' already exists but has no " + "matching database record (likely from a previous failed " + "rollback). Delete it via the AWS CLI and retry: " + "`aws bedrock-agentcore-control delete-oauth2-credential-provider " + f"--name {provider_data.provider_id}`." 
+ ), + ) + except InvalidCustomProviderConfigError as err: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(err)) + + try: + provider = _build_provider_from_create(provider_data, credential_info) + await provider_repo.put_provider(provider) + except Exception: + logger.exception( + "DB write failed for %s; rolling back AgentCore credential provider", + provider_data.provider_id, ) + await _rollback_orphaned_provider(registrar, provider_data.provider_id) + raise + + return OAuthProviderResponse.from_provider(provider) @router.patch("/{provider_id}", response_model=OAuthProviderResponse) @@ -134,54 +241,83 @@ async def update_provider( updates: OAuthProviderUpdate, admin: User = Depends(require_admin), provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + registrar: AgentCoreRegistrar = Depends(get_agentcore_registrar), ): - """ - Update an OAuth provider. - - Requires admin access. - - Note: If scopes are updated, existing user connections may need to re-authenticate. - The system tracks scope changes via hash and will prompt users to re-auth. - - Args: - provider_id: Provider identifier - updates: Fields to update - admin: Authenticated admin user (injected) - - Returns: - Updated OAuthProviderResponse + """Update a provider's metadata, and optionally rotate credentials. - Raises: - HTTPException: - - 400 if validation fails - - 404 if provider not found + Metadata edits (display name, scopes, roles, icon, enabled) write + straight to DynamoDB. Credential or discovery-config changes require + a corresponding AgentCore update — this is done first, and only if it + succeeds do we persist the new metadata and cached pointers. 
""" - logger.info("Admin updating OAuth provider") - - # Track if scopes changed (will invalidate cached tokens) - old_provider = await provider_repo.get_provider(provider_id) - scopes_changed = ( - old_provider - and updates.scopes is not None - and set(updates.scopes) != set(old_provider.scopes) - ) + existing = await provider_repo.get_provider(provider_id) + if not existing: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Provider '{provider_id}' not found", + ) - provider = await provider_repo.update_provider(provider_id, updates) + rotating_credentials = bool(updates.client_id and updates.client_secret) + changing_discovery = ( + updates.oauth_discovery_url is not None + or updates.authorization_server_metadata is not None + ) - if not provider: + credential_info: CredentialProviderInfo | None = None + if rotating_credentials or changing_discovery: + discovery_url = ( + updates.oauth_discovery_url + if updates.oauth_discovery_url is not None + else existing.oauth_discovery_url + ) + authorization_server_metadata = ( + updates.authorization_server_metadata + if updates.authorization_server_metadata is not None + else existing.authorization_server_metadata + ) + if not rotating_credentials: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + "Discovery config can only be updated together with a " + "credential rotation (client_id + client_secret)." + ), + ) + try: + credential_info = registrar.update_credential_provider( + provider_id=provider_id, + provider_type=existing.provider_type, + client_id=updates.client_id, + client_secret=updates.client_secret, + discovery_url=discovery_url, + authorization_server_metadata=authorization_server_metadata, + fallback_arn=existing.credential_provider_arn, + ) + except CredentialProviderNotFoundError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=( + f"AgentCore credential provider for '{provider_id}' " + "was not found. 
The DynamoDB record may be stale."
+                ),
+            )
+        except InvalidCustomProviderConfigError as err:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST, detail=str(err)
+            )
+
+    provider = await provider_repo.apply_metadata_update(provider_id, updates)
+    if provider is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider '{provider_id}' not found",
        )
 
-    # Invalidate cached tokens for this provider if scopes changed
-    if scopes_changed:
-        cache = get_token_cache()
-        evicted = cache.delete_for_provider(provider_id)
-        logger.info(
-            "Scopes changed for provider, "
-            f"evicted {evicted} cached tokens"
-        )
+    if credential_info is not None:
+        provider.credential_provider_arn = credential_info.credential_provider_arn
+        provider.callback_url = credential_info.callback_url
+        provider.updated_at = datetime.now(timezone.utc).isoformat()
+        await provider_repo.put_provider(provider)
 
    return OAuthProviderResponse.from_provider(provider)
@@ -189,107 +325,56 @@ async def update_provider(
 @router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
 async def delete_provider(
     provider_id: str,
-    force: bool = Query(
-        False,
-        description="Force delete even if users are connected (will delete their tokens)",
-    ),
     admin: User = Depends(require_admin),
     provider_repo: OAuthProviderRepository = Depends(get_provider_repository),
-    token_repo: OAuthTokenRepository = Depends(get_token_repository),
+    registrar: AgentCoreRegistrar = Depends(get_agentcore_registrar),
):
-    """
-    Delete an OAuth provider.
-
-    Requires admin access.
-
-    Warning: If users are connected to this provider, their tokens will be deleted
-    unless force=False (default), in which case the deletion will fail.
+    """Delete a provider from AgentCore and DynamoDB. 
- Args: - provider_id: Provider identifier - force: If True, delete even if users are connected - admin: Authenticated admin user (injected) - - Raises: - HTTPException: - - 400 if users are connected and force=False - - 404 if provider not found + AgentCore's deletion also removes every user token stored in its vault + for this provider, so connected users must reconnect the next time + they invoke a tool that needs it. """ - logger.info("Admin deleting OAuth provider") - - # Check for connected users - connected_tokens = await token_repo.list_provider_tokens(provider_id) - - if connected_tokens and not force: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=( - f"Cannot delete provider with {len(connected_tokens)} connected users. " - "Use force=true to delete anyway (will remove user connections)." - ), - ) - - # Delete user tokens if any - if connected_tokens: - deleted_count = await token_repo.delete_provider_tokens(provider_id) - logger.info(f"Deleted {deleted_count} user tokens for provider") - - # Delete provider - deleted = await provider_repo.delete_provider(provider_id) - - if not deleted: + existing = await provider_repo.get_provider(provider_id) + if not existing: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Provider '{provider_id}' not found", ) - # Clear cached tokens - cache = get_token_cache() - cache.delete_for_provider(provider_id) - + registrar.delete_credential_provider(provider_id) + await provider_repo.delete_provider(provider_id) return None # ============================================================================= -# Provider Statistics +# Helpers # ============================================================================= -@router.get("/{provider_id}/connections/count") -async def get_provider_connection_count( - provider_id: str, - admin: User = Depends(require_admin), - provider_repo: OAuthProviderRepository = Depends(get_provider_repository), - token_repo: OAuthTokenRepository = 
Depends(get_token_repository), -): - """ - Get the number of users connected to a provider. - - Requires admin access. - - Args: - provider_id: Provider identifier - admin: Authenticated admin user (injected) - - Returns: - Dict with provider_id and connection_count - - Raises: - HTTPException: 404 if provider not found - """ - logger.info("Admin getting connection count for provider") - - # Verify provider exists - provider = await provider_repo.get_provider(provider_id) - if not provider: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Provider '{provider_id}' not found", - ) - - tokens = await token_repo.list_provider_tokens(provider_id) - - return { - "provider_id": provider_id, - "connection_count": len(tokens), - } +def _build_provider_from_create( + data: OAuthProviderCreate, credential_info: CredentialProviderInfo +) -> OAuthProvider: + now = datetime.now(timezone.utc).isoformat() + "Z" + return OAuthProvider( + provider_id=data.provider_id, + display_name=data.display_name, + provider_type=data.provider_type, + scopes=data.scopes, + allowed_roles=data.allowed_roles, + enabled=data.enabled, + icon_name=data.icon_name, + # `""` from the form means "no uploaded icon" — store as None so + # absent and explicitly-cleared round-trip identically. + icon_data=data.icon_data or None, + credential_provider_arn=credential_info.credential_provider_arn, + callback_url=credential_info.callback_url, + oauth_discovery_url=data.oauth_discovery_url, + authorization_server_metadata=data.authorization_server_metadata, + # `{}` from the form means "explicitly no extras" — store as None + # so absent/empty are indistinguishable in DynamoDB and `from_*` + # lookups round-trip identically. 
+ custom_parameters=data.custom_parameters or None, + created_at=now, + updated_at=now, + ) diff --git a/backend/src/apis/app_api/admin/tools/routes.py b/backend/src/apis/app_api/admin/tools/routes.py index 89e0a477..88cfb825 100644 --- a/backend/src/apis/app_api/admin/tools/routes.py +++ b/backend/src/apis/app_api/admin/tools/routes.py @@ -177,6 +177,13 @@ async def admin_update_tool( if not updated: raise HTTPException(status_code=404, detail=f"Tool '{tool_id}' not found") + # Invalidate the freshness TTL entry so the next chat turn in this + # process sees the new updated_at immediately (no wait for the TTL + # to lapse). Other processes pick the change up within one TTL + # window via their own freshness reads. + from apis.app_api.tools.freshness import invalidate as invalidate_freshness + invalidate_freshness(tool_id) + return AdminToolResponse.from_tool_definition(updated) @@ -210,6 +217,9 @@ async def admin_delete_tool( if not deleted: raise HTTPException(status_code=404, detail=f"Tool '{tool_id}' not found") + from apis.app_api.tools.freshness import invalidate as invalidate_freshness + invalidate_freshness(tool_id) + action = "deleted" if hard else "disabled" return {"message": f"Tool '{tool_id}' {action} successfully"} diff --git a/backend/src/apis/app_api/assistants/routes.py b/backend/src/apis/app_api/assistants/routes.py index 3ad1a0a2..9be35b50 100644 --- a/backend/src/apis/app_api/assistants/routes.py +++ b/backend/src/apis/app_api/assistants/routes.py @@ -446,7 +446,7 @@ async def test_chat_endpoint(assistant_id: str, request: AssistantTestChatReques augmented_message = augment_prompt_with_context(user_message=request.message, context_chunks=context_chunks) # 6. 
Create agent with assistant's instructions as system prompt - agent = get_agent( + agent = await get_agent( session_id=session_id, user_id=user_id, enabled_tools=None, # No tools for test chat diff --git a/backend/src/apis/app_api/chat/routes.py b/backend/src/apis/app_api/chat/routes.py index 839e24f9..8c479a2f 100644 --- a/backend/src/apis/app_api/chat/routes.py +++ b/backend/src/apis/app_api/chat/routes.py @@ -363,7 +363,7 @@ async def chat_stream(request: ChatRequest, current_user: User = Depends(get_cur try: # Get agent instance (with or without tool filtering) # Use assistant's system prompt if provided - agent = get_agent( + agent = await get_agent( session_id=request.session_id, user_id=user_id, enabled_tools=authorized_tools, # Filtered by RBAC (may be None for all allowed) diff --git a/backend/src/apis/app_api/connectors/__init__.py b/backend/src/apis/app_api/connectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/src/apis/app_api/connectors/routes.py b/backend/src/apis/app_api/connectors/routes.py new file mode 100644 index 00000000..eb761e02 --- /dev/null +++ b/backend/src/apis/app_api/connectors/routes.py @@ -0,0 +1,81 @@ +"""User-facing connector routes. + +Lets a signed-in user see which OAuth connectors are available to them +(role-filtered). Consent is initiated by the inference API, which has the +AgentCore Runtime workload context; this router is purely a data source. 
+""" + +import logging +from typing import List, Optional + +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from apis.shared.auth import User, get_current_user +from apis.shared.oauth.models import OAuthProvider, OAuthProviderType +from apis.shared.oauth.provider_repository import ( + OAuthProviderRepository, + get_provider_repository, +) +from apis.shared.rbac.service import AppRoleService, get_app_role_service + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/connectors", tags=["connectors"]) + + +class UserConnector(BaseModel): + """Connector as visible to a signed-in user. + + Drops admin-only fields (ARN, callback URL, role allow-list) and keeps + only what the settings page needs to render. + """ + + provider_id: str + display_name: str + provider_type: OAuthProviderType + icon_name: str + icon_data: Optional[str] = None + scopes: List[str] + + +class UserConnectorListResponse(BaseModel): + connectors: List[UserConnector] + + +def _visible_to_user(provider: OAuthProvider, user_role_ids: List[str]) -> bool: + """True when the user is allowed to use this connector. + + An empty `allowed_roles` list means unrestricted access. A non-empty + list grants access to users who share at least one AppRole ID. 
+ """ + if not provider.enabled: + return False + if not provider.allowed_roles: + return True + return bool(set(provider.allowed_roles) & set(user_role_ids)) + + +@router.get("/", response_model=UserConnectorListResponse) +async def list_connectors( + current_user: User = Depends(get_current_user), + provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + role_service: AppRoleService = Depends(get_app_role_service), +) -> UserConnectorListResponse: + """List enabled connectors available to the current user.""" + permissions = await role_service.resolve_user_permissions(current_user) + providers = await provider_repo.list_providers(enabled_only=True) + visible = [p for p in providers if _visible_to_user(p, permissions.app_roles)] + return UserConnectorListResponse( + connectors=[ + UserConnector( + provider_id=p.provider_id, + display_name=p.display_name, + provider_type=p.provider_type, + icon_name=p.icon_name, + icon_data=p.icon_data, + scopes=p.scopes, + ) + for p in visible + ] + ) diff --git a/backend/src/apis/app_api/main.py b/backend/src/apis/app_api/main.py index 0c16aef8..4e76e08b 100644 --- a/backend/src/apis/app_api/main.py +++ b/backend/src/apis/app_api/main.py @@ -87,7 +87,7 @@ async def lifespan(app: FastAPI): from apis.app_api.documents.routes import router as documents_router from apis.app_api.users.routes import router as users_router from apis.app_api.user_settings.routes import router as user_settings_router -from apis.shared.oauth.routes import router as oauth_router +from apis.app_api.connectors.routes import router as connectors_router from apis.app_api.system.routes import router as system_router from apis.app_api.shares.routes import conversations_share_router, shares_router, shared_view_router @@ -108,7 +108,7 @@ async def lifespan(app: FastAPI): app.include_router(memory_router) # AgentCore Memory access endpoints app.include_router(tools_router) # Tool discovery and permissions app.include_router(files_router) # File 
upload via pre-signed URLs -app.include_router(oauth_router) # OAuth provider connections +app.include_router(connectors_router) # User-facing connector catalog app.include_router(system_router) # System status and first-boot endpoints app.include_router(conversations_share_router) # Share conversations endpoints app.include_router(shares_router) # Share management (update, revoke, export) @@ -142,11 +142,16 @@ async def lifespan(app: FastAPI): if __name__ == "__main__": import uvicorn + # Watch the full backend/src tree so edits to shared modules outside + # app_api/ (apis/shared/, agents/) trigger reload instead of defaulting + # to cwd, which only sees this API's own files. + src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) # Run with full module path when executing directly uvicorn.run( "apis.app_api.main:app", host="0.0.0.0", port=8000, reload=True, + reload_dirs=[src_root], log_level="info" ) diff --git a/backend/src/apis/app_api/sessions/routes.py b/backend/src/apis/app_api/sessions/routes.py index d540b797..0cac994f 100644 --- a/backend/src/apis/app_api/sessions/routes.py +++ b/backend/src/apis/app_api/sessions/routes.py @@ -19,7 +19,12 @@ MessagesListResponse ) from apis.shared.sessions.messages import get_messages -from apis.shared.sessions.metadata import store_session_metadata, get_session_metadata, list_user_sessions +from apis.shared.sessions.metadata import ( + list_user_sessions, + get_session_metadata, + remove_pending_interrupts, + store_session_metadata, +) from .services.session_service import SessionService from apis.app_api.shares.service import get_share_service from apis.shared.auth.dependencies import get_current_user @@ -546,3 +551,38 @@ async def get_session_messages_endpoint( status_code=500, detail=f"Failed to retrieve messages: {str(e)}" ) + + +@router.delete("/{session_id}/pending-interrupts/{interrupt_id:path}", status_code=204) +async def dismiss_pending_interrupt_endpoint( + session_id: str, + 
interrupt_id: str, + current_user: User = Depends(get_current_user), +): + """Dismiss a pending OAuth consent interrupt for the caller's session. + + The frontend calls this when the user clicks the dismiss button on an + inline consent prompt, so a refresh doesn't redisplay it. The id is + matched as-is — Strands generates ids like ``oauth:google-calendar``, + so we accept ``:path`` to keep the colon literal in the URL. + + No-op for unknown ids and missing sessions, returning 204 in both + cases (the user's intent is satisfied). + """ + user_id = current_user.user_id + + logger.info("DELETE /sessions/%s/pending-interrupts/%s", session_id, interrupt_id) + + try: + await remove_pending_interrupts( + session_id=session_id, + user_id=user_id, + interrupt_ids=[interrupt_id], + ) + return Response(status_code=204) + except Exception as e: + logger.error("Error dismissing pending interrupt", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Failed to dismiss interrupt: {str(e)}", + ) diff --git a/backend/src/apis/app_api/tools/freshness.py b/backend/src/apis/app_api/tools/freshness.py new file mode 100644 index 00000000..b8048e34 --- /dev/null +++ b/backend/src/apis/app_api/tools/freshness.py @@ -0,0 +1,92 @@ +"""Per-process TTL cache of tool-config freshness tokens. + +Cheap change-detection signal for the agent and MCP-client caches: any +admin edit to a tool bumps its `updated_at`, so including the freshness +hash in a cache key causes the next build to miss and rebuild with the +fresh config. + +Reads are TTL-cached so the per-turn overhead is bounded to at most one +DynamoDB GetItem per tool per TTL window, per process. Admin routes +call `invalidate(tool_id)` after a write so same-process visibility is +immediate; other processes see the change within one TTL window. 
+""" + +import asyncio +import hashlib +import logging +import time +from typing import Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +# tool_id -> (updated_at_iso_or_none, monotonic_fetched_at) +# None is stored when the tool is missing, so negative lookups are also +# TTL-cached — a deleted tool doesn't trigger a DynamoDB read every turn. +_cache: Dict[str, Tuple[Optional[str], float]] = {} +_TTL_SECONDS = 10.0 + + +def _reset_for_tests() -> None: + _cache.clear() + + +async def _fetch_updated_at(tool_id: str) -> Optional[str]: + from apis.app_api.tools.repository import get_tool_catalog_repository + + repo = get_tool_catalog_repository() + tool = await repo.get_tool(tool_id) + if tool is None or tool.updated_at is None: + return None + return tool.updated_at.isoformat() + "Z" + + +async def get_tool_updated_at(tool_id: str) -> Optional[str]: + """Return the `updated_at` for one tool, TTL-cached per process.""" + now = time.monotonic() + cached = _cache.get(tool_id) + if cached is not None and now - cached[1] < _TTL_SECONDS: + return cached[0] + + try: + updated_at = await _fetch_updated_at(tool_id) + except Exception: + logger.exception("Failed to fetch updated_at for tool %s", tool_id) + # On failure, return the last-known value if we have one, else + # None. Never raise — freshness is advisory for cache keying and + # must not break the chat turn. + return cached[0] if cached is not None else None + + _cache[tool_id] = (updated_at, now) + return updated_at + + +async def get_freshness_hash(tool_ids: List[str]) -> str: + """Return a stable 16-char hash of (tool_id -> updated_at). + + Changes when any of the given tools' config is edited. Empty list + returns the empty string so callers can short-circuit. 
+ """ + if not tool_ids: + return "" + + sorted_ids = sorted(tool_ids) + values = await asyncio.gather( + *(get_tool_updated_at(tid) for tid in sorted_ids) + ) + + payload = "|".join( + f"{tid}={val or 'none'}" for tid, val in zip(sorted_ids, values) + ) + return hashlib.md5(payload.encode()).hexdigest()[:16] + + +def invalidate(tool_id: Optional[str] = None) -> None: + """Drop an entry (or the whole cache) from the TTL store. + + Call this from admin write paths so changes are visible in the same + process on the very next turn, without waiting for the TTL to lapse. + """ + if tool_id is None: + _cache.clear() + else: + _cache.pop(tool_id, None) diff --git a/backend/src/apis/inference_api/chat/models.py b/backend/src/apis/inference_api/chat/models.py index f2b4a5d9..09968ed6 100644 --- a/backend/src/apis/inference_api/chat/models.py +++ b/backend/src/apis/inference_api/chat/models.py @@ -17,11 +17,23 @@ class FileContent(BaseModel): bytes: str # Base64 encoded +class InterruptResponseEntry(BaseModel): + """One user response to a Strands interrupt, in the SDK's prompt shape. + + Posted by the frontend after the user completes (or declines) an OAuth + consent popup. The backend forwards the list verbatim to + `agent.stream_async(...)` to resume the paused turn. + """ + + interruptId: str + response: Any = None + + class InvocationRequest(BaseModel): """Input for /invocations endpoint with multi-provider support""" session_id: str - message: str + message: str = "" model_id: Optional[str] = None temperature: Optional[float] = None system_prompt: Optional[str] = None @@ -36,6 +48,10 @@ class InvocationRequest(BaseModel): # AgentCore Runtime returns 424 when it sees a non-empty 'assistant_id' field, # likely trying to resolve it as an AWS Bedrock Agent ID. rag_assistant_id: Optional[str] = None + # When set, the route resumes a paused agent turn instead of starting a + # new one. 
`message` is ignored in that case — the original prompt is + # already in the agent's interrupt context. + interrupt_responses: Optional[List[InterruptResponseEntry]] = None class InvocationResponse(BaseModel): diff --git a/backend/src/apis/inference_api/chat/routes.py b/backend/src/apis/inference_api/chat/routes.py index ae4d6893..b35b884f 100644 --- a/backend/src/apis/inference_api/chat/routes.py +++ b/backend/src/apis/inference_api/chat/routes.py @@ -7,6 +7,7 @@ These endpoints are at the root level to comply with AWS Bedrock AgentCore Runtime requirements. """ +import asyncio import json import logging import os @@ -35,9 +36,10 @@ ) from apis.shared.rbac.service import get_app_role_service +from apis.shared.sessions.metadata import ensure_session_metadata_exists from .models import FileContent, InvocationRequest -from .service import get_agent +from .service import generate_conversation_title, get_agent logger = logging.getLogger(__name__) @@ -209,7 +211,13 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g input_data = request user_id = current_user.user_id auth_token = current_user.raw_token - logger.info("Invocation request received") + # Resume requests reuse the cached agent and its paused interrupt state; + # they bypass quota, file resolution, and RAG augmentation because those + # already ran on the original turn that got paused. + is_resume = bool(input_data.interrupt_responses) + logger.info( + "Invocation request received (resume=%s)" % is_resume + ) logger.info("Message received") if input_data.enabled_tools: @@ -242,10 +250,41 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g logger.warning("Failed to resolve file upload IDs") # Continue without files rather than failing the request + # Pre-create session metadata so OAuth interrupts and other state can + # attach to the session row from turn one. 
Best-effort; on failure the + # post-stream lazy-create in StreamCoordinator still covers it. + # + # Also clear any stale paused_turn snapshot at the start of a fresh turn. + # If the user abandoned a paused turn and started a new one, the prior + # snapshot is no longer authorized — letting it survive would let a + # later (mistaken) resume request pick up against a turn the user + # already moved past. + is_new_session = False + if not is_resume: + is_new_session = await ensure_session_metadata_exists(input_data.session_id, user_id) + try: + from apis.shared.sessions.metadata import clear_paused_turn + await clear_paused_turn(input_data.session_id, user_id) + except Exception as e: + logger.error("Failed to clear stale paused_turn on new turn: %s", e, exc_info=True) + + # First turn → kick off title generation concurrently with the stream. + # Runs as a background task so it doesn't add latency to TTFT. The + # targeted UpdateExpression in update_session_title is race-safe with + # the post-stream _update_session_metadata write. 
+ if is_new_session and input_data.message: + asyncio.create_task( + generate_conversation_title( + session_id=input_data.session_id, + user_id=user_id, + user_input=input_data.message, + ) + ) + # Check quota if enforcement is enabled quota_warning_event = None quota_exceeded_event = None - if is_quota_enforcement_enabled(): + if is_quota_enforcement_enabled() and not is_resume: try: quota_checker = get_quota_checker() quota_result = await quota_checker.check_quota(user=current_user, session_id=input_data.session_id) @@ -304,7 +343,7 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g "Invocation request - processing with assistant context" ) - if input_data.rag_assistant_id: + if input_data.rag_assistant_id and not is_resume: # Local imports to avoid circular dependency from apis.shared.assistants.rag_service import ( augment_prompt_with_context, @@ -509,29 +548,102 @@ async def invocations(request: InvocationRequest, current_user: User = Depends(g logger.info("Preview session - skipping assistant_id persistence") try: - # Resolve caching_enabled based on managed model configuration - # This allows admins to disable caching for models that don't support it - caching_enabled = await _resolve_caching_enabled(model_id=input_data.model_id, explicit_caching_enabled=input_data.caching_enabled) - - if caching_enabled is False: - logger.info("Prompt caching disabled for model") - - # Get agent instance with user-specific configuration - # AgentCore Memory tracks preferences across sessions per user_id - # Supports multiple LLM providers: AWS Bedrock, OpenAI, and Google Gemini - # Use augmented message and assistant system prompt if assistant RAG was applied - agent = get_agent( - session_id=input_data.session_id, - user_id=user_id, - auth_token=auth_token, - enabled_tools=input_data.enabled_tools, - model_id=input_data.model_id, - temperature=input_data.temperature, - system_prompt=system_prompt, # Use assistant's instructions if available 
- caching_enabled=caching_enabled, - provider=input_data.provider, - max_tokens=input_data.max_tokens, - ) + # Resume requests rebuild the agent from the persisted PausedTurnSnapshot + # so a refresh / cache eviction / pod restart between pause and resume + # still lands on the same MainAgent shape (matching tool registry, + # model, prompt). Strands' SessionManager separately restores + # `_interrupt_state` from AgentCore Memory, so the paused tool call + # picks up where it left off. Non-resume requests use the request + # body as before. + if is_resume: + from datetime import datetime, timezone + from apis.shared.sessions.metadata import clear_paused_turn, get_paused_turn + + snapshot = await get_paused_turn(input_data.session_id, user_id) + if not snapshot: + logger.warning( + "Resume rejected: no paused_turn snapshot for session %s", + input_data.session_id, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No paused turn for this session; restart the turn.", + ) + try: + expires_at = datetime.fromisoformat(snapshot.expires_at) + except ValueError: + expires_at = None + if expires_at and datetime.now(timezone.utc) > expires_at: + logger.warning( + "Resume rejected: paused_turn snapshot expired for session %s", + input_data.session_id, + ) + await clear_paused_turn(input_data.session_id, user_id) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Paused turn expired; restart the turn.", + ) + + caching_enabled = snapshot.caching_enabled + agent = await get_agent( + session_id=input_data.session_id, + user_id=user_id, + auth_token=auth_token, + enabled_tools=snapshot.enabled_tools, + model_id=snapshot.model_id, + temperature=snapshot.temperature, + system_prompt=snapshot.system_prompt, + caching_enabled=snapshot.caching_enabled, + provider=snapshot.provider, + max_tokens=snapshot.max_tokens, + ) + else: + # Resolve caching_enabled based on managed model configuration + # This allows admins to disable caching 
for models that don't support it + caching_enabled = await _resolve_caching_enabled(model_id=input_data.model_id, explicit_caching_enabled=input_data.caching_enabled) + + if caching_enabled is False: + logger.info("Prompt caching disabled for model") + + # Get agent instance with user-specific configuration + # AgentCore Memory tracks preferences across sessions per user_id + # Supports multiple LLM providers: AWS Bedrock, OpenAI, and Google Gemini + # Use augmented message and assistant system prompt if assistant RAG was applied + agent = await get_agent( + session_id=input_data.session_id, + user_id=user_id, + auth_token=auth_token, + enabled_tools=input_data.enabled_tools, + model_id=input_data.model_id, + temperature=input_data.temperature, + system_prompt=system_prompt, # Use assistant's instructions if available + caching_enabled=caching_enabled, + provider=input_data.provider, + max_tokens=input_data.max_tokens, + ) + + # Resume requests must target interrupts that the cached agent + # actually has paused. Cache eviction, a process restart, or a + # forged request will otherwise be silently accepted by Strands + # and drop the client's response. Reject up front so the client + # sees a 400 and can restart the turn cleanly. 
+ if is_resume: + strands_agent = getattr(agent, "agent", None) + interrupt_state = getattr(strands_agent, "_interrupt_state", None) if strands_agent else None + known_ids: set[str] = set() + if interrupt_state and getattr(interrupt_state, "activated", False): + interrupts = getattr(interrupt_state, "interrupts", None) or {} + known_ids = set(interrupts.keys()) + submitted_ids = [entry.interruptId for entry in (input_data.interrupt_responses or [])] + unknown_ids = [iid for iid in submitted_ids if iid not in known_ids] + if unknown_ids: + logger.warning( + "Resume rejected: submitted interrupt ids not in paused state" + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Unknown or expired interrupt ids; restart the turn.", + ) # Build citations list for persistence (convert context chunks to citation format) citations_for_storage = [] @@ -573,15 +685,70 @@ async def stream_with_quota_warning() -> AsyncGenerator[str, None]: augmented_message != input_data.message # RAG augmentation or bool(files_to_send) # File attachments ) + # Strands' resume protocol wants each entry wrapped as + # {"interruptResponse": {...}}. The InvocationRequest schema + # accepts the inner shape so callers don't have to think about + # the SDK's content-block convention. 
+ interrupt_responses_payload = ( + [{"interruptResponse": entry.model_dump()} for entry in input_data.interrupt_responses] + if input_data.interrupt_responses + else None + ) + async for event in agent.stream_async( augmented_message, session_id=input_data.session_id, files=files_to_send if files_to_send else None, citations=citations_for_storage if citations_for_storage else None, original_message=input_data.message if message_will_be_modified else None, + interrupt_responses=interrupt_responses_payload, ): yield event + # Resume bookkeeping: any interrupt that was submitted in this + # request and is no longer present in the agent's interrupt state + # has been resolved — drop the persisted breadcrumb so a refresh + # doesn't redisplay a stale prompt. Interrupts that re-paused + # (same provider, new url) are left in place; the next event + # extractor will refresh them. + # + # When the agent's interrupt state is no longer activated after + # streaming, the turn fully completed — clear ``paused_turn`` too + # so a stale snapshot doesn't authorize a phantom resume against + # an already-finished turn. If interrupts re-paused, the snapshot + # was overwritten by ``_extract_oauth_required_events`` for the + # next pause, so leave it alone. 
+ if is_resume and input_data.interrupt_responses: + try: + strands_agent = getattr(agent, "agent", None) + interrupt_state = getattr(strands_agent, "_interrupt_state", None) if strands_agent else None + still_paused: set[str] = set() + state_activated = bool( + interrupt_state and getattr(interrupt_state, "activated", False) + ) + if state_activated: + still_paused = set((getattr(interrupt_state, "interrupts", None) or {}).keys()) + resolved_ids = [ + entry.interruptId + for entry in input_data.interrupt_responses + if entry.interruptId not in still_paused + ] + if resolved_ids: + from apis.shared.sessions.metadata import remove_pending_interrupts + await remove_pending_interrupts( + session_id=input_data.session_id, + user_id=user_id, + interrupt_ids=resolved_ids, + ) + if not state_activated: + from apis.shared.sessions.metadata import clear_paused_turn + await clear_paused_turn( + session_id=input_data.session_id, + user_id=user_id, + ) + except Exception as cleanup_err: + logger.error("Failed to clear resolved pending_interrupts: %s", cleanup_err, exc_info=True) + # Stream response from agent as SSE (with optional files) # Note: Compression is handled by GZipMiddleware if configured in main.py return StreamingResponse( diff --git a/backend/src/apis/inference_api/chat/service.py b/backend/src/apis/inference_api/chat/service.py index 11006fd2..6a11e380 100644 --- a/backend/src/apis/inference_api/chat/service.py +++ b/backend/src/apis/inference_api/chat/service.py @@ -7,14 +7,12 @@ import hashlib import os from typing import Optional, List, Tuple -from datetime import datetime, timezone import boto3 # from agentcore.agent.agent import ChatbotAgent from agents.main_agent.main_agent import MainAgent -from apis.shared.sessions.models import SessionMetadata -from apis.shared.sessions.metadata import store_session_metadata +from apis.shared.sessions.metadata import update_session_title logger = logging.getLogger(__name__) @@ -47,24 +45,16 @@ def _create_cache_key( 
system_prompt: Optional[str], caching_enabled: Optional[bool], provider: Optional[str], - max_tokens: Optional[int] + max_tokens: Optional[int], + freshness_hash: str, ) -> Tuple: """ - Create a cache key for agent instances + Create a cache key for agent instances. - Args: - session_id: Session identifier - user_id: User identifier - enabled_tools: List of enabled tool names - model_id: Model identifier - temperature: Model temperature - system_prompt: System prompt text - caching_enabled: Whether caching is enabled - provider: LLM provider - max_tokens: Maximum tokens to generate - - Returns: - Tuple suitable for use as cache key + `freshness_hash` is a short digest of the enabled tools' current + `updated_at` values (see `freshness.get_freshness_hash`). When an + admin edits a tool's config, the hash changes and the cache misses, + so the next turn builds a fresh agent with the new config. """ # Hash the tools list for stable key tools_hash = _hash_tools(enabled_tools) @@ -83,7 +73,8 @@ def _create_cache_key( prompt_hash, caching_enabled or False, provider or "bedrock", - max_tokens or 0 + max_tokens or 0, + freshness_hash, ) @@ -94,7 +85,7 @@ def _create_cache_key( _CACHE_MAX_SIZE = 100 -def get_agent( +async def get_agent( session_id: str, user_id: Optional[str] = None, auth_token: Optional[str] = None, @@ -110,8 +101,9 @@ def get_agent( Get or create agent instance with current configuration for session Implements LRU caching to reduce agent initialization overhead. - Cache key includes all configuration parameters to ensure correct behavior. - Session message history is managed by AgentCore Memory automatically. + Cache key includes all configuration parameters plus a freshness + hash of the enabled tools' `updated_at` values, so admin edits to a + tool's config invalidate the cached agent on the next turn. 
Args: session_id: Session identifier @@ -127,7 +119,10 @@ def get_agent( Returns: MainAgent instance (cached or newly created) """ - # Create cache key from all configuration parameters + from apis.app_api.tools.freshness import get_freshness_hash + + freshness_hash = await get_freshness_hash(enabled_tools or []) + cache_key = _create_cache_key( session_id=session_id, user_id=user_id, @@ -137,7 +132,8 @@ def get_agent( system_prompt=system_prompt, caching_enabled=caching_enabled, provider=provider, - max_tokens=max_tokens + max_tokens=max_tokens, + freshness_hash=freshness_hash, ) # Check cache @@ -291,113 +287,17 @@ async def generate_conversation_title( logger.info(f"✅ Generated title: '{title}' for session {session_id}") - # Update session metadata with the generated title - # IMPORTANT: We must read existing metadata first and only update the title field. - # The streaming coordinator has already set message_count correctly, and we must - # not overwrite it. This function is called async after streaming completes, - # so there's a race condition where we could overwrite the correct message_count - # with 0 if we don't preserve existing values. 
- from apis.shared.sessions.metadata import get_session_metadata - - logger.info(f"📖 Title generation: Reading existing metadata for session {session_id}") - existing_metadata = await get_session_metadata(session_id, user_id) - - if existing_metadata: - logger.info(f"📊 Title generation: Found existing metadata with message_count={existing_metadata.message_count}") - # Preserve existing metadata, only update title - session_metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=title, # Only update this field - status=existing_metadata.status, - created_at=existing_metadata.created_at, - last_message_at=existing_metadata.last_message_at, - message_count=existing_metadata.message_count, # PRESERVE existing count - starred=existing_metadata.starred, - tags=existing_metadata.tags, - preferences=existing_metadata.preferences - ) - else: - logger.warning(f"⚠️ Title generation: No existing metadata found - creating new with message_count=0") - # Fallback: If metadata doesn't exist yet (rare edge case), create it - # The streaming coordinator will update message_count shortly after - now = datetime.now(timezone.utc).isoformat() - session_metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=title, - status="active", - created_at=now, - last_message_at=now, - message_count=0, # Safe fallback - will be set by streaming coordinator - starred=False, - tags=[], - preferences=None - ) - - logger.info(f"📝 Title generation: About to store metadata with message_count={session_metadata.message_count}") - await store_session_metadata( - session_id=session_id, - user_id=user_id, - session_metadata=session_metadata - ) - - logger.info(f"💾 Title generation: Stored session metadata with title for session {session_id}") + # Targeted update — only writes the title attribute. The post-stream + # update_session_activity write is also targeted and disjoint, so the + # two cannot clobber each other on overlapping turns. 
+ await update_session_title(session_id=session_id, user_id=user_id, title=title) return title except Exception as e: - # Log error but don't fail the request - # Title generation is nice-to-have, not critical + # Title generation is nice-to-have. Leave the existing "New Conversation" + # placeholder in place rather than writing a fallback; the row already + # exists from the pre-create. logger.error(f"Failed to generate title for session {session_id}: {e}", exc_info=True) - - # Return fallback title - fallback_title = "New Conversation" - - # Still try to store metadata with fallback title - # Same as above: preserve existing metadata to avoid race conditions - try: - from apis.shared.sessions.metadata import get_session_metadata - - existing_metadata = await get_session_metadata(session_id, user_id) - - if existing_metadata: - # Preserve existing metadata, only update title - session_metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=fallback_title, - status=existing_metadata.status, - created_at=existing_metadata.created_at, - last_message_at=existing_metadata.last_message_at, - message_count=existing_metadata.message_count, # PRESERVE - starred=existing_metadata.starred, - tags=existing_metadata.tags, - preferences=existing_metadata.preferences - ) - else: - # Fallback: metadata doesn't exist yet - now = datetime.now(timezone.utc).isoformat() - session_metadata = SessionMetadata( - session_id=session_id, - user_id=user_id, - title=fallback_title, - status="active", - created_at=now, - last_message_at=now, - message_count=0, - starred=False, - tags=[], - preferences=None - ) - - await store_session_metadata( - session_id=session_id, - user_id=user_id, - session_metadata=session_metadata - ) - except Exception as metadata_error: - logger.error(f"Failed to store fallback metadata: {metadata_error}") - - return fallback_title + return "New Conversation" diff --git a/backend/src/apis/inference_api/connectors/__init__.py 
b/backend/src/apis/inference_api/connectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/src/apis/inference_api/connectors/routes.py b/backend/src/apis/inference_api/connectors/routes.py new file mode 100644 index 00000000..e86e961b --- /dev/null +++ b/backend/src/apis/inference_api/connectors/routes.py @@ -0,0 +1,362 @@ +"""User-initiated OAuth consent for connectors. + +Lives on the inference API because the AgentCore Runtime injects the +workload access token via `AgentCoreContextMiddleware` on every request +proxied through `InvokeAgentRuntime`. `IdentityClient.get_token` reads +that token from `BedrockAgentCoreContext`, which is only populated here +— never on the app API. + +Flow: the settings page posts to `/connectors/{id}/initiate-consent`. +If AgentCore already has a valid token for this user + provider, we +return `{connected: true}` so the UI can show a success state. If +consent is required, AgentCore hands us back an authorization URL and we +forward it for the frontend popup. 
+""" + +from __future__ import annotations + +import logging +import os +from functools import lru_cache + +import boto3 +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel + +from agents.main_agent.integrations import oauth_token_cache +from agents.main_agent.integrations.agentcore_identity import ( + CallbackUrlUnavailableError, + WorkloadTokenUnavailableError, + custom_parameters_for, + get_agentcore_identity_client, +) +from apis.shared.auth.dependencies import get_current_user_trusted +from apis.shared.auth.models import User +from apis.shared.oauth.disconnect_repository import ( + OAuthDisconnectRepository, + get_disconnect_repository, +) +from apis.shared.oauth.provider_repository import ( + OAuthProviderRepository, + get_provider_repository, +) +from apis.shared.rbac.service import AppRoleService, get_app_role_service + +logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=1) +def _agentcore_control_client(): + """Process-wide bedrock-agentcore control-plane client. + + Cached so `complete_consent` doesn't reconstruct the boto3 client + (and re-resolve credentials) on every request. + """ + region = os.environ.get("AWS_REGION", "us-west-2") + return boto3.client("bedrock-agentcore", region_name=region) + +router = APIRouter(prefix="/connectors", tags=["connectors"]) + + +class InitiateConsentResponse(BaseModel): + """Either a pending consent URL or a confirmation of existing access.""" + + connected: bool = False + authorization_url: str | None = None + + +class ConnectorStatusResponse(BaseModel): + """Whether the caller has a usable token in AgentCore's vault. + + Side-effect-free: unlike `initiate-consent`, this endpoint discards + the authorization URL when consent is required, and does NOT remember + the session_uri server-side. Use it from listing UIs that need a + "Connected" badge without committing the user to a consent flow. 
+ """ + + connected: bool = False + + +class CompleteConsentRequest(BaseModel): + """Body for finalizing a consent flow after the popup returns.""" + + session_uri: str + provider_id: str | None = None + + +class CompleteConsentResponse(BaseModel): + ok: bool = True + + +def _is_visible_to_user(provider, user_role_ids: list[str]) -> bool: + if not provider.enabled: + return False + if not provider.allowed_roles: + return True + return bool(set(provider.allowed_roles) & set(user_role_ids)) + + +async def _resolve_visible_provider( + provider_id: str, + current_user: User, + provider_repo: OAuthProviderRepository, + role_service: AppRoleService, +): + """Fetch a provider and 404/403 if it isn't visible to the caller. + + Centralizes the lookup so `initiate_consent` and `connector_status` + use identical visibility rules. + """ + provider = await provider_repo.get_provider(provider_id) + if not provider or not provider.enabled: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Connector '{provider_id}' not found", + ) + + permissions = await role_service.resolve_user_permissions(current_user) + if not _is_visible_to_user(provider, permissions.app_roles): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have access to this connector", + ) + return provider + + +@router.post( + "/{provider_id}/initiate-consent", + response_model=InitiateConsentResponse, +) +async def initiate_consent( + provider_id: str, + current_user: User = Depends(get_current_user_trusted), + provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + role_service: AppRoleService = Depends(get_app_role_service), + disconnect_repo: OAuthDisconnectRepository = Depends(get_disconnect_repository), +) -> InitiateConsentResponse: + """Start (or verify) AgentCore consent for the given provider.""" + provider = await _resolve_visible_provider( + provider_id, current_user, provider_repo, role_service + ) + + # If the user 
previously disconnected, force a fresh consent flow even + # though AgentCore's vault still holds an unexpired token — they + # explicitly opted out, and re-using the cached entry would silently + # undo that. + force_auth = await disconnect_repo.is_disconnected( + current_user.user_id, provider.provider_id + ) + + identity = get_agentcore_identity_client() + try: + result = await identity.get_token_for_user( + provider_name=provider.provider_id, + scopes=provider.scopes, + user_id=current_user.user_id, + force_authentication=force_auth, + custom_parameters=custom_parameters_for( + provider.provider_type.value, provider.custom_parameters + ), + # No custom_state: AgentCore appears to treat its presence as a + # signal to start a fresh flow, never short-circuiting to the + # cached token. The frontend passes provider_id via the + # callback URL query string so /oauth-complete still knows + # which provider resolved. + ) + except WorkloadTokenUnavailableError as err: + # Only happens when the route is called outside an AgentCore Runtime + # invocation (e.g. local dev without the runtime proxy). Surface a + # clear error instead of a 500. + logger.warning("Consent initiation attempted without workload context: %s", err) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=( + "AgentCore workload context is not available. In prod, this " + "endpoint runs under the runtime proxy; locally, set " + "AGENTCORE_RUNTIME_WORKLOAD_NAME to enable the mint fallback." + ), + ) + except CallbackUrlUnavailableError as err: + # Frontend is expected to send the OAuth2CallbackUrl header on this + # path; if the header is missing AND the env-var fallback is unset, + # tell the caller exactly what to fix. 
+ logger.warning("Consent initiation missing callback URL: %s", err) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(err), + ) + + if result.requires_consent: + return InitiateConsentResponse(authorization_url=result.authorization_url) + return InitiateConsentResponse(connected=True) + + +@router.get( + "/{provider_id}/status", + response_model=ConnectorStatusResponse, +) +async def connector_status( + provider_id: str, + current_user: User = Depends(get_current_user_trusted), + provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + role_service: AppRoleService = Depends(get_app_role_service), + disconnect_repo: OAuthDisconnectRepository = Depends(get_disconnect_repository), +) -> ConnectorStatusResponse: + """Report whether AgentCore's vault has a usable token for this caller. + + Side-effect-free read: when the vault is empty we discard the + authorization URL the SDK returns. The settings page uses this to + decorate the list with a "Connected" badge without committing the + user to a flow. + + GET so it's cache-friendly and idempotent. The HTTP status only + reflects request validity (401/403/404/503); whether the user is + *connected* is in the response body. + """ + provider = await _resolve_visible_provider( + provider_id, current_user, provider_repo, role_service + ) + + # User just disconnected — they're not connected, regardless of what + # AgentCore's vault still holds. This avoids a misleading "Connected" + # badge between disconnect and the next re-consent. 
+ if await disconnect_repo.is_disconnected( + current_user.user_id, provider.provider_id + ): + return ConnectorStatusResponse(connected=False) + + identity = get_agentcore_identity_client() + try: + result = await identity.get_token_for_user( + provider_name=provider.provider_id, + scopes=provider.scopes, + user_id=current_user.user_id, + custom_parameters=custom_parameters_for( + provider.provider_type.value, provider.custom_parameters + ), + ) + except WorkloadTokenUnavailableError as err: + logger.warning("Status check without workload context: %s", err) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(err), + ) + except CallbackUrlUnavailableError as err: + logger.warning("Status check missing callback URL: %s", err) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=str(err), + ) + + return ConnectorStatusResponse(connected=not result.requires_consent) + + +@router.post( + "/complete-consent", + response_model=CompleteConsentResponse, +) +async def complete_consent( + body: CompleteConsentRequest, + current_user: User = Depends(get_current_user_trusted), + disconnect_repo: OAuthDisconnectRepository = Depends(get_disconnect_repository), +) -> CompleteConsentResponse: + """Finalize an OAuth consent flow after the popup redirects home. + + AgentCore's `/identities/oauth2/authorize` redirect comes back with the + same `request_uri` it was initiated with (as `session_id` on our landing + page). Until we call `CompleteResourceTokenAuth` with that URI and the + user's identity, AgentCore treats the flow as unfinished and the token + vault stays empty — the next `GetResourceOauth2Token` call returns a + fresh authorization URL. + + Returns `ok: true` on success; errors from AgentCore bubble up as 502. + + Authorization: the inbound JWT (`current_user`) is verified by + `get_current_user_trusted`, and we pass that user's id as + `userIdentifier` to AgentCore. 
AgentCore's own binding rejects a + completion attempt whose `userIdentifier` doesn't match the identity + that initiated the session, so a leaked `session_uri` cannot be + redeemed under a different user. + """ + control = _agentcore_control_client() + + try: + control.complete_resource_token_auth( + userIdentifier={"userId": current_user.user_id}, + sessionUri=body.session_uri, + ) + except Exception as err: + logger.error( + "CompleteResourceTokenAuth failed for user=%s provider=%s: %s", + current_user.user_id, + body.provider_id, + err, + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Failed to finalize OAuth consent: {err}", + ) + + # Successful re-consent supersedes any prior disconnect — clear the + # durable flag so subsequent status checks report the user as connected + # without waiting for the agent loop to warm the cache. + if body.provider_id: + await disconnect_repo.clear_disconnected( + current_user.user_id, body.provider_id + ) + + logger.info( + "Completed OAuth consent for user=%s provider=%s", + current_user.user_id, + body.provider_id, + ) + return CompleteConsentResponse(ok=True) + + +@router.delete( + "/{provider_id}/connection", + status_code=status.HTTP_204_NO_CONTENT, +) +async def disconnect_connector( + provider_id: str, + current_user: User = Depends(get_current_user_trusted), + provider_repo: OAuthProviderRepository = Depends(get_provider_repository), + role_service: AppRoleService = Depends(get_app_role_service), + disconnect_repo: OAuthDisconnectRepository = Depends(get_disconnect_repository), +): + """Best-effort disconnect for the caller's connection to this provider. + + AgentCore Identity exposes no per-user vault-delete API, so we cannot + actually destroy the user's stored token. What we can do: + + 1. 
Persist the disconnect intent in DynamoDB so every replica's agent + loop and `/status` endpoint reads the same state — the next attempt + to use the connector triggers a fresh consent flow with + `force_authentication=True`, which makes AgentCore replace the vault + entry rather than reuse it. + 2. Drop the local hot-path cache entry on this replica, so no in-flight + MCP request continues to inject the (stale-by-intent) bearer token. + Other replicas pick up the change on their next `BeforeToolCallEvent` + (the consent hook reads the disconnect repo every gate call). + + The existing vault entry stays valid at the upstream provider until it + expires naturally or the user revokes the application from their + provider account (e.g. https://myaccount.google.com/connections). This + is documented as part of the disconnect UX. + """ + provider = await _resolve_visible_provider( + provider_id, current_user, provider_repo, role_service + ) + + await disconnect_repo.mark_disconnected( + current_user.user_id, provider.provider_id + ) + oauth_token_cache.clear_user_provider( + current_user.user_id, provider.provider_id + ) + logger.info( + "Marked connector for re-consent on next use: user=%s provider=%s", + current_user.user_id, + provider.provider_id, + ) + return None diff --git a/backend/src/apis/inference_api/main.py b/backend/src/apis/inference_api/main.py index de9308d0..434941e8 100644 --- a/backend/src/apis/inference_api/main.py +++ b/backend/src/apis/inference_api/main.py @@ -33,6 +33,8 @@ from contextlib import asynccontextmanager import logging +from apis.inference_api.middleware.agentcore_context import AgentCoreContextMiddleware + # Set up logging logging.basicConfig( level=logging.INFO, @@ -117,6 +119,12 @@ async def lifespan(app: FastAPI): ) logger.info("Added GZip middleware for response compression") +# Bridge AgentCore Runtime headers (WorkloadAccessToken, OAuth2CallbackUrl, +# session ID) into BedrockAgentCoreContext so downstream code can look up +# 
per-user OAuth tokens via AgentCore Identity. +app.add_middleware(AgentCoreContextMiddleware) +logger.info("Added AgentCore Runtime context middleware") + # Add CORS middleware - origins from CDK-provided CORS_ORIGINS env var _cors_origins = os.environ.get("CORS_ORIGINS", "").split(",") app.add_middleware( @@ -132,11 +140,13 @@ async def lifespan(app: FastAPI): from apis.inference_api.chat.routes import router as agentcore_router from apis.inference_api.chat.converse_routes import router as converse_router from apis.inference_api.chat.voice_routes import router as voice_router +from apis.inference_api.connectors.routes import router as connectors_router # Include routers #app.include_router(health_router) app.include_router(agentcore_router) # AgentCore Runtime endpoints: /ping, /invocations app.include_router(converse_router) # API-key authenticated converse endpoint app.include_router(voice_router) # WebSocket voice streaming endpoint +app.include_router(connectors_router) # User-initiated OAuth consent # Mount static file directories for serving generated content # These are created by tools (visualization, code interpreter, etc.) @@ -160,10 +170,15 @@ async def lifespan(app: FastAPI): if __name__ == "__main__": import uvicorn + # Watch the full backend/src tree so edits to shared modules (agents/, + # apis/shared/) trigger reload. Without this uvicorn defaults to cwd, + # which hides changes outside inference_api/. 
+ src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) uvicorn.run( "apis.inference_api.main:app", host="0.0.0.0", port=8001, reload=True, + reload_dirs=[src_root], log_level="info" ) diff --git a/backend/src/apis/inference_api/middleware/__init__.py b/backend/src/apis/inference_api/middleware/__init__.py new file mode 100644 index 00000000..44a0d67c --- /dev/null +++ b/backend/src/apis/inference_api/middleware/__init__.py @@ -0,0 +1 @@ +"""Inference API middleware.""" diff --git a/backend/src/apis/inference_api/middleware/agentcore_context.py b/backend/src/apis/inference_api/middleware/agentcore_context.py new file mode 100644 index 00000000..677b8a2c --- /dev/null +++ b/backend/src/apis/inference_api/middleware/agentcore_context.py @@ -0,0 +1,103 @@ +"""AgentCore Runtime context middleware. + +Bridges AgentCore Runtime request headers into BedrockAgentCoreContext so that +downstream code (e.g. IdentityClient token lookups) can access the per-invocation +workload identity token without threading it through every function call. + +AgentCore Runtime injects these headers on every invocation: + - WorkloadAccessToken: per-user workload identity token, derived from the + validated inbound JWT by the Runtime's managed JWT authorizer. + - OAuth2CallbackUrl: OAuth2 callback URL registered on the workload identity. + - X-Amzn-Bedrock-AgentCore-Runtime-Session-Id: current session ID. + - X-Amzn-Request-Id: per-request trace ID. + +When the inference API is wrapped by BedrockAgentCoreApp, these are populated +automatically. Because this service runs as a plain FastAPI app inside +AgentCore Runtime, we populate the context ourselves. + +The middleware is a no-op in local development where these headers are absent, +which keeps tests and `python -m main` runs working without mocks. 
+""" + +import logging +import os +from urllib.parse import urlparse + +from bedrock_agentcore.runtime import BedrockAgentCoreContext +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +logger = logging.getLogger(__name__) + +HEADER_WORKLOAD_ACCESS_TOKEN = "WorkloadAccessToken" +HEADER_OAUTH2_CALLBACK_URL = "OAuth2CallbackUrl" +HEADER_SESSION_ID = "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id" +HEADER_REQUEST_ID = "X-Amzn-Request-Id" + +# The frontend always posts to `/oauth-complete` (see +# `frontend/ai.client/src/app/settings/connectors/services/user-connectors.service.ts`). +# Pinning the path closes off path-traversal and arbitrary-endpoint variants of +# the same attack class once the origin is allowlisted. +_ALLOWED_CALLBACK_PATH = "/oauth-complete" + + +def _allowed_callback_origins() -> frozenset[str]: + """Origins that are allowed to set `OAuth2CallbackUrl`. + + Reuses `CORS_ORIGINS` (set by CDK on the inference-api task) as the trust + boundary: the frontend lives at one of those origins, so its callback URL + must too. Read at request time so tests can monkeypatch the env var. + """ + raw = os.environ.get("CORS_ORIGINS", "") + return frozenset(o.strip().rstrip("/") for o in raw.split(",") if o.strip()) + + +def _is_safe_callback_url(url: str) -> bool: + """Return True iff `url` is an allowlisted `/oauth-complete` URL. + + The header is client-supplied (see `user-connectors.service.ts`), so an + authenticated user can otherwise pivot the OAuth redirect to an + attacker-controlled origin and capture the authorization code on consent. 
+ """ + try: + parsed = urlparse(url) + except ValueError: + return False + if parsed.scheme not in ("http", "https") or not parsed.netloc: + return False + if parsed.path != _ALLOWED_CALLBACK_PATH: + return False + if parsed.query or parsed.fragment: + return False + origin = f"{parsed.scheme}://{parsed.netloc}" + return origin in _allowed_callback_origins() + + +class AgentCoreContextMiddleware(BaseHTTPMiddleware): + """Populates BedrockAgentCoreContext from Runtime request headers.""" + + async def dispatch(self, request: Request, call_next) -> Response: + workload_token = request.headers.get(HEADER_WORKLOAD_ACCESS_TOKEN) + if workload_token: + BedrockAgentCoreContext.set_workload_access_token(workload_token) + + callback_url = request.headers.get(HEADER_OAUTH2_CALLBACK_URL) + if callback_url: + if _is_safe_callback_url(callback_url): + BedrockAgentCoreContext.set_oauth2_callback_url(callback_url) + else: + logger.warning( + "Rejected OAuth2CallbackUrl header: not in CORS_ORIGINS " + "allowlist or path != %s", + _ALLOWED_CALLBACK_PATH, + ) + + session_id = request.headers.get(HEADER_SESSION_ID) + if session_id: + BedrockAgentCoreContext.set_request_context( + request_id=request.headers.get(HEADER_REQUEST_ID, ""), + session_id=session_id, + ) + + return await call_next(request) diff --git a/backend/src/apis/shared/oauth/__init__.py b/backend/src/apis/shared/oauth/__init__.py index bd792d88..6750d0e6 100644 --- a/backend/src/apis/shared/oauth/__init__.py +++ b/backend/src/apis/shared/oauth/__init__.py @@ -1,52 +1,48 @@ -"""OAuth provider management module. +"""OAuth provider administration. -This module provides OAuth connection management for third-party integrations. -Admins can configure OAuth providers (Google, Microsoft, Canvas, etc.) and -users can connect their accounts for MCP tool requests. +Providers are registered and authenticated against AWS Bedrock AgentCore +Identity. 
This module exposes the provider metadata model, the DynamoDB +repository, and the AgentCore registrar used by admin CRUD routes. """ +from .agentcore_registrar import ( + AgentCoreRegistrar, + CredentialProviderConflictError, + CredentialProviderInfo, + CredentialProviderNotFoundError, + InvalidCustomProviderConfigError, + get_agentcore_registrar, +) from .models import ( - OAuthProviderType, - OAuthConnectionStatus, OAuthProvider, - OAuthUserToken, OAuthProviderCreate, - OAuthProviderUpdate, - OAuthProviderResponse, OAuthProviderListResponse, - OAuthConnectionResponse, - OAuthConnectionListResponse, - OAuthConnectResponse, + OAuthProviderResponse, + OAuthProviderType, + OAuthProviderUpdate, + OAuthRequiredEvent, + compute_scopes_hash, +) +from .provider_repository import ( + OAuthProviderRepository, + get_provider_repository, ) -from .encryption import TokenEncryptionService, get_token_encryption_service -from .token_cache import TokenCache, get_token_cache -from .provider_repository import OAuthProviderRepository, get_provider_repository -from .token_repository import OAuthTokenRepository, get_token_repository -from .service import OAuthService, get_oauth_service __all__ = [ - # Enums "OAuthProviderType", - "OAuthConnectionStatus", - # Models "OAuthProvider", - "OAuthUserToken", "OAuthProviderCreate", "OAuthProviderUpdate", "OAuthProviderResponse", "OAuthProviderListResponse", - "OAuthConnectionResponse", - "OAuthConnectionListResponse", - "OAuthConnectResponse", - # Services - "TokenEncryptionService", - "get_token_encryption_service", - "TokenCache", - "get_token_cache", + "OAuthRequiredEvent", + "compute_scopes_hash", "OAuthProviderRepository", "get_provider_repository", - "OAuthTokenRepository", - "get_token_repository", - "OAuthService", - "get_oauth_service", + "AgentCoreRegistrar", + "CredentialProviderInfo", + "CredentialProviderConflictError", + "CredentialProviderNotFoundError", + "InvalidCustomProviderConfigError", + "get_agentcore_registrar", ] diff 
--git a/backend/src/apis/shared/oauth/agentcore_registrar.py b/backend/src/apis/shared/oauth/agentcore_registrar.py new file mode 100644 index 00000000..a353d256 --- /dev/null +++ b/backend/src/apis/shared/oauth/agentcore_registrar.py @@ -0,0 +1,386 @@ +"""AgentCore Identity credential-provider registrar. + +Wraps `bedrock-agentcore-control` for managing OAuth2 credential providers +owned by AgentCore Identity. Callers upsert one of our `OAuthProvider` +records by first registering the client_id + client_secret here, then +storing the returned `callbackUrl` and `credentialProviderArn` on the +DynamoDB record. + +Division of authority: + +- AgentCore Identity: clientId, clientSecret, vendor-specific endpoint + config, callback URL. Returns a `credentialProviderArn` that identifies + the provider within the default token vault. +- Our DynamoDB: displayName, scopes, allowedRoles, iconName, enabled flag. + +Update semantics matter: `UpdateOauth2CredentialProvider` is NOT a partial +update — the full `oauth2ProviderConfigInput` (including clientId and +clientSecret) must be re-submitted on every call. Because +`GetOauth2CredentialProvider` returns only `clientSecretArn` (not the +secret value), credential rotation always requires the admin to re-enter +both fields. `update_credential_provider` enforces this. +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import boto3 +from botocore.exceptions import ClientError + +from .models import OAuthProviderType + +logger = logging.getLogger(__name__) + + +# Mapping from our OAuthProviderType to AgentCore's credentialProviderVendor +# and the corresponding vendor-specific config key inside +# `oauth2ProviderConfigInput`. 
AgentCore exposes a dedicated config struct +# for some vendors (Google, Microsoft, GitHub, Slack, Salesforce, Custom) and +# a shared `includedOauth2ProviderConfig` for "simpler" vendors that just +# need clientId/clientSecret (Zoom and most of the long tail). The vendor +# enum string is unchanged either way — only the surrounding config key +# differs. CANVAS routes through CustomOauth2 because AgentCore does not +# ship a first-class Canvas vendor. +_VENDOR_BY_TYPE: Dict[OAuthProviderType, str] = { + OAuthProviderType.GOOGLE: "GoogleOauth2", + OAuthProviderType.MICROSOFT: "MicrosoftOauth2", + OAuthProviderType.GITHUB: "GithubOauth2", + OAuthProviderType.SLACK: "SlackOauth2", + OAuthProviderType.SALESFORCE: "SalesforceOauth2", + OAuthProviderType.ZOOM: "ZoomOauth2", + OAuthProviderType.CANVAS: "CustomOauth2", + OAuthProviderType.CUSTOM: "CustomOauth2", +} + +_CONFIG_KEY_BY_TYPE: Dict[OAuthProviderType, str] = { + OAuthProviderType.GOOGLE: "googleOauth2ProviderConfig", + OAuthProviderType.MICROSOFT: "microsoftOauth2ProviderConfig", + OAuthProviderType.GITHUB: "githubOauth2ProviderConfig", + OAuthProviderType.SLACK: "slackOauth2ProviderConfig", + OAuthProviderType.SALESFORCE: "salesforceOauth2ProviderConfig", + # Zoom shares the `includedOauth2ProviderConfig` slot with most of the + # long-tail vendors (Okta, Notion, Dropbox, etc). The vendor enum + # string still discriminates the actual provider; the shared config + # key just carries the credentials. + OAuthProviderType.ZOOM: "includedOauth2ProviderConfig", + OAuthProviderType.CANVAS: "customOauth2ProviderConfig", + OAuthProviderType.CUSTOM: "customOauth2ProviderConfig", +} + + +@dataclass(frozen=True) +class CredentialProviderInfo: + """AgentCore Identity record for one OAuth2 credential provider. + + `client_id` is populated on `get_credential_provider`; Create/Update + responses include it in `oauth2ProviderConfigOutput` when the vendor + echoes it back, and we surface it when present. 
`client_secret` is + never returned by AgentCore — only `client_secret_arn`. + """ + + provider_id: str + vendor: str + credential_provider_arn: str + client_secret_arn: str + callback_url: str + client_id: Optional[str] = None + + +class CredentialProviderNotFoundError(LookupError): + """Raised when an AgentCore credential provider does not exist.""" + + +class CredentialProviderConflictError(RuntimeError): + """Raised when creating a provider that already exists in AgentCore.""" + + +class InvalidCustomProviderConfigError(ValueError): + """Raised when Custom vendor is selected without exactly one discovery mode.""" + + +class AgentCoreRegistrar: + """Thin wrapper around `bedrock-agentcore-control` for OAuth2 providers. + + Stateless apart from the boto3 client. Safe to share across requests. + """ + + def __init__( + self, + *, + region: Optional[str] = None, + client: Any = None, + ): + self._region = region or os.environ.get("AWS_REGION", "us-east-1") + self._client = client or boto3.client( + "bedrock-agentcore-control", region_name=self._region + ) + + # ------------------------------------------------------------------ create + def create_credential_provider( + self, + *, + provider_id: str, + provider_type: OAuthProviderType, + client_id: str, + client_secret: str, + discovery_url: Optional[str] = None, + authorization_server_metadata: Optional[Dict[str, Any]] = None, + ) -> CredentialProviderInfo: + """Register a new OAuth2 credential provider in AgentCore Identity. + + Raises: + CredentialProviderConflictError: A provider with `provider_id` + already exists. + InvalidCustomProviderConfigError: Custom/Canvas vendor was used + without exactly one of `discovery_url` or + `authorization_server_metadata`. + botocore.exceptions.ClientError: Any other AWS error bubbles up. 
+ """ + vendor, config_input = self._build_config_input( + provider_type=provider_type, + client_id=client_id, + client_secret=client_secret, + discovery_url=discovery_url, + authorization_server_metadata=authorization_server_metadata, + ) + + try: + response = self._client.create_oauth2_credential_provider( + name=provider_id, + credentialProviderVendor=vendor, + oauth2ProviderConfigInput=config_input, + ) + except ClientError as err: + code = err.response.get("Error", {}).get("Code") + if code in ("ConflictException", "ResourceAlreadyExistsException"): + raise CredentialProviderConflictError( + f"AgentCore credential provider '{provider_id}' already exists" + ) from err + raise + + return self._info_from_response( + provider_id=provider_id, vendor=vendor, response=response + ) + + # ------------------------------------------------------------------ update + def update_credential_provider( + self, + *, + provider_id: str, + provider_type: OAuthProviderType, + client_id: str, + client_secret: str, + discovery_url: Optional[str] = None, + authorization_server_metadata: Optional[Dict[str, Any]] = None, + fallback_arn: Optional[str] = None, + ) -> CredentialProviderInfo: + """Replace the AgentCore provider's full config. + + `UpdateOauth2CredentialProvider` requires the full + `oauth2ProviderConfigInput`, so the caller must supply both + `client_id` and `client_secret`. There is no "change only the + secret" path — Get does not return the existing secret, and the API + does not support partial updates. + + Unlike `CreateOauth2CredentialProvider`, the Update response does + NOT include `credentialProviderArn`. Since the ARN is immutable + across updates, callers pass the known ARN via `fallback_arn` + (typically from the existing DynamoDB record) so the returned + `CredentialProviderInfo` has the same shape as create/get. + + Raises: + CredentialProviderNotFoundError: No such provider. 
+ InvalidCustomProviderConfigError: Custom/Canvas without exactly + one discovery mode. + botocore.exceptions.ClientError: Any other AWS error. + """ + vendor, config_input = self._build_config_input( + provider_type=provider_type, + client_id=client_id, + client_secret=client_secret, + discovery_url=discovery_url, + authorization_server_metadata=authorization_server_metadata, + ) + + try: + response = self._client.update_oauth2_credential_provider( + name=provider_id, + credentialProviderVendor=vendor, + oauth2ProviderConfigInput=config_input, + ) + except ClientError as err: + if self._is_not_found(err): + raise CredentialProviderNotFoundError(provider_id) from err + raise + + return self._info_from_response( + provider_id=provider_id, + vendor=vendor, + response=response, + fallback_arn=fallback_arn, + ) + + # --------------------------------------------------------------------- get + def get_credential_provider(self, provider_id: str) -> CredentialProviderInfo: + """Fetch the AgentCore record for `provider_id`. + + Raises: + CredentialProviderNotFoundError: No such provider. + """ + try: + response = self._client.get_oauth2_credential_provider(name=provider_id) + except ClientError as err: + if self._is_not_found(err): + raise CredentialProviderNotFoundError(provider_id) from err + raise + + return self._info_from_response( + provider_id=provider_id, + vendor=response["credentialProviderVendor"], + response=response, + ) + + # ------------------------------------------------------------------ delete + def delete_credential_provider(self, provider_id: str) -> None: + """Delete the AgentCore provider. 
Missing providers are treated as success.""" + try: + self._client.delete_oauth2_credential_provider(name=provider_id) + except ClientError as err: + if self._is_not_found(err): + logger.info( + "AgentCore provider '%s' already absent; delete is a no-op", + provider_id, + ) + return + raise + + # ------------------------------------------------------------- build helper + def _build_config_input( + self, + *, + provider_type: OAuthProviderType, + client_id: str, + client_secret: str, + discovery_url: Optional[str], + authorization_server_metadata: Optional[Dict[str, Any]], + ) -> tuple[str, Dict[str, Any]]: + """Return `(vendor, oauth2ProviderConfigInput)` for AgentCore.""" + try: + vendor = _VENDOR_BY_TYPE[provider_type] + config_key = _CONFIG_KEY_BY_TYPE[provider_type] + except KeyError as err: + raise ValueError(f"Unsupported OAuth provider type: {provider_type}") from err + + vendor_config: Dict[str, Any] = { + "clientId": client_id, + "clientSecret": client_secret, + } + + if config_key == "customOauth2ProviderConfig": + vendor_config["oauthDiscovery"] = self._build_oauth_discovery( + discovery_url=discovery_url, + authorization_server_metadata=authorization_server_metadata, + ) + elif discovery_url or authorization_server_metadata: + raise ValueError( + f"Discovery config is only valid for CustomOauth2; " + f"provider_type={provider_type} ignores it" + ) + + return vendor, {config_key: vendor_config} + + @staticmethod + def _build_oauth_discovery( + *, + discovery_url: Optional[str], + authorization_server_metadata: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + if bool(discovery_url) == bool(authorization_server_metadata): + raise InvalidCustomProviderConfigError( + "CustomOauth2 requires exactly one of discovery_url or " + "authorization_server_metadata" + ) + if discovery_url: + return {"discoveryUrl": discovery_url} + return {"authorizationServerMetadata": authorization_server_metadata} + + # ----------------------------------------------------------- 
parse helpers + @staticmethod + def _info_from_response( + *, + provider_id: str, + vendor: str, + response: Dict[str, Any], + fallback_arn: Optional[str] = None, + ) -> CredentialProviderInfo: + # AgentCore's documented shape is `clientSecretArn: {secretArn: str}`. + # If the field is missing (possible for vendors that don't persist a + # secret) we tolerate that. If it's present but shaped differently, + # that's a contract change we want to fail loudly on rather than + # silently storing an empty string. + raw_secret = response.get("clientSecretArn") + if raw_secret is None: + client_secret_arn = "" + elif isinstance(raw_secret, dict): + secret_arn = raw_secret.get("secretArn", "") + if not isinstance(secret_arn, str): + raise TypeError( + f"AgentCore returned clientSecretArn.secretArn of unexpected " + f"type {type(secret_arn).__name__}; expected str" + ) + client_secret_arn = secret_arn + else: + raise TypeError( + f"AgentCore returned clientSecretArn of unexpected type " + f"{type(raw_secret).__name__}; expected dict or None" + ) + + output_config = response.get("oauth2ProviderConfigOutput") or {} + # Each vendor variant nests its own output object; the clientId lives + # one level deeper when present. We tolerate its absence. + client_id: Optional[str] = None + for nested in output_config.values(): + if isinstance(nested, dict) and "clientId" in nested: + client_id = nested["clientId"] + break + + credential_provider_arn = response.get("credentialProviderArn") + if not isinstance(credential_provider_arn, str) or not credential_provider_arn: + # UpdateOauth2CredentialProvider omits the ARN; callers pass the + # known-immutable ARN as `fallback_arn` in that case. Create/Get + # never supply a fallback, so a missing ARN there still fails. 
+ if fallback_arn: + credential_provider_arn = fallback_arn + else: + raise TypeError( + "AgentCore response missing credentialProviderArn or wrong type" + ) + + return CredentialProviderInfo( + provider_id=provider_id, + vendor=vendor, + credential_provider_arn=credential_provider_arn, + client_secret_arn=client_secret_arn, + callback_url=response.get("callbackUrl", "") or "", + client_id=client_id, + ) + + @staticmethod + def _is_not_found(err: ClientError) -> bool: + code = err.response.get("Error", {}).get("Code") + return code in ("ResourceNotFoundException", "NotFoundException") + + +_default_registrar: Optional[AgentCoreRegistrar] = None + + +def get_agentcore_registrar() -> AgentCoreRegistrar: + """Return the process-wide `AgentCoreRegistrar` singleton.""" + global _default_registrar + if _default_registrar is None: + _default_registrar = AgentCoreRegistrar() + return _default_registrar diff --git a/backend/src/apis/shared/oauth/disconnect_repository.py b/backend/src/apis/shared/oauth/disconnect_repository.py new file mode 100644 index 00000000..2cc623e6 --- /dev/null +++ b/backend/src/apis/shared/oauth/disconnect_repository.py @@ -0,0 +1,135 @@ +"""DynamoDB repository for OAuth disconnect intent. + +Records that a user has explicitly disconnected from a provider, or that a +tool call surfaced a 401 against AgentCore Identity's vault token. Both +conditions mean the next consent flow must use `force_authentication=True` +so AgentCore replaces the vault entry rather than reusing it. + +Lives in the same `oauth-user-tokens` table as user OAuth tokens (currently +unused by backend code) — the table already has `PK`/`SK` keys, KMS +encryption, and R/W IAM for the inference API. Items use a `DISCONNECT#` +sort-key prefix so they cannot collide with future per-user token storage. 
+ +The flag is durable so the disconnect intent survives across replicas: a +disconnect on one inference-API replica is visible to the next request, +which may land on a different replica. +""" + +import logging +import os +from datetime import datetime, timezone +from typing import Optional + +import boto3 +from botocore.exceptions import ClientError + +logger = logging.getLogger(__name__) + + +class OAuthDisconnectRepository: + """Per-(user, provider) disconnect flag backed by DynamoDB.""" + + def __init__( + self, + table_name: Optional[str] = None, + region: Optional[str] = None, + ): + self._table_name = table_name or os.getenv("DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME") + self._region = region or os.getenv("AWS_REGION", "us-west-2") + self._enabled = bool(self._table_name) + + if not self._enabled: + logger.warning( + "DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME not set. " + "OAuth disconnect repository is disabled — disconnect intent " + "will not be durable across replicas." + ) + return + + profile = os.getenv("AWS_PROFILE") + session = boto3.Session(profile_name=profile) if profile else boto3 + self._dynamodb = session.resource("dynamodb", region_name=self._region) + self._table = self._dynamodb.Table(self._table_name) + logger.info("Initialized OAuth disconnect repository: table=%s", self._table_name) + + @property + def enabled(self) -> bool: + return self._enabled + + @staticmethod + def _key(user_id: str, provider_id: str) -> dict: + return { + "PK": f"USER#{user_id}", + "SK": f"DISCONNECT#{provider_id}", + } + + async def is_disconnected(self, user_id: str, provider_id: str) -> bool: + """Return True if the user has been marked disconnected from `provider_id`. + + Failure-mode policy: if the read fails, return False. 
Treating an + unreachable DDB as "disconnected" would lock every user out of every + connector during a transient outage; treating it as "not disconnected" + falls back to the AgentCore vault state — the prior, less-correct + behavior, but still safe. + """ + if not self._enabled: + return False + try: + response = self._table.get_item(Key=self._key(user_id, provider_id)) + return "Item" in response + except ClientError as e: + logger.error( + "Disconnect lookup failed for user=%s provider=%s: %s", + user_id, + provider_id, + e, + ) + return False + + async def mark_disconnected(self, user_id: str, provider_id: str) -> None: + """Record that `(user_id, provider_id)` requires fresh consent. + + Idempotent — overwrites any prior `disconnected_at` if called twice. + """ + if not self._enabled: + return + item = { + **self._key(user_id, provider_id), + "disconnected_at": datetime.now(timezone.utc).isoformat() + "Z", + } + try: + self._table.put_item(Item=item) + except ClientError as e: + logger.error( + "Failed to mark disconnect for user=%s provider=%s: %s", + user_id, + provider_id, + e, + ) + raise + + async def clear_disconnected(self, user_id: str, provider_id: str) -> None: + """Remove the disconnect flag — called after a successful re-consent.""" + if not self._enabled: + return + try: + self._table.delete_item(Key=self._key(user_id, provider_id)) + except ClientError as e: + logger.error( + "Failed to clear disconnect for user=%s provider=%s: %s", + user_id, + provider_id, + e, + ) + raise + + +_disconnect_repository: Optional[OAuthDisconnectRepository] = None + + +def get_disconnect_repository() -> OAuthDisconnectRepository: + """Get the process-wide disconnect repository singleton.""" + global _disconnect_repository + if _disconnect_repository is None: + _disconnect_repository = OAuthDisconnectRepository() + return _disconnect_repository diff --git a/backend/src/apis/shared/oauth/encryption.py b/backend/src/apis/shared/oauth/encryption.py deleted file mode 
100644 index 0b1dc877..00000000 --- a/backend/src/apis/shared/oauth/encryption.py +++ /dev/null @@ -1,138 +0,0 @@ -"""KMS encryption service for OAuth tokens.""" - -import base64 -import logging -import os -from typing import Optional - -logger = logging.getLogger(__name__) - - -class TokenEncryptionService: - """ - Service for encrypting/decrypting OAuth tokens using AWS KMS. - - Tokens are encrypted before storage in DynamoDB and decrypted on retrieval. - Uses AWS KMS with a customer-managed key for secure key management. - """ - - def __init__( - self, - key_arn: Optional[str] = None, - region: Optional[str] = None, - ): - """ - Initialize the encryption service. - - Args: - key_arn: KMS key ARN for encryption (defaults to env var) - region: AWS region (defaults to env var) - """ - self._key_arn = key_arn or os.getenv("OAUTH_TOKEN_ENCRYPTION_KEY_ARN") - self._region = region or os.getenv("AWS_REGION", "us-west-2") - self._client = None - self._enabled = bool(self._key_arn) - - if not self._enabled: - logger.warning( - "OAUTH_TOKEN_ENCRYPTION_KEY_ARN not set. " - "Token encryption is disabled (development mode only)." - ) - - @property - def enabled(self) -> bool: - """Check if encryption is enabled.""" - return self._enabled - - def _get_client(self): - """Lazy initialization of KMS client.""" - if self._client is None: - import boto3 - - profile = os.getenv("AWS_PROFILE") - if profile: - session = boto3.Session(profile_name=profile) - self._client = session.client("kms", region_name=self._region) - else: - self._client = boto3.client("kms", region_name=self._region) - return self._client - - def encrypt(self, plaintext: str) -> str: - """ - Encrypt a plaintext string using KMS. - - Args: - plaintext: The string to encrypt - - Returns: - Base64-encoded ciphertext - - Raises: - RuntimeError: If encryption fails - """ - if not self._enabled: - # Development mode: return base64-encoded plaintext (NOT secure!) 
- logger.warning("Using development mode encryption (NOT SECURE)") - return f"DEV:{base64.b64encode(plaintext.encode()).decode()}" - - try: - client = self._get_client() - response = client.encrypt( - KeyId=self._key_arn, - Plaintext=plaintext.encode("utf-8"), - ) - ciphertext = base64.b64encode(response["CiphertextBlob"]).decode("utf-8") - logger.debug(f"Encrypted token (length={len(plaintext)} -> {len(ciphertext)})") - return ciphertext - - except Exception as e: - logger.error(f"Failed to encrypt token: {e}") - raise RuntimeError(f"Token encryption failed: {e}") from e - - def decrypt(self, ciphertext: str) -> str: - """ - Decrypt a ciphertext string using KMS. - - Args: - ciphertext: Base64-encoded ciphertext - - Returns: - Decrypted plaintext string - - Raises: - RuntimeError: If decryption fails - """ - if not self._enabled: - # Development mode: decode base64 plaintext - if ciphertext.startswith("DEV:"): - logger.warning("Using development mode decryption (NOT SECURE)") - return base64.b64decode(ciphertext[4:]).decode() - else: - raise RuntimeError("Cannot decrypt production ciphertext without KMS key") - - try: - client = self._get_client() - ciphertext_blob = base64.b64decode(ciphertext) - response = client.decrypt( - CiphertextBlob=ciphertext_blob, - KeyId=self._key_arn, - ) - plaintext = response["Plaintext"].decode("utf-8") - logger.debug(f"Decrypted token (length={len(ciphertext)} -> {len(plaintext)})") - return plaintext - - except Exception as e: - logger.error(f"Failed to decrypt token: {e}") - raise RuntimeError(f"Token decryption failed: {e}") from e - - -# Singleton instance -_encryption_service: Optional[TokenEncryptionService] = None - - -def get_token_encryption_service() -> TokenEncryptionService: - """Get the token encryption service singleton.""" - global _encryption_service - if _encryption_service is None: - _encryption_service = TokenEncryptionService() - return _encryption_service diff --git a/backend/src/apis/shared/oauth/models.py 
b/backend/src/apis/shared/oauth/models.py index 5c1a384d..492b3686 100644 --- a/backend/src/apis/shared/oauth/models.py +++ b/backend/src/apis/shared/oauth/models.py @@ -1,48 +1,87 @@ -"""OAuth models for provider configuration and user tokens.""" +"""OAuth provider models. +Providers are registered and administered through AWS Bedrock AgentCore +Identity — AgentCore owns `clientId`, `clientSecret`, endpoint config, and +the callback URL. Our DynamoDB record keeps the display metadata, scopes, +role gates, and cached pointers (ARN + callback URL) for convenience. +""" + +import base64 import hashlib import logging +import re from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, model_validator logger = logging.getLogger(__name__) +# Inline-icon data URLs are persisted directly in the provider record so we +# don't have to stand up an S3 bucket / CDN just for connector icons. The +# 100KB cap (after base64 decode) keeps the DynamoDB item well under its +# 400KB limit and is generous for an icon — a tuned 64x64 PNG is < 10KB. +ICON_DATA_MAX_BYTES = 100 * 1024 +_ICON_DATA_URL_RE = re.compile( + r"^data:image/(png|jpeg|jpg|gif|webp|svg\+xml);base64,([A-Za-z0-9+/=]+)$" +) -class OAuthProviderType(str, Enum): - """Supported OAuth provider types.""" - GOOGLE = "google" - MICROSOFT = "microsoft" - GITHUB = "github" - CANVAS = "canvas" - CUSTOM = "custom" +def validate_icon_data(value: Optional[str]) -> Optional[str]: + """Validate an inline icon data URL. + Returns the value unchanged when valid, raises `ValueError` otherwise. + `None` is allowed (no icon set). Empty string is preserved by the caller + as a "clear the icon" signal — handled at the repository layer. 
+ """ + if value is None or value == "": + return value + match = _ICON_DATA_URL_RE.match(value) + if not match: + raise ValueError( + "icon_data must be a base64 data URL of the form " + "data:image/;base64,<...>" + ) + try: + decoded = base64.b64decode(match.group(2), validate=True) + except Exception as err: + raise ValueError(f"icon_data base64 payload is invalid: {err}") + if len(decoded) > ICON_DATA_MAX_BYTES: + raise ValueError( + f"icon_data exceeds {ICON_DATA_MAX_BYTES // 1024}KB " + f"(got {len(decoded) // 1024}KB)" + ) + return value -class OAuthConnectionStatus(str, Enum): - """Connection status for user OAuth tokens.""" - CONNECTED = "connected" - EXPIRED = "expired" - REVOKED = "revoked" - NEEDS_REAUTH = "needs_reauth" +class OAuthProviderType(str, Enum): + """Supported OAuth provider types. + `CANVAS` routes through AgentCore's `CustomOauth2` vendor but is kept + as a distinct type so the admin UI can surface Canvas-specific guidance + if/when we add a preset. Today the admin form treats it as Custom. -def compute_scopes_hash(scopes: List[str]) -> str: + `SLACK`, `SALESFORCE`, and `ZOOM` are first-class AgentCore Identity + vendors (per + https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/identity-idps.html) + — endpoints and provider-specific defaults are pre-configured by + AgentCore, so admins only need to supply client credentials and scopes. """ - Compute a hash of the scopes list for change detection. - Used to detect when provider scopes change and user needs to re-authenticate. 
+ GOOGLE = "google" + MICROSOFT = "microsoft" + GITHUB = "github" + SLACK = "slack" + SALESFORCE = "salesforce" + ZOOM = "zoom" + CANVAS = "canvas" + CUSTOM = "custom" - Args: - scopes: List of OAuth scopes - Returns: - SHA-256 hash of sorted scopes - """ +def compute_scopes_hash(scopes: List[str]) -> str: + """Return a short, order-independent hash of `scopes` for change detection.""" sorted_scopes = sorted(scopes) scopes_str = ",".join(sorted_scopes) return hashlib.sha256(scopes_str.encode()).hexdigest()[:16] @@ -50,288 +89,256 @@ def compute_scopes_hash(scopes: List[str]) -> str: @dataclass class OAuthProvider: - """OAuth provider configuration stored in DynamoDB.""" + """OAuth provider record stored in DynamoDB. + + AgentCore-owned fields (`credential_provider_arn`, `callback_url`) are + populated after a successful registration and kept in sync on update. + They are cached for admin UX — the source of truth lives in AgentCore. + """ provider_id: str display_name: str provider_type: OAuthProviderType - authorization_endpoint: str - token_endpoint: str - client_id: str scopes: List[str] allowed_roles: List[str] # AppRole IDs that can use this provider enabled: bool = True - icon_name: str = "heroLink" # Default icon - userinfo_endpoint: Optional[str] = None # Optional userinfo endpoint - revocation_endpoint: Optional[str] = None # Optional token revocation endpoint - pkce_required: bool = True # PKCE is required by default for security - authorization_params: Dict[str, str] = field(default_factory=dict) # Extra params for auth URL (e.g., access_type=offline) + icon_name: str = "heroLink" + # Optional admin-uploaded icon as a base64 data URL. When present, + # frontends prefer this over `icon_name`. See `validate_icon_data` + # for the accepted shape and size cap. + icon_data: Optional[str] = None + credential_provider_arn: Optional[str] = None + callback_url: Optional[str] = None + # Custom vendor only — mirrors AgentCore's Oauth2Discovery union. 
+ # Exactly one of these is populated when `provider_type` is CUSTOM or + # CANVAS; both are None for Google/Microsoft/GitHub. + oauth_discovery_url: Optional[str] = None + authorization_server_metadata: Optional[Dict[str, Any]] = None + # Vendor-specific OAuth parameters merged into AgentCore Identity's + # `customParameters` at request time. Examples: Google `hd=mycorp.com` + # to restrict to a Workspace domain, `prompt=consent` to force the + # consent screen. Hardcoded baselines (e.g. Google's + # `access_type=offline`) win on conflict — admins cannot accidentally + # turn off a documented requirement. + custom_parameters: Optional[Dict[str, str]] = None created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat() + "Z") updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat() + "Z") @property def scopes_hash(self) -> str: - """Get the scopes hash for this provider.""" return compute_scopes_hash(self.scopes) def to_dynamo_item(self) -> Dict[str, Any]: - """Convert to DynamoDB item format.""" return { "PK": f"PROVIDER#{self.provider_id}", "SK": "CONFIG", - # GSI for enabled providers "GSI1PK": f"ENABLED#{str(self.enabled).lower()}", "GSI1SK": f"PROVIDER#{self.provider_id}", - # Main attributes "providerId": self.provider_id, "displayName": self.display_name, "providerType": self.provider_type.value, - "authorizationEndpoint": self.authorization_endpoint, - "tokenEndpoint": self.token_endpoint, - "clientId": self.client_id, "scopes": self.scopes, "scopesHash": self.scopes_hash, "allowedRoles": self.allowed_roles, "enabled": self.enabled, "iconName": self.icon_name, - "userinfoEndpoint": self.userinfo_endpoint, - "revocationEndpoint": self.revocation_endpoint, - "pkceRequired": self.pkce_required, - "authorizationParams": self.authorization_params, + "iconData": self.icon_data, + "credentialProviderArn": self.credential_provider_arn, + "callbackUrl": self.callback_url, + "oauthDiscoveryUrl": 
self.oauth_discovery_url, + "authorizationServerMetadata": self.authorization_server_metadata, + "customParameters": self.custom_parameters, "createdAt": self.created_at, "updatedAt": self.updated_at, } @classmethod def from_dynamo_item(cls, item: Dict[str, Any]) -> "OAuthProvider": - """Create from DynamoDB item.""" return cls( provider_id=item["providerId"], display_name=item["displayName"], provider_type=OAuthProviderType(item["providerType"]), - authorization_endpoint=item["authorizationEndpoint"], - token_endpoint=item["tokenEndpoint"], - client_id=item["clientId"], scopes=item.get("scopes", []), allowed_roles=item.get("allowedRoles", []), enabled=item.get("enabled", True), icon_name=item.get("iconName", "heroLink"), - userinfo_endpoint=item.get("userinfoEndpoint"), - revocation_endpoint=item.get("revocationEndpoint"), - pkce_required=item.get("pkceRequired", True), - authorization_params=item.get("authorizationParams", {}), + icon_data=item.get("iconData"), + credential_provider_arn=item.get("credentialProviderArn"), + callback_url=item.get("callbackUrl"), + oauth_discovery_url=item.get("oauthDiscoveryUrl"), + authorization_server_metadata=item.get("authorizationServerMetadata"), + custom_parameters=item.get("customParameters"), created_at=item.get("createdAt", datetime.now(timezone.utc).isoformat() + "Z"), updated_at=item.get("updatedAt", datetime.now(timezone.utc).isoformat() + "Z"), ) -@dataclass -class OAuthUserToken: - """User's OAuth token stored in DynamoDB (encrypted).""" - - user_id: str - provider_id: str - access_token_encrypted: str # KMS-encrypted access token - refresh_token_encrypted: Optional[str] = None # KMS-encrypted refresh token - token_type: str = "Bearer" - expires_at: Optional[int] = None # Unix timestamp - scopes_hash: str = "" # Hash of scopes at time of authorization - status: OAuthConnectionStatus = OAuthConnectionStatus.CONNECTED - connected_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat() + "Z") - 
updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat() + "Z") - - @property - def is_expired(self) -> bool: - """Check if token has expired.""" - if not self.expires_at: - return False - import time - - return time.time() > self.expires_at - - def to_dynamo_item(self) -> Dict[str, Any]: - """Convert to DynamoDB item format.""" - item = { - "PK": f"USER#{self.user_id}", - "SK": f"PROVIDER#{self.provider_id}", - # GSI for listing users by provider - "GSI1PK": f"PROVIDER#{self.provider_id}", - "GSI1SK": f"USER#{self.user_id}", - # Main attributes - "userId": self.user_id, - "providerId": self.provider_id, - "accessTokenEncrypted": self.access_token_encrypted, - "tokenType": self.token_type, - "scopesHash": self.scopes_hash, - "status": self.status.value, - "connectedAt": self.connected_at, - "updatedAt": self.updated_at, - } - - if self.refresh_token_encrypted: - item["refreshTokenEncrypted"] = self.refresh_token_encrypted - - if self.expires_at: - item["expiresAt"] = self.expires_at - - return item - - @classmethod - def from_dynamo_item(cls, item: Dict[str, Any]) -> "OAuthUserToken": - """Create from DynamoDB item.""" - return cls( - user_id=item["userId"], - provider_id=item["providerId"], - access_token_encrypted=item["accessTokenEncrypted"], - refresh_token_encrypted=item.get("refreshTokenEncrypted"), - token_type=item.get("tokenType", "Bearer"), - expires_at=item.get("expiresAt"), - scopes_hash=item.get("scopesHash", ""), - status=OAuthConnectionStatus(item.get("status", "connected")), - connected_at=item.get("connectedAt", datetime.now(timezone.utc).isoformat() + "Z"), - updated_at=item.get("updatedAt", datetime.now(timezone.utc).isoformat() + "Z"), - ) - - # ============================================================================= -# Pydantic Request/Response Models +# Pydantic request/response models # ============================================================================= +_CUSTOM_TYPES = {OAuthProviderType.CUSTOM, 
OAuthProviderType.CANVAS} + + class OAuthProviderCreate(BaseModel): - """Request model for creating an OAuth provider.""" + """Request model for creating an OAuth provider. + + `client_id` and `client_secret` are forwarded to AgentCore Identity and + are never persisted in our DynamoDB table. For Custom/Canvas providers + the caller must supply exactly one of `oauth_discovery_url` or + `authorization_server_metadata`. + """ provider_id: str = Field(..., min_length=1, max_length=64, pattern=r"^[a-z0-9-]+$") display_name: str = Field(..., min_length=1, max_length=128) provider_type: OAuthProviderType - authorization_endpoint: str = Field(..., min_length=1) - token_endpoint: str = Field(..., min_length=1) client_id: str = Field(..., min_length=1) - client_secret: str = Field(..., min_length=1) # Will be stored in Secrets Manager + client_secret: str = Field(..., min_length=1) scopes: List[str] = Field(default_factory=list) allowed_roles: List[str] = Field(default_factory=list) enabled: bool = True icon_name: str = "heroLink" - userinfo_endpoint: Optional[str] = None - revocation_endpoint: Optional[str] = None - pkce_required: bool = True - authorization_params: Dict[str, str] = Field(default_factory=dict) + icon_data: Optional[str] = None + oauth_discovery_url: Optional[str] = None + authorization_server_metadata: Optional[Dict[str, Any]] = None + custom_parameters: Optional[Dict[str, str]] = None model_config = ConfigDict(json_schema_extra={ "example": { "provider_id": "google-workspace", "display_name": "Google Workspace", "provider_type": "google", - "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth", - "token_endpoint": "https://oauth2.googleapis.com/token", "client_id": "your-client-id.apps.googleusercontent.com", "client_secret": "your-client-secret", - "scopes": ["openid", "email", "profile", "https://www.googleapis.com/auth/drive.readonly"], + "scopes": ["openid", "email", "profile"], "allowed_roles": ["admin", "user"], "enabled": True, 
"icon_name": "heroCloud", - "pkce_required": True, - "authorization_params": {"access_type": "offline", "prompt": "consent"}, } }) + @model_validator(mode="after") + def _validate_discovery(self) -> "OAuthProviderCreate": + if self.provider_type in _CUSTOM_TYPES: + if bool(self.oauth_discovery_url) == bool(self.authorization_server_metadata): + raise ValueError( + "Custom providers require exactly one of " + "oauth_discovery_url or authorization_server_metadata" + ) + elif self.oauth_discovery_url or self.authorization_server_metadata: + raise ValueError( + f"Discovery config is only valid for custom/canvas providers; " + f"provider_type={self.provider_type.value} does not accept it" + ) + self.icon_data = validate_icon_data(self.icon_data) + return self + class OAuthProviderUpdate(BaseModel): - """Request model for updating an OAuth provider.""" + """Request model for updating an OAuth provider. + + Credential rotation requires both `client_id` and `client_secret` + because AgentCore's update API demands the full config and does not + echo back the stored secret. Partial edits to metadata (display name, + scopes, roles, icon, enabled) are allowed without touching credentials. + """ display_name: Optional[str] = Field(None, min_length=1, max_length=128) - authorization_endpoint: Optional[str] = None - token_endpoint: Optional[str] = None client_id: Optional[str] = None - client_secret: Optional[str] = None # Only if rotating secret + client_secret: Optional[str] = None scopes: Optional[List[str]] = None allowed_roles: Optional[List[str]] = None enabled: Optional[bool] = None icon_name: Optional[str] = None - userinfo_endpoint: Optional[str] = None - revocation_endpoint: Optional[str] = None - pkce_required: Optional[bool] = None - authorization_params: Optional[Dict[str, str]] = None + # `""` clears any uploaded icon (falls back to `icon_name`); `None` + # leaves the existing value alone. Validated for shape and size cap. 
+ icon_data: Optional[str] = None + oauth_discovery_url: Optional[str] = None + authorization_server_metadata: Optional[Dict[str, Any]] = None + custom_parameters: Optional[Dict[str, str]] = None + + @model_validator(mode="after") + def _validate_credential_pair(self) -> "OAuthProviderUpdate": + if bool(self.client_id) != bool(self.client_secret): + raise ValueError( + "client_id and client_secret must be provided together for rotation" + ) + if self.oauth_discovery_url and self.authorization_server_metadata: + raise ValueError( + "oauth_discovery_url and authorization_server_metadata are mutually exclusive" + ) + if self.icon_data is not None: + self.icon_data = validate_icon_data(self.icon_data) + return self class OAuthProviderResponse(BaseModel): - """Response model for an OAuth provider (excludes secrets).""" + """Response model for an OAuth provider.""" provider_id: str display_name: str provider_type: OAuthProviderType - authorization_endpoint: str - token_endpoint: str - client_id: str scopes: List[str] allowed_roles: List[str] enabled: bool icon_name: str - userinfo_endpoint: Optional[str] = None - revocation_endpoint: Optional[str] = None - pkce_required: bool - authorization_params: Dict[str, str] + icon_data: Optional[str] = None + credential_provider_arn: Optional[str] = None + callback_url: Optional[str] = None + oauth_discovery_url: Optional[str] = None + authorization_server_metadata: Optional[Dict[str, Any]] = None + custom_parameters: Optional[Dict[str, str]] = None created_at: str updated_at: str @classmethod def from_provider(cls, provider: OAuthProvider) -> "OAuthProviderResponse": - """Create from OAuthProvider dataclass.""" return cls( provider_id=provider.provider_id, display_name=provider.display_name, provider_type=provider.provider_type, - authorization_endpoint=provider.authorization_endpoint, - token_endpoint=provider.token_endpoint, - client_id=provider.client_id, scopes=provider.scopes, allowed_roles=provider.allowed_roles, 
enabled=provider.enabled, icon_name=provider.icon_name, - userinfo_endpoint=provider.userinfo_endpoint, - revocation_endpoint=provider.revocation_endpoint, - pkce_required=provider.pkce_required, - authorization_params=provider.authorization_params, + icon_data=provider.icon_data, + credential_provider_arn=provider.credential_provider_arn, + callback_url=provider.callback_url, + oauth_discovery_url=provider.oauth_discovery_url, + authorization_server_metadata=provider.authorization_server_metadata, + custom_parameters=provider.custom_parameters, created_at=provider.created_at, updated_at=provider.updated_at, ) class OAuthProviderListResponse(BaseModel): - """Response model for listing OAuth providers.""" - providers: List[OAuthProviderResponse] total: int -class OAuthConnectionResponse(BaseModel): - """Response model for a user's OAuth connection.""" - - provider_id: str - display_name: str - provider_type: OAuthProviderType - icon_name: str - status: OAuthConnectionStatus - connected_at: Optional[str] = None - needs_reauth: bool = False - - -class OAuthConnectionListResponse(BaseModel): - """Response model for listing user's OAuth connections.""" - - connections: List[OAuthConnectionResponse] +class OAuthRequiredEvent(BaseModel): + """SSE event signalling that a tool needs user consent before it can run. + Emitted mid-turn when `OAuthConsentHook` raises a Strands interrupt: the + agent's tool call is paused (its in-flight state is held in + `_interrupt_state`), the frontend receives this event, opens the + consent popup at `authorizationUrl`, and on completion POSTs an + interrupt response carrying `interruptId` back to `/invocations`. The + backend resumes the same turn — no retype, no replay. 
+ """ -class OAuthConnectResponse(BaseModel): - """Response model for initiating OAuth connection.""" - - authorization_url: str - + model_config = ConfigDict(populate_by_name=True) -class OAuthCallbackResult(BaseModel): - """Internal model for OAuth callback result.""" + type: str = "oauth_required" + provider_id: str = Field(..., alias="providerId") + authorization_url: str = Field(..., alias="authorizationUrl") + interrupt_id: str = Field(..., alias="interruptId") - success: bool - provider_id: Optional[str] = None - error: Optional[str] = None - error_description: Optional[str] = None + def to_sse_format(self) -> str: + import json + return ( + f"event: oauth_required\n" + f"data: {json.dumps(self.model_dump(by_alias=True, exclude_none=True))}\n\n" + ) diff --git a/backend/src/apis/shared/oauth/provider_repository.py b/backend/src/apis/shared/oauth/provider_repository.py index 31df2825..a738e91b 100644 --- a/backend/src/apis/shared/oauth/provider_repository.py +++ b/backend/src/apis/shared/oauth/provider_repository.py @@ -1,6 +1,10 @@ -"""DynamoDB repository for OAuth provider configurations.""" +"""DynamoDB repository for OAuth provider configurations. + +Only display metadata, scopes, and AgentCore-owned pointers (the credential +provider ARN and callback URL) live here. `clientId` / `clientSecret` are +registered directly with AgentCore Identity by the admin route. +""" -import json import logging import os from datetime import datetime, timezone @@ -9,35 +13,20 @@ import boto3 from botocore.exceptions import ClientError -from .models import OAuthProvider, OAuthProviderCreate, OAuthProviderUpdate +from .models import OAuthProvider, OAuthProviderUpdate logger = logging.getLogger(__name__) class OAuthProviderRepository: - """ - Repository for OAuth provider CRUD operations in DynamoDB. - - Handles provider configurations and client secrets in Secrets Manager. - Uses single-table design with GSI for querying enabled providers. 
- """ + """CRUD over the oauth-providers DynamoDB table.""" def __init__( self, table_name: Optional[str] = None, - secrets_arn: Optional[str] = None, region: Optional[str] = None, ): - """ - Initialize repository. - - Args: - table_name: DynamoDB table name (defaults to env var) - secrets_arn: Secrets Manager ARN for client secrets (defaults to env var) - region: AWS region (defaults to env var) - """ self._table_name = table_name or os.getenv("DYNAMODB_OAUTH_PROVIDERS_TABLE_NAME") - self._secrets_arn = secrets_arn or os.getenv("OAUTH_CLIENT_SECRETS_ARN") self._region = region or os.getenv("AWS_REGION", "us-west-2") self._enabled = bool(self._table_name) @@ -48,38 +37,18 @@ def __init__( ) return - # Initialize clients profile = os.getenv("AWS_PROFILE") - if profile: - session = boto3.Session(profile_name=profile) - self._dynamodb = session.resource("dynamodb", region_name=self._region) - self._secrets_client = session.client("secretsmanager", region_name=self._region) - else: - self._dynamodb = boto3.resource("dynamodb", region_name=self._region) - self._secrets_client = boto3.client("secretsmanager", region_name=self._region) - + session = boto3.Session(profile_name=profile) if profile else boto3 + self._dynamodb = session.resource("dynamodb", region_name=self._region) self._table = self._dynamodb.Table(self._table_name) - logger.info(f"Initialized OAuth provider repository: table={self._table_name}") + logger.info("Initialized OAuth provider repository: table=%s", self._table_name) @property def enabled(self) -> bool: - """Check if repository is enabled.""" return self._enabled - # ========================================================================= - # Provider CRUD - # ========================================================================= - + # ------------------------------------------------------------------- reads async def get_provider(self, provider_id: str) -> Optional[OAuthProvider]: - """ - Get a provider by ID. 
- - Args: - provider_id: Provider identifier - - Returns: - OAuthProvider if found, None otherwise - """ if not self._enabled: return None @@ -88,38 +57,23 @@ async def get_provider(self, provider_id: str) -> Optional[OAuthProvider]: Key={"PK": f"PROVIDER#{provider_id}", "SK": "CONFIG"} ) item = response.get("Item") - if not item: - return None - return OAuthProvider.from_dynamo_item(item) - + return OAuthProvider.from_dynamo_item(item) if item else None except ClientError as e: - logger.error(f"Error getting provider {provider_id}: {e}") + logger.error("Error getting provider %s: %s", provider_id, e) raise async def list_providers(self, enabled_only: bool = False) -> List[OAuthProvider]: - """ - List all providers. - - Args: - enabled_only: If True, only return enabled providers - - Returns: - List of OAuthProvider objects - """ if not self._enabled: return [] try: if enabled_only: - # Use GSI for efficient query response = self._table.query( IndexName="EnabledProvidersIndex", KeyConditionExpression="GSI1PK = :pk", ExpressionAttributeValues={":pk": "ENABLED#true"}, ) items = response.get("Items", []) - - # Handle pagination while "LastEvaluatedKey" in response: response = self._table.query( IndexName="EnabledProvidersIndex", @@ -129,14 +83,11 @@ async def list_providers(self, enabled_only: bool = False) -> List[OAuthProvider ) items.extend(response.get("Items", [])) else: - # Scan all providers response = self._table.scan( FilterExpression="SK = :sk", ExpressionAttributeValues={":sk": "CONFIG"}, ) items = response.get("Items", []) - - # Handle pagination while "LastEvaluatedKey" in response: response = self._table.scan( FilterExpression="SK = :sk", @@ -146,157 +97,70 @@ async def list_providers(self, enabled_only: bool = False) -> List[OAuthProvider items.extend(response.get("Items", [])) providers = [OAuthProvider.from_dynamo_item(item) for item in items] - - # Sort by display name providers.sort(key=lambda p: p.display_name.lower()) - return providers - except 
ClientError as e: - logger.error(f"Error listing providers: {e}") + logger.error("Error listing providers: %s", e) raise - async def create_provider( - self, create_request: OAuthProviderCreate - ) -> OAuthProvider: - """ - Create a new provider. - - Args: - create_request: Provider creation data including client secret + # ------------------------------------------------------------------ writes + async def put_provider(self, provider: OAuthProvider) -> OAuthProvider: + """Upsert a fully-formed provider record. - Returns: - Created OAuthProvider - - Raises: - ValueError: If provider already exists + The admin route is expected to build the `OAuthProvider` with all + AgentCore-owned fields already populated from the registrar call. """ if not self._enabled: raise RuntimeError("OAuth provider repository is not enabled") - # Check if provider exists - existing = await self.get_provider(create_request.provider_id) - if existing: - raise ValueError(f"Provider '{create_request.provider_id}' already exists") - - try: - now = datetime.now(timezone.utc).isoformat() + "Z" + self._table.put_item(Item=provider.to_dynamo_item()) + logger.info("Upserted OAuth provider: %s", provider.provider_id) + return provider - # Create provider object - provider = OAuthProvider( - provider_id=create_request.provider_id, - display_name=create_request.display_name, - provider_type=create_request.provider_type, - authorization_endpoint=create_request.authorization_endpoint, - token_endpoint=create_request.token_endpoint, - client_id=create_request.client_id, - scopes=create_request.scopes, - allowed_roles=create_request.allowed_roles, - enabled=create_request.enabled, - icon_name=create_request.icon_name, - userinfo_endpoint=create_request.userinfo_endpoint, - revocation_endpoint=create_request.revocation_endpoint, - pkce_required=create_request.pkce_required, - authorization_params=create_request.authorization_params, - created_at=now, - updated_at=now, - ) - - # Store client secret in 
Secrets Manager - await self._store_client_secret( - create_request.provider_id, create_request.client_secret - ) - - # Store provider in DynamoDB - self._table.put_item( - Item=provider.to_dynamo_item(), - ConditionExpression="attribute_not_exists(PK)", - ) - - logger.info(f"Created OAuth provider: {provider.provider_id}") - return provider - - except ClientError as e: - if e.response["Error"]["Code"] == "ConditionalCheckFailedException": - raise ValueError( - f"Provider '{create_request.provider_id}' already exists" - ) - logger.error(f"Error creating provider: {e}") - raise - - async def update_provider( + async def apply_metadata_update( self, provider_id: str, updates: OAuthProviderUpdate ) -> Optional[OAuthProvider]: - """ - Update an existing provider. - - Args: - provider_id: Provider identifier - updates: Fields to update + """Apply a metadata-only update to an existing provider record. - Returns: - Updated OAuthProvider, or None if not found + Does not touch AgentCore — the admin route is responsible for + calling the registrar first when credentials or the discovery + config change, then passing the refreshed metadata through here. + Fields left `None` on `updates` are preserved. 
""" - if not self._enabled: - return None - existing = await self.get_provider(provider_id) if not existing: return None - try: - # Apply updates - if updates.display_name is not None: - existing.display_name = updates.display_name - if updates.authorization_endpoint is not None: - existing.authorization_endpoint = updates.authorization_endpoint - if updates.token_endpoint is not None: - existing.token_endpoint = updates.token_endpoint - if updates.client_id is not None: - existing.client_id = updates.client_id - if updates.scopes is not None: - existing.scopes = updates.scopes - if updates.allowed_roles is not None: - existing.allowed_roles = updates.allowed_roles - if updates.enabled is not None: - existing.enabled = updates.enabled - if updates.icon_name is not None: - existing.icon_name = updates.icon_name - if updates.userinfo_endpoint is not None: - existing.userinfo_endpoint = updates.userinfo_endpoint - if updates.revocation_endpoint is not None: - existing.revocation_endpoint = updates.revocation_endpoint - if updates.pkce_required is not None: - existing.pkce_required = updates.pkce_required - if updates.authorization_params is not None: - existing.authorization_params = updates.authorization_params - - existing.updated_at = datetime.now(timezone.utc).isoformat() + "Z" - - # Update client secret if provided - if updates.client_secret is not None: - await self._store_client_secret(provider_id, updates.client_secret) - - # Store updated provider - self._table.put_item(Item=existing.to_dynamo_item()) - - logger.info(f"Updated OAuth provider: {provider_id}") - return existing - - except ClientError as e: - logger.error(f"Error updating provider {provider_id}: {e}") - raise + if updates.display_name is not None: + existing.display_name = updates.display_name + if updates.scopes is not None: + existing.scopes = updates.scopes + if updates.allowed_roles is not None: + existing.allowed_roles = updates.allowed_roles + if updates.enabled is not None: + 
existing.enabled = updates.enabled + if updates.icon_name is not None: + existing.icon_name = updates.icon_name + if updates.icon_data is not None: + # Empty string explicitly clears any uploaded icon (frontends + # then fall back to `icon_name`); a populated data URL replaces + # it. `None` on the update model leaves the existing value alone. + existing.icon_data = updates.icon_data or None + if updates.oauth_discovery_url is not None: + existing.oauth_discovery_url = updates.oauth_discovery_url + if updates.authorization_server_metadata is not None: + existing.authorization_server_metadata = updates.authorization_server_metadata + if updates.custom_parameters is not None: + # Empty dict (`{}`) explicitly clears the field; pass None on the + # update model to leave the existing value alone. + existing.custom_parameters = updates.custom_parameters or None + + existing.updated_at = datetime.now(timezone.utc).isoformat() + "Z" + self._table.put_item(Item=existing.to_dynamo_item()) + logger.info("Updated OAuth provider metadata: %s", provider_id) + return existing async def delete_provider(self, provider_id: str) -> bool: - """ - Delete a provider. 
- - Args: - provider_id: Provider identifier - - Returns: - True if deleted, False if not found - """ if not self._enabled: return False @@ -305,139 +169,21 @@ async def delete_provider(self, provider_id: str) -> bool: return False try: - # Delete from DynamoDB self._table.delete_item( Key={"PK": f"PROVIDER#{provider_id}", "SK": "CONFIG"} ) - - # Remove client secret from Secrets Manager - await self._delete_client_secret(provider_id) - - logger.info(f"Deleted OAuth provider: {provider_id}") + logger.info("Deleted OAuth provider: %s", provider_id) return True - - except ClientError as e: - logger.error(f"Error deleting provider {provider_id}: {e}") - raise - - # ========================================================================= - # Client Secret Management (Secrets Manager) - # ========================================================================= - - async def get_client_secret(self, provider_id: str) -> Optional[str]: - """ - Get client secret for a provider from Secrets Manager. - - Args: - provider_id: Provider identifier - - Returns: - Client secret string, or None if not found - """ - if not self._secrets_arn: - logger.warning("Secrets ARN not configured, cannot retrieve client secret") - return None - - try: - response = self._secrets_client.get_secret_value( - SecretId=self._secrets_arn - ) - secrets = json.loads(response["SecretString"]) - return secrets.get(provider_id) - - except ClientError as e: - if e.response["Error"]["Code"] == "ResourceNotFoundException": - logger.warning("OAuth secrets not found in Secrets Manager") - return None - logger.error(f"Error getting client secret for {provider_id}: {e}") - raise - - async def _store_client_secret( - self, provider_id: str, client_secret: str - ) -> None: - """ - Store client secret in Secrets Manager. - - Args: - provider_id: Provider identifier - client_secret: Client secret to store - """ - if not self._secrets_arn: - logger.warning( - "Secrets ARN not configured, cannot store client secret. 
" - "This is only acceptable in development." - ) - return - - try: - # Get existing secrets - try: - response = self._secrets_client.get_secret_value( - SecretId=self._secrets_arn - ) - secrets = json.loads(response["SecretString"]) - except ClientError as e: - if e.response["Error"]["Code"] == "ResourceNotFoundException": - secrets = {} - else: - raise - - # Update with new secret - secrets[provider_id] = client_secret - - # Store back - self._secrets_client.put_secret_value( - SecretId=self._secrets_arn, - SecretString=json.dumps(secrets), - ) - - logger.info(f"Stored client secret for provider: {provider_id}") - except ClientError as e: - logger.error(f"Error storing client secret for {provider_id}: {e}") + logger.error("Error deleting provider %s: %s", provider_id, e) raise - async def _delete_client_secret(self, provider_id: str) -> None: - """ - Remove client secret from Secrets Manager. - - Args: - provider_id: Provider identifier - """ - if not self._secrets_arn: - return - - try: - # Get existing secrets - response = self._secrets_client.get_secret_value( - SecretId=self._secrets_arn - ) - secrets = json.loads(response["SecretString"]) - - # Remove provider's secret - if provider_id in secrets: - del secrets[provider_id] - - # Store back - self._secrets_client.put_secret_value( - SecretId=self._secrets_arn, - SecretString=json.dumps(secrets), - ) - - logger.info(f"Removed client secret for provider: {provider_id}") - - except ClientError as e: - if e.response["Error"]["Code"] != "ResourceNotFoundException": - logger.error(f"Error deleting client secret for {provider_id}: {e}") - raise - -# Singleton instance _provider_repository: Optional[OAuthProviderRepository] = None def get_provider_repository() -> OAuthProviderRepository: - """Get the provider repository singleton.""" + """Get the process-wide provider repository singleton.""" global _provider_repository if _provider_repository is None: _provider_repository = OAuthProviderRepository() diff --git 
a/backend/src/apis/shared/oauth/routes.py b/backend/src/apis/shared/oauth/routes.py deleted file mode 100644 index 93f2edf7..00000000 --- a/backend/src/apis/shared/oauth/routes.py +++ /dev/null @@ -1,268 +0,0 @@ -"""User-facing OAuth routes for connection management.""" - -import logging -import os -from typing import Optional -from urllib.parse import urlencode, urlparse - -from fastapi import APIRouter, Depends, HTTPException, Query, status -from fastapi.responses import RedirectResponse - -from apis.shared.auth import User, get_current_user -from apis.shared.rbac.service import AppRoleService, get_app_role_service - -from apis.shared.oauth.models import ( - OAuthConnectionListResponse, - OAuthConnectResponse, - OAuthProviderListResponse, - OAuthProviderResponse, -) -from apis.shared.oauth.provider_repository import OAuthProviderRepository, get_provider_repository -from apis.shared.oauth.service import OAuthService, get_oauth_service - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/oauth", tags=["oauth"]) - - -# ============================================================================= -# Provider Discovery (filtered by user roles) -# ============================================================================= - - -@router.get("/providers", response_model=OAuthProviderListResponse) -async def list_available_providers( - current_user: User = Depends(get_current_user), - provider_repo: OAuthProviderRepository = Depends(get_provider_repository), - role_service: AppRoleService = Depends(get_app_role_service), -): - """ - List OAuth providers available to the current user. - - Filters providers based on user's application roles. 
- - Returns: - OAuthProviderListResponse with available providers - """ - logger.info(f"User {current_user.name} listing available OAuth providers") - - # Resolve user's application roles - permissions = await role_service.resolve_user_permissions(current_user) - user_roles = permissions.app_roles if permissions.app_roles else [] - - # Get enabled providers - providers = await provider_repo.list_providers(enabled_only=True) - - # Filter by user roles - available = [] - for provider in providers: - if not provider.allowed_roles or any( - role in provider.allowed_roles for role in user_roles - ): - available.append(OAuthProviderResponse.from_provider(provider)) - - return OAuthProviderListResponse( - providers=available, - total=len(available), - ) - - -# ============================================================================= -# User Connections -# ============================================================================= - - -@router.get("/connections", response_model=OAuthConnectionListResponse) -async def list_user_connections( - current_user: User = Depends(get_current_user), - oauth_service: OAuthService = Depends(get_oauth_service), - role_service: AppRoleService = Depends(get_app_role_service), -): - """ - List the current user's OAuth connections. - - Returns all available providers with connection status. 
- - Returns: - OAuthConnectionListResponse with connection statuses - """ - logger.info(f"User {current_user.name} listing OAuth connections") - - # Resolve user's application roles - permissions = await role_service.resolve_user_permissions(current_user) - user_roles = permissions.app_roles if permissions.app_roles else [] - - connections = await oauth_service.get_user_connections( - user_id=current_user.user_id, - user_roles=user_roles, - ) - - return OAuthConnectionListResponse(connections=connections) - - -# ============================================================================= -# OAuth Flow -# ============================================================================= - - -@router.get("/connect/{provider_id}", response_model=OAuthConnectResponse) -async def initiate_connection( - provider_id: str, - redirect: Optional[str] = Query( - None, - description="Frontend URL to redirect after OAuth callback", - ), - current_user: User = Depends(get_current_user), - oauth_service: OAuthService = Depends(get_oauth_service), - role_service: AppRoleService = Depends(get_app_role_service), -): - """ - Initiate OAuth connection flow for a provider. - - Returns an authorization URL that the frontend should redirect to. 
- - Args: - provider_id: Provider to connect to - redirect: Optional frontend redirect URL after completion - - Returns: - OAuthConnectResponse with authorization URL - - Raises: - HTTPException: - - 404 if provider not found - - 403 if user not authorized for provider - """ - logger.info( - "User initiating OAuth connection" - ) - - # Resolve user's application roles - permissions = await role_service.resolve_user_permissions(current_user) - user_roles = permissions.app_roles if permissions.app_roles else [] - - authorization_url = await oauth_service.initiate_connect( - provider_id=provider_id, - user_id=current_user.user_id, - user_roles=user_roles, - frontend_redirect=redirect, - ) - - return OAuthConnectResponse(authorization_url=authorization_url) - - -@router.get("/callback") -async def oauth_callback( - code: Optional[str] = Query(None), - state: Optional[str] = Query(None), - error: Optional[str] = Query(None), - error_description: Optional[str] = Query(None), - oauth_service: OAuthService = Depends(get_oauth_service), -): - """ - Handle OAuth callback from provider. - - This endpoint is called by the OAuth provider after user authorization. - Exchanges the code for tokens and redirects to the frontend. 
- - Args: - code: Authorization code from provider - state: State parameter for validation - error: Error code if authorization failed - error_description: Error description if authorization failed - - Returns: - Redirect to frontend with success/error query params - """ - # Get frontend base URL from environment - frontend_url = os.getenv("FRONTEND_URL", "http://localhost:4200") - callback_path = "/settings/oauth/callback" - - # Handle error from provider - if error: - logger.warning("OAuth callback error") - params = urlencode({"error": error, "error_description": error_description or ""}) - return RedirectResponse( - url=f"{frontend_url}{callback_path}?{params}", - status_code=status.HTTP_302_FOUND, - ) - - # Validate required params - if not code or not state: - logger.warning("OAuth callback missing code or state") - params = urlencode({"error": "missing_params"}) - return RedirectResponse( - url=f"{frontend_url}{callback_path}?{params}", - status_code=status.HTTP_302_FOUND, - ) - - # Process callback - provider_id, frontend_redirect, callback_error = await oauth_service.handle_callback( - code=code, - state=state, - ) - - # Build redirect URL — validate that frontend_redirect is same-origin - # to prevent open redirect attacks via manipulated OAuth state - redirect_base = f"{frontend_url}{callback_path}" - if frontend_redirect: - parsed = urlparse(frontend_redirect) - parsed_frontend = urlparse(frontend_url) - if parsed.scheme == parsed_frontend.scheme and parsed.netloc == parsed_frontend.netloc: - redirect_base = frontend_redirect - else: - logger.warning( - f"OAuth callback redirect blocked — origin mismatch: {parsed.netloc} != {parsed_frontend.netloc}" - ) - - if callback_error: - params = urlencode({"error": callback_error, "provider": provider_id}) - return RedirectResponse( - url=f"{redirect_base}?{params}", - status_code=status.HTTP_302_FOUND, - ) - - # Success - params = urlencode({"success": "true", "provider": provider_id}) - return 
RedirectResponse( - url=f"{redirect_base}?{params}", - status_code=status.HTTP_302_FOUND, - ) - - -# ============================================================================= -# Disconnect -# ============================================================================= - - -@router.delete("/connections/{provider_id}", status_code=status.HTTP_204_NO_CONTENT) -async def disconnect_provider( - provider_id: str, - current_user: User = Depends(get_current_user), - oauth_service: OAuthService = Depends(get_oauth_service), -): - """ - Disconnect from an OAuth provider. - - Revokes tokens if possible and removes the connection. - - Args: - provider_id: Provider to disconnect from - - Raises: - HTTPException: 404 if not connected to provider - """ - logger.info("User disconnecting from OAuth provider") - - disconnected = await oauth_service.disconnect( - user_id=current_user.user_id, - provider_id=provider_id, - ) - - if not disconnected: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Not connected to provider '{provider_id}'", - ) - - return None diff --git a/backend/src/apis/shared/oauth/service.py b/backend/src/apis/shared/oauth/service.py deleted file mode 100644 index 78612c75..00000000 --- a/backend/src/apis/shared/oauth/service.py +++ /dev/null @@ -1,664 +0,0 @@ -"""OAuth service for managing provider connections and token exchange.""" - -import base64 -import hashlib -import logging -import os -import secrets -import time -from dataclasses import dataclass -from datetime import datetime, timezone -from typing import Dict, List, Optional, Tuple - -import httpx -from authlib.integrations.httpx_client import AsyncOAuth2Client -from fastapi import HTTPException, status - -from apis.shared.auth.state_store import StateStore, create_state_store - -from .encryption import TokenEncryptionService, get_token_encryption_service -from .models import ( - OAuthConnectionResponse, - OAuthConnectionStatus, - OAuthProvider, - OAuthUserToken, -) -from 
.provider_repository import OAuthProviderRepository, get_provider_repository -from .token_cache import TokenCache, get_token_cache -from .token_repository import OAuthTokenRepository, get_token_repository - -logger = logging.getLogger(__name__) - - -@dataclass -class OAuthStateData: - """Data stored with OAuth state for security validation.""" - - provider_id: str - user_id: str - code_verifier: Optional[str] = None # PKCE code verifier (S256) - redirect_uri: Optional[str] = None # Frontend redirect after callback - - -def generate_pkce_pair() -> Tuple[str, str]: - """ - Generate PKCE code verifier and challenge (S256). - - Returns: - Tuple of (code_verifier, code_challenge) - """ - # Generate 32 bytes of random data for code_verifier - code_verifier = secrets.token_urlsafe(32) - - # Create code_challenge using S256: BASE64URL(SHA256(code_verifier)) - digest = hashlib.sha256(code_verifier.encode("ascii")).digest() - code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") - - return code_verifier, code_challenge - - -class OAuthService: - """ - Service for OAuth flow management. - - Handles: - - Initiating OAuth connection flows - - Processing OAuth callbacks and token exchange - - Token refresh and decryption - - Connection status management - """ - - def __init__( - self, - provider_repo: Optional[OAuthProviderRepository] = None, - token_repo: Optional[OAuthTokenRepository] = None, - encryption_service: Optional[TokenEncryptionService] = None, - token_cache: Optional[TokenCache] = None, - state_store: Optional[StateStore] = None, - ): - """ - Initialize OAuth service. 
- - Args: - provider_repo: Provider repository (defaults to singleton) - token_repo: Token repository (defaults to singleton) - encryption_service: Token encryption service (defaults to singleton) - token_cache: Token cache (defaults to singleton) - state_store: State store for OAuth state (defaults to create_state_store) - """ - self._provider_repo = provider_repo or get_provider_repository() - self._token_repo = token_repo or get_token_repository() - self._encryption = encryption_service or get_token_encryption_service() - self._cache = token_cache or get_token_cache() - self._state_store = state_store or create_state_store() - - # OAuth callback URL (configured in environment) - self._callback_url = os.getenv("OAUTH_CALLBACK_URL", "") - if not self._callback_url: - logger.warning( - "OAUTH_CALLBACK_URL not set. OAuth flows will fail. " - "Set to e.g. https://your-app.com/oauth/callback" - ) - - # State TTL in seconds (10 minutes) - self._state_ttl = 600 - - @property - def enabled(self) -> bool: - """Check if OAuth service is enabled.""" - return self._provider_repo.enabled and self._token_repo.enabled - - # ========================================================================= - # Connection Flow - # ========================================================================= - - async def initiate_connect( - self, - provider_id: str, - user_id: str, - user_roles: List[str], - frontend_redirect: Optional[str] = None, - ) -> str: - """ - Initiate OAuth connection flow. - - Generates authorization URL for user to visit. 
- - Args: - provider_id: Provider to connect to - user_id: User initiating connection - user_roles: User's roles for access check - frontend_redirect: URL to redirect after callback - - Returns: - Authorization URL - - Raises: - HTTPException: If provider not found or user not authorized - """ - # Get provider - provider = await self._provider_repo.get_provider(provider_id) - if not provider: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Provider '{provider_id}' not found", - ) - - if not provider.enabled: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Provider '{provider_id}' is not enabled", - ) - - # Check role access - if provider.allowed_roles and not any( - role in provider.allowed_roles for role in user_roles - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You do not have access to this provider", - ) - - # Generate state and PKCE - state = secrets.token_urlsafe(32) - - code_verifier = None - code_challenge = None - if provider.pkce_required: - code_verifier, code_challenge = generate_pkce_pair() - - # Store state data - state_data = OAuthStateData( - provider_id=provider_id, - user_id=user_id, - code_verifier=code_verifier, - redirect_uri=frontend_redirect, - ) - self._store_state(state, state_data) - - # Build authorization URL - params = { - "client_id": provider.client_id, - "redirect_uri": self._callback_url, - "response_type": "code", - "scope": " ".join(provider.scopes), - "state": state, - } - - if code_challenge: - params["code_challenge"] = code_challenge - params["code_challenge_method"] = "S256" - - # Add provider-specific authorization params (e.g., access_type=offline for Google) - if provider.authorization_params: - params.update(provider.authorization_params) - - # Build URL - auth_url = f"{provider.authorization_endpoint}?" 
- auth_url += "&".join(f"{k}={v}" for k, v in params.items()) - - logger.info(f"Initiated OAuth flow for user {user_id}, provider {provider_id}") - return auth_url - - async def handle_callback( - self, - code: str, - state: str, - ) -> Tuple[str, Optional[str], Optional[str]]: - """ - Handle OAuth callback after user authorization. - - Exchanges code for tokens and stores encrypted. - - Args: - code: Authorization code from provider - state: State parameter for validation - - Returns: - Tuple of (provider_id, frontend_redirect, error) - """ - # Validate and retrieve state data - valid, state_data = self._get_and_delete_state(state) - if not valid or not state_data: - logger.warning(f"Invalid or expired OAuth state: {state[:16]}...") - return "", None, "invalid_state" - - provider_id = state_data.provider_id - user_id = state_data.user_id - code_verifier = state_data.code_verifier - frontend_redirect = state_data.redirect_uri - - try: - # Get provider - provider = await self._provider_repo.get_provider(provider_id) - if not provider: - logger.error(f"Provider not found during callback: {provider_id}") - return provider_id, frontend_redirect, "provider_not_found" - - # Get client secret - client_secret = await self._provider_repo.get_client_secret(provider_id) - if not client_secret: - logger.error(f"Client secret not found for provider: {provider_id}") - return provider_id, frontend_redirect, "configuration_error" - - # Exchange code for tokens - token_data = await self._exchange_code( - provider=provider, - client_secret=client_secret, - code=code, - code_verifier=code_verifier, - ) - - if not token_data: - return provider_id, frontend_redirect, "token_exchange_failed" - - # Encrypt and store tokens - access_token = token_data.get("access_token") - refresh_token = token_data.get("refresh_token") - expires_in = token_data.get("expires_in") - token_type = token_data.get("token_type", "Bearer") - - if not access_token: - logger.error("No access token in response") - 
return provider_id, frontend_redirect, "no_access_token" - - # Calculate expiration - expires_at = None - if expires_in: - expires_at = int(time.time()) + int(expires_in) - - # Encrypt tokens - encrypted_access = self._encryption.encrypt(access_token) - encrypted_refresh = None - if refresh_token: - encrypted_refresh = self._encryption.encrypt(refresh_token) - - # Create token record - now = datetime.now(timezone.utc).isoformat() + "Z" - user_token = OAuthUserToken( - user_id=user_id, - provider_id=provider_id, - access_token_encrypted=encrypted_access, - refresh_token_encrypted=encrypted_refresh, - token_type=token_type, - expires_at=expires_at, - scopes_hash=provider.scopes_hash, - status=OAuthConnectionStatus.CONNECTED, - connected_at=now, - updated_at=now, - ) - - await self._token_repo.save_token(user_token) - - # Cache the decrypted token - self._cache.set(user_id, provider_id, access_token) - - logger.info(f"Successfully connected user {user_id} to provider {provider_id}") - return provider_id, frontend_redirect, None - - except Exception as e: - logger.error(f"Error handling OAuth callback: {e}", exc_info=True) - return provider_id, frontend_redirect, str(e) - - async def _exchange_code( - self, - provider: OAuthProvider, - client_secret: str, - code: str, - code_verifier: Optional[str] = None, - ) -> Optional[Dict]: - """ - Exchange authorization code for tokens. 
- - Args: - provider: OAuth provider - client_secret: Provider client secret - code: Authorization code - code_verifier: PKCE code verifier (if used) - - Returns: - Token response dict or None on error - """ - try: - async with AsyncOAuth2Client( - client_id=provider.client_id, - client_secret=client_secret, - token_endpoint=provider.token_endpoint, - ) as client: - token = await client.fetch_token( - url=provider.token_endpoint, - grant_type="authorization_code", - code=code, - redirect_uri=self._callback_url, - code_verifier=code_verifier, - ) - return dict(token) - - except Exception as e: - logger.error(f"Token exchange failed: {e}") - return None - - # ========================================================================= - # Token Access - # ========================================================================= - - async def get_decrypted_token( - self, - user_id: str, - provider_id: str, - ) -> Optional[str]: - """ - Get decrypted access token for a user's provider connection. - - Checks cache first, then decrypts from storage. - Handles token refresh if needed. 
- - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - Decrypted access token, or None if not connected - """ - # Check cache first - cached = self._cache.get(user_id, provider_id) - if cached: - return cached - - # Get token from storage - token = await self._token_repo.get_token(user_id, provider_id) - if not token: - return None - - if token.status == OAuthConnectionStatus.REVOKED: - return None - - # Check if expired and needs refresh - if token.is_expired: - if token.refresh_token_encrypted: - refreshed = await self._refresh_token(user_id, provider_id) - if refreshed: - return refreshed - # Mark as expired - await self._token_repo.update_token_status( - user_id, provider_id, OAuthConnectionStatus.EXPIRED - ) - return None - - # Decrypt and cache - try: - access_token = self._encryption.decrypt(token.access_token_encrypted) - self._cache.set(user_id, provider_id, access_token) - return access_token - except Exception as e: - logger.error(f"Failed to decrypt token: {e}") - return None - - async def _refresh_token( - self, - user_id: str, - provider_id: str, - ) -> Optional[str]: - """ - Refresh an expired token. 
- - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - New access token, or None on failure - """ - token = await self._token_repo.get_token(user_id, provider_id) - if not token or not token.refresh_token_encrypted: - return None - - provider = await self._provider_repo.get_provider(provider_id) - if not provider: - return None - - client_secret = await self._provider_repo.get_client_secret(provider_id) - if not client_secret: - return None - - try: - # Decrypt refresh token - refresh_token = self._encryption.decrypt(token.refresh_token_encrypted) - - # Request new tokens - async with AsyncOAuth2Client( - client_id=provider.client_id, - client_secret=client_secret, - token_endpoint=provider.token_endpoint, - ) as client: - new_token = await client.refresh_token( - url=provider.token_endpoint, - refresh_token=refresh_token, - ) - - if not new_token or "access_token" not in new_token: - logger.warning(f"Token refresh failed for {user_id}/{provider_id}") - return None - - # Update stored token - token.access_token_encrypted = self._encryption.encrypt( - new_token["access_token"] - ) - if "refresh_token" in new_token: - token.refresh_token_encrypted = self._encryption.encrypt( - new_token["refresh_token"] - ) - if "expires_in" in new_token: - token.expires_at = int(time.time()) + int(new_token["expires_in"]) - token.status = OAuthConnectionStatus.CONNECTED - - await self._token_repo.save_token(token) - - # Update cache - self._cache.set(user_id, provider_id, new_token["access_token"]) - - logger.info(f"Refreshed token for user {user_id}, provider {provider_id}") - return new_token["access_token"] - - except Exception as e: - logger.error(f"Token refresh failed: {e}") - return None - - # ========================================================================= - # Connection Management - # ========================================================================= - - async def get_user_connections( - self, - user_id: str, - user_roles: 
List[str], - ) -> List[OAuthConnectionResponse]: - """ - Get user's OAuth connections with status. - - Returns all available providers with connection status. - - Args: - user_id: User identifier - user_roles: User's roles for filtering available providers - - Returns: - List of connection responses - """ - # Get available providers for user's roles - providers = await self._provider_repo.list_providers(enabled_only=True) - available_providers = [ - p for p in providers - if not p.allowed_roles or any(r in p.allowed_roles for r in user_roles) - ] - - # Get user's tokens - tokens = await self._token_repo.list_user_tokens(user_id) - token_map = {t.provider_id: t for t in tokens} - - # Build connection responses - connections = [] - for provider in available_providers: - token = token_map.get(provider.provider_id) - - if token: - # Check if needs re-auth (scope changes) - needs_reauth = token.scopes_hash != provider.scopes_hash - - # Update status based on token state - if token.is_expired: - status = OAuthConnectionStatus.EXPIRED - elif needs_reauth: - status = OAuthConnectionStatus.NEEDS_REAUTH - else: - status = token.status - - connections.append( - OAuthConnectionResponse( - provider_id=provider.provider_id, - display_name=provider.display_name, - provider_type=provider.provider_type, - icon_name=provider.icon_name, - status=status, - connected_at=token.connected_at, - needs_reauth=needs_reauth, - ) - ) - else: - # Not connected - connections.append( - OAuthConnectionResponse( - provider_id=provider.provider_id, - display_name=provider.display_name, - provider_type=provider.provider_type, - icon_name=provider.icon_name, - status=OAuthConnectionStatus.REVOKED, # Use REVOKED as "not connected" - connected_at=None, - needs_reauth=False, - ) - ) - - return connections - - async def disconnect( - self, - user_id: str, - provider_id: str, - ) -> bool: - """ - Disconnect user from a provider. - - Revokes token if possible and deletes from storage. 
- - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - True if disconnected, False if not connected - """ - # Get token - token = await self._token_repo.get_token(user_id, provider_id) - if not token: - return False - - # Try to revoke token at provider - provider = await self._provider_repo.get_provider(provider_id) - if provider and provider.revocation_endpoint: - try: - access_token = self._encryption.decrypt(token.access_token_encrypted) - await self._revoke_token(provider, access_token) - except Exception as e: - logger.warning(f"Failed to revoke token at provider: {e}") - # Continue with deletion anyway - - # Delete from storage - deleted = await self._token_repo.delete_token(user_id, provider_id) - - # Clear cache - self._cache.delete(user_id, provider_id) - - logger.info(f"Disconnected user {user_id} from provider {provider_id}") - return deleted - - async def _revoke_token( - self, - provider: OAuthProvider, - access_token: str, - ) -> None: - """ - Revoke token at the provider's revocation endpoint. 
- - Args: - provider: OAuth provider - access_token: Token to revoke - """ - if not provider.revocation_endpoint: - return - - client_secret = await self._provider_repo.get_client_secret(provider.provider_id) - if not client_secret: - return - - try: - async with httpx.AsyncClient() as client: - await client.post( - provider.revocation_endpoint, - data={ - "token": access_token, - "client_id": provider.client_id, - "client_secret": client_secret, - }, - timeout=10.0, - ) - except Exception as e: - logger.warning(f"Token revocation request failed: {e}") - - # ========================================================================= - # State Management (reuses OIDC state store pattern) - # ========================================================================= - - def _store_state(self, state: str, data: OAuthStateData) -> None: - """Store OAuth state data.""" - # Convert to dict for storage - from apis.shared.auth.state_store import OIDCStateData - - oidc_data = OIDCStateData( - redirect_uri=data.redirect_uri, - code_verifier=data.code_verifier, - nonce=f"{data.provider_id}|{data.user_id}", # Encode provider/user in nonce - ) - self._state_store.store_state(state, oidc_data, self._state_ttl) - - def _get_and_delete_state( - self, state: str - ) -> Tuple[bool, Optional[OAuthStateData]]: - """Retrieve and delete OAuth state data.""" - valid, oidc_data = self._state_store.get_and_delete_state(state) - if not valid or not oidc_data: - return False, None - - # Decode provider/user from nonce - if not oidc_data.nonce or "|" not in oidc_data.nonce: - return False, None - - provider_id, user_id = oidc_data.nonce.split("|", 1) - - return True, OAuthStateData( - provider_id=provider_id, - user_id=user_id, - code_verifier=oidc_data.code_verifier, - redirect_uri=oidc_data.redirect_uri, - ) - - -# Singleton instance -_oauth_service: Optional[OAuthService] = None - - -def get_oauth_service() -> OAuthService: - """Get the OAuth service singleton.""" - global _oauth_service - if 
_oauth_service is None: - _oauth_service = OAuthService() - return _oauth_service diff --git a/backend/src/apis/shared/oauth/token_cache.py b/backend/src/apis/shared/oauth/token_cache.py deleted file mode 100644 index 33224e11..00000000 --- a/backend/src/apis/shared/oauth/token_cache.py +++ /dev/null @@ -1,165 +0,0 @@ -"""In-memory cache for decrypted OAuth tokens. - -Uses TTLCache to avoid repeated KMS decrypt calls for frequently accessed tokens. -Cache entries expire after 5 minutes by default. -""" - -import logging -from typing import Optional - -from cachetools import TTLCache - -logger = logging.getLogger(__name__) - -# Default cache TTL in seconds (5 minutes) -DEFAULT_CACHE_TTL = 300 - -# Default cache size (number of tokens to cache) -DEFAULT_CACHE_SIZE = 1000 - - -class TokenCache: - """ - TTL-based cache for decrypted OAuth access tokens. - - Reduces KMS API calls by caching decrypted tokens in memory. - Tokens are automatically evicted after the TTL expires. - - Thread-safe through cachetools' internal locking. - """ - - def __init__( - self, - maxsize: int = DEFAULT_CACHE_SIZE, - ttl: int = DEFAULT_CACHE_TTL, - ): - """ - Initialize the token cache. - - Args: - maxsize: Maximum number of tokens to cache - ttl: Time-to-live in seconds for cache entries - """ - self._cache: TTLCache = TTLCache(maxsize=maxsize, ttl=ttl) - self._ttl = ttl - self._maxsize = maxsize - - logger.info(f"Initialized token cache: maxsize={maxsize}, ttl={ttl}s") - - def _make_key(self, user_id: str, provider_id: str) -> str: - """Create a cache key from user and provider IDs.""" - return f"{user_id}:{provider_id}" - - def get(self, user_id: str, provider_id: str) -> Optional[str]: - """ - Get a cached access token. 
- - Args: - user_id: User identifier - provider_id: OAuth provider identifier - - Returns: - Decrypted access token if cached, None otherwise - """ - key = self._make_key(user_id, provider_id) - token = self._cache.get(key) - if token: - logger.debug(f"Token cache hit: user={user_id}, provider={provider_id}") - else: - logger.debug(f"Token cache miss: user={user_id}, provider={provider_id}") - return token - - def set(self, user_id: str, provider_id: str, access_token: str) -> None: - """ - Cache a decrypted access token. - - Args: - user_id: User identifier - provider_id: OAuth provider identifier - access_token: Decrypted access token to cache - """ - key = self._make_key(user_id, provider_id) - self._cache[key] = access_token - logger.debug(f"Cached token: user={user_id}, provider={provider_id}") - - def delete(self, user_id: str, provider_id: str) -> bool: - """ - Remove a token from cache. - - Args: - user_id: User identifier - provider_id: OAuth provider identifier - - Returns: - True if token was in cache, False otherwise - """ - key = self._make_key(user_id, provider_id) - if key in self._cache: - del self._cache[key] - logger.debug(f"Evicted token from cache: user={user_id}, provider={provider_id}") - return True - return False - - def delete_for_user(self, user_id: str) -> int: - """ - Remove all cached tokens for a user. - - Args: - user_id: User identifier - - Returns: - Number of tokens removed - """ - prefix = f"{user_id}:" - keys_to_delete = [k for k in self._cache.keys() if k.startswith(prefix)] - for key in keys_to_delete: - del self._cache[key] - if keys_to_delete: - logger.debug(f"Evicted {len(keys_to_delete)} tokens for user: {user_id}") - return len(keys_to_delete) - - def delete_for_provider(self, provider_id: str) -> int: - """ - Remove all cached tokens for a provider. - - Useful when provider configuration changes (e.g., scopes updated). 
- - Args: - provider_id: OAuth provider identifier - - Returns: - Number of tokens removed - """ - suffix = f":{provider_id}" - keys_to_delete = [k for k in self._cache.keys() if k.endswith(suffix)] - for key in keys_to_delete: - del self._cache[key] - if keys_to_delete: - logger.debug(f"Evicted {len(keys_to_delete)} tokens for provider: {provider_id}") - return len(keys_to_delete) - - def clear(self) -> None: - """Clear all cached tokens.""" - size = len(self._cache) - self._cache.clear() - logger.info(f"Cleared token cache ({size} entries)") - - def get_stats(self) -> dict: - """Get cache statistics.""" - return { - "size": len(self._cache), - "maxsize": self._maxsize, - "ttl_seconds": self._ttl, - } - - -# Singleton instance -_token_cache: Optional[TokenCache] = None - - -def get_token_cache() -> TokenCache: - """Get the token cache singleton.""" - global _token_cache - if _token_cache is None: - _token_cache = TokenCache() - return _token_cache diff --git a/backend/src/apis/shared/oauth/token_repository.py b/backend/src/apis/shared/oauth/token_repository.py deleted file mode 100644 index a695b148..00000000 --- a/backend/src/apis/shared/oauth/token_repository.py +++ /dev/null @@ -1,337 +0,0 @@ -"""DynamoDB repository for OAuth user tokens.""" - -import logging -import os -from datetime import datetime, timezone -from typing import List, Optional - -import boto3 -from botocore.exceptions import ClientError - -from .models import OAuthUserToken, OAuthConnectionStatus - -logger = logging.getLogger(__name__) - - -class OAuthTokenRepository: - """ - Repository for OAuth user token CRUD operations in DynamoDB. - - Handles encrypted token storage with KMS. - Uses single-table design with GSI for querying by provider. - """ - - def __init__( - self, - table_name: Optional[str] = None, - region: Optional[str] = None, - ): - """ - Initialize repository. 
- - Args: - table_name: DynamoDB table name (defaults to env var) - region: AWS region (defaults to env var) - """ - self._table_name = table_name or os.getenv("DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME") - self._region = region or os.getenv("AWS_REGION", "us-west-2") - self._enabled = bool(self._table_name) - - if not self._enabled: - logger.warning( - "DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME not set. " - "OAuth token repository is disabled." - ) - return - - # Initialize client - profile = os.getenv("AWS_PROFILE") - if profile: - session = boto3.Session(profile_name=profile) - self._dynamodb = session.resource("dynamodb", region_name=self._region) - else: - self._dynamodb = boto3.resource("dynamodb", region_name=self._region) - - self._table = self._dynamodb.Table(self._table_name) - logger.info(f"Initialized OAuth token repository: table={self._table_name}") - - @property - def enabled(self) -> bool: - """Check if repository is enabled.""" - return self._enabled - - # ========================================================================= - # Token CRUD - # ========================================================================= - - async def get_token( - self, user_id: str, provider_id: str - ) -> Optional[OAuthUserToken]: - """ - Get a user's token for a provider. - - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - OAuthUserToken if found, None otherwise - """ - if not self._enabled: - return None - - try: - response = self._table.get_item( - Key={ - "PK": f"USER#{user_id}", - "SK": f"PROVIDER#{provider_id}", - } - ) - item = response.get("Item") - if not item: - return None - return OAuthUserToken.from_dynamo_item(item) - - except ClientError as e: - logger.error(f"Error getting token for user {user_id}, provider {provider_id}: {e}") - raise - - async def list_user_tokens(self, user_id: str) -> List[OAuthUserToken]: - """ - List all tokens for a user. 
- - Args: - user_id: User identifier - - Returns: - List of OAuthUserToken objects - """ - if not self._enabled: - return [] - - try: - response = self._table.query( - KeyConditionExpression="PK = :pk AND begins_with(SK, :sk_prefix)", - ExpressionAttributeValues={ - ":pk": f"USER#{user_id}", - ":sk_prefix": "PROVIDER#", - }, - ) - items = response.get("Items", []) - - # Handle pagination - while "LastEvaluatedKey" in response: - response = self._table.query( - KeyConditionExpression="PK = :pk AND begins_with(SK, :sk_prefix)", - ExpressionAttributeValues={ - ":pk": f"USER#{user_id}", - ":sk_prefix": "PROVIDER#", - }, - ExclusiveStartKey=response["LastEvaluatedKey"], - ) - items.extend(response.get("Items", [])) - - return [OAuthUserToken.from_dynamo_item(item) for item in items] - - except ClientError as e: - logger.error(f"Error listing tokens for user {user_id}: {e}") - raise - - async def list_provider_tokens(self, provider_id: str) -> List[OAuthUserToken]: - """ - List all user tokens for a provider (admin view). - - Uses GSI for efficient lookup. 
- - Args: - provider_id: Provider identifier - - Returns: - List of OAuthUserToken objects - """ - if not self._enabled: - return [] - - try: - response = self._table.query( - IndexName="ProviderUsersIndex", - KeyConditionExpression="GSI1PK = :pk", - ExpressionAttributeValues={":pk": f"PROVIDER#{provider_id}"}, - ) - items = response.get("Items", []) - - # Handle pagination - while "LastEvaluatedKey" in response: - response = self._table.query( - IndexName="ProviderUsersIndex", - KeyConditionExpression="GSI1PK = :pk", - ExpressionAttributeValues={":pk": f"PROVIDER#{provider_id}"}, - ExclusiveStartKey=response["LastEvaluatedKey"], - ) - items.extend(response.get("Items", [])) - - return [OAuthUserToken.from_dynamo_item(item) for item in items] - - except ClientError as e: - logger.error(f"Error listing tokens for provider {provider_id}: {e}") - raise - - async def save_token(self, token: OAuthUserToken) -> OAuthUserToken: - """ - Save or update a user token. - - Args: - token: Token to save - - Returns: - Saved token - """ - if not self._enabled: - raise RuntimeError("OAuth token repository is not enabled") - - try: - token.updated_at = datetime.now(timezone.utc).isoformat() + "Z" - self._table.put_item(Item=token.to_dynamo_item()) - logger.info(f"Saved token for user {token.user_id}, provider {token.provider_id}") - return token - - except ClientError as e: - logger.error(f"Error saving token: {e}") - raise - - async def update_token_status( - self, - user_id: str, - provider_id: str, - status: OAuthConnectionStatus, - ) -> Optional[OAuthUserToken]: - """ - Update token status. 
- - Args: - user_id: User identifier - provider_id: Provider identifier - status: New status - - Returns: - Updated token, or None if not found - """ - if not self._enabled: - return None - - token = await self.get_token(user_id, provider_id) - if not token: - return None - - token.status = status - return await self.save_token(token) - - async def delete_token(self, user_id: str, provider_id: str) -> bool: - """ - Delete a user's token for a provider. - - Args: - user_id: User identifier - provider_id: Provider identifier - - Returns: - True if deleted, False if not found - """ - if not self._enabled: - return False - - try: - # Check if exists first - existing = await self.get_token(user_id, provider_id) - if not existing: - return False - - self._table.delete_item( - Key={ - "PK": f"USER#{user_id}", - "SK": f"PROVIDER#{provider_id}", - } - ) - - logger.info(f"Deleted token for user {user_id}, provider {provider_id}") - return True - - except ClientError as e: - logger.error(f"Error deleting token: {e}") - raise - - async def delete_user_tokens(self, user_id: str) -> int: - """ - Delete all tokens for a user. - - Args: - user_id: User identifier - - Returns: - Number of tokens deleted - """ - if not self._enabled: - return 0 - - try: - tokens = await self.list_user_tokens(user_id) - - with self._table.batch_writer() as batch: - for token in tokens: - batch.delete_item( - Key={ - "PK": f"USER#{user_id}", - "SK": f"PROVIDER#{token.provider_id}", - } - ) - - logger.info(f"Deleted {len(tokens)} tokens for user {user_id}") - return len(tokens) - - except ClientError as e: - logger.error(f"Error deleting tokens for user {user_id}: {e}") - raise - - async def delete_provider_tokens(self, provider_id: str) -> int: - """ - Delete all tokens for a provider (when provider is deleted). 
- - Args: - provider_id: Provider identifier - - Returns: - Number of tokens deleted - """ - if not self._enabled: - return 0 - - try: - tokens = await self.list_provider_tokens(provider_id) - - with self._table.batch_writer() as batch: - for token in tokens: - batch.delete_item( - Key={ - "PK": f"USER#{token.user_id}", - "SK": f"PROVIDER#{provider_id}", - } - ) - - logger.info(f"Deleted {len(tokens)} tokens for provider {provider_id}") - return len(tokens) - - except ClientError as e: - logger.error(f"Error deleting tokens for provider {provider_id}: {e}") - raise - - -# Singleton instance -_token_repository: Optional[OAuthTokenRepository] = None - - -def get_token_repository() -> OAuthTokenRepository: - """Get the token repository singleton.""" - global _token_repository - if _token_repository is None: - _token_repository = OAuthTokenRepository() - return _token_repository diff --git a/backend/src/apis/shared/sessions/__init__.py b/backend/src/apis/shared/sessions/__init__.py index 5b8a7c44..5e479ca0 100644 --- a/backend/src/apis/shared/sessions/__init__.py +++ b/backend/src/apis/shared/sessions/__init__.py @@ -33,6 +33,9 @@ from .metadata import ( store_message_metadata, store_session_metadata, + ensure_session_metadata_exists, + update_session_title, + update_session_activity, get_session_metadata, get_all_message_metadata, list_user_sessions, @@ -69,6 +72,9 @@ # Metadata operations "store_message_metadata", "store_session_metadata", + "ensure_session_metadata_exists", + "update_session_title", + "update_session_activity", "get_session_metadata", "get_all_message_metadata", "list_user_sessions", diff --git a/backend/src/apis/shared/sessions/messages.py b/backend/src/apis/shared/sessions/messages.py index 1ec668ab..32de1e95 100644 --- a/backend/src/apis/shared/sessions/messages.py +++ b/backend/src/apis/shared/sessions/messages.py @@ -411,9 +411,20 @@ async def fetch_metadata(): return await get_all_message_metadata(session_id, user_id) + async def 
fetch_pending_interrupts(): + """Fetch pending OAuth consent interrupts from session metadata. + + Returns an empty list when the session has none — the only signal + the frontend needs to know whether a refresh-restored conversation + still has a paused turn awaiting consent. + """ + from .metadata import get_pending_interrupts + + return await get_pending_interrupts(session_id, user_id) + # Run fetches in parallel - messages_raw, metadata_index = await asyncio.gather( - fetch_messages(), fetch_metadata() + messages_raw, metadata_index, pending_interrupts = await asyncio.gather( + fetch_messages(), fetch_metadata(), fetch_pending_interrupts() ) messages_raw = list(messages_raw or []) @@ -462,7 +473,11 @@ async def fetch_metadata(): message_responses = [_convert_message_to_response(msg, session_id, start_seq + idx) for idx, msg in enumerate(paginated_messages)] - return MessagesListResponse(messages=message_responses, next_token=next_page_token) + return MessagesListResponse( + messages=message_responses, + next_token=next_page_token, + pending_interrupts=pending_interrupts, + ) except Exception as e: logger.error(f"Error retrieving messages from AgentCore Memory: {e}") diff --git a/backend/src/apis/shared/sessions/metadata.py b/backend/src/apis/shared/sessions/metadata.py index acdb7b83..a4eb15ad 100644 --- a/backend/src/apis/shared/sessions/metadata.py +++ b/backend/src/apis/shared/sessions/metadata.py @@ -11,11 +11,11 @@ import json import os import base64 -from typing import Optional, Tuple, Any, Dict +from typing import Iterable, List, Optional, Tuple, Any, Dict from decimal import Decimal # Relative imports from shared sessions module -from .models import MessageMetadata, SessionMetadata +from .models import MessageMetadata, PausedTurnSnapshot, PendingInterrupt, SessionMetadata, SessionPreferences # Import preview session helper from agents.main_agent.session.preview_session_manager import is_preview_session @@ -697,6 +697,230 @@ async def 
_store_session_metadata_cloud( ) +async def ensure_session_metadata_exists(session_id: str, user_id: str) -> bool: + """Idempotently create a session metadata row if it doesn't exist yet. + + Uses a conditional ``put_item`` keyed on ``attribute_not_exists(PK)`` so + concurrent first-turn requests for the same session can't clobber each + other. Returns ``True`` when a new row was created (caller can use this + as the "first turn" signal, e.g. to fire title generation). + + No-op for preview sessions, which intentionally skip persistence. + """ + if is_preview_session(session_id): + return False + + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + raise RuntimeError("DYNAMODB_SESSIONS_METADATA_TABLE_NAME environment variable is required") + + try: + import boto3 + from botocore.exceptions import ClientError + from datetime import datetime, timezone + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + now = datetime.now(timezone.utc).isoformat() + item = { + "PK": f"USER#{user_id}", + "SK": f"S#ACTIVE#{now}#{session_id}", + "GSI_PK": f"SESSION#{session_id}", + "GSI_SK": "META", + "sessionId": session_id, + "userId": user_id, + "title": "New Conversation", + "status": "active", + "createdAt": now, + "lastMessageAt": now, + "messageCount": 0, + "starred": False, + "tags": [], + } + + try: + table.put_item(Item=item, ConditionExpression="attribute_not_exists(PK)") + logger.info(f"💾 Pre-created session metadata for {session_id}") + return True + except ClientError as e: + if e.response.get("Error", {}).get("Code") == "ConditionalCheckFailedException": + # Session already exists — expected for continuing conversations. + return False + raise + except Exception as e: + # Best-effort: failures must not block the stream. update_session_activity + # self-heals by retrying this call once if the row is missing post-stream. 
+ logger.error(f"ensure_session_metadata_exists failed: {e}", exc_info=True) + return False + + +async def update_session_title(session_id: str, user_id: str, title: str) -> None: + """Update only the title attribute on the session row. + + Uses a targeted ``UpdateExpression`` so it can run concurrently with + ``store_session_metadata`` (which does a full-row merge) without racing + on other fields like ``messageCount`` or ``lastMessageAt``. Looks up the + current SK via the GSI because the SK contains a timestamp. + + No-op when the session row doesn't exist (preview sessions, sessions + deleted mid-turn). + """ + if is_preview_session(session_id): + return + + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + raise RuntimeError("DYNAMODB_SESSIONS_METADATA_TABLE_NAME environment variable is required") + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + logger.info(f"update_session_title: session {session_id} not found, skipping") + return + sk = existing.get("SK") + if not sk: + logger.warning(f"update_session_title: session {session_id} has no SK") + return + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="SET title = :t", + ExpressionAttributeValues={":t": title}, + ) + logger.info(f"💾 Updated title for session {session_id}") + except Exception as e: + logger.error(f"update_session_title failed: {e}", exc_info=True) + + +async def update_session_activity( + session_id: str, + user_id: str, + *, + last_model: Optional[str] = None, + last_temperature: Optional[float] = None, + enabled_tools: Optional[List[str]] = None, + system_prompt_hash: Optional[str] = None, +) -> bool: + """Per-turn session activity update with targeted writes. 
+ + Increments ``messageCount``, advances ``lastMessageAt`` to now, and + merges agent-derived preferences. No other attributes are written, so + concurrent writers (``update_session_title``, ``add_pending_interrupt``) + cannot be clobbered by this path. + + Phase A is a targeted ``UpdateExpression`` on the current SK. Phase B + rotates the SK because ``lastMessageAt`` is encoded in it for recency + listing — fresh-read after Phase A, put at the new SK, delete the old. + The Phase B carry picks up any concurrent write that landed between + Phase A and the fresh read; the residual race window is bounded to + that small interval (full elimination requires the schema change in + issue #175). + + Self-heals when the row is missing by calling + ``ensure_session_metadata_exists`` and retrying the lookup once. + No-op for preview sessions. Returns ``True`` when the update applied. + """ + if is_preview_session(session_id): + return False + + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + raise RuntimeError("DYNAMODB_SESSIONS_METADATA_TABLE_NAME environment variable is required") + + try: + import boto3 + from datetime import datetime, timezone + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + # Pre-create may have failed at /invocations entry — try once + # more so we don't lose the session record entirely. 
+ await ensure_session_metadata_exists(session_id, user_id) + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + logger.warning( + "update_session_activity: session %s missing and could not be created", + session_id, + ) + return False + + old_sk = existing.get("SK") + if not old_sk: + logger.warning("update_session_activity: session %s has no SK", session_id) + return False + + # Merge preferences: existing values take effect for keys the + # caller didn't pass (e.g. assistantId set by the assistant-attach + # flow). We replace the whole `preferences` map in one SET so the + # update works whether the attribute exists yet or not — DynamoDB + # disallows updating both a parent path and its children in the + # same expression. + existing_prefs_raw = existing.get("preferences") or {} + try: + existing_prefs = SessionPreferences.model_validate(existing_prefs_raw) + except Exception: + existing_prefs = SessionPreferences() + prefs_dict = existing_prefs.model_dump(by_alias=False, exclude_none=True) + if last_model is not None: + prefs_dict["last_model"] = last_model + if last_temperature is not None: + prefs_dict["last_temperature"] = last_temperature + if enabled_tools is not None: + prefs_dict["enabled_tools"] = enabled_tools + if system_prompt_hash is not None: + prefs_dict["system_prompt_hash"] = system_prompt_hash + merged_prefs = SessionPreferences(**prefs_dict).model_dump(by_alias=True, exclude_none=True) + + now = datetime.now(timezone.utc).isoformat() + pk = f"USER#{user_id}" + + # Phase A: targeted update of owned attributes on the current SK. + # Disjoint from title, starred, tags, pendingInterrupts. + table.update_item( + Key={"PK": pk, "SK": old_sk}, + UpdateExpression="ADD messageCount :one SET lastMessageAt = :t, preferences = :p", + ExpressionAttributeValues={ + ":one": 1, + ":t": now, + ":p": _convert_floats_to_decimal(merged_prefs), + }, + ) + + # Phase B: SK rotation. 
lastMessageAt is encoded in the SK for + # recency listing, so a per-turn change forces a row move. Fresh + # read carries any concurrent write (e.g. title-gen) that landed + # between Phase A and now. + new_sk = f"S#ACTIVE#{now}#{session_id}" + if new_sk != old_sk: + fresh_resp = table.get_item(Key={"PK": pk, "SK": old_sk}) + fresh = fresh_resp.get("Item") + if not fresh: + logger.warning( + "update_session_activity: row vanished between Phase A and Phase B for %s", + session_id, + ) + return True + carried = {k: v for k, v in fresh.items() if k not in ("PK", "SK")} + new_item = {"PK": pk, "SK": new_sk, **carried} + table.put_item(Item=new_item) + table.delete_item(Key={"PK": pk, "SK": old_sk}) + + logger.info("Updated session activity for %s (sk_rotated=%s)", session_id, new_sk != old_sk) + return True + except Exception as e: + logger.error("update_session_activity failed for %s: %s", session_id, e, exc_info=True) + return False + + async def _get_session_by_gsi(session_id: str, user_id: str, table) -> Optional[dict]: """ Get session record using GSI (SessionLookupIndex) @@ -929,6 +1153,11 @@ async def _get_session_metadata_cloud( for key in ['PK', 'SK', 'GSI_PK', 'GSI_SK']: item.pop(key, None) + # Dedupe pending interrupts at the storage boundary so list_append + # re-emits don't surface as duplicate consent prompts. 
+ if "pendingInterrupts" in item: + item["pendingInterrupts"] = _dedupe_interrupt_dicts(item["pendingInterrupts"]) + return SessionMetadata.model_validate(item) except Exception as e: @@ -1110,6 +1339,9 @@ async def _list_user_sessions_cloud( if is_preview_session(session_id): continue + if "pendingInterrupts" in item: + item["pendingInterrupts"] = _dedupe_interrupt_dicts(item["pendingInterrupts"]) + metadata = SessionMetadata.model_validate(item) sessions.append(metadata) @@ -1183,3 +1415,275 @@ def _deep_merge(base: dict, updates: dict) -> dict: result[key] = value return result + + +# ============================================================================ +# Pending OAuth interrupts +# ============================================================================ +# +# Pending interrupts persist the breadcrumb the SSE stream emits when the +# agent pauses on `oauth_required`, so the frontend can rediscover them on +# reload. We do read-modify-write through the SessionLookupIndex GSI: +# OAuth flows are rare and one-at-a-time per user, so the simplicity wins +# over an UpdateExpression with list_append/REMOVE-by-index gymnastics. + + +def _interrupts_to_dynamo(interrupts: Iterable[PendingInterrupt]) -> List[Dict[str, Any]]: + """Serialize PendingInterrupt list for DynamoDB storage (camelCase keys).""" + return [item.model_dump(by_alias=True, exclude_none=True) for item in interrupts] + + +def _dedupe_interrupt_dicts(raw: Any) -> List[Dict[str, Any]]: + """Last-write-wins dedupe of raw interrupt dicts by ``interruptId``. + + ``add_pending_interrupt`` uses ``list_append`` to be race-free against + concurrent writers, which means re-emits of the same interrupt across + stream replays accumulate as duplicate list entries. Storage-layer + callers run this on the raw list before handing it to the model so + Pydantic validation sees a clean list. Insertion order of the first + occurrence is preserved. 
+ """ + if not raw or not isinstance(raw, list): + return [] + by_id: Dict[str, Dict[str, Any]] = {} + order: List[str] = [] + for entry in raw: + if not isinstance(entry, dict): + continue + iid = entry.get("interruptId") or entry.get("interrupt_id") + if not iid: + continue + if iid not in by_id: + order.append(iid) + by_id[iid] = entry + return [by_id[iid] for iid in order] + + +def _interrupts_from_dynamo(raw: Any) -> List[PendingInterrupt]: + """Parse stored interrupt entries with dedupe and corrupted-entry tolerance.""" + parsed: List[PendingInterrupt] = [] + for entry in _dedupe_interrupt_dicts(raw): + try: + parsed.append(PendingInterrupt.model_validate(entry)) + except Exception as exc: # pragma: no cover — corrupted entry shouldn't break load + logger.warning("Skipping unparseable pending_interrupts entry: %s", exc) + return parsed + + +async def add_pending_interrupt( + session_id: str, + user_id: str, + interrupt: PendingInterrupt, +) -> None: + """Append a pending OAuth interrupt to the session record. + + Uses ``list_append`` with ``if_not_exists`` so concurrent writers can't + lose each other's entries — no read-modify-write window. Re-emits of + the same ``interrupt_id`` across stream replays accumulate as duplicate + list entries and are collapsed last-write-wins by + ``_interrupts_from_dynamo`` on read. + + No-op when the session metadata record is missing (preview sessions, + sessions deleted mid-turn). The frontend will fall back to its in-memory + consent state in that case. 
+ """ + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + logger.warning("DYNAMODB_SESSIONS_METADATA_TABLE_NAME not set; skipping pending_interrupts persistence") + return + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + logger.info("Skipping pending_interrupts add — session %s not found", session_id) + return + + sk = existing.get("SK") + if not sk: + logger.warning("Session %s has no SK; cannot update pending_interrupts", session_id) + return + + new_entry = interrupt.model_dump(by_alias=True, exclude_none=True) + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="SET #pi = list_append(if_not_exists(#pi, :empty), :new)", + ExpressionAttributeNames={"#pi": "pendingInterrupts"}, + ExpressionAttributeValues={":empty": [], ":new": [new_entry]}, + ) + logger.info( + "Persisted pending_interrupt %s (provider=%s) for session %s", + interrupt.interrupt_id, interrupt.provider_id, session_id, + ) + except Exception as e: + # Persistence failure must not break the live SSE flow — the in-memory + # consent on the live tab still works; refresh-resume just won't. + logger.error("Failed to persist pending_interrupt: %s", e, exc_info=True) + + +async def remove_pending_interrupts( + session_id: str, + user_id: str, + interrupt_ids: Iterable[str], +) -> None: + """Drop the given ``interrupt_ids`` from the session's pending list. + + No-op for unknown ids and missing sessions. Used by the resume path + (after the agent successfully completes the resumed turn) and by the + explicit dismiss endpoint. 
+ """ + drop_set = {iid for iid in interrupt_ids if iid} + if not drop_set: + return + + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + return + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + return + + sk = existing.get("SK") + if not sk: + return + + current = _interrupts_from_dynamo(existing.get("pendingInterrupts") or []) + kept = [p for p in current if p.interrupt_id not in drop_set] + + if len(kept) == len(current): + return # Nothing matched + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="SET #pi = :pi", + ExpressionAttributeNames={"#pi": "pendingInterrupts"}, + ExpressionAttributeValues={":pi": _interrupts_to_dynamo(kept)}, + ) + logger.info( + "Cleared %d pending_interrupt(s) from session %s", + len(current) - len(kept), session_id, + ) + except Exception as e: + logger.error("Failed to remove pending_interrupts: %s", e, exc_info=True) + + +async def get_pending_interrupts(session_id: str, user_id: str) -> List[PendingInterrupt]: + """Return the current pending OAuth interrupts for a session. + + Returns an empty list when the session doesn't exist or has none. + """ + metadata = await get_session_metadata(session_id, user_id) + if not metadata: + return [] + return list(metadata.pending_interrupts or []) + + +async def set_paused_turn( + session_id: str, + user_id: str, + snapshot: PausedTurnSnapshot, +) -> None: + """Persist (or replace) the agent-construction snapshot for a paused turn. + + Idempotent overwrite: re-emits within the same turn replace the prior + snapshot rather than accumulating, since the snapshot is turn-scoped + rather than interrupt-scoped — multiple OAuth interrupts in a single + turn share the same construction context. 
+ + No-op when the session metadata record is missing or when the table + name env var is unset (preview/anonymous flows). + """ + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + logger.warning("DYNAMODB_SESSIONS_METADATA_TABLE_NAME not set; skipping paused_turn persistence") + return + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + logger.info("Skipping paused_turn write — session %s not found", session_id) + return + + sk = existing.get("SK") + if not sk: + logger.warning("Session %s has no SK; cannot update paused_turn", session_id) + return + + snapshot_dict = _convert_floats_to_decimal( + snapshot.model_dump(by_alias=True, exclude_none=True) + ) + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="SET #pt = :pt", + ExpressionAttributeNames={"#pt": "pausedTurn"}, + ExpressionAttributeValues={":pt": snapshot_dict}, + ) + logger.info("Persisted paused_turn snapshot for session %s", session_id) + except Exception as e: + # Best-effort: a write failure shouldn't break the live SSE flow. + # The same-process resume still works via the in-memory agent cache. + logger.error("Failed to persist paused_turn: %s", e, exc_info=True) + + +async def get_paused_turn(session_id: str, user_id: str) -> Optional[PausedTurnSnapshot]: + """Return the persisted paused-turn snapshot for a session, if any.""" + metadata = await get_session_metadata(session_id, user_id) + if not metadata: + return None + return metadata.paused_turn + + +async def clear_paused_turn(session_id: str, user_id: str) -> None: + """Drop the paused-turn snapshot for a session. 
+ + Called on successful resume completion, on explicit dismiss, and at the + start of a non-resume invocation so a stale snapshot from an abandoned + turn doesn't poison a fresh one. + """ + sessions_metadata_table = os.environ.get("DYNAMODB_SESSIONS_METADATA_TABLE_NAME") + if not sessions_metadata_table: + return + + try: + import boto3 + + dynamodb = boto3.resource("dynamodb") + table = dynamodb.Table(sessions_metadata_table) + + existing = await _get_session_by_gsi(session_id, user_id, table) + if not existing: + return + + sk = existing.get("SK") + if not sk: + return + + if "pausedTurn" not in existing: + return # Already clear + + table.update_item( + Key={"PK": f"USER#{user_id}", "SK": sk}, + UpdateExpression="REMOVE #pt", + ExpressionAttributeNames={"#pt": "pausedTurn"}, + ) + logger.info("Cleared paused_turn for session %s", session_id) + except Exception as e: + logger.error("Failed to clear paused_turn: %s", e, exc_info=True) diff --git a/backend/src/apis/shared/sessions/models.py b/backend/src/apis/shared/sessions/models.py index d684b6c5..baaa2510 100644 --- a/backend/src/apis/shared/sessions/models.py +++ b/backend/src/apis/shared/sessions/models.py @@ -20,6 +20,58 @@ class VisualDisplayState(BaseModel): expanded: bool = Field(default=True, description="Visual is expanded vs collapsed") +class PendingInterrupt(BaseModel): + """An OAuth consent request that paused an agent turn and is awaiting user action. + + Persisted to session metadata when ``OAuthConsentHook`` fires the interrupt + so the frontend can rediscover pending consents on reload — without it, a + browser refresh leaves the consent prompt stuck and the tool call orphaned + in ``pending`` forever. + + Note: ``authorization_url`` is intentionally omitted. AgentCore Identity's + consent URLs are short-lived; storing them invites stale-URL bugs on + refresh-after-an-hour. Frontend re-fetches via ``initiate-consent`` on + Connect. 
+ """ + + model_config = ConfigDict(populate_by_name=True) + interrupt_id: str = Field(..., alias="interruptId", description="Strands interrupt id used to resume the paused turn") + provider_id: str = Field(..., alias="providerId", description="Connector providerId needing consent") + triggering_message_id: Optional[str] = Field( + None, + alias="triggeringMessageId", + description="Id of the assistant message whose tool call triggered this interrupt, when known", + ) + created_at: str = Field(..., alias="createdAt", description="ISO 8601 timestamp when the interrupt was recorded") + + +class PausedTurnSnapshot(BaseModel): + """Frozen agent-construction context for a turn that paused on OAuth consent. + + Written once per paused turn so the resume request can rebuild the same + ``MainAgent`` shape (matching tool registry, model, prompt) regardless of + whether the in-process agent cache still holds it. Strands' session + manager separately persists ``_interrupt_state`` to AgentCore Memory, so + once the agent is rebuilt with the right shape the interrupt restores + automatically and the paused tool call can resume. + + Snapshot wins over current request state on resume: a turn the user + already authorized completes with the connector set it was authorized + against, even if the user toggled connectors mid-pause. 
+ """ + + model_config = ConfigDict(populate_by_name=True) + enabled_tools: Optional[List[str]] = Field(default=None, alias="enabledTools") + model_id: Optional[str] = Field(default=None, alias="modelId") + provider: Optional[str] = Field(default=None) + temperature: Optional[float] = Field(default=None) + system_prompt: Optional[str] = Field(default=None, alias="systemPrompt") + caching_enabled: Optional[bool] = Field(default=None, alias="cachingEnabled") + max_tokens: Optional[int] = Field(default=None, alias="maxTokens") + captured_at: str = Field(..., alias="capturedAt", description="ISO 8601 timestamp when the turn paused") + expires_at: str = Field(..., alias="expiresAt", description="ISO 8601 timestamp after which the snapshot is no longer valid for resume") + + class SessionPreferences(BaseModel): """User preferences for a session""" @@ -78,6 +130,18 @@ class SessionMetadata(BaseModel): deleted: Optional[bool] = Field(False, description="Whether session is soft-deleted") deleted_at: Optional[str] = Field(None, alias="deletedAt", description="ISO 8601 timestamp of deletion") + # OAuth consent state + pending_interrupts: Optional[List[PendingInterrupt]] = Field( + default=None, + alias="pendingInterrupts", + description="Pending OAuth consent interrupts that paused agent turns in this session", + ) + paused_turn: Optional[PausedTurnSnapshot] = Field( + default=None, + alias="pausedTurn", + description="Agent-construction snapshot for a turn paused on OAuth consent; cleared on successful resume or when a new turn supersedes it", + ) + class UpdateSessionMetadataRequest(BaseModel): """Request body for updating session metadata""" @@ -298,3 +362,8 @@ class MessagesListResponse(BaseModel): messages: List[MessageResponse] = Field(..., description="List of messages in the session") next_token: Optional[str] = Field(None, alias="nextToken", description="Pagination token for retrieving the next page of results") + pending_interrupts: List[PendingInterrupt] = Field( + 
default_factory=list, + alias="pendingInterrupts", + description="OAuth consent interrupts that paused agent turns in this session and are awaiting user action", + ) diff --git a/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py b/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py new file mode 100644 index 00000000..f4794c8a --- /dev/null +++ b/backend/tests/agents/main_agent/integrations/test_agentcore_identity.py @@ -0,0 +1,278 @@ +"""Tests for AgentCoreIdentityClient.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agents.main_agent.integrations.agentcore_identity import ( + AgentCoreIdentityClient, + TokenResult, + WorkloadTokenUnavailableError, + custom_parameters_for, +) + + +class TestCustomParametersFor: + """Merge of vendor baseline + admin extras. Baseline is non-negotiable + because admins can't safely turn off documented requirements (e.g. + Google's `access_type=offline` for refresh tokens).""" + + def test_google_baseline_alone(self) -> None: + assert custom_parameters_for("google") == {"access_type": "offline"} + + def test_google_match_is_case_insensitive(self) -> None: + # OAuthProviderType.GOOGLE.value is "google", but defensive against + # callers that pass the upper-case enum name. + assert custom_parameters_for("Google") == {"access_type": "offline"} + + @pytest.mark.parametrize( + "vendor", ["microsoft", "github", "canvas", "custom", "unknown"] + ) + def test_other_vendors_with_no_extras_return_none(self, vendor: str) -> None: + # Per the AgentCore Identity docs, only Google requires baseline + # extras today. Returning None lets callers pass through. 
+ assert custom_parameters_for(vendor) is None + + def test_none_returns_none(self) -> None: + assert custom_parameters_for(None) is None + + def test_empty_string_returns_none(self) -> None: + assert custom_parameters_for("") is None + + def test_admin_extras_merged_with_google_baseline(self) -> None: + # Admin can add domain restriction / prompt without losing + # the access_type=offline requirement. + result = custom_parameters_for( + "google", {"hd": "mycompany.com", "prompt": "consent"} + ) + assert result == { + "access_type": "offline", + "hd": "mycompany.com", + "prompt": "consent", + } + + def test_admin_cannot_override_baseline_keys(self) -> None: + # Admin-supplied access_type=online is silently superseded by the + # baseline. This is intentional — overriding it would silently + # break refresh tokens, the exact bug we hardcoded against. + result = custom_parameters_for("google", {"access_type": "online"}) + assert result == {"access_type": "offline"} + + def test_admin_extras_only_for_non_baseline_vendor(self) -> None: + # Vendors with no baseline still pass through admin extras. 
+ result = custom_parameters_for("github", {"prompt": "consent"}) + assert result == {"prompt": "consent"} + + def test_empty_admin_extras_treated_as_none(self) -> None: + assert custom_parameters_for("microsoft", {}) is None + assert custom_parameters_for("microsoft", None) is None + + +class TestTokenResult: + def test_access_token_only_is_valid(self) -> None: + result = TokenResult(access_token="abc") + assert result.access_token == "abc" + assert result.authorization_url is None + assert result.requires_consent is False + + def test_authorization_url_only_is_valid(self) -> None: + result = TokenResult(authorization_url="https://example.com/auth") + assert result.requires_consent is True + + def test_both_populated_raises(self) -> None: + with pytest.raises(ValueError): + TokenResult(access_token="a", authorization_url="https://example.com") + + def test_neither_populated_raises(self) -> None: + with pytest.raises(ValueError): + TokenResult() + + +@pytest.fixture +def mock_identity_sdk(): + """Patch the IdentityClient class used inside the wrapper.""" + with patch( + "agents.main_agent.integrations.agentcore_identity.IdentityClient" + ) as sdk_cls: + yield sdk_cls + + +@pytest.fixture +def mock_context(): + """Patch BedrockAgentCoreContext accessors used inside the wrapper.""" + with patch( + "agents.main_agent.integrations.agentcore_identity.BedrockAgentCoreContext" + ) as ctx: + ctx.get_workload_access_token.return_value = "workload-token-xyz" + ctx.get_oauth2_callback_url.return_value = "https://cb.example.com/oauth" + yield ctx + + +class TestGetTokenForUserCacheHit: + @pytest.mark.asyncio + async def test_returns_access_token_when_vault_has_token( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token = AsyncMock(return_value="ya29.access-token") + + client = AgentCoreIdentityClient(region="us-east-1") + result = await client.get_token_for_user( + 
provider_name="google-workspace", scopes=["openid"] + ) + + assert result.access_token == "ya29.access-token" + assert result.requires_consent is False + + sdk_instance.get_token.assert_called_once() + kwargs = sdk_instance.get_token.call_args.kwargs + assert kwargs["provider_name"] == "google-workspace" + assert kwargs["scopes"] == ["openid"] + assert kwargs["auth_flow"] == "USER_FEDERATION" + assert kwargs["agent_identity_token"] == "workload-token-xyz" + # Wrapper appends provider_id to the callback so the /oauth-complete + # page knows which provider to dismiss in the consent banner. + assert kwargs["callback_url"] == ( + "https://cb.example.com/oauth?provider_id=google-workspace" + ) + assert kwargs["force_authentication"] is False + + @pytest.mark.asyncio + async def test_explicit_callback_url_overrides_context( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token = AsyncMock(return_value="t") + + client = AgentCoreIdentityClient() + await client.get_token_for_user( + provider_name="p", + scopes=["s"], + callback_url="https://override.example.com/cb", + ) + + kwargs = sdk_instance.get_token.call_args.kwargs + assert kwargs["callback_url"] == ( + "https://override.example.com/cb?provider_id=p" + ) + + +class TestGetTokenForUserConsentRequired: + @pytest.mark.asyncio + async def test_returns_authorization_url_when_sdk_invokes_callback( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + """When the user needs to consent, the SDK calls on_auth_url with the + consent URL. 
The wrapper captures it and returns a TokenResult with + authorization_url set rather than raising.""" + sdk_instance = mock_identity_sdk.return_value + + async def fake_get_token(**kwargs): + kwargs["on_auth_url"]("https://accounts.example.com/consent?x=1") + return None + + sdk_instance.get_token = AsyncMock(side_effect=fake_get_token) + + client = AgentCoreIdentityClient() + result = await client.get_token_for_user(provider_name="p", scopes=["s"]) + + assert result.requires_consent is True + assert result.authorization_url == "https://accounts.example.com/consent?x=1" + assert result.access_token is None + + @pytest.mark.asyncio + async def test_auth_url_takes_precedence_over_stale_token( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + """Defensive: if the SDK both returns a token AND invokes on_auth_url, + we treat consent-required as the authoritative signal.""" + sdk_instance = mock_identity_sdk.return_value + + async def fake_get_token(**kwargs): + kwargs["on_auth_url"]("https://consent.example.com") + return "stale-token" + + sdk_instance.get_token = AsyncMock(side_effect=fake_get_token) + + client = AgentCoreIdentityClient() + result = await client.get_token_for_user(provider_name="p", scopes=["s"]) + + assert result.requires_consent is True + assert result.authorization_url == "https://consent.example.com" + + +class TestGetTokenForUserErrors: + @pytest.mark.asyncio + async def test_raises_when_no_workload_token_on_context( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + mock_context.get_workload_access_token.return_value = None + + client = AgentCoreIdentityClient() + with pytest.raises(WorkloadTokenUnavailableError): + await client.get_token_for_user(provider_name="p", scopes=["s"]) + + @pytest.mark.asyncio + async def test_raises_when_sdk_returns_nothing_and_no_auth_url( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + sdk_instance = mock_identity_sdk.return_value + 
sdk_instance.get_token = AsyncMock(return_value=None) + + client = AgentCoreIdentityClient() + with pytest.raises(RuntimeError, match="neither a token nor"): + await client.get_token_for_user(provider_name="p", scopes=["s"]) + + @pytest.mark.asyncio + async def test_force_authentication_flag_is_forwarded( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token = AsyncMock(return_value="t") + + client = AgentCoreIdentityClient() + await client.get_token_for_user( + provider_name="p", scopes=["s"], force_authentication=True + ) + + kwargs = sdk_instance.get_token.call_args.kwargs + assert kwargs["force_authentication"] is True + + @pytest.mark.asyncio + async def test_custom_parameters_are_forwarded_to_sdk( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + # AgentCore Identity needs Google's `access_type=offline` forwarded + # via the SDK's `custom_parameters` kwarg — without it Google + # issues no refresh token and the vault entry expires after 1hr. + sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token = AsyncMock(return_value="t") + + client = AgentCoreIdentityClient() + await client.get_token_for_user( + provider_name="p", + scopes=["s"], + custom_parameters={"access_type": "offline"}, + ) + + kwargs = sdk_instance.get_token.call_args.kwargs + assert kwargs["custom_parameters"] == {"access_type": "offline"} + + @pytest.mark.asyncio + async def test_custom_parameters_omitted_when_none_or_empty( + self, mock_identity_sdk: MagicMock, mock_context: MagicMock + ) -> None: + # The SDK only ships the kwarg when we actually have something to + # send; absent custom_parameters should not appear in the call. 
+ sdk_instance = mock_identity_sdk.return_value + sdk_instance.get_token = AsyncMock(return_value="t") + + client = AgentCoreIdentityClient() + await client.get_token_for_user(provider_name="p", scopes=["s"]) + assert "custom_parameters" not in sdk_instance.get_token.call_args.kwargs + + sdk_instance.get_token.reset_mock() + await client.get_token_for_user( + provider_name="p", scopes=["s"], custom_parameters={} + ) + assert "custom_parameters" not in sdk_instance.get_token.call_args.kwargs diff --git a/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py b/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py index 2023c2ed..fef59b2b 100644 --- a/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py +++ b/backend/tests/agents/main_agent/integrations/test_external_mcp_client.py @@ -1,13 +1,23 @@ """ -Tests for extract_region_from_url and detect_aws_service_from_url. +Tests for the external MCP client helpers. + +OAuth provisioning moved to `OAuthConsentHook` (see +`tests/agents/main_agent/session/hooks/test_oauth_consent.py`); this +module covers the URL-parsing helpers and the integration's +MCPClient -> provider_id map that the hook reads from. 
Requirements: 25.1–25.3 """ +from datetime import datetime, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + import pytest from agents.main_agent.integrations.external_mcp_client import ( - extract_region_from_url, + ExternalMCPIntegration, detect_aws_service_from_url, + extract_region_from_url, ) @@ -62,3 +72,234 @@ def test_defaults_to_lambda_for_unknown_url(self): """Req 25.3: Defaults to 'lambda' for unrecognized URL patterns.""" url = "https://example.com/api/v1" assert detect_aws_service_from_url(url) == "lambda" + + +class TestProviderForClient: + """The integration's MCPClient -> provider_id map is what + `OAuthConsentHook.provider_lookup` consults.""" + + def test_unknown_client_returns_none(self): + integration = ExternalMCPIntegration() + + class FakeClient: + pass + + assert integration.provider_for_client(FakeClient()) is None + + def test_records_and_resolves_provider_for_client(self): + integration = ExternalMCPIntegration() + + class FakeClient: + pass + + client = FakeClient() + # Simulate what `load_external_tools` does after creating an + # OAuth-gated MCP client. 
+ integration._provider_for_client_id[id(client)] = "google-workspace" + + assert integration.provider_for_client(client) == "google-workspace" + + def test_clear_user_clients_drops_provider_mapping(self): + integration = ExternalMCPIntegration() + + class FakeClient: + pass + + client = FakeClient() + integration.clients["alice:gmail"] = client + integration._provider_for_client_id[id(client)] = "google-workspace" + + integration.clear_user_clients("alice") + + assert "alice:gmail" not in integration.clients + assert integration.provider_for_client(client) is None + + +class TestClearToolClients: + """Admin updates to a tool must invalidate cached clients for that + tool so the next agent build reconnects with the updated config.""" + + def test_clears_non_oauth_tool_and_keeps_other_tools(self): + integration = ExternalMCPIntegration() + + class FakeClient: + pass + + gmail = FakeClient() + jira = FakeClient() + integration.clients["gmail"] = gmail + integration.clients["jira"] = jira + + integration.clear_tool_clients("gmail") + + assert "gmail" not in integration.clients + assert integration.clients["jira"] is jira + + def test_clears_all_user_scoped_keys_for_tool(self): + integration = ExternalMCPIntegration() + + class FakeClient: + pass + + alice_gmail = FakeClient() + bob_gmail = FakeClient() + alice_jira = FakeClient() + integration.clients["alice:gmail"] = alice_gmail + integration.clients["bob:gmail"] = bob_gmail + integration.clients["alice:jira"] = alice_jira + integration._provider_for_client_id[id(alice_gmail)] = "google-workspace" + integration._provider_for_client_id[id(bob_gmail)] = "google-workspace" + + integration.clear_tool_clients("gmail") + + assert "alice:gmail" not in integration.clients + assert "bob:gmail" not in integration.clients + assert integration.clients["alice:jira"] is alice_jira + assert integration.provider_for_client(alice_gmail) is None + assert integration.provider_for_client(bob_gmail) is None + + def 
test_does_not_match_tool_id_as_key_suffix_without_colon(self): + """Guard against substring false positives: a tool named "gmail" + must not clear a tool named "super-gmail".""" + integration = ExternalMCPIntegration() + + class FakeClient: + pass + + super_gmail = FakeClient() + integration.clients["super-gmail"] = super_gmail + + integration.clear_tool_clients("gmail") + + assert integration.clients["super-gmail"] is super_gmail + + def test_no_op_when_tool_not_cached(self): + integration = ExternalMCPIntegration() + integration.clear_tool_clients("never-loaded") + assert integration.clients == {} + + +def _fake_tool(updated_at, tool_id="gmail"): + """Minimal tool stand-in for load_external_tools.""" + return SimpleNamespace( + tool_id=tool_id, + protocol="mcp_external", + mcp_config=SimpleNamespace(server_url="https://example.com/mcp"), + forward_auth_token=False, + requires_oauth_provider=None, + updated_at=updated_at, + ) + + +class TestLoadExternalToolsVersioning: + """`load_external_tools` must rebuild the MCPClient when the tool's + `updated_at` changes. Without this, admin edits to MCP config (URL, + auth mode, etc.) 
never take effect for the process lifetime.""" + + @pytest.mark.asyncio + async def test_reuses_client_when_updated_at_unchanged(self): + integration = ExternalMCPIntegration() + tool = _fake_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + repo = SimpleNamespace(get_tool=AsyncMock(return_value=tool)) + + client = SimpleNamespace(load_tools=AsyncMock(return_value=[])) + + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.create_external_mcp_client", + return_value=client, + ) as create_mock: + first = await integration.load_external_tools(["gmail"]) + second = await integration.load_external_tools(["gmail"]) + + assert first == second + assert create_mock.call_count == 1 + + @pytest.mark.asyncio + async def test_rebuilds_client_when_updated_at_changes(self): + integration = ExternalMCPIntegration() + old = _fake_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + new = _fake_tool(datetime(2025, 2, 1, tzinfo=timezone.utc)) + + repo = SimpleNamespace(get_tool=AsyncMock(side_effect=[old, new])) + + client_old = SimpleNamespace(load_tools=AsyncMock(return_value=[])) + client_new = SimpleNamespace(load_tools=AsyncMock(return_value=[])) + + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.create_external_mcp_client", + side_effect=[client_old, client_new], + ): + first = await integration.load_external_tools(["gmail"]) + second = await integration.load_external_tools(["gmail"]) + + assert first == [client_old] + assert second == [client_new] + assert integration.clients["gmail"] is client_new + # Old client must be evicted, not left dangling under the same key. 
+ assert client_old not in integration.clients.values() + + +class TestLoadExternalToolsPreflight: + """A single unreachable MCP server must not fail the whole turn — + `load_external_tools` pre-flights each new client and silently drops + the ones whose session can't be opened.""" + + @pytest.mark.asyncio + async def test_skips_client_when_preflight_fails(self): + integration = ExternalMCPIntegration() + tool = _fake_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + repo = SimpleNamespace(get_tool=AsyncMock(return_value=tool)) + + bad_client = SimpleNamespace( + load_tools=AsyncMock(side_effect=RuntimeError("connection refused")) + ) + + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.create_external_mcp_client", + return_value=bad_client, + ): + result = await integration.load_external_tools(["gmail"]) + + assert result == [] + # Failed clients must not be cached — otherwise we'd serve a + # broken client back on subsequent turns. 
+ assert "gmail" not in integration.clients + + @pytest.mark.asyncio + async def test_one_failing_client_does_not_block_others(self): + integration = ExternalMCPIntegration() + bad_tool = _fake_tool( + datetime(2025, 1, 1, tzinfo=timezone.utc), tool_id="calendar" + ) + good_tool = _fake_tool( + datetime(2025, 1, 1, tzinfo=timezone.utc), tool_id="gmail" + ) + repo = SimpleNamespace( + get_tool=AsyncMock(side_effect=[bad_tool, good_tool]) + ) + + bad_client = SimpleNamespace( + load_tools=AsyncMock(side_effect=RuntimeError("connection refused")) + ) + good_client = SimpleNamespace(load_tools=AsyncMock(return_value=[])) + + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ), patch( + "agents.main_agent.integrations.external_mcp_client.create_external_mcp_client", + side_effect=[bad_client, good_client], + ): + result = await integration.load_external_tools(["calendar", "gmail"]) + + assert result == [good_client] + assert integration.clients == {"gmail": good_client} diff --git a/backend/tests/agents/main_agent/integrations/test_oauth_token_cache.py b/backend/tests/agents/main_agent/integrations/test_oauth_token_cache.py new file mode 100644 index 00000000..939ec006 --- /dev/null +++ b/backend/tests/agents/main_agent/integrations/test_oauth_token_cache.py @@ -0,0 +1,61 @@ +"""Tests for the in-process OAuth token cache.""" + +from agents.main_agent.integrations import oauth_token_cache + + +def _isolate(user: str = "tester") -> None: + oauth_token_cache.clear_user(user) + + +def test_get_returns_none_when_unset(): + _isolate() + assert oauth_token_cache.get("tester", "google") is None + + +def test_set_then_get_roundtrip(): + _isolate() + oauth_token_cache.set("tester", "google", "tok-1") + assert oauth_token_cache.get("tester", "google") == "tok-1" + + +def test_per_user_isolation(): + oauth_token_cache.clear_user("alice") + oauth_token_cache.clear_user("bob") + oauth_token_cache.set("alice", "google", "alice-tok") + 
oauth_token_cache.set("bob", "google", "bob-tok") + + assert oauth_token_cache.get("alice", "google") == "alice-tok" + assert oauth_token_cache.get("bob", "google") == "bob-tok" + + +def test_per_provider_isolation(): + _isolate() + oauth_token_cache.set("tester", "google", "g-tok") + oauth_token_cache.set("tester", "github", "gh-tok") + + assert oauth_token_cache.get("tester", "google") == "g-tok" + assert oauth_token_cache.get("tester", "github") == "gh-tok" + + +def test_clear_user_drops_only_that_user(): + oauth_token_cache.clear_user("alice") + oauth_token_cache.clear_user("bob") + oauth_token_cache.set("alice", "google", "a") + oauth_token_cache.set("bob", "google", "b") + + removed = oauth_token_cache.clear_user("alice") + + assert removed == 1 + assert oauth_token_cache.get("alice", "google") is None + assert oauth_token_cache.get("bob", "google") == "b" + + +def test_clear_user_provider_drops_only_that_pair(): + _isolate() + oauth_token_cache.set("tester", "google", "g") + oauth_token_cache.set("tester", "github", "gh") + + oauth_token_cache.clear_user_provider("tester", "google") + + assert oauth_token_cache.get("tester", "google") is None + assert oauth_token_cache.get("tester", "github") == "gh" diff --git a/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py b/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py new file mode 100644 index 00000000..570af4ea --- /dev/null +++ b/backend/tests/agents/main_agent/session/test_oauth_consent_hook.py @@ -0,0 +1,787 @@ +"""Tests for OAuthConsentHook. + +Covers the lazy token-resolution path: hook fires before each tool call, +asks AgentCore Identity for the user's token, caches it on a hit, and +raises a Strands interrupt with the consent URL on a miss. Resume is +exercised by pre-seeding the interrupt with a response so the second +`event.interrupt(...)` returns instead of raising. 
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from strands.interrupt import Interrupt, InterruptException + +from agents.main_agent.integrations import oauth_token_cache +from agents.main_agent.integrations.agentcore_identity import ( + TokenResult, + WorkloadTokenUnavailableError, +) +from agents.main_agent.session.hooks.oauth_consent import ( + OAuthConsentHook, + _looks_like_auth_failure, +) + + +@pytest.fixture(autouse=True) +def _clear_cache(): + """Token cache is process-global; isolate between tests.""" + oauth_token_cache.clear_user("alice") + yield + oauth_token_cache.clear_user("alice") + + +def _make_event(provider_id: str | None, *, agent=None) -> MagicMock: + """Build a stand-in for `BeforeToolCallEvent`. + + The hook reads `event.selected_tool` (passed straight to + `provider_lookup`) and calls `event.interrupt(...)`. We forward + `interrupt` to a real `_Interruptible.interrupt` style implementation + so the test exercises the same raise/return semantics as the SDK. + """ + event = MagicMock() + event.selected_tool = MagicMock() + event.cancel_tool = None + + agent = agent or MagicMock() + agent._interrupt_state = MagicMock() + agent._interrupt_state.interrupts = {} + + def interrupt(name: str, reason=None, response=None): + # Mirror the SDK: deterministic id keyed on the name so the second + # call returns the response instead of raising. 
+ interrupt_id = f"v1:before_tool_call:tu_test:{name}" + existing = agent._interrupt_state.interrupts.setdefault( + interrupt_id, Interrupt(interrupt_id, name, reason, response) + ) + if existing.response is not None: + return existing.response + raise InterruptException(existing) + + event.interrupt = interrupt + event._agent = agent # for tests that want to inspect interrupt state + return event + + +class TestOAuthConsentHookCacheHit: + @pytest.mark.asyncio + async def test_no_op_when_tool_not_oauth_gated(self): + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: None, + scopes_lookup=lambda _: [], + ) + event = _make_event(provider_id=None) + + await hook._gate(event) + + assert event.cancel_tool is None + + @pytest.mark.asyncio + async def test_disconnected_lookup_bypasses_token_cache(self): + """When the durable disconnect flag is set, the in-process token + cache must not short-circuit — even on the same replica that holds + a warm token. Confirms the source-of-truth reordering: DDB first, + then cache.""" + oauth_token_cache.set("alice", "google", "stale-cached-token") + + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(authorization_url="https://accounts/consent") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + disconnected_lookup=lambda _pid: True, + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + with pytest.raises(InterruptException): + await hook._gate(event) + + # Identity was consulted with force_authentication=True so AgentCore + # bypasses its vault and returns a fresh consent URL. 
+ identity.get_token_for_user.assert_called_once() + kwargs = identity.get_token_for_user.call_args.kwargs + assert kwargs["force_authentication"] is True + + @pytest.mark.asyncio + async def test_uses_cached_token_without_calling_identity(self): + oauth_token_cache.set("alice", "google", "cached-token") + + identity = MagicMock() + identity.get_token_for_user = AsyncMock() + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + identity.get_token_for_user.assert_not_called() + assert event.cancel_tool is None + + +class TestOAuthConsentHookVaultHit: + @pytest.mark.asyncio + async def test_warms_cache_when_vault_returns_token(self): + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="tok-from-vault") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + assert oauth_token_cache.get("alice", "google") == "tok-from-vault" + identity.get_token_for_user.assert_called_once() + kwargs = identity.get_token_for_user.call_args.kwargs + assert kwargs["provider_name"] == "google" + assert kwargs["scopes"] == ["openid"] + assert kwargs["user_id"] == "alice" + assert kwargs["force_authentication"] is False + + +class TestOAuthConsentHookConsentRequired: + @pytest.mark.asyncio + async def test_raises_interrupt_with_oauth_required_reason(self): + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + 
return_value=TokenResult(authorization_url="https://accounts/consent") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + with pytest.raises(InterruptException) as excinfo: + await hook._gate(event) + + interrupt = excinfo.value.interrupt + assert interrupt.name == "oauth:google" + assert interrupt.reason == { + "type": "oauth_required", + "providerId": "google", + "authorizationUrl": "https://accounts/consent", + } + # Cache stays empty until consent actually completes. + assert oauth_token_cache.get("alice", "google") is None + + @pytest.mark.asyncio + async def test_resume_warms_cache_with_post_consent_token(self): + """On resume the SDK pre-populates the interrupt's response so + `event.interrupt(...)` returns. The hook then re-fetches from the + vault (which now has a token) and primes the cache so subsequent + MCP requests pick up the bearer token without another round trip.""" + identity = MagicMock() + # First call: consent required. Second call (post-consent): token. + identity.get_token_for_user = AsyncMock( + side_effect=[ + TokenResult(authorization_url="https://accounts/consent"), + TokenResult(access_token="post-consent-token"), + ] + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + # Pre-seed the interrupt with a response — simulates the SDK + # restoring `_interrupt_state` before re-running the hook on resume. 
+ agent = event._agent + interrupt_id = "v1:before_tool_call:tu_test:oauth:google" + agent._interrupt_state.interrupts[interrupt_id] = Interrupt( + interrupt_id, "oauth:google", reason=None, response="consented" + ) + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + assert oauth_token_cache.get("alice", "google") == "post-consent-token" + assert event.cancel_tool is None + + @pytest.mark.asyncio + async def test_resume_without_token_cancels_tool(self): + """If the user closes the popup mid-flow, AgentCore's vault stays + empty. Resuming surfaces this as a cancel_tool so the model + gets a tool_error and can apologize/replan instead of looping.""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + side_effect=[ + TokenResult(authorization_url="https://accounts/consent"), + TokenResult(authorization_url="https://accounts/consent"), + ] + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + agent = event._agent + interrupt_id = "v1:before_tool_call:tu_test:oauth:google" + agent._interrupt_state.interrupts[interrupt_id] = Interrupt( + interrupt_id, "oauth:google", reason=None, response="consented" + ) + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + assert event.cancel_tool is not None + assert "google" in event.cancel_tool + + +class TestParallelToolCallsSameProvider: + """Regression guard for the OAuth interrupt collision concern. + + `_gate` calls `event.interrupt(name=f"oauth:{provider_id}")` with a + name that is *not* unique across parallel tool calls to the same + provider. 
We rely on Strands' BeforeToolCallEvent._interrupt_id to + fold `tool_use.toolUseId` into the final id, so two parallel calls + produce distinct entries in `_interrupt_state.interrupts`. + + If Strands ever drops `toolUseId` from the id formula, this test + fails and we'd need to incorporate it ourselves in the hook. + """ + + def test_parallel_tool_calls_same_provider_produce_distinct_interrupt_ids(self): + from strands.hooks import BeforeToolCallEvent + + agent = MagicMock() + event_a = BeforeToolCallEvent( + agent=agent, + selected_tool=MagicMock(), + tool_use={"toolUseId": "tu_parallel_a", "name": "search"}, + invocation_state={}, + ) + event_b = BeforeToolCallEvent( + agent=agent, + selected_tool=MagicMock(), + tool_use={"toolUseId": "tu_parallel_b", "name": "search"}, + invocation_state={}, + ) + + id_a = event_a._interrupt_id("oauth:google") + id_b = event_b._interrupt_id("oauth:google") + + assert id_a != id_b, ( + "Strands no longer disambiguates BeforeToolCallEvent interrupts " + "by toolUseId. OAuthConsentHook must now incorporate toolUseId " + "into the interrupt name to prevent parallel-call collision." 
+ ) + assert "tu_parallel_a" in id_a + assert "tu_parallel_b" in id_b + + +class TestOAuthConsentHookAuthFailureRetry: + """The AfterToolCallEvent handler turns a 401-style tool error into + a retry that forces re-consent at AgentCore Identity.""" + + def _after_event( + self, + provider_id: str | None, + result_text: str, + *, + result_status: str = "error", + tool_use_id: str = "tu_1", + tool_name: str = "whoami", + ) -> MagicMock: + event = MagicMock() + event.selected_tool = MagicMock() + event.tool_use = {"name": tool_name, "toolUseId": tool_use_id} + event.invocation_state = {} + event.result = { + "toolUseId": tool_use_id, + "status": result_status, + "content": [{"text": result_text}], + } + event.retry = False + return event + + @pytest.mark.asyncio + async def test_401_records_disconnect_and_retries(self): + recorded: list[str] = [] + + async def mark_disconnected(pid: str) -> None: + recorded.append(pid) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: [], + mark_disconnected=mark_disconnected, + ) + oauth_token_cache.set("alice", "google", "stale-token") + event = self._after_event( + "google", + "Error executing tool whoami: Google rejected the OAuth token (401).", + ) + + await hook._handle_auth_failure(event) + + assert event.retry is True + # Durable record of the disconnect intent so other replicas force + # fresh consent on the next request, too. + assert recorded == ["google"] + # Local cache cleared so the BeforeToolCallEvent retry doesn't + # short-circuit on this replica. 
+ assert oauth_token_cache.get("alice", "google") is None + + @pytest.mark.asyncio + async def test_non_oauth_tool_is_ignored(self): + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: None, + scopes_lookup=lambda _: [], + ) + event = self._after_event(None, "401 Unauthorized") + + await hook._handle_auth_failure(event) + + assert event.retry is False + + @pytest.mark.asyncio + async def test_non_auth_error_is_ignored(self): + recorded: list[str] = [] + + async def mark_disconnected(pid: str) -> None: + recorded.append(pid) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: [], + mark_disconnected=mark_disconnected, + ) + event = self._after_event("google", "Network unreachable") + + await hook._handle_auth_failure(event) + + assert event.retry is False + # No disconnect persisted — the failure wasn't auth-related. + assert recorded == [] + + @pytest.mark.asyncio + async def test_does_not_retry_twice_for_same_tool_use(self): + """Second 401 in the same retry cycle must not loop forever.""" + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: [], + ) + event1 = self._after_event("google", "401 Unauthorized") + await hook._handle_auth_failure(event1) + assert event1.retry is True + + # Same tool_use, second failure must surrender so the user sees + # the error instead of looping. + event2 = self._after_event("google", "401 Unauthorized") + await hook._handle_auth_failure(event2) + assert event2.retry is False + + @pytest.mark.asyncio + async def test_caps_retry_across_tool_calls_in_same_turn(self): + """A misconfigured provider would otherwise spawn a consent prompt + on every tool call in a turn. 
Cap at one retry per provider per + turn so subsequent 401s for the same provider just surface to the + model instead of triggering another consent flow.""" + recorded: list[str] = [] + + async def mark_disconnected(pid: str) -> None: + recorded.append(pid) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: [], + mark_disconnected=mark_disconnected, + ) + + # First tool call 401s — retry path fires. + event1 = self._after_event( + "google", "401 Unauthorized", tool_use_id="tu_1", tool_name="search" + ) + await hook._handle_auth_failure(event1) + assert event1.retry is True + + # A *different* tool call (different toolUseId) on the same + # provider 401s later in the same turn. The per-turn cap must + # block another retry even though invocation_state is fresh. + event2 = self._after_event( + "google", "401 Unauthorized", tool_use_id="tu_2", tool_name="list" + ) + await hook._handle_auth_failure(event2) + assert event2.retry is False + # Disconnect was already recorded on the first 401 — don't write + # again. + assert recorded == ["google"] + + @pytest.mark.asyncio + async def test_before_invocation_event_resets_per_turn_budget(self): + """The agent instance is cached across turns by `get_agent`, so + the per-provider retry budget on the hook must be reset whenever + a new agent invocation begins (fresh turn or resume).""" + from unittest.mock import MagicMock + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: [], + ) + + event1 = self._after_event("google", "401 Unauthorized", tool_use_id="tu_1") + await hook._handle_auth_failure(event1) + assert event1.retry is True + + # Simulate a new turn starting (Strands fires BeforeInvocationEvent + # on each `agent.stream_async` call, including resume). 
+ hook._on_invocation_start(MagicMock()) + + event2 = self._after_event("google", "401 Unauthorized", tool_use_id="tu_2") + await hook._handle_auth_failure(event2) + assert event2.retry is True + + @pytest.mark.asyncio + async def test_cap_is_per_provider_not_global(self): + """One provider hitting its cap mustn't starve a different + provider that just happens to 401 later in the same turn.""" + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda tool: getattr(tool, "_provider", None), + scopes_lookup=lambda _: [], + ) + + google_event = self._after_event( + "google", "401 Unauthorized", tool_use_id="tu_g" + ) + google_event.selected_tool._provider = "google" + await hook._handle_auth_failure(google_event) + assert google_event.retry is True + + slack_event = self._after_event( + "slack", "401 Unauthorized", tool_use_id="tu_s" + ) + slack_event.selected_tool._provider = "slack" + await hook._handle_auth_failure(slack_event) + assert slack_event.retry is True + + +class TestOAuthConsentHookErrors: + @pytest.mark.asyncio + async def test_workload_token_unavailable_lets_tool_proceed(self): + """A misconfigured runtime context shouldn't crash the agent; the + tool runs, the MCP server 401s, and the failure surfaces as a + normal tool_error the user can act on.""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + side_effect=WorkloadTokenUnavailableError("no ctx") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) # must not raise + + assert event.cancel_tool is None + assert oauth_token_cache.get("alice", "google") is None + + @pytest.mark.asyncio + async def test_scopes_lookup_can_be_async(self): + """Hook accepts async scopes_lookup so callers can read 
directly + from an async repository without a sync wrapper.""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="t") + ) + + async def async_scopes(_pid: str) -> list[str]: + return ["openid", "profile"] + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=async_scopes, + ) + event = _make_event(provider_id="google") + + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + kwargs = identity.get_token_for_user.call_args.kwargs + assert kwargs["scopes"] == ["openid", "profile"] + + @pytest.mark.asyncio + async def test_scopes_lookup_is_cached_across_calls(self): + """Repeated tool calls for the same provider hit the scopes lookup + once per hook lifetime (one agent invocation).""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="t") + ) + + scopes_lookup = MagicMock(return_value=["openid"]) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=scopes_lookup, + ) + + # First call hits identity (and the lookup). + event1 = _make_event(provider_id="google") + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event1) + + # Cache now warm — second call short-circuits before identity. + # Force a vault fetch by clearing the token cache. 
+ oauth_token_cache.clear_user("alice") + event2 = _make_event(provider_id="google") + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event2) + + assert scopes_lookup.call_count == 1 + + @pytest.mark.asyncio + async def test_provider_type_lookup_forwards_custom_parameters(self): + """When the provider is Google, the hook forwards + `custom_parameters={"access_type": "offline"}` to AgentCore Identity + so Google issues a refresh token (vault entry would otherwise expire + after ~1 hour with no refresh path).""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="t") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + provider_type_lookup=lambda _: "google", + ) + + event = _make_event(provider_id="google") + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + identity.get_token_for_user.assert_called_once() + assert identity.get_token_for_user.call_args.kwargs["custom_parameters"] == { + "access_type": "offline", + } + + @pytest.mark.asyncio + async def test_admin_custom_parameters_merge_with_baseline(self): + """Hook merges admin-supplied extras (e.g. Google `hd=` for Workspace + domain restriction) with the vendor baseline before forwarding to + AgentCore. 
Baseline still wins on conflict.""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="t") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "google", + scopes_lookup=lambda _: ["openid"], + provider_type_lookup=lambda _: "google", + custom_parameters_lookup=lambda _: { + "hd": "mycompany.com", + "access_type": "online", # admin attempts override; ignored + }, + ) + + event = _make_event(provider_id="google") + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + identity.get_token_for_user.assert_called_once() + assert identity.get_token_for_user.call_args.kwargs["custom_parameters"] == { + "access_type": "offline", # baseline wins + "hd": "mycompany.com", + } + + @pytest.mark.asyncio + async def test_no_provider_type_lookup_omits_custom_parameters(self): + """When the lookup is omitted (legacy callers / non-Google vendors), + no `custom_parameters` is sent — AgentCore handles vendor defaults + and we don't accidentally inject Google-specific keys elsewhere.""" + identity = MagicMock() + identity.get_token_for_user = AsyncMock( + return_value=TokenResult(access_token="t") + ) + + hook = OAuthConsentHook( + user_id="alice", + provider_lookup=lambda _tool: "github", + scopes_lookup=lambda _: ["read:user"], + # no provider_type_lookup + ) + + event = _make_event(provider_id="github") + with patch( + "agents.main_agent.session.hooks.oauth_consent.get_agentcore_identity_client", + return_value=identity, + ): + await hook._gate(event) + + identity.get_token_for_user.assert_called_once() + assert identity.get_token_for_user.call_args.kwargs["custom_parameters"] is None + + +class TestLooksLikeAuthFailure: + """Detector must fire on genuine auth failures and ignore everything + else — paths containing `401`, non-error statuses, benign prose. 
+ """ + + def _err(self, text: str) -> dict: + return {"status": "error", "content": [{"text": text}]} + + def _ok(self, text: str) -> dict: + return {"status": "success", "content": [{"text": text}]} + + @pytest.mark.parametrize( + "text", + [ + # HTTP 401 in various shapes. + "HTTP 401 Unauthorized", + "Request failed: 401", + "status=401 message=unauthorized", + "401 Client Error: Unauthorized for url: https://...", + # "Unauthorized" paired with an HTTP/status/code keyword. + "HTTP response: Unauthorized", + "status code Unauthorized", + # Unambiguous OAuth/token signals stand alone. + "The server rejected the OAuth token", + "invalid_token", + "invalid-token", + "invalid token", + "expired_token", + "token expired", + "token_expired", + "oauth token expired", + "oauth token has expired", + # Refresh-token revocation surfaces with this OAuth error code. + "invalid_grant: Token has been expired or revoked", + # Google API auth signals. + "Request had invalid authentication credentials", + "Request had invalid_authentication_credentials", + 'status "UNAUTHENTICATED"', + ], + ) + def test_matches_genuine_auth_errors(self, text): + assert _looks_like_auth_failure(self._err(text)) is True + + @pytest.mark.parametrize( + "text", + [ + # Path segments containing 401 — previously false-positive. + "GET /v1/401/foo failed with 500", + "https://example.com/api/401/items returned empty", + # Digits embedded in other numbers. + "returned 4010 rows", + "status 14011", + # Token as substring of longer words should not match. + "refreshtokenRequired", + "ExpiredTokens", # plural — not \btoken\b + # Prose mentions of "unauthorized" without HTTP/status context. + # Previously fired off the bare \bunauthorized\b alternative; + # tightening means application-level "not authorized" prose no + # longer triggers an OAuth re-auth. 
+ "The weather today is unauthorized-feeling, but fine", + "Unauthorized", # bare — too ambiguous on its own + "You are not authorized to view this calendar entry", + # Prose that shouldn't trigger anything. + "Everything is fine, nothing to see here", + "Rate limit exceeded", + "500 Internal Server Error", + # "PERMISSION_DENIED" is intentionally NOT matched — it's a + # scope/ACL problem at the provider, not an OAuth credential + # failure, and re-consenting won't change the outcome. + "PERMISSION_DENIED", + "Insufficient permissions", + ], + ) + def test_avoids_false_positives(self, text): + assert _looks_like_auth_failure(self._err(text)) is False + + def test_ignores_non_error_status(self): + # Even an auth-shaped body doesn't count if status is success. + assert _looks_like_auth_failure(self._ok("401 Unauthorized")) is False + + def test_ignores_non_dict_result(self): + assert _looks_like_auth_failure("HTTP 401 Unauthorized") is False + assert _looks_like_auth_failure(None) is False + assert _looks_like_auth_failure(["401"]) is False + + def test_ignores_missing_content(self): + assert _looks_like_auth_failure({"status": "error"}) is False + assert _looks_like_auth_failure({"status": "error", "content": None}) is False + assert _looks_like_auth_failure({"status": "error", "content": []}) is False + + def test_ignores_non_dict_content_blocks(self): + result = {"status": "error", "content": ["401 Unauthorized"]} + assert _looks_like_auth_failure(result) is False diff --git a/backend/tests/apis/app_api/test_admin_oauth_rollback.py b/backend/tests/apis/app_api/test_admin_oauth_rollback.py new file mode 100644 index 00000000..898fc1f8 --- /dev/null +++ b/backend/tests/apis/app_api/test_admin_oauth_rollback.py @@ -0,0 +1,119 @@ +"""Tests for the admin OAuth create-provider rollback helpers. + +Verifies retry + CloudWatch-metric behaviour of +`_rollback_orphaned_provider` and `_emit_orphan_metric`. 
These run +inside the admin `create_provider` route after a DB write fails; if +both the rollback AND the metric emit fail, the admin's original error +still propagates — these helpers must never raise. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from botocore.exceptions import ClientError + +from apis.app_api.admin.oauth import routes + + +def _client_error(code: str = "ThrottlingException") -> ClientError: + return ClientError( + error_response={"Error": {"Code": code, "Message": code}}, + operation_name="DeleteOauth2CredentialProvider", + ) + + +@pytest.fixture(autouse=True) +def _reset_cw_client_cache(): + routes._cloudwatch_client.cache_clear() + yield + routes._cloudwatch_client.cache_clear() + + +@pytest.fixture +def fast_backoff(monkeypatch): + """Rollback retries sleep between attempts; zero those out in tests.""" + monkeypatch.setattr(routes, "_ROLLBACK_RETRY_DELAYS_SECONDS", (0.0, 0.0)) + + +class TestRollbackOrphanedProvider: + @pytest.mark.asyncio + async def test_succeeds_on_first_attempt(self, fast_backoff): + registrar = MagicMock() + registrar.delete_credential_provider.return_value = None + + await routes._rollback_orphaned_provider(registrar, "google") + + registrar.delete_credential_provider.assert_called_once_with("google") + + @pytest.mark.asyncio + async def test_retries_on_transient_error(self, fast_backoff): + registrar = MagicMock() + registrar.delete_credential_provider.side_effect = [ + _client_error("ThrottlingException"), + None, + ] + + await routes._rollback_orphaned_provider(registrar, "google") + + assert registrar.delete_credential_provider.call_count == 2 + + @pytest.mark.asyncio + async def test_emits_metric_when_all_retries_exhausted(self, fast_backoff): + registrar = MagicMock() + registrar.delete_credential_provider.side_effect = _client_error( + "InternalServerException" + ) + + with patch.object(routes, "_emit_orphan_metric") as emit: + await 
routes._rollback_orphaned_provider(registrar, "google") + + # 1 initial + 2 retries = 3 attempts with zero backoff schedule. + assert registrar.delete_credential_provider.call_count == 3 + emit.assert_called_once_with("google") + + @pytest.mark.asyncio + async def test_never_raises_even_when_everything_fails(self, fast_backoff): + """Rollback runs inside an `except` block that already re-raises + the admin-facing error. A secondary raise here would shadow it. + + Simulates the worst case: every registrar call fails AND + CloudWatch is down. `_emit_orphan_metric` swallows its own + errors, so the overall helper must return cleanly. + """ + registrar = MagicMock() + registrar.delete_credential_provider.side_effect = RuntimeError("boom") + + fake_cw = MagicMock() + fake_cw.put_metric_data.side_effect = RuntimeError("cw boom") + + with patch.object(routes, "_cloudwatch_client", lambda: fake_cw): + # Must not raise — secondary failures only log. + await routes._rollback_orphaned_provider(registrar, "google") + + +class TestEmitOrphanMetric: + def test_puts_metric_with_provider_dimension(self): + fake_cw = MagicMock() + with patch.object(routes, "_cloudwatch_client", lambda: fake_cw): + routes._emit_orphan_metric("google-workspace") + + fake_cw.put_metric_data.assert_called_once() + call_kwargs = fake_cw.put_metric_data.call_args.kwargs + assert call_kwargs["Namespace"] == "Agentcore/OAuth" + metric = call_kwargs["MetricData"][0] + assert metric["MetricName"] == "ProviderOrphaned" + assert metric["Dimensions"] == [ + {"Name": "ProviderId", "Value": "google-workspace"} + ] + assert metric["Value"] == 1 + assert metric["Unit"] == "Count" + + def test_swallows_cloudwatch_failure(self): + """CloudWatch outage must not shadow the admin's create failure.""" + fake_cw = MagicMock() + fake_cw.put_metric_data.side_effect = RuntimeError("cw down") + + with patch.object(routes, "_cloudwatch_client", lambda: fake_cw): + routes._emit_orphan_metric("google") # no raise diff --git 
a/backend/tests/apis/app_api/tools/test_freshness.py b/backend/tests/apis/app_api/tools/test_freshness.py new file mode 100644 index 00000000..600d9698 --- /dev/null +++ b/backend/tests/apis/app_api/tools/test_freshness.py @@ -0,0 +1,179 @@ +"""Tests for the tool-config freshness TTL cache.""" + +import asyncio +from datetime import datetime, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from apis.app_api.tools import freshness + + +@pytest.fixture(autouse=True) +def _clear_cache(): + freshness._reset_for_tests() + yield + freshness._reset_for_tests() + + +def _tool(updated_at: datetime): + return SimpleNamespace(updated_at=updated_at) + + +@pytest.mark.asyncio +async def test_empty_tool_list_returns_empty_hash(): + assert await freshness.get_freshness_hash([]) == "" + + +@pytest.mark.asyncio +async def test_hash_reflects_updated_at_changes(): + repo = SimpleNamespace( + get_tool=AsyncMock( + return_value=_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + ) + ) + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ): + h1 = await freshness.get_freshness_hash(["gmail"]) + + # Invalidate so the next call re-fetches instead of hitting the TTL cache. 
+ freshness.invalidate("gmail") + + repo.get_tool = AsyncMock( + return_value=_tool(datetime(2025, 2, 1, tzinfo=timezone.utc)) + ) + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ): + h2 = await freshness.get_freshness_hash(["gmail"]) + + assert h1 != h2 + + +@pytest.mark.asyncio +async def test_ttl_avoids_repeat_reads_within_window(): + """Second call in the TTL window must not hit the repository.""" + repo = SimpleNamespace( + get_tool=AsyncMock( + return_value=_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + ) + ) + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ): + await freshness.get_freshness_hash(["gmail"]) + await freshness.get_freshness_hash(["gmail"]) + await freshness.get_freshness_hash(["gmail"]) + + assert repo.get_tool.await_count == 1 + + +@pytest.mark.asyncio +async def test_invalidate_forces_refetch(): + repo = SimpleNamespace( + get_tool=AsyncMock( + return_value=_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + ) + ) + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ): + await freshness.get_tool_updated_at("gmail") + freshness.invalidate("gmail") + await freshness.get_tool_updated_at("gmail") + + assert repo.get_tool.await_count == 2 + + +@pytest.mark.asyncio +async def test_invalidate_all_clears_every_entry(): + repo = SimpleNamespace( + get_tool=AsyncMock( + return_value=_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + ) + ) + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ): + await freshness.get_tool_updated_at("gmail") + await freshness.get_tool_updated_at("jira") + + freshness.invalidate() + assert freshness._cache == {} + + +@pytest.mark.asyncio +async def test_missing_tool_is_cached_as_none(): + """A deleted or never-existed tool must not cause a DB hit every turn.""" + repo = SimpleNamespace(get_tool=AsyncMock(return_value=None)) + 
with patch(
+        "apis.app_api.tools.repository.get_tool_catalog_repository",
+        return_value=repo,
+    ):
+        result1 = await freshness.get_tool_updated_at("ghost")
+        result2 = await freshness.get_tool_updated_at("ghost")
+
+    assert result1 is None
+    assert result2 is None
+    assert repo.get_tool.await_count == 1
+
+
+@pytest.mark.asyncio
+async def test_repository_error_does_not_raise():
+    """Freshness is advisory — a DB blip must not fail the chat turn."""
+    repo = SimpleNamespace(get_tool=AsyncMock(side_effect=RuntimeError("boom")))
+    with patch(
+        "apis.app_api.tools.repository.get_tool_catalog_repository",
+        return_value=repo,
+    ):
+        result = await freshness.get_tool_updated_at("gmail")
+
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_repository_error_falls_back_to_last_known_value():
+    repo_ok = SimpleNamespace(
+        get_tool=AsyncMock(
+            return_value=_tool(datetime(2025, 1, 1, tzinfo=timezone.utc))
+        )
+    )
+    with patch(
+        "apis.app_api.tools.repository.get_tool_catalog_repository",
+        return_value=repo_ok,
+    ):
+        await freshness.get_tool_updated_at("gmail")
+
+    freshness.invalidate("gmail")
+
+    repo_err = SimpleNamespace(get_tool=AsyncMock(side_effect=RuntimeError("boom")))
+    with patch(
+        "apis.app_api.tools.repository.get_tool_catalog_repository",
+        return_value=repo_err,
+    ):
+        # invalidate() cleared the cache entry, so the error path yields None — not the stale value.
+ assert await freshness.get_tool_updated_at("gmail") is None + + +@pytest.mark.asyncio +async def test_hash_is_stable_regardless_of_input_order(): + repo = SimpleNamespace( + get_tool=AsyncMock( + return_value=_tool(datetime(2025, 1, 1, tzinfo=timezone.utc)) + ) + ) + with patch( + "apis.app_api.tools.repository.get_tool_catalog_repository", + return_value=repo, + ): + h1 = await freshness.get_freshness_hash(["gmail", "jira"]) + h2 = await freshness.get_freshness_hash(["jira", "gmail"]) + + assert h1 == h2 diff --git a/backend/tests/apis/inference_api/test_agentcore_context_middleware.py b/backend/tests/apis/inference_api/test_agentcore_context_middleware.py new file mode 100644 index 00000000..58448145 --- /dev/null +++ b/backend/tests/apis/inference_api/test_agentcore_context_middleware.py @@ -0,0 +1,189 @@ +"""Tests for AgentCoreContextMiddleware. + +Verifies that Runtime headers are copied into BedrockAgentCoreContext on +each request and that the middleware is a no-op when headers are absent +(local development, unit tests without Runtime). 
+""" + +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from apis.inference_api.middleware.agentcore_context import ( + HEADER_OAUTH2_CALLBACK_URL, + HEADER_REQUEST_ID, + HEADER_SESSION_ID, + HEADER_WORKLOAD_ACCESS_TOKEN, + AgentCoreContextMiddleware, +) + + +@pytest.fixture +def app() -> FastAPI: + app = FastAPI() + app.add_middleware(AgentCoreContextMiddleware) + + @app.get("/echo") + def echo() -> dict: + return {"ok": True} + + return app + + +@pytest.fixture +def client(app: FastAPI) -> TestClient: + return TestClient(app) + + +class TestAgentCoreContextMiddleware: + def test_copies_workload_access_token_to_context(self, client: TestClient) -> None: + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + response = client.get( + "/echo", headers={HEADER_WORKLOAD_ACCESS_TOKEN: "wat-abc123"} + ) + + assert response.status_code == 200 + ctx.set_workload_access_token.assert_called_once_with("wat-abc123") + + def test_copies_allowlisted_oauth2_callback_url_to_context( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CORS_ORIGINS", "https://app.example.com,https://staging.example.com") + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get( + "/echo", + headers={HEADER_OAUTH2_CALLBACK_URL: "https://app.example.com/oauth-complete"}, + ) + + ctx.set_oauth2_callback_url.assert_called_once_with( + "https://app.example.com/oauth-complete" + ) + + def test_rejects_callback_url_with_origin_outside_allowlist( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CORS_ORIGINS", "https://app.example.com") + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + response = client.get( + "/echo", + headers={HEADER_OAUTH2_CALLBACK_URL: 
"https://evil.example.com/oauth-complete"}, + ) + + assert response.status_code == 200 + ctx.set_oauth2_callback_url.assert_not_called() + + def test_rejects_callback_url_with_wrong_path( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CORS_ORIGINS", "https://app.example.com") + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get( + "/echo", + headers={HEADER_OAUTH2_CALLBACK_URL: "https://app.example.com/admin"}, + ) + + ctx.set_oauth2_callback_url.assert_not_called() + + def test_rejects_callback_url_with_query_or_fragment( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CORS_ORIGINS", "https://app.example.com") + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get( + "/echo", + headers={ + HEADER_OAUTH2_CALLBACK_URL: "https://app.example.com/oauth-complete?next=/admin" + }, + ) + client.get( + "/echo", + headers={ + HEADER_OAUTH2_CALLBACK_URL: "https://app.example.com/oauth-complete#x" + }, + ) + + ctx.set_oauth2_callback_url.assert_not_called() + + def test_rejects_callback_url_with_unsupported_scheme( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("CORS_ORIGINS", "https://app.example.com") + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get( + "/echo", + headers={ + HEADER_OAUTH2_CALLBACK_URL: "javascript:alert(1)//app.example.com/oauth-complete" + }, + ) + + ctx.set_oauth2_callback_url.assert_not_called() + + def test_rejects_callback_url_when_allowlist_empty( + self, client: TestClient, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("CORS_ORIGINS", raising=False) + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get( + "/echo", + 
headers={HEADER_OAUTH2_CALLBACK_URL: "https://app.example.com/oauth-complete"}, + ) + + ctx.set_oauth2_callback_url.assert_not_called() + + def test_copies_session_and_request_id_to_context( + self, client: TestClient + ) -> None: + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get( + "/echo", + headers={ + HEADER_SESSION_ID: "sess-1", + HEADER_REQUEST_ID: "req-1", + }, + ) + + ctx.set_request_context.assert_called_once_with( + request_id="req-1", session_id="sess-1" + ) + + def test_noop_when_headers_absent(self, client: TestClient) -> None: + """Local dev and tests without AgentCore Runtime must still work.""" + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + response = client.get("/echo") + + assert response.status_code == 200 + ctx.set_workload_access_token.assert_not_called() + ctx.set_oauth2_callback_url.assert_not_called() + ctx.set_request_context.assert_not_called() + + def test_session_id_defaults_request_id_to_empty(self, client: TestClient) -> None: + """When session is present but request-id header is missing, the + request_id falls back to empty string rather than None.""" + with patch( + "apis.inference_api.middleware.agentcore_context.BedrockAgentCoreContext" + ) as ctx: + client.get("/echo", headers={HEADER_SESSION_ID: "sess-1"}) + + ctx.set_request_context.assert_called_once_with( + request_id="", session_id="sess-1" + ) diff --git a/backend/tests/apis/inference_api/test_connectors_routes.py b/backend/tests/apis/inference_api/test_connectors_routes.py new file mode 100644 index 00000000..b66434c0 --- /dev/null +++ b/backend/tests/apis/inference_api/test_connectors_routes.py @@ -0,0 +1,456 @@ +"""Route-level tests for the inference-API connectors endpoints. + +Covers `complete-consent` (forwards to AgentCore, surfaces errors), and +the side-effect-free `GET /{provider_id}/status`. 
+ +External boundaries (AgentCore control-plane client, identity client, +provider repository, role service) are patched — we test our gating and +response shape, not the downstream calls. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from agents.main_agent.integrations import oauth_token_cache +from agents.main_agent.integrations.agentcore_identity import ( + CallbackUrlUnavailableError, + TokenResult, + WorkloadTokenUnavailableError, +) +from apis.inference_api.connectors import routes +from apis.shared.auth.models import User +from apis.shared.oauth.disconnect_repository import get_disconnect_repository +from apis.shared.oauth.models import OAuthProvider, OAuthProviderType +from apis.shared.oauth.provider_repository import get_provider_repository +from apis.shared.rbac.models import UserEffectivePermissions +from apis.shared.rbac.service import get_app_role_service + + +@pytest.fixture(autouse=True) +def _reset_token_cache(): + """`oauth_token_cache` is process-global; isolate between tests.""" + oauth_token_cache.clear_user("alice") + oauth_token_cache.clear_user("bob") + yield + oauth_token_cache.clear_user("alice") + oauth_token_cache.clear_user("bob") + + +@pytest.fixture(autouse=True) +def _reset_control_client(): + """`_agentcore_control_client` is `lru_cache`d; reset between tests.""" + routes._agentcore_control_client.cache_clear() + yield + routes._agentcore_control_client.cache_clear() + + +def _make_user(user_id: str) -> User: + return User( + user_id=user_id, + email=f"{user_id}@example.com", + name=user_id.capitalize(), + roles=[], + raw_token="test-token", + ) + + +@pytest.fixture +def app_for_user(): + """Build a minimal FastAPI app with the connectors router mounted and + the `get_current_user_trusted` dependency stubbed to a specific user. 
+ Returns a factory so each test picks the caller's identity. + """ + + def _build(user_id: str) -> FastAPI: + app = FastAPI() + app.include_router(routes.router) + app.dependency_overrides[routes.get_current_user_trusted] = lambda: _make_user(user_id) + return app + + return _build + + +class TestCompleteConsent: + """`complete-consent` is a thin wrapper around AgentCore's + `CompleteResourceTokenAuth`. The auth boundary is `current_user` + (verified by `get_current_user_trusted`) — we forward that identity + as `userIdentifier` and AgentCore's own binding rejects mismatches. + """ + + def test_forwards_caller_identity_to_agentcore(self, app_for_user, monkeypatch): + mock_client = MagicMock() + monkeypatch.setattr(routes, "_agentcore_control_client", lambda: mock_client) + + app = app_for_user("alice") + response = TestClient(app).post( + "/connectors/complete-consent", + json={"session_uri": "uri-abc", "provider_id": "google"}, + ) + + assert response.status_code == 200 + assert response.json() == {"ok": True} + mock_client.complete_resource_token_auth.assert_called_once_with( + userIdentifier={"userId": "alice"}, + sessionUri="uri-abc", + ) + + def test_surfaces_agentcore_error_as_502(self, app_for_user, monkeypatch): + mock_client = MagicMock() + mock_client.complete_resource_token_auth.side_effect = RuntimeError("agentcore down") + monkeypatch.setattr(routes, "_agentcore_control_client", lambda: mock_client) + + app = app_for_user("alice") + response = TestClient(app).post( + "/connectors/complete-consent", + json={"session_uri": "uri-abc", "provider_id": "google"}, + ) + + assert response.status_code == 502 + assert "agentcore down" in response.json()["detail"] + + +def _make_provider( + provider_id: str = "google", + *, + enabled: bool = True, + allowed_roles: list[str] | None = None, + custom_parameters: dict[str, str] | None = None, +) -> OAuthProvider: + now = datetime.now(timezone.utc).isoformat() + "Z" + return OAuthProvider( + provider_id=provider_id, 
+ display_name=provider_id.capitalize(), + provider_type=OAuthProviderType.GOOGLE, + scopes=["openid", "email"], + allowed_roles=allowed_roles or [], + enabled=enabled, + custom_parameters=custom_parameters, + created_at=now, + updated_at=now, + ) + + +def _make_permissions(user_id: str, *, roles: list[str] | None = None) -> UserEffectivePermissions: + return UserEffectivePermissions( + user_id=user_id, + app_roles=roles or [], + tools=[], + models=[], + quota_tier=None, + resolved_at=datetime.now(timezone.utc).isoformat() + "Z", + ) + + +class _FakeDisconnectRepo: + """In-memory stand-in for the durable DDB-backed disconnect repository.""" + + def __init__(self) -> None: + self.disconnected: set[tuple[str, str]] = set() + + async def is_disconnected(self, user_id: str, provider_id: str) -> bool: + return (user_id, provider_id) in self.disconnected + + async def mark_disconnected(self, user_id: str, provider_id: str) -> None: + self.disconnected.add((user_id, provider_id)) + + async def clear_disconnected(self, user_id: str, provider_id: str) -> None: + self.disconnected.discard((user_id, provider_id)) + + +@pytest.fixture +def app_with_deps(app_for_user, monkeypatch): + """Mount the router and stub provider repo, role service, identity client. + + Returns a builder so each test wires the specific responses it needs. 
+ """ + + def _build( + user_id: str, + *, + provider: OAuthProvider | None, + permissions: UserEffectivePermissions | None = None, + identity_result: TokenResult | None = None, + identity_raises: Exception | None = None, + disconnect_repo: _FakeDisconnectRepo | None = None, + ) -> tuple[FastAPI, MagicMock, _FakeDisconnectRepo]: + app = app_for_user(user_id) + + repo = MagicMock() + repo.get_provider = AsyncMock(return_value=provider) + app.dependency_overrides[get_provider_repository] = lambda: repo + + role_service = MagicMock() + role_service.resolve_user_permissions = AsyncMock( + return_value=permissions or _make_permissions(user_id), + ) + app.dependency_overrides[get_app_role_service] = lambda: role_service + + disconnect_repo = disconnect_repo or _FakeDisconnectRepo() + app.dependency_overrides[get_disconnect_repository] = lambda: disconnect_repo + + identity = MagicMock() + if identity_raises is not None: + identity.get_token_for_user = AsyncMock(side_effect=identity_raises) + else: + identity.get_token_for_user = AsyncMock( + return_value=identity_result + or TokenResult(access_token="vault-token"), + ) + monkeypatch.setattr(routes, "get_agentcore_identity_client", lambda: identity) + + return app, identity, disconnect_repo + + return _build + + +class TestConnectorStatus: + def test_returns_connected_when_vault_has_token(self, app_with_deps): + app, identity, _ = app_with_deps( + "alice", + provider=_make_provider(), + identity_result=TokenResult(access_token="vault-token"), + ) + response = TestClient(app).get("/connectors/google/status") + + assert response.status_code == 200 + assert response.json() == {"connected": True} + identity.get_token_for_user.assert_called_once() + + def test_returns_not_connected_when_vault_empty(self, app_with_deps): + # The point of /status: when the vault is empty we report it as + # {connected: false} and discard the auth URL — the listing UI + # only wants the badge, not to start a flow. 
+ app, identity, _ = app_with_deps( + "alice", + provider=_make_provider(), + identity_result=TokenResult( + authorization_url="https://example.com/auth?request_uri=abc", + ), + ) + response = TestClient(app).get("/connectors/google/status") + + assert response.status_code == 200 + assert response.json() == {"connected": False} + # The auth URL is intentionally NOT echoed back. + assert "authorization_url" not in response.json() + assert "authorizationUrl" not in response.json() + + def test_404_when_provider_missing(self, app_with_deps): + app, identity, _ = app_with_deps("alice", provider=None) + response = TestClient(app).get("/connectors/google/status") + + assert response.status_code == 404 + identity.get_token_for_user.assert_not_called() + + def test_404_when_provider_disabled(self, app_with_deps): + app, identity, _ = app_with_deps( + "alice", provider=_make_provider(enabled=False) + ) + response = TestClient(app).get("/connectors/google/status") + + # Disabled providers are indistinguishable from missing to the user. 
+ assert response.status_code == 404 + identity.get_token_for_user.assert_not_called() + + def test_403_when_user_lacks_role(self, app_with_deps): + app, identity, _ = app_with_deps( + "alice", + provider=_make_provider(allowed_roles=["admins"]), + permissions=_make_permissions("alice", roles=["users"]), + ) + response = TestClient(app).get("/connectors/google/status") + + assert response.status_code == 403 + identity.get_token_for_user.assert_not_called() + + def test_503_when_workload_token_unavailable(self, app_with_deps): + app, _, _ = app_with_deps( + "alice", + provider=_make_provider(), + identity_raises=WorkloadTokenUnavailableError("no workload token"), + ) + response = TestClient(app).get("/connectors/google/status") + assert response.status_code == 503 + assert "no workload token" in response.json()["detail"] + + def test_503_when_callback_url_unavailable(self, app_with_deps): + app, _, _ = app_with_deps( + "alice", + provider=_make_provider(), + identity_raises=CallbackUrlUnavailableError("no callback URL"), + ) + response = TestClient(app).get("/connectors/google/status") + assert response.status_code == 503 + assert "no callback URL" in response.json()["detail"] + + def test_disconnected_overrides_vault_state(self, app_with_deps): + # After a disconnect, the user is "not connected" even if AgentCore's + # vault still holds a valid token — and AgentCore is not consulted, + # so a stale vault entry can't accidentally flip the badge back on. + # The flag lives in the DDB-backed disconnect repository so a + # disconnect on one replica is honored on every subsequent request, + # even if it lands on a different replica. 
+ repo = _FakeDisconnectRepo() + repo.disconnected.add(("alice", "google")) + app, identity, _ = app_with_deps( + "alice", + provider=_make_provider(), + identity_result=TokenResult(access_token="vault-token"), + disconnect_repo=repo, + ) + response = TestClient(app).get("/connectors/google/status") + + assert response.status_code == 200 + assert response.json() == {"connected": False} + identity.get_token_for_user.assert_not_called() + + +class TestDisconnect: + def test_marks_provider_disconnected_durably(self, app_with_deps): + app, _, repo = app_with_deps("alice", provider=_make_provider()) + assert ("alice", "google") not in repo.disconnected + + response = TestClient(app).delete("/connectors/google/connection") + + assert response.status_code == 204 + assert ("alice", "google") in repo.disconnected + + def test_clears_cached_token_on_disconnect(self, app_with_deps): + # The local hot-path cache must not keep serving the disconnected + # token to in-flight MCP requests — otherwise a tool call still in + # progress on this replica could leak past the disconnect. 
+ oauth_token_cache.set("alice", "google", "warm-token") + app, _, _ = app_with_deps("alice", provider=_make_provider()) + + TestClient(app).delete("/connectors/google/connection") + + assert oauth_token_cache.get("alice", "google") is None + + def test_404_when_provider_missing(self, app_with_deps): + app, _, repo = app_with_deps("alice", provider=None) + response = TestClient(app).delete("/connectors/google/connection") + + assert response.status_code == 404 + assert ("alice", "google") not in repo.disconnected + + def test_403_when_user_lacks_role(self, app_with_deps): + app, _, repo = app_with_deps( + "alice", + provider=_make_provider(allowed_roles=["admins"]), + permissions=_make_permissions("alice", roles=["users"]), + ) + response = TestClient(app).delete("/connectors/google/connection") + + assert response.status_code == 403 + assert ("alice", "google") not in repo.disconnected + + +class TestForceReauthLifecycle: + """End-to-end-ish: disconnect → initiate-consent → complete-consent.""" + + def test_initiate_consent_forces_auth_after_disconnect(self, app_with_deps): + # disconnect → next initiate-consent must pass force_authentication + # so AgentCore returns a fresh authorize URL instead of the cached + # token (which we just told the user we'd stop using). 
+ repo = _FakeDisconnectRepo() + repo.disconnected.add(("alice", "google")) + app, identity, _ = app_with_deps( + "alice", + provider=_make_provider(), + identity_result=TokenResult( + authorization_url="https://example.com/auth?request_uri=abc", + ), + disconnect_repo=repo, + ) + TestClient(app).post("/connectors/google/initiate-consent") + + identity.get_token_for_user.assert_called_once() + assert identity.get_token_for_user.call_args.kwargs["force_authentication"] is True + + def test_initiate_consent_does_not_force_when_not_disconnected(self, app_with_deps): + app, identity, _ = app_with_deps( + "alice", + provider=_make_provider(), + identity_result=TokenResult(access_token="vault-token"), + ) + TestClient(app).post("/connectors/google/initiate-consent") + + identity.get_token_for_user.assert_called_once() + assert identity.get_token_for_user.call_args.kwargs["force_authentication"] is False + + def test_status_forwards_google_access_type_offline(self, app_with_deps): + # Per AgentCore docs, Google needs `access_type=offline` in + # customParameters so the vault gets a refresh token. 
+ app, identity, _ = app_with_deps( + "alice", + provider=_make_provider(), # default is OAuthProviderType.GOOGLE + identity_result=TokenResult(access_token="vault-token"), + ) + TestClient(app).get("/connectors/google/status") + + identity.get_token_for_user.assert_called_once() + assert identity.get_token_for_user.call_args.kwargs["custom_parameters"] == { + "access_type": "offline", + } + + def test_initiate_consent_forwards_google_access_type_offline(self, app_with_deps): + app, identity, _ = app_with_deps( + "alice", + provider=_make_provider(), + identity_result=TokenResult(access_token="vault-token"), + ) + TestClient(app).post("/connectors/google/initiate-consent") + + identity.get_token_for_user.assert_called_once() + assert identity.get_token_for_user.call_args.kwargs["custom_parameters"] == { + "access_type": "offline", + } + + def test_admin_custom_parameters_merge_with_google_baseline(self, app_with_deps): + # Admin set Workspace domain restriction. The route must merge + # admin extras with the hardcoded baseline before forwarding to + # AgentCore — and the baseline still wins on key conflict. + app, identity, _ = app_with_deps( + "alice", + provider=_make_provider( + custom_parameters={ + "hd": "mycompany.com", + "access_type": "online", # admin tries to override; ignored + }, + ), + identity_result=TokenResult(access_token="vault-token"), + ) + TestClient(app).get("/connectors/google/status") + + kwargs = identity.get_token_for_user.call_args.kwargs + assert kwargs["custom_parameters"] == { + "access_type": "offline", # baseline wins + "hd": "mycompany.com", + } + + def test_complete_consent_clears_disconnect_flag(self, app_for_user, monkeypatch): + # After a successful re-consent the disconnect intent is satisfied — + # the next status check should report connected without waiting for + # the agent loop to warm the cache. 
+ repo = _FakeDisconnectRepo() + repo.disconnected.add(("alice", "google")) + + mock_client = MagicMock() + monkeypatch.setattr(routes, "_agentcore_control_client", lambda: mock_client) + + app = app_for_user("alice") + app.dependency_overrides[get_disconnect_repository] = lambda: repo + + response = TestClient(app).post( + "/connectors/complete-consent", + json={"session_uri": "uri-abc", "provider_id": "google"}, + ) + + assert response.status_code == 200 + assert ("alice", "google") not in repo.disconnected diff --git a/backend/tests/routes/conftest.py b/backend/tests/routes/conftest.py index e8587576..884e4fc3 100644 --- a/backend/tests/routes/conftest.py +++ b/backend/tests/routes/conftest.py @@ -19,6 +19,7 @@ os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1") from typing import Any, Callable, List, Optional +from unittest.mock import AsyncMock import pytest from fastapi import FastAPI, HTTPException, status @@ -28,6 +29,29 @@ from apis.shared.auth.models import User +# --------------------------------------------------------------------------- +# Auto-stub session-metadata pre-stream hook +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _stub_ensure_session_metadata_exists(monkeypatch): + """The /invocations route calls ensure_session_metadata_exists() before + streaming, which raises RuntimeError when DYNAMODB_SESSIONS_METADATA_TABLE_NAME + is unset. Route tests don't exercise metadata persistence, so stub it to a + no-op that reports "session already exists" (False) — this skips the + first-turn title-generation branch too. + + Tests that need real metadata behavior should provision the + `sessions_metadata_table` fixture from tests/shared/conftest.py and + monkeypatch this back to the real implementation. 
+ """ + monkeypatch.setattr( + "apis.inference_api.chat.routes.ensure_session_metadata_exists", + AsyncMock(return_value=False), + ) + + # --------------------------------------------------------------------------- # Requirement 1.3: User factory fixture # --------------------------------------------------------------------------- diff --git a/backend/tests/routes/test_inference.py b/backend/tests/routes/test_inference.py index fdfc59fa..1e08ef0b 100644 --- a/backend/tests/routes/test_inference.py +++ b/backend/tests/routes/test_inference.py @@ -151,18 +151,10 @@ class TestInvocationsInvalid: """POST /invocations with invalid payload returns 422.""" def test_missing_required_fields_returns_422(self, authed_app, authed_client): - """Req 15.3: Missing session_id and message should return 422.""" + """Req 15.3: Missing session_id should return 422.""" resp = authed_client.post("/invocations", json={}) assert resp.status_code == 422 - def test_missing_message_returns_422(self, authed_app, authed_client): - """Req 15.3: Missing message field should return 422.""" - resp = authed_client.post( - "/invocations", - json={"session_id": "sess-001"}, - ) - assert resp.status_code == 422 - def test_missing_session_id_returns_422(self, authed_app, authed_client): """Req 15.3: Missing session_id field should return 422.""" resp = authed_client.post( diff --git a/backend/tests/shared/conftest.py b/backend/tests/shared/conftest.py index b00f2ca2..56f647c0 100644 --- a/backend/tests/shared/conftest.py +++ b/backend/tests/shared/conftest.py @@ -323,21 +323,14 @@ def auth_provider_repository(auth_providers_table, secrets_manager): @pytest.fixture() -def oauth_provider_repository(oauth_providers_table, secrets_manager): +def oauth_provider_repository(oauth_providers_table): from apis.shared.oauth.provider_repository import OAuthProviderRepository return OAuthProviderRepository( table_name="test-oauth-providers", - secrets_arn="oauth-client-secrets", region=AWS_REGION, ) 
-@pytest.fixture() -def oauth_token_repository(oauth_tokens_table): - from apis.shared.oauth.token_repository import OAuthTokenRepository - return OAuthTokenRepository(table_name="test-oauth-user-tokens", region=AWS_REGION) - - @pytest.fixture() def file_repository(files_table): from apis.shared.files.repository import FileUploadRepository diff --git a/backend/tests/shared/test_conftest_smoke.py b/backend/tests/shared/test_conftest_smoke.py index 5ad26732..9aa29a88 100644 --- a/backend/tests/shared/test_conftest_smoke.py +++ b/backend/tests/shared/test_conftest_smoke.py @@ -71,9 +71,6 @@ def test_auth_provider_repository(self, auth_provider_repository): def test_oauth_provider_repository(self, oauth_provider_repository): assert oauth_provider_repository.enabled - def test_oauth_token_repository(self, oauth_token_repository): - assert oauth_token_repository.enabled - def test_file_repository(self, file_repository): assert file_repository is not None diff --git a/backend/tests/shared/test_coverage_boost.py b/backend/tests/shared/test_coverage_boost.py index 72af7502..0f27173e 100644 --- a/backend/tests/shared/test_coverage_boost.py +++ b/backend/tests/shared/test_coverage_boost.py @@ -185,85 +185,6 @@ def test_convert_message_to_response(self): assert resp.role == "user" -class TestOAuthServiceExtended: - """Cover generate_pkce_pair, state management, disconnect, get_user_connections.""" - - def test_generate_pkce_pair(self): - from apis.shared.oauth.service import generate_pkce_pair - verifier, challenge = generate_pkce_pair() - assert len(verifier) > 20 - assert len(challenge) > 20 - assert verifier != challenge - - @pytest.mark.asyncio - async def test_disconnect(self, oauth_providers_table, oauth_tokens_table, secrets_manager, kms_key_arn, monkeypatch): - monkeypatch.setenv("OAUTH_CALLBACK_URL", "http://localhost/callback") - from apis.shared.oauth.service import OAuthService - from apis.shared.oauth.provider_repository import OAuthProviderRepository - from 
apis.shared.oauth.token_repository import OAuthTokenRepository - from apis.shared.oauth.encryption import TokenEncryptionService - from apis.shared.oauth.token_cache import TokenCache - from apis.shared.oauth.models import OAuthUserToken, OAuthConnectionStatus - - pr = OAuthProviderRepository(table_name="test-oauth-providers", secrets_arn="oauth-client-secrets", region="us-east-1") - tr = OAuthTokenRepository(table_name="test-oauth-user-tokens", region="us-east-1") - enc = TokenEncryptionService(key_arn=kms_key_arn, region="us-east-1") - cache = TokenCache() - - # Save a token - token = OAuthUserToken( - user_id="u1", provider_id="github", - access_token_encrypted=enc.encrypt("tok123"), - status=OAuthConnectionStatus.CONNECTED, - connected_at="2026-01-01", - ) - await tr.save_token(token) - - svc = OAuthService(provider_repo=pr, token_repo=tr, encryption_service=enc, token_cache=cache) - assert await svc.disconnect("u1", "github") is True - assert await svc.disconnect("u1", "github") is False # already deleted - - @pytest.mark.asyncio - async def test_get_decrypted_token_cached(self, oauth_tokens_table, monkeypatch): - monkeypatch.setenv("OAUTH_CALLBACK_URL", "http://localhost/callback") - from apis.shared.oauth.service import OAuthService - from apis.shared.oauth.token_cache import TokenCache - - cache = TokenCache() - cache.set("u1", "github", "cached_token") - svc = OAuthService(token_cache=cache) - result = await svc.get_decrypted_token("u1", "github") - assert result == "cached_token" - - @pytest.mark.asyncio - async def test_get_user_connections(self, oauth_providers_table, oauth_tokens_table, secrets_manager, kms_key_arn, monkeypatch): - monkeypatch.setenv("OAUTH_CALLBACK_URL", "http://localhost/callback") - from apis.shared.oauth.service import OAuthService - from apis.shared.oauth.provider_repository import OAuthProviderRepository - from apis.shared.oauth.token_repository import OAuthTokenRepository - from apis.shared.oauth.encryption import 
TokenEncryptionService - from apis.shared.oauth.token_cache import TokenCache - from apis.shared.oauth.models import OAuthProviderCreate, OAuthProviderType - - pr = OAuthProviderRepository(table_name="test-oauth-providers", secrets_arn="oauth-client-secrets", region="us-east-1") - tr = OAuthTokenRepository(table_name="test-oauth-user-tokens", region="us-east-1") - enc = TokenEncryptionService(key_arn=kms_key_arn, region="us-east-1") - cache = TokenCache() - - # Create a provider - await pr.create_provider(OAuthProviderCreate( - provider_id="github", display_name="GitHub", provider_type=OAuthProviderType.GITHUB, - authorization_endpoint="https://github.com/login/oauth/authorize", - token_endpoint="https://github.com/login/oauth/access_token", - client_id="cid", client_secret="csec", scopes=["repo"], allowed_roles=["viewer"], - )) - - svc = OAuthService(provider_repo=pr, token_repo=tr, encryption_service=enc, token_cache=cache) - connections = await svc.get_user_connections("u1", ["viewer"]) - assert len(connections) == 1 - assert connections[0].provider_id == "github" - - class TestStateStore: def test_in_memory_store_and_retrieve(self): from apis.shared.auth.state_store import InMemoryStateStore, OIDCStateData diff --git a/backend/tests/shared/test_models_and_utils.py b/backend/tests/shared/test_models_and_utils.py index c3cb7c87..90e38a83 100644 --- a/backend/tests/shared/test_models_and_utils.py +++ b/backend/tests/shared/test_models_and_utils.py @@ -97,35 +97,28 @@ def test_compute_scopes_hash(self): def test_provider_scopes_hash_property(self): from apis.shared.oauth.models import OAuthProvider, OAuthProviderType - p = OAuthProvider(provider_id="p1", display_name="P", provider_type=OAuthProviderType.CUSTOM, client_id="c", authorization_endpoint="http://a", token_endpoint="http://t", scopes=["a", "b"], allowed_roles=[]) + p = OAuthProvider( + provider_id="p1", display_name="P", + provider_type=OAuthProviderType.GOOGLE, + scopes=["a", "b"], allowed_roles=[], + ) 
assert p.scopes_hash == p.scopes_hash # consistent - def test_token_is_expired(self): - from apis.shared.oauth.models import OAuthUserToken - t = OAuthUserToken(user_id="u1", provider_id="p1", access_token_encrypted="x", expires_at=946684800) - assert t.is_expired is True - - def test_token_not_expired(self): - from apis.shared.oauth.models import OAuthUserToken - t = OAuthUserToken(user_id="u1", provider_id="p1", access_token_encrypted="x", expires_at=4070908800) - assert t.is_expired is False - def test_provider_dynamo_roundtrip(self): from apis.shared.oauth.models import OAuthProvider, OAuthProviderType - p = OAuthProvider(provider_id="p1", display_name="P", provider_type=OAuthProviderType.CUSTOM, client_id="c", authorization_endpoint="http://a", token_endpoint="http://t", scopes=["a"], allowed_roles=[]) + p = OAuthProvider( + provider_id="p1", display_name="P", + provider_type=OAuthProviderType.GOOGLE, + scopes=["a"], allowed_roles=[], + credential_provider_arn="arn:aws:bedrock-agentcore:us-east-1:1:cp/p1", + callback_url="https://bedrock-agentcore.us-east-1.amazonaws.com/cb/p1", + ) item = p.to_dynamo_item() assert item["PK"] == "PROVIDER#p1" restored = OAuthProvider.from_dynamo_item(item) assert restored.provider_id == "p1" - - def test_token_dynamo_roundtrip(self): - from apis.shared.oauth.models import OAuthUserToken - t = OAuthUserToken(user_id="u1", provider_id="p1", access_token_encrypted="enc", expires_at=4070908800) - item = t.to_dynamo_item() - assert item["PK"] == "USER#u1" - assert item["SK"] == "PROVIDER#p1" - restored = OAuthUserToken.from_dynamo_item(item) - assert restored.user_id == "u1" + assert restored.callback_url == p.callback_url + assert restored.credential_provider_arn == p.credential_provider_arn # =================================================================== diff --git a/backend/tests/shared/test_oauth_agentcore_registrar.py b/backend/tests/shared/test_oauth_agentcore_registrar.py new file mode 100644 index 00000000..e0ce801b 
--- /dev/null +++ b/backend/tests/shared/test_oauth_agentcore_registrar.py @@ -0,0 +1,403 @@ +"""AgentCore Identity credential-provider registrar tests. + +Mocks the `bedrock-agentcore-control` boto3 client directly — these tests +verify our translation layer (our OAuthProviderType → AgentCore vendor + +config shape), not AWS behaviour. +""" + +from unittest.mock import MagicMock + +import pytest +from botocore.exceptions import ClientError + +from apis.shared.oauth.agentcore_registrar import ( + AgentCoreRegistrar, + CredentialProviderConflictError, + CredentialProviderNotFoundError, + InvalidCustomProviderConfigError, +) +from apis.shared.oauth.models import OAuthProviderType + + +def _client_error(code: str) -> ClientError: + return ClientError( + error_response={"Error": {"Code": code, "Message": code}}, + operation_name="op", + ) + + +def _create_response( + *, arn="arn:aws:acps:us-east-1:123:token-vault/default/oauth2credentialprovider/p", + secret_arn="arn:aws:secretsmanager:us-east-1:123:secret:s", + callback="https://example.invalid/callback/p", + config_output=None, +): + return { + "credentialProviderArn": arn, + "clientSecretArn": {"secretArn": secret_arn}, + "callbackUrl": callback, + "name": "p", + "oauth2ProviderConfigOutput": config_output or {}, + } + + +@pytest.fixture +def boto_client(): + return MagicMock() + + +@pytest.fixture +def registrar(boto_client): + return AgentCoreRegistrar(client=boto_client, region="us-east-1") + + +class TestCreateCredentialProvider: + def test_google_uses_google_vendor_and_config_key(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.return_value = _create_response() + + info = registrar.create_credential_provider( + provider_id="google-workspace", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + ) + + boto_client.create_oauth2_credential_provider.assert_called_once_with( + name="google-workspace", + credentialProviderVendor="GoogleOauth2", + 
oauth2ProviderConfigInput={ + "googleOauth2ProviderConfig": {"clientId": "cid", "clientSecret": "sec"} + }, + ) + assert info.vendor == "GoogleOauth2" + assert info.callback_url.endswith("/callback/p") + + @pytest.mark.parametrize( + "provider_type,expected_vendor,expected_key", + [ + (OAuthProviderType.MICROSOFT, "MicrosoftOauth2", "microsoftOauth2ProviderConfig"), + (OAuthProviderType.GITHUB, "GithubOauth2", "githubOauth2ProviderConfig"), + (OAuthProviderType.SLACK, "SlackOauth2", "slackOauth2ProviderConfig"), + ( + OAuthProviderType.SALESFORCE, + "SalesforceOauth2", + "salesforceOauth2ProviderConfig", + ), + # Zoom is a first-class vendor but uses the shared + # `includedOauth2ProviderConfig` slot rather than its own + # config struct — see the SDK's Oauth2ProviderConfigInput + # shape for the authoritative list. + (OAuthProviderType.ZOOM, "ZoomOauth2", "includedOauth2ProviderConfig"), + ], + ) + def test_other_known_vendors( + self, registrar, boto_client, provider_type, expected_vendor, expected_key + ): + boto_client.create_oauth2_credential_provider.return_value = _create_response() + + registrar.create_credential_provider( + provider_id="p", + provider_type=provider_type, + client_id="cid", + client_secret="sec", + ) + + call = boto_client.create_oauth2_credential_provider.call_args.kwargs + assert call["credentialProviderVendor"] == expected_vendor + assert expected_key in call["oauth2ProviderConfigInput"] + # And no `oauthDiscovery` block — that's customOauth2-only. 
+ config = call["oauth2ProviderConfigInput"][expected_key] + assert "oauthDiscovery" not in config + + def test_custom_requires_discovery(self, registrar): + with pytest.raises(InvalidCustomProviderConfigError): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.CUSTOM, + client_id="cid", + client_secret="sec", + ) + + def test_custom_rejects_both_discovery_modes(self, registrar): + with pytest.raises(InvalidCustomProviderConfigError): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.CUSTOM, + client_id="cid", + client_secret="sec", + discovery_url="https://idp.example/.well-known/openid-configuration", + authorization_server_metadata={"authorizationEndpoint": "https://idp/auth"}, + ) + + def test_custom_with_discovery_url(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.return_value = _create_response() + + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.CUSTOM, + client_id="cid", + client_secret="sec", + discovery_url="https://idp.example/.well-known/openid-configuration", + ) + + config = boto_client.create_oauth2_credential_provider.call_args.kwargs[ + "oauth2ProviderConfigInput" + ]["customOauth2ProviderConfig"] + assert config["oauthDiscovery"] == { + "discoveryUrl": "https://idp.example/.well-known/openid-configuration" + } + + def test_canvas_routes_through_custom_vendor(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.return_value = _create_response() + + registrar.create_credential_provider( + provider_id="canvas", + provider_type=OAuthProviderType.CANVAS, + client_id="cid", + client_secret="sec", + authorization_server_metadata={ + "authorizationEndpoint": "https://canvas.example/login/oauth2/auth", + "tokenEndpoint": "https://canvas.example/login/oauth2/token", + }, + ) + + call = boto_client.create_oauth2_credential_provider.call_args.kwargs + assert 
call["credentialProviderVendor"] == "CustomOauth2" + assert "customOauth2ProviderConfig" in call["oauth2ProviderConfigInput"] + + @pytest.mark.parametrize( + "provider_type", + [ + OAuthProviderType.GOOGLE, + OAuthProviderType.MICROSOFT, + OAuthProviderType.GITHUB, + OAuthProviderType.SLACK, + OAuthProviderType.SALESFORCE, + OAuthProviderType.ZOOM, + ], + ) + def test_known_vendor_rejects_discovery_params(self, registrar, provider_type): + # Every first-class vendor (Google, Microsoft, GitHub, Slack, + # Salesforce, Zoom) has its endpoints baked in by AgentCore. A + # discovery URL only makes sense for the CustomOauth2 path. + with pytest.raises(ValueError, match="only valid for CustomOauth2"): + registrar.create_credential_provider( + provider_id="p", + provider_type=provider_type, + client_id="cid", + client_secret="sec", + discovery_url="https://idp.example/.well-known/openid-configuration", + ) + + def test_conflict_maps_to_domain_error(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.side_effect = _client_error( + "ConflictException" + ) + + with pytest.raises(CredentialProviderConflictError): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + ) + + def test_surfaces_client_id_when_echoed(self, registrar, boto_client): + boto_client.create_oauth2_credential_provider.return_value = _create_response( + config_output={"googleOauth2ProviderConfig": {"clientId": "cid"}}, + ) + + info = registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + ) + + assert info.client_id == "cid" + + +class TestUpdateCredentialProvider: + def test_sends_full_config(self, registrar, boto_client): + boto_client.update_oauth2_credential_provider.return_value = _create_response() + + registrar.update_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GITHUB, 
+ client_id="new-cid", + client_secret="new-sec", + ) + + call = boto_client.update_oauth2_credential_provider.call_args.kwargs + assert call["credentialProviderVendor"] == "GithubOauth2" + assert call["oauth2ProviderConfigInput"]["githubOauth2ProviderConfig"] == { + "clientId": "new-cid", + "clientSecret": "new-sec", + } + + def test_not_found_maps_to_domain_error(self, registrar, boto_client): + boto_client.update_oauth2_credential_provider.side_effect = _client_error( + "ResourceNotFoundException" + ) + + with pytest.raises(CredentialProviderNotFoundError): + registrar.update_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + ) + + def test_uses_fallback_arn_when_update_response_omits_it( + self, registrar, boto_client + ): + """AWS's UpdateOauth2CredentialProvider response doesn't include + credentialProviderArn. The caller passes the known-immutable ARN + via `fallback_arn` so the returned info is well-formed.""" + update_response = _create_response() + update_response.pop("credentialProviderArn") + boto_client.update_oauth2_credential_provider.return_value = update_response + + info = registrar.update_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="sec", + fallback_arn="arn:aws:acps:us-east-1:123:token-vault/default/oauth2credentialprovider/p", + ) + + assert info.credential_provider_arn == ( + "arn:aws:acps:us-east-1:123:token-vault/default/oauth2credentialprovider/p" + ) + + def test_raises_when_response_lacks_arn_and_no_fallback( + self, registrar, boto_client + ): + update_response = _create_response() + update_response.pop("credentialProviderArn") + boto_client.update_oauth2_credential_provider.return_value = update_response + + with pytest.raises(TypeError, match="credentialProviderArn"): + registrar.update_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + 
client_secret="sec", + ) + + +class TestGetCredentialProvider: + def test_returns_info_including_callback_url(self, registrar, boto_client): + boto_client.get_oauth2_credential_provider.return_value = { + **_create_response( + config_output={"googleOauth2ProviderConfig": {"clientId": "cid"}}, + ), + "credentialProviderVendor": "GoogleOauth2", + } + + info = registrar.get_credential_provider("p") + + assert info.vendor == "GoogleOauth2" + assert info.client_id == "cid" + assert info.callback_url.endswith("/callback/p") + + def test_not_found(self, registrar, boto_client): + boto_client.get_oauth2_credential_provider.side_effect = _client_error( + "ResourceNotFoundException" + ) + + with pytest.raises(CredentialProviderNotFoundError): + registrar.get_credential_provider("missing") + + +class TestDeleteCredentialProvider: + def test_calls_boto(self, registrar, boto_client): + registrar.delete_credential_provider("p") + boto_client.delete_oauth2_credential_provider.assert_called_once_with(name="p") + + def test_not_found_is_success(self, registrar, boto_client): + boto_client.delete_oauth2_credential_provider.side_effect = _client_error( + "ResourceNotFoundException" + ) + registrar.delete_credential_provider("missing") # no raise + + def test_other_errors_bubble(self, registrar, boto_client): + boto_client.delete_oauth2_credential_provider.side_effect = _client_error( + "AccessDeniedException" + ) + with pytest.raises(ClientError): + registrar.delete_credential_provider("p") + + +class TestResponseParsing: + """`_info_from_response` must fail loudly on contract violations and + tolerate documented variations (missing secret, missing clientId).""" + + def test_tolerates_missing_client_secret_arn(self, registrar, boto_client): + response = _create_response() + response.pop("clientSecretArn") + boto_client.create_oauth2_credential_provider.return_value = response + + info = registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, 
+ client_id="cid", + client_secret="csec", + ) + assert info.client_secret_arn == "" + + def test_rejects_client_secret_arn_as_string(self, registrar, boto_client): + """AgentCore contract is `{secretArn: str}`; a raw string signals + an API contract change we should fail on loudly.""" + response = _create_response() + response["clientSecretArn"] = "arn:aws:secretsmanager:...:secret:s" + boto_client.create_oauth2_credential_provider.return_value = response + + with pytest.raises(TypeError, match="clientSecretArn of unexpected type"): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + + def test_rejects_non_string_nested_secret_arn(self, registrar, boto_client): + response = _create_response() + response["clientSecretArn"] = {"secretArn": 12345} + boto_client.create_oauth2_credential_provider.return_value = response + + with pytest.raises(TypeError, match="secretArn of unexpected type"): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + + def test_rejects_missing_credential_provider_arn(self, registrar, boto_client): + response = _create_response() + response.pop("credentialProviderArn") + boto_client.create_oauth2_credential_provider.return_value = response + + with pytest.raises(TypeError, match="credentialProviderArn"): + registrar.create_credential_provider( + provider_id="p", + provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + + def test_tolerates_missing_callback_url(self, registrar, boto_client): + """Callback URL absence falls back to empty string — non-fatal for + vendors that don't declare one yet.""" + response = _create_response() + response["callbackUrl"] = None + boto_client.create_oauth2_credential_provider.return_value = response + + info = registrar.create_credential_provider( + provider_id="p", + 
provider_type=OAuthProviderType.GOOGLE, + client_id="cid", + client_secret="csec", + ) + assert info.callback_url == "" diff --git a/backend/tests/shared/test_oauth_disconnect_repository.py b/backend/tests/shared/test_oauth_disconnect_repository.py new file mode 100644 index 00000000..884ca360 --- /dev/null +++ b/backend/tests/shared/test_oauth_disconnect_repository.py @@ -0,0 +1,79 @@ +"""OAuth disconnect repository tests (moto DynamoDB). + +Co-tenants the existing `oauth-user-tokens` table — items use a +`DISCONNECT#{provider_id}` sort-key prefix so they cannot collide with +future per-user token storage. +""" + +import boto3 +import pytest + +from apis.shared.oauth.disconnect_repository import OAuthDisconnectRepository + + +@pytest.fixture() +def disconnect_repository(oauth_tokens_table, monkeypatch): + # Earlier tests in the suite can leave boto3's default session bound + # to real-world SSO credentials. moto only mocks API calls, not the + # credential resolution chain, so a stale SSO session would later try + # to refresh against AWS and fail (`GetRoleCredentials: Not yet + # implemented`). Resetting the default session forces boto3 to + # rebuild it under the conftest's `AWS_ACCESS_KEY_ID=testing` env + # vars on first use. 
+ monkeypatch.setattr(boto3, "DEFAULT_SESSION", None) + monkeypatch.delenv("AWS_PROFILE", raising=False) + return OAuthDisconnectRepository( + table_name="test-oauth-user-tokens", + region="us-east-1", + ) + + +class TestOAuthDisconnectRepository: + @pytest.mark.asyncio + async def test_default_state_is_not_disconnected(self, disconnect_repository): + assert await disconnect_repository.is_disconnected("alice", "google") is False + + @pytest.mark.asyncio + async def test_mark_then_check(self, disconnect_repository): + await disconnect_repository.mark_disconnected("alice", "google") + assert await disconnect_repository.is_disconnected("alice", "google") is True + + @pytest.mark.asyncio + async def test_per_user_isolation(self, disconnect_repository): + await disconnect_repository.mark_disconnected("alice", "google") + assert await disconnect_repository.is_disconnected("bob", "google") is False + + @pytest.mark.asyncio + async def test_per_provider_isolation(self, disconnect_repository): + await disconnect_repository.mark_disconnected("alice", "google") + assert await disconnect_repository.is_disconnected("alice", "github") is False + + @pytest.mark.asyncio + async def test_clear_disconnected(self, disconnect_repository): + await disconnect_repository.mark_disconnected("alice", "google") + await disconnect_repository.clear_disconnected("alice", "google") + assert await disconnect_repository.is_disconnected("alice", "google") is False + + @pytest.mark.asyncio + async def test_clear_when_not_set_is_noop(self, disconnect_repository): + # Idempotent: clearing a flag that was never set should not raise. 
+ await disconnect_repository.clear_disconnected("alice", "google") + assert await disconnect_repository.is_disconnected("alice", "google") is False + + @pytest.mark.asyncio + async def test_mark_is_idempotent(self, disconnect_repository): + await disconnect_repository.mark_disconnected("alice", "google") + await disconnect_repository.mark_disconnected("alice", "google") + assert await disconnect_repository.is_disconnected("alice", "google") is True + + @pytest.mark.asyncio + async def test_disabled_when_table_env_unset(self, monkeypatch): + # Without the env var set, the repo silently no-ops so local-dev + # without OAuth wiring still boots and `/status` falls through to + # AgentCore's vault state. + monkeypatch.delenv("DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME", raising=False) + repo = OAuthDisconnectRepository() + assert repo.enabled is False + assert await repo.is_disconnected("alice", "google") is False + await repo.mark_disconnected("alice", "google") # no-op, no raise + assert await repo.is_disconnected("alice", "google") is False diff --git a/backend/tests/shared/test_oauth_models.py b/backend/tests/shared/test_oauth_models.py new file mode 100644 index 00000000..46fa015b --- /dev/null +++ b/backend/tests/shared/test_oauth_models.py @@ -0,0 +1,54 @@ +"""DynamoDB serde tests for `OAuthProvider`. + +Focused on the round-trip — admins paste vendor-specific OAuth params +into the form, we persist them, and they must come back unchanged so +the merge with the hardcoded baseline at runtime stays correct. 
+""" + +from __future__ import annotations + +import pytest + +from apis.shared.oauth.models import OAuthProvider, OAuthProviderType + + +def _provider(**overrides) -> OAuthProvider: + base = dict( + provider_id="google-workspace", + display_name="Google Workspace", + provider_type=OAuthProviderType.GOOGLE, + scopes=["openid", "email"], + allowed_roles=[], + ) + base.update(overrides) + return OAuthProvider(**base) + + +class TestCustomParametersRoundTrip: + def test_round_trip_preserves_custom_parameters(self) -> None: + original = _provider( + custom_parameters={"hd": "mycompany.com", "prompt": "consent"}, + ) + revived = OAuthProvider.from_dynamo_item(original.to_dynamo_item()) + assert revived.custom_parameters == { + "hd": "mycompany.com", + "prompt": "consent", + } + + def test_none_round_trips_as_none(self) -> None: + # Default state for vendors with no admin-supplied extras. + original = _provider(custom_parameters=None) + item = original.to_dynamo_item() + # Persisted as None — the route layer treats `{}` as "explicitly + # cleared" and converts it to None before save. + assert item["customParameters"] is None + revived = OAuthProvider.from_dynamo_item(item) + assert revived.custom_parameters is None + + def test_legacy_item_without_field_loads_as_none(self) -> None: + # Simulates a pre-migration row in DynamoDB. New code must not + # KeyError on the missing attribute. 
+ item = _provider(custom_parameters={"foo": "bar"}).to_dynamo_item() + del item["customParameters"] + revived = OAuthProvider.from_dynamo_item(item) + assert revived.custom_parameters is None diff --git a/backend/tests/shared/test_oauth_repos_extended.py b/backend/tests/shared/test_oauth_repos_extended.py deleted file mode 100644 index 0c876fe5..00000000 --- a/backend/tests/shared/test_oauth_repos_extended.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Extended OAuth repository tests for deeper coverage.""" - -import pytest - - -class TestOAuthProviderRepositoryExtended: - @pytest.fixture(autouse=True) - def _setup(self, oauth_provider_repository): - self.repo = oauth_provider_repository - - def _make_create(self, pid="github", **kw): - from apis.shared.oauth.models import OAuthProviderCreate, OAuthProviderType - defaults = dict( - provider_id=pid, display_name="GitHub", provider_type=OAuthProviderType.GITHUB, - authorization_endpoint="https://github.com/login/oauth/authorize", - token_endpoint="https://github.com/login/oauth/access_token", - client_id="cid", client_secret="secret", scopes=["repo"], - allowed_roles=["viewer"], - ) - defaults.update(kw) - return OAuthProviderCreate(**defaults) - - @pytest.mark.asyncio - async def test_create_and_get(self): - p = await self.repo.create_provider(self._make_create()) - assert p.provider_id == "github" - got = await self.repo.get_provider("github") - assert got is not None - assert got.display_name == "GitHub" - - @pytest.mark.asyncio - async def test_get_nonexistent(self): - assert await self.repo.get_provider("nope") is None - - @pytest.mark.asyncio - async def test_create_duplicate_raises(self): - await self.repo.create_provider(self._make_create()) - with pytest.raises(ValueError, match="already exists"): - await self.repo.create_provider(self._make_create()) - - @pytest.mark.asyncio - async def test_list_all(self): - await self.repo.create_provider(self._make_create("a")) - await 
self.repo.create_provider(self._make_create("b")) - providers = await self.repo.list_providers() - assert len(providers) == 2 - - @pytest.mark.asyncio - async def test_list_enabled_only(self): - await self.repo.create_provider(self._make_create("a", enabled=True)) - await self.repo.create_provider(self._make_create("b", enabled=False)) - enabled = await self.repo.list_providers(enabled_only=True) - assert all(p.enabled for p in enabled) - - @pytest.mark.asyncio - async def test_update_provider(self): - from apis.shared.oauth.models import OAuthProviderUpdate - await self.repo.create_provider(self._make_create()) - updated = await self.repo.update_provider("github", OAuthProviderUpdate(display_name="GH")) - assert updated.display_name == "GH" - - @pytest.mark.asyncio - async def test_update_nonexistent(self): - from apis.shared.oauth.models import OAuthProviderUpdate - assert await self.repo.update_provider("nope", OAuthProviderUpdate(display_name="X")) is None - - @pytest.mark.asyncio - async def test_update_client_secret(self): - from apis.shared.oauth.models import OAuthProviderUpdate - await self.repo.create_provider(self._make_create()) - await self.repo.update_provider("github", OAuthProviderUpdate(client_secret="new_secret")) - secret = await self.repo.get_client_secret("github") - assert secret == "new_secret" - - @pytest.mark.asyncio - async def test_delete_provider(self): - await self.repo.create_provider(self._make_create()) - assert await self.repo.delete_provider("github") is True - assert await self.repo.get_provider("github") is None - - @pytest.mark.asyncio - async def test_delete_nonexistent(self): - assert await self.repo.delete_provider("nope") is False - - @pytest.mark.asyncio - async def test_get_client_secret(self): - await self.repo.create_provider(self._make_create()) - secret = await self.repo.get_client_secret("github") - assert secret == "secret" - - @pytest.mark.asyncio - async def test_get_client_secret_missing(self): - secret = await 
self.repo.get_client_secret("nope") - assert secret is None - - @pytest.mark.asyncio - async def test_disabled_repo(self, monkeypatch): - monkeypatch.delenv("DYNAMODB_OAUTH_PROVIDERS_TABLE_NAME", raising=False) - from apis.shared.oauth.provider_repository import OAuthProviderRepository - repo = OAuthProviderRepository(table_name=None) - assert repo.enabled is False - assert await repo.get_provider("x") is None - assert await repo.list_providers() == [] - assert await repo.delete_provider("x") is False - - -class TestOAuthTokenRepositoryExtended: - @pytest.fixture(autouse=True) - def _setup(self, oauth_token_repository): - self.repo = oauth_token_repository - - def _make_token(self, user_id="u1", provider_id="github", **kw): - from apis.shared.oauth.models import OAuthUserToken, OAuthConnectionStatus - defaults = dict( - user_id=user_id, provider_id=provider_id, - access_token_encrypted="enc_tok", - status=OAuthConnectionStatus.CONNECTED, - connected_at="2026-01-01", - ) - defaults.update(kw) - return OAuthUserToken(**defaults) - - @pytest.mark.asyncio - async def test_save_and_get(self): - token = self._make_token() - saved = await self.repo.save_token(token) - assert saved.updated_at is not None - got = await self.repo.get_token("u1", "github") - assert got is not None - assert got.access_token_encrypted == "enc_tok" - - @pytest.mark.asyncio - async def test_get_nonexistent(self): - assert await self.repo.get_token("u1", "nope") is None - - @pytest.mark.asyncio - async def test_list_user_tokens(self): - await self.repo.save_token(self._make_token(provider_id="a")) - await self.repo.save_token(self._make_token(provider_id="b")) - tokens = await self.repo.list_user_tokens("u1") - assert len(tokens) == 2 - - @pytest.mark.asyncio - async def test_list_provider_tokens(self): - await self.repo.save_token(self._make_token(user_id="u1")) - await self.repo.save_token(self._make_token(user_id="u2")) - tokens = await self.repo.list_provider_tokens("github") - assert 
len(tokens) == 2 - - @pytest.mark.asyncio - async def test_update_token_status(self): - from apis.shared.oauth.models import OAuthConnectionStatus - await self.repo.save_token(self._make_token()) - updated = await self.repo.update_token_status("u1", "github", OAuthConnectionStatus.EXPIRED) - assert updated.status == OAuthConnectionStatus.EXPIRED - - @pytest.mark.asyncio - async def test_update_token_status_nonexistent(self): - from apis.shared.oauth.models import OAuthConnectionStatus - assert await self.repo.update_token_status("u1", "nope", OAuthConnectionStatus.EXPIRED) is None - - @pytest.mark.asyncio - async def test_delete_token(self): - await self.repo.save_token(self._make_token()) - assert await self.repo.delete_token("u1", "github") is True - assert await self.repo.get_token("u1", "github") is None - - @pytest.mark.asyncio - async def test_delete_nonexistent(self): - assert await self.repo.delete_token("u1", "nope") is False - - @pytest.mark.asyncio - async def test_delete_user_tokens(self): - await self.repo.save_token(self._make_token(provider_id="a")) - await self.repo.save_token(self._make_token(provider_id="b")) - count = await self.repo.delete_user_tokens("u1") - assert count == 2 - assert await self.repo.list_user_tokens("u1") == [] - - @pytest.mark.asyncio - async def test_delete_provider_tokens(self): - await self.repo.save_token(self._make_token(user_id="u1")) - await self.repo.save_token(self._make_token(user_id="u2")) - count = await self.repo.delete_provider_tokens("github") - assert count == 2 - - @pytest.mark.asyncio - async def test_disabled_repo(self, monkeypatch): - monkeypatch.delenv("DYNAMODB_OAUTH_USER_TOKENS_TABLE_NAME", raising=False) - from apis.shared.oauth.token_repository import OAuthTokenRepository - repo = OAuthTokenRepository(table_name=None) - assert repo.enabled is False - assert await repo.get_token("u1", "x") is None - assert await repo.list_user_tokens("u1") == [] - assert await repo.list_provider_tokens("x") == [] - 
assert await repo.delete_token("u1", "x") is False - assert await repo.delete_user_tokens("u1") == 0 - assert await repo.delete_provider_tokens("x") == 0 diff --git a/backend/tests/shared/test_oauth_repositories.py b/backend/tests/shared/test_oauth_repositories.py index 9035b2b9..381ecf0f 100644 --- a/backend/tests/shared/test_oauth_repositories.py +++ b/backend/tests/shared/test_oauth_repositories.py @@ -1,44 +1,38 @@ -"""Task 6: OAuth provider + token repositories (moto DynamoDB + Secrets Manager).""" +"""OAuth provider repository tests (moto DynamoDB).""" import pytest -from apis.shared.oauth.models import OAuthProvider, OAuthProviderType, OAuthUserToken, OAuthConnectionStatus - -def _make_provider(provider_id="github", **kw): - defaults = dict( - provider_id=provider_id, display_name="GitHub", provider_type=OAuthProviderType.GITHUB, - authorization_endpoint="https://github.com/login/oauth/authorize", - token_endpoint="https://github.com/login/oauth/access_token", - client_id="cid", scopes=["repo"], allowed_roles=["editor"], - ) - defaults.update(kw) - return defaults +from apis.shared.oauth.models import ( + OAuthProvider, + OAuthProviderType, + OAuthProviderUpdate, +) +from apis.shared.oauth.provider_repository import OAuthProviderRepository -def _make_token(user_id="u1", provider_id="github", **kw): +def _make_provider(provider_id="github", **kw) -> OAuthProvider: defaults = dict( - user_id=user_id, provider_id=provider_id, - access_token_encrypted="enc-token", token_type="Bearer", - scopes_hash="abc", status=OAuthConnectionStatus.CONNECTED, + provider_id=provider_id, + display_name="GitHub", + provider_type=OAuthProviderType.GITHUB, + scopes=["repo"], + allowed_roles=["editor"], + credential_provider_arn=f"arn:aws:bedrock-agentcore:us-east-1:1:cp/{provider_id}", + callback_url=f"https://bedrock-agentcore.us-east-1.amazonaws.com/cb/{provider_id}", ) defaults.update(kw) - return OAuthUserToken(**defaults) + return OAuthProvider(**defaults) -# 
=================================================================== -# OAuthProviderRepository -# =================================================================== - class TestOAuthProviderRepository: @pytest.mark.asyncio - async def test_create_and_get(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - data = OAuthProviderCreate(**_make_provider(), client_secret="secret") - provider = await oauth_provider_repository.create_provider(data) - assert provider.provider_id == "github" + async def test_put_and_get(self, oauth_provider_repository): + await oauth_provider_repository.put_provider(_make_provider()) result = await oauth_provider_repository.get_provider("github") assert result is not None assert result.display_name == "GitHub" + assert result.callback_url.endswith("/cb/github") + assert result.credential_provider_arn @pytest.mark.asyncio async def test_get_nonexistent(self, oauth_provider_repository): @@ -46,101 +40,51 @@ async def test_get_nonexistent(self, oauth_provider_repository): @pytest.mark.asyncio async def test_list_all(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider("p1"), client_secret="s")) - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider("p2"), client_secret="s")) + await oauth_provider_repository.put_provider(_make_provider("p1")) + await oauth_provider_repository.put_provider(_make_provider("p2")) providers = await oauth_provider_repository.list_providers() assert len(providers) == 2 @pytest.mark.asyncio async def test_list_enabled_only(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider("p1"), client_secret="s")) - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider("p2", 
enabled=False), client_secret="s")) + await oauth_provider_repository.put_provider(_make_provider("p1")) + await oauth_provider_repository.put_provider(_make_provider("p2", enabled=False)) providers = await oauth_provider_repository.list_providers(enabled_only=True) assert len(providers) == 1 + assert providers[0].provider_id == "p1" + + @pytest.mark.asyncio + async def test_apply_metadata_update(self, oauth_provider_repository): + await oauth_provider_repository.put_provider(_make_provider()) + updated = await oauth_provider_repository.apply_metadata_update( + "github", + OAuthProviderUpdate(display_name="GH", scopes=["repo", "read:user"]), + ) + assert updated.display_name == "GH" + assert updated.scopes == ["repo", "read:user"] + + @pytest.mark.asyncio + async def test_apply_metadata_update_nonexistent(self, oauth_provider_repository): + updates = OAuthProviderUpdate(display_name="X") + assert await oauth_provider_repository.apply_metadata_update("nope", updates) is None @pytest.mark.asyncio async def test_delete_provider(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider(), client_secret="s")) + await oauth_provider_repository.put_provider(_make_provider()) assert await oauth_provider_repository.delete_provider("github") is True assert await oauth_provider_repository.get_provider("github") is None @pytest.mark.asyncio - async def test_client_secret(self, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider(OAuthProviderCreate(**_make_provider(), client_secret="my-secret")) - secret = await oauth_provider_repository.get_client_secret("github") - assert secret == "my-secret" + async def test_delete_nonexistent(self, oauth_provider_repository): + assert await oauth_provider_repository.delete_provider("nope") is False def test_disabled_when_no_table(self): 
- from apis.shared.oauth.provider_repository import OAuthProviderRepository repo = OAuthProviderRepository(table_name=None) assert repo.enabled is False - -# =================================================================== -# OAuthTokenRepository -# =================================================================== - -class TestOAuthTokenRepository: - @pytest.mark.asyncio - async def test_save_and_get(self, oauth_token_repository): - token = _make_token() - saved = await oauth_token_repository.save_token(token) - assert saved.user_id == "u1" - result = await oauth_token_repository.get_token("u1", "github") - assert result is not None - assert result.access_token_encrypted == "enc-token" - - @pytest.mark.asyncio - async def test_get_nonexistent(self, oauth_token_repository): - assert await oauth_token_repository.get_token("u1", "nope") is None - - @pytest.mark.asyncio - async def test_list_user_tokens(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token(provider_id="p1")) - await oauth_token_repository.save_token(_make_token(provider_id="p2")) - tokens = await oauth_token_repository.list_user_tokens("u1") - assert len(tokens) == 2 - - @pytest.mark.asyncio - async def test_list_provider_tokens(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token(user_id="u1")) - await oauth_token_repository.save_token(_make_token(user_id="u2")) - tokens = await oauth_token_repository.list_provider_tokens("github") - assert len(tokens) == 2 - - @pytest.mark.asyncio - async def test_update_token_status(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token()) - updated = await oauth_token_repository.update_token_status("u1", "github", OAuthConnectionStatus.EXPIRED) - assert updated is not None - assert updated.status == OAuthConnectionStatus.EXPIRED - @pytest.mark.asyncio - async def test_delete_token(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token()) - 
assert await oauth_token_repository.delete_token("u1", "github") is True - assert await oauth_token_repository.get_token("u1", "github") is None - - @pytest.mark.asyncio - async def test_delete_user_tokens(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token(provider_id="p1")) - await oauth_token_repository.save_token(_make_token(provider_id="p2")) - count = await oauth_token_repository.delete_user_tokens("u1") - assert count == 2 - assert len(await oauth_token_repository.list_user_tokens("u1")) == 0 - - @pytest.mark.asyncio - async def test_delete_provider_tokens(self, oauth_token_repository): - await oauth_token_repository.save_token(_make_token(user_id="u1")) - await oauth_token_repository.save_token(_make_token(user_id="u2")) - count = await oauth_token_repository.delete_provider_tokens("github") - assert count == 2 - - def test_disabled_when_no_table(self): - from apis.shared.oauth.token_repository import OAuthTokenRepository - repo = OAuthTokenRepository(table_name=None) - assert repo.enabled is False + async def test_disabled_repo_is_inert(self): + repo = OAuthProviderRepository(table_name=None) + assert await repo.get_provider("x") is None + assert await repo.list_providers() == [] + assert await repo.delete_provider("x") is False diff --git a/backend/tests/shared/test_oauth_service.py b/backend/tests/shared/test_oauth_service.py deleted file mode 100644 index 8b4085cf..00000000 --- a/backend/tests/shared/test_oauth_service.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Task 7: OAuth encryption (moto KMS), token cache, and service tests.""" - -import time -import pytest -from unittest.mock import AsyncMock, MagicMock, patch - - -# =================================================================== -# TokenEncryptionService (moto KMS) -# =================================================================== - -class TestTokenEncryptionService: - def test_encrypt_decrypt_roundtrip(self, kms_key_arn): - from apis.shared.oauth.encryption 
import TokenEncryptionService - svc = TokenEncryptionService(key_arn=kms_key_arn, region="us-east-1") - ciphertext = svc.encrypt("my-secret-token") - assert ciphertext != "my-secret-token" - plaintext = svc.decrypt(ciphertext) - assert plaintext == "my-secret-token" - - def test_disabled_without_key(self): - from apis.shared.oauth.encryption import TokenEncryptionService - svc = TokenEncryptionService(key_arn=None) - assert svc.enabled is False - - def test_encrypt_disabled_returns_plaintext(self): - from apis.shared.oauth.encryption import TokenEncryptionService - svc = TokenEncryptionService(key_arn=None) - result = svc.encrypt("token") - assert result.startswith("DEV:") - - def test_decrypt_disabled_returns_ciphertext(self): - from apis.shared.oauth.encryption import TokenEncryptionService - svc = TokenEncryptionService(key_arn=None) - encrypted = svc.encrypt("token") - assert svc.decrypt(encrypted) == "token" - - -# =================================================================== -# TokenCache (pure in-memory) -# =================================================================== - -class TestTokenCache: - @pytest.fixture() - def cache(self): - from apis.shared.oauth.token_cache import TokenCache - return TokenCache() - - def test_set_and_get(self, cache): - cache.set("u1", "p1", "token-abc") - assert cache.get("u1", "p1") == "token-abc" - - def test_get_missing(self, cache): - assert cache.get("u1", "p1") is None - - def test_delete(self, cache): - cache.set("u1", "p1", "token") - assert cache.delete("u1", "p1") is True - assert cache.get("u1", "p1") is None - - def test_delete_missing(self, cache): - assert cache.delete("u1", "p1") is False - - def test_delete_for_user(self, cache): - cache.set("u1", "p1", "t1") - cache.set("u1", "p2", "t2") - cache.set("u2", "p1", "t3") - count = cache.delete_for_user("u1") - assert count == 2 - assert cache.get("u2", "p1") == "t3" - - def test_delete_for_provider(self, cache): - cache.set("u1", "p1", "t1") - 
cache.set("u2", "p1", "t2") - cache.set("u1", "p2", "t3") - count = cache.delete_for_provider("p1") - assert count == 2 - assert cache.get("u1", "p2") == "t3" - - def test_clear(self, cache): - cache.set("u1", "p1", "t") - cache.clear() - assert cache.get("u1", "p1") is None - - def test_get_stats(self, cache): - cache.set("u1", "p1", "t") - stats = cache.get_stats() - assert stats["size"] == 1 - - -# =================================================================== -# OAuthService -# =================================================================== - -class TestOAuthServicePKCE: - def test_generate_pkce_pair(self): - from apis.shared.oauth.service import generate_pkce_pair - verifier, challenge = generate_pkce_pair() - assert len(verifier) > 20 - assert len(challenge) > 20 - assert verifier != challenge - - -class TestOAuthServiceConnect: - @pytest.fixture() - def oauth_service(self, oauth_provider_repository, oauth_token_repository, kms_key_arn, monkeypatch): - monkeypatch.setenv("OAUTH_CALLBACK_URL", "http://localhost:8000/api/oauth/callback") - from apis.shared.oauth.service import OAuthService - from apis.shared.oauth.encryption import TokenEncryptionService - from apis.shared.oauth.token_cache import TokenCache - from apis.shared.auth.state_store import InMemoryStateStore - - enc = TokenEncryptionService(key_arn=kms_key_arn, region="us-east-1") - cache = TokenCache() - state_store = InMemoryStateStore() - - return OAuthService( - provider_repo=oauth_provider_repository, - token_repo=oauth_token_repository, - encryption_service=enc, - token_cache=cache, - state_store=state_store, - ) - - @pytest.mark.asyncio - async def test_initiate_connect(self, oauth_service, oauth_provider_repository): - from apis.shared.oauth.models import OAuthProviderCreate - await oauth_provider_repository.create_provider( - OAuthProviderCreate( - provider_id="github", display_name="GitHub", - provider_type="github", client_id="cid", client_secret="secret", - 
authorization_endpoint="https://github.com/login/oauth/authorize", - token_endpoint="https://github.com/login/oauth/access_token", - scopes=["repo"], allowed_roles=["editor"], - ) - ) - result = await oauth_service.initiate_connect( - provider_id="github", user_id="u1", - user_roles=["editor"], - frontend_redirect="http://localhost/callback" - ) - assert "github.com" in result - - @pytest.mark.asyncio - async def test_get_user_connections_empty(self, oauth_service): - connections = await oauth_service.get_user_connections("u1", user_roles=["editor"]) - assert connections == [] - - @pytest.mark.asyncio - async def test_disconnect_nonexistent(self, oauth_service): - # Should not raise - result = await oauth_service.disconnect("u1", "github") - assert result is True or result is False # implementation-dependent diff --git a/backend/tests/shared/test_sessions_messages.py b/backend/tests/shared/test_sessions_messages.py index df82f7f8..1cc098ab 100644 --- a/backend/tests/shared/test_sessions_messages.py +++ b/backend/tests/shared/test_sessions_messages.py @@ -20,7 +20,8 @@ async def test_get_messages_from_cloud(self, monkeypatch): with patch("apis.shared.sessions.messages.AgentCoreMemorySessionManager", return_value=mock_session_mgr), \ patch("apis.shared.sessions.messages.AgentCoreMemoryConfig"), \ patch("apis.shared.sessions.messages.AGENTCORE_MEMORY_AVAILABLE", True), \ - patch("apis.shared.sessions.metadata.get_all_message_metadata", new_callable=AsyncMock, return_value={}): + patch("apis.shared.sessions.metadata.get_all_message_metadata", new_callable=AsyncMock, return_value={}), \ + patch("apis.shared.sessions.metadata.get_pending_interrupts", new_callable=AsyncMock, return_value=[]): from apis.shared.sessions.messages import get_messages_from_cloud result = await get_messages_from_cloud("s1", "u1") assert len(result.messages) == 2 @@ -38,7 +39,8 @@ async def test_get_messages_pagination(self, monkeypatch): with 
patch("apis.shared.sessions.messages.AgentCoreMemorySessionManager", return_value=mock_session_mgr), \ patch("apis.shared.sessions.messages.AgentCoreMemoryConfig"), \ patch("apis.shared.sessions.messages.AGENTCORE_MEMORY_AVAILABLE", True), \ - patch("apis.shared.sessions.metadata.get_all_message_metadata", new_callable=AsyncMock, return_value={}): + patch("apis.shared.sessions.metadata.get_all_message_metadata", new_callable=AsyncMock, return_value={}), \ + patch("apis.shared.sessions.metadata.get_pending_interrupts", new_callable=AsyncMock, return_value=[]): from apis.shared.sessions.messages import get_messages_from_cloud result = await get_messages_from_cloud("s1", "u1", limit=3) assert len(result.messages) == 3 diff --git a/backend/tests/shared/test_sessions_metadata.py b/backend/tests/shared/test_sessions_metadata.py index e1d64edb..4e46e40f 100644 --- a/backend/tests/shared/test_sessions_metadata.py +++ b/backend/tests/shared/test_sessions_metadata.py @@ -230,3 +230,403 @@ async def test_missing_env_raises(self, sessions_metadata_table, monkeypatch): await store_user_display_text( session_id="s1", user_id="u1", message_id=0, display_text="boom", ) + + +class TestUpdateSessionActivity: + """Per-turn metadata update via targeted writes — closes the merge-write race.""" + + @pytest.mark.asyncio + async def test_increments_message_count(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + update_session_activity, + get_session_metadata, + ) + await ensure_session_metadata_exists("s1", "u1") + before = await get_session_metadata("s1", "u1") + assert before.message_count == 0 + + applied = await update_session_activity( + session_id="s1", user_id="u1", last_model="claude-3", last_temperature=0.7, + ) + assert applied is True + after = await get_session_metadata("s1", "u1") + assert after.message_count == 1 + + await update_session_activity(session_id="s1", user_id="u1", last_model="claude-3") + after2 = 
await get_session_metadata("s1", "u1") + assert after2.message_count == 2 + + @pytest.mark.asyncio + async def test_preserves_title_set_by_title_gen(self, sessions_metadata_table): + """Race regression: post-stream activity update must not clobber title-gen's write.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + update_session_title, + update_session_activity, + get_session_metadata, + ) + await ensure_session_metadata_exists("s1", "u1") + await update_session_title("s1", "u1", "My Generated Title") + await update_session_activity( + session_id="s1", user_id="u1", last_model="claude-3", last_temperature=0.5, + ) + result = await get_session_metadata("s1", "u1") + assert result.title == "My Generated Title" + + @pytest.mark.asyncio + async def test_preserves_pending_interrupts(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + add_pending_interrupt, + update_session_activity, + get_pending_interrupts, + ) + from apis.shared.sessions.models import PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + await update_session_activity(session_id="s1", user_id="u1", last_model="claude-3") + interrupts = await get_pending_interrupts("s1", "u1") + assert len(interrupts) == 1 + assert interrupts[0].interrupt_id == "i1" + + @pytest.mark.asyncio + async def test_preserves_assistant_id_in_preferences(self, sessions_metadata_table): + """assistant_id set by the assistant-attach flow must survive per-turn updates.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + store_session_metadata, + update_session_activity, + get_session_metadata, + ) + from apis.shared.sessions.models import SessionMetadata, SessionPreferences + + await 
ensure_session_metadata_exists("s1", "u1") + existing = await get_session_metadata("s1", "u1") + seeded = SessionMetadata( + sessionId="s1", userId="u1", + title=existing.title, status="active", + createdAt=existing.created_at, + lastMessageAt=existing.last_message_at, + messageCount=existing.message_count, + preferences=SessionPreferences(assistantId="asst-abc"), + ) + await store_session_metadata("s1", "u1", seeded) + + await update_session_activity( + session_id="s1", user_id="u1", last_model="claude-3", last_temperature=0.5, + ) + result = await get_session_metadata("s1", "u1") + assert result.preferences.assistant_id == "asst-abc" + assert result.preferences.last_model == "claude-3" + assert result.preferences.last_temperature == 0.5 + + @pytest.mark.asyncio + async def test_rotates_sk_to_new_timestamp(self, sessions_metadata_table): + """SK rotation keeps recency listing correct — only one row remains.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, + update_session_activity, + ) + await ensure_session_metadata_exists("s1", "u1") + items = sessions_metadata_table.scan()["Items"] + s_items = [i for i in items if i["SK"].startswith("S#ACTIVE#")] + assert len(s_items) == 1 + old_sk = s_items[0]["SK"] + + await update_session_activity(session_id="s1", user_id="u1", last_model="claude-3") + + items = sessions_metadata_table.scan()["Items"] + s_items_after = [i for i in items if i["SK"].startswith("S#ACTIVE#")] + assert len(s_items_after) == 1 + assert s_items_after[0]["SK"] != old_sk + + @pytest.mark.asyncio + async def test_self_heals_when_row_missing(self, sessions_metadata_table): + """If pre-create failed (or row was deleted), update self-heals via ensure_session_metadata_exists.""" + from apis.shared.sessions.metadata import update_session_activity, get_session_metadata + applied = await update_session_activity( + session_id="never-pre-created", user_id="u1", last_model="claude-3", + ) + assert applied is True + result = 
await get_session_metadata("never-pre-created", "u1") + assert result is not None + assert result.message_count == 1 + + @pytest.mark.asyncio + async def test_noop_for_preview_session(self, sessions_metadata_table): + from apis.shared.sessions.metadata import update_session_activity + applied = await update_session_activity( + session_id="preview-abc", user_id="u1", last_model="claude-3", + ) + assert applied is False + items = sessions_metadata_table.scan()["Items"] + assert items == [] + + +class TestAddPendingInterruptListAppend: + """list_append-based persistence — race-free with no read-modify-write.""" + + @pytest.mark.asyncio + async def test_first_interrupt(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, add_pending_interrupt, get_pending_interrupts, + ) + from apis.shared.sessions.models import PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + interrupts = await get_pending_interrupts("s1", "u1") + assert len(interrupts) == 1 + assert interrupts[0].interrupt_id == "i1" + assert interrupts[0].provider_id == "slack" + + @pytest.mark.asyncio + async def test_two_distinct_interrupts_accumulate(self, sessions_metadata_table): + """Two adds for different ids accumulate — list_append is atomic in DynamoDB.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, add_pending_interrupt, get_pending_interrupts, + ) + from apis.shared.sessions.models import PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + await add_pending_interrupt( + session_id="s1", user_id="u1", + 
interrupt=PendingInterrupt( + interruptId="i2", providerId="gmail", createdAt="2026-04-25T00:00:01Z", + ), + ) + interrupts = await get_pending_interrupts("s1", "u1") + ids = {p.interrupt_id for p in interrupts} + assert ids == {"i1", "i2"} + + @pytest.mark.asyncio + async def test_reemit_dedupes_on_read_last_write_wins(self, sessions_metadata_table): + """Same id added twice → one entry on read, last write's payload survives.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, add_pending_interrupt, get_pending_interrupts, + ) + from apis.shared.sessions.models import PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + await add_pending_interrupt( + session_id="s1", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", + triggeringMessageId="msg-7", + createdAt="2026-04-25T00:00:05Z", + ), + ) + interrupts = await get_pending_interrupts("s1", "u1") + assert len(interrupts) == 1 + assert interrupts[0].interrupt_id == "i1" + assert interrupts[0].triggering_message_id == "msg-7" + assert interrupts[0].created_at == "2026-04-25T00:00:05Z" + + @pytest.mark.asyncio + async def test_noop_when_session_missing(self, sessions_metadata_table): + from apis.shared.sessions.metadata import add_pending_interrupt, get_pending_interrupts + from apis.shared.sessions.models import PendingInterrupt + + await add_pending_interrupt( + session_id="never-created", user_id="u1", + interrupt=PendingInterrupt( + interruptId="i1", providerId="slack", createdAt="2026-04-25T00:00:00Z", + ), + ) + interrupts = await get_pending_interrupts("never-created", "u1") + assert interrupts == [] + + +class TestInterruptsFromDynamoDedupe: + """Read-side dedupe collapses duplicates produced by list_append re-emits.""" + + def 
test_dedupe_by_id_last_write_wins(self): + from apis.shared.sessions.metadata import _interrupts_from_dynamo + raw = [ + {"interruptId": "i1", "providerId": "slack", "createdAt": "2026-04-25T00:00:00Z"}, + {"interruptId": "i2", "providerId": "gmail", "createdAt": "2026-04-25T00:00:01Z"}, + {"interruptId": "i1", "providerId": "slack", + "triggeringMessageId": "msg-7", "createdAt": "2026-04-25T00:00:05Z"}, + ] + result = _interrupts_from_dynamo(raw) + assert [p.interrupt_id for p in result] == ["i1", "i2"] + i1 = next(p for p in result if p.interrupt_id == "i1") + assert i1.triggering_message_id == "msg-7" + assert i1.created_at == "2026-04-25T00:00:05Z" + + def test_skips_unparseable_entries(self): + from apis.shared.sessions.metadata import _interrupts_from_dynamo + raw = [ + {"interruptId": "i1", "providerId": "slack", "createdAt": "2026-04-25T00:00:00Z"}, + {"missing": "required-fields"}, + "not a dict", + ] + result = _interrupts_from_dynamo(raw) + assert len(result) == 1 + assert result[0].interrupt_id == "i1" + + def test_empty_input(self): + from apis.shared.sessions.metadata import _interrupts_from_dynamo + assert _interrupts_from_dynamo(None) == [] + assert _interrupts_from_dynamo([]) == [] + assert _interrupts_from_dynamo("not a list") == [] + + +class TestPausedTurnSnapshot: + """PausedTurnSnapshot persistence — singleton, idempotent, round-trippable. + + The snapshot is the durable contract that lets a refresh / cache eviction + resume a paused agent turn — without it, the resume rebuilds an agent + with an empty tool registry and the paused tool call has nothing to + resume against. 
+ """ + + @pytest.mark.asyncio + async def test_set_get_round_trip(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, set_paused_turn, get_paused_turn, + ) + from apis.shared.sessions.models import PausedTurnSnapshot + + await ensure_session_metadata_exists("s1", "u1") + snap = PausedTurnSnapshot( + enabledTools=["calendar", "gmail"], modelId="claude-sonnet-4-6", + provider="bedrock", temperature=0.2, systemPrompt="prompt-text", + cachingEnabled=True, maxTokens=4096, + capturedAt="2026-04-25T00:00:00Z", expiresAt="2026-04-25T01:00:00Z", + ) + await set_paused_turn("s1", "u1", snap) + got = await get_paused_turn("s1", "u1") + assert got is not None + assert got.enabled_tools == ["calendar", "gmail"] + assert got.model_id == "claude-sonnet-4-6" + assert got.system_prompt == "prompt-text" + assert got.temperature == 0.2 + assert got.caching_enabled is True + + @pytest.mark.asyncio + async def test_idempotent_overwrite(self, sessions_metadata_table): + """Multiple OAuth interrupts in one turn share a single snapshot — + re-writing replaces in place rather than accumulating.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, set_paused_turn, get_paused_turn, + ) + from apis.shared.sessions.models import PausedTurnSnapshot + + await ensure_session_metadata_exists("s1", "u1") + first = PausedTurnSnapshot( + enabledTools=["calendar"], capturedAt="2026-04-25T00:00:00Z", + expiresAt="2026-04-25T01:00:00Z", + ) + second = PausedTurnSnapshot( + enabledTools=["calendar", "gmail"], capturedAt="2026-04-25T00:00:01Z", + expiresAt="2026-04-25T01:00:01Z", + ) + await set_paused_turn("s1", "u1", first) + await set_paused_turn("s1", "u1", second) + got = await get_paused_turn("s1", "u1") + assert got is not None + assert got.enabled_tools == ["calendar", "gmail"] + assert got.captured_at == "2026-04-25T00:00:01Z" + + @pytest.mark.asyncio + async def test_clear_removes_snapshot(self, 
sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, set_paused_turn, + get_paused_turn, clear_paused_turn, + ) + from apis.shared.sessions.models import PausedTurnSnapshot + + await ensure_session_metadata_exists("s1", "u1") + await set_paused_turn( + "s1", "u1", + PausedTurnSnapshot( + enabledTools=["calendar"], capturedAt="2026-04-25T00:00:00Z", + expiresAt="2026-04-25T01:00:00Z", + ), + ) + assert await get_paused_turn("s1", "u1") is not None + await clear_paused_turn("s1", "u1") + assert await get_paused_turn("s1", "u1") is None + + @pytest.mark.asyncio + async def test_clear_is_noop_when_already_clear(self, sessions_metadata_table): + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, clear_paused_turn, get_paused_turn, + ) + + await ensure_session_metadata_exists("s1", "u1") + await clear_paused_turn("s1", "u1") + assert await get_paused_turn("s1", "u1") is None + + @pytest.mark.asyncio + async def test_set_noop_when_session_missing(self, sessions_metadata_table): + """Preview/anonymous sessions don't have a metadata row — write must + not crash and a subsequent get returns None.""" + from apis.shared.sessions.metadata import set_paused_turn, get_paused_turn + from apis.shared.sessions.models import PausedTurnSnapshot + + await set_paused_turn( + "never-created", "u1", + PausedTurnSnapshot( + enabledTools=["calendar"], capturedAt="2026-04-25T00:00:00Z", + expiresAt="2026-04-25T01:00:00Z", + ), + ) + assert await get_paused_turn("never-created", "u1") is None + + @pytest.mark.asyncio + async def test_paused_turn_independent_of_pending_interrupts(self, sessions_metadata_table): + """``paused_turn`` and ``pending_interrupts`` live on the same row + but their lifecycles don't intrude on each other — clearing one + leaves the other intact.""" + from apis.shared.sessions.metadata import ( + ensure_session_metadata_exists, set_paused_turn, clear_paused_turn, + add_pending_interrupt, 
get_pending_interrupts, get_paused_turn, + ) + from apis.shared.sessions.models import PausedTurnSnapshot, PendingInterrupt + + await ensure_session_metadata_exists("s1", "u1") + await set_paused_turn( + "s1", "u1", + PausedTurnSnapshot( + enabledTools=["calendar"], capturedAt="2026-04-25T00:00:00Z", + expiresAt="2026-04-25T01:00:00Z", + ), + ) + await add_pending_interrupt( + "s1", "u1", + PendingInterrupt( + interruptId="i1", providerId="calendar", createdAt="2026-04-25T00:00:00Z", + ), + ) + + await clear_paused_turn("s1", "u1") + assert await get_paused_turn("s1", "u1") is None + interrupts = await get_pending_interrupts("s1", "u1") + assert len(interrupts) == 1 + assert interrupts[0].interrupt_id == "i1" diff --git a/docs/SSE_ERROR_MESSAGING.md b/docs/SSE_ERROR_MESSAGING.md index e6a435ab..2d06af8b 100644 --- a/docs/SSE_ERROR_MESSAGING.md +++ b/docs/SSE_ERROR_MESSAGING.md @@ -138,6 +138,19 @@ if hasattr(session_manager, 'base_manager') and hasattr(session_manager.base_man - `StreamError` interface - Error state signals for UI +### OAuth Required + +**Backend**: `apis/shared/oauth/models.py` +- `OAuthRequiredEvent` - Pydantic model with `to_sse_format()` method +- Event name: `oauth_required`, payload: `{providerId, authorizationUrl}` + +**Routes**: `apis/inference_api/chat/routes.py` +- After the agent stream drains, pull consent URLs from + `ExternalMCPIntegration.drain_pending_consent(user_id)` and emit one + `oauth_required` event per provider before `done`. +- Consent URLs are populated during external MCP tool loading whenever + AgentCore Identity reports the user hasn't granted the required scopes. + ## Adding New Error Types 1. 
**Create event model** in `apis/shared/errors.py`: diff --git a/frontend/ai.client/src/app/admin/admin.page.ts b/frontend/ai.client/src/app/admin/admin.page.ts index dba672ef..c5b4166b 100644 --- a/frontend/ai.client/src/app/admin/admin.page.ts +++ b/frontend/ai.client/src/app/admin/admin.page.ts @@ -110,10 +110,10 @@ export class AdminPage { route: '/admin/auth-providers', }, { - title: 'OAuth Providers', - description: 'Configure third-party OAuth integrations for MCP tool authentication. Manage Google, Microsoft, GitHub, and custom providers.', + title: 'Connectors', + description: 'Configure third-party OAuth integrations that users can connect for MCP tool authentication. Manage Google, Microsoft, GitHub, and custom connectors.', icon: 'heroLink', - route: '/admin/oauth-providers', + route: '/admin/connectors', }, { title: 'Fine-Tuning Access', diff --git a/frontend/ai.client/src/app/admin/connectors/index.ts b/frontend/ai.client/src/app/admin/connectors/index.ts new file mode 100644 index 00000000..c87a34f0 --- /dev/null +++ b/frontend/ai.client/src/app/admin/connectors/index.ts @@ -0,0 +1,9 @@ +// Models +export * from './models/connector.model'; + +// Services +export * from './services/connectors.service'; + +// Pages +export * from './pages/connector-list.page'; +export * from './pages/connector-form.page'; diff --git a/frontend/ai.client/src/app/admin/connectors/models/connector.model.ts b/frontend/ai.client/src/app/admin/connectors/models/connector.model.ts new file mode 100644 index 00000000..9e7428af --- /dev/null +++ b/frontend/ai.client/src/app/admin/connectors/models/connector.model.ts @@ -0,0 +1,234 @@ +/** + * Connector type enumeration. + * + * `canvas` routes through AgentCore's `CustomOauth2` vendor but is kept + * distinct so the UI can surface Canvas-specific guidance if we add a + * preset later. Today it is treated like `custom`. 
+ * + * `slack`, `salesforce`, and `zoom` are first-class AgentCore Identity + * vendors — endpoints + provider defaults are pre-configured by AgentCore, + * so admins only supply credentials and scopes (no discovery URL). + */ +export type ConnectorType = + | 'google' + | 'microsoft' + | 'github' + | 'slack' + | 'salesforce' + | 'zoom' + | 'canvas' + | 'custom'; + +/** + * Connector record as returned by the admin API. + * + * AgentCore Identity is authoritative for `clientId`, `clientSecret`, the + * vendor-specific endpoint config, and `callbackUrl`. Our backend caches + * the ARN and callback URL on the record for admin convenience. + */ +export interface Connector { + providerId: string; + displayName: string; + providerType: ConnectorType; + scopes: string[]; + allowedRoles: string[]; + enabled: boolean; + iconName: string; + /** + * Optional admin-uploaded icon as a base64 data URL. When set, frontends + * prefer this over `iconName`. Stored inline on the provider record. + */ + iconData?: string | null; + credentialProviderArn?: string | null; + callbackUrl?: string | null; + /** Custom/Canvas only — OIDC discovery URL or explicit server metadata. */ + oauthDiscoveryUrl?: string | null; + authorizationServerMetadata?: Record | null; + /** + * Vendor-specific OAuth params merged into AgentCore Identity's + * `customParameters` at request time. Examples: Google `hd=mycorp.com` + * for Workspace domain restriction, `prompt=consent` for stricter UX. + * Hardcoded vendor baselines (e.g. Google's `access_type=offline`) + * always win on conflict. + */ + customParameters?: Record | null; + createdAt: string; + updatedAt: string; +} + +/** + * Response model for listing connectors. + * + * The backend still returns the array under `providers` — we preserve the + * field name to match the wire format exactly. + */ +export interface ConnectorListResponse { + providers: Connector[]; + total: number; +} + +/** + * Create request. 
`clientId` and `clientSecret` are forwarded to AgentCore + * Identity and are never stored in our DynamoDB. Custom/Canvas providers + * must supply exactly one of `oauthDiscoveryUrl` or + * `authorizationServerMetadata`. + */ +export interface ConnectorCreateRequest { + providerId: string; + displayName: string; + providerType: ConnectorType; + clientId: string; + clientSecret: string; + scopes: string[]; + allowedRoles?: string[]; + enabled?: boolean; + iconName?: string; + /** Optional admin-uploaded icon as a base64 data URL. */ + iconData?: string; + oauthDiscoveryUrl?: string; + authorizationServerMetadata?: Record; + customParameters?: Record; +} + +/** + * Update request. Credential rotation requires `clientId` and + * `clientSecret` together; metadata-only edits leave them undefined. + * + * `customParameters: {}` explicitly clears all admin-supplied extras; + * `undefined` leaves the existing value alone. `iconData: ""` clears any + * uploaded icon (frontends fall back to `iconName`); `undefined` leaves it. + */ +export interface ConnectorUpdateRequest { + displayName?: string; + clientId?: string; + clientSecret?: string; + scopes?: string[]; + allowedRoles?: string[]; + enabled?: boolean; + iconName?: string; + iconData?: string; + oauthDiscoveryUrl?: string; + authorizationServerMetadata?: Record; + customParameters?: Record; +} + +/** + * Form data bound to the connector form. Scopes are a comma-separated + * string for admin entry; parsed into `string[]` before submit. + */ +export interface ConnectorFormData { + providerId: string; + displayName: string; + providerType: ConnectorType; + clientId: string; + clientSecret: string; + scopes: string; + allowedRoles: string[]; + enabled: boolean; + iconName: string; + oauthDiscoveryUrl: string; + /** + * Free-form `key=value` lines for admin-supplied custom OAuth parameters, + * one per line. Parsed to `Record` before submit. 
+ */ + customParameters: string; +} + +/** + * Preset configuration for the connector picker. Endpoints are owned by + * AgentCore Identity and not configurable here. + * + * `defaultScopes` and `defaultCustomParameters` populate the form when the + * admin clicks a preset. `scopesPlaceholder` and `customParametersPlaceholder` + * are vendor-relevant examples shown when the field is empty (e.g. after + * the admin clears one to type their own). + */ +export interface ConnectorPreset { + type: ConnectorType; + displayName: string; + defaultScopes: string[]; + defaultCustomParameters?: Record; + iconName: string; + scopesPlaceholder?: string; + customParametersPlaceholder?: string; + /** Optional hint shown to the admin when selecting the preset. */ + hint?: string; +} + +export const CONNECTOR_PRESETS: ConnectorPreset[] = [ + { + type: 'google', + displayName: 'Google', + // No defaults — Google connectors are too multi-purpose to pre-pick + // (Calendar / Gmail / Drive / Docs all need different scopes, and the + // OIDC-only `openid email profile` set doesn't let an agent do + // anything useful). The placeholder shows the URL format so admins + // know what to type. 
+ defaultScopes: [], + iconName: 'heroCloud', + scopesPlaceholder: + 'openid, email, profile, https://www.googleapis.com/auth/calendar.readonly', + customParametersPlaceholder: 'hd=mycompany.com\nprompt=consent', + }, + { + type: 'microsoft', + displayName: 'Microsoft', + defaultScopes: ['openid', 'email', 'profile', 'offline_access'], + iconName: 'heroCloud', + scopesPlaceholder: + 'openid, email, profile, offline_access, User.Read, Calendars.Read', + customParametersPlaceholder: 'domain_hint=mycompany.com\nprompt=consent', + }, + { + type: 'github', + displayName: 'GitHub', + defaultScopes: ['read:user', 'user:email'], + iconName: 'heroCodeBracket', + scopesPlaceholder: 'read:user, user:email, repo', + }, + { + type: 'slack', + displayName: 'Slack', + defaultScopes: ['chat:write', 'channels:read', 'users:read'], + iconName: 'heroChatBubbleLeftRight', + scopesPlaceholder: + 'chat:write, channels:read, channels:history, users:read, files:read', + customParametersPlaceholder: 'team=T0123456789', + }, + { + type: 'salesforce', + displayName: 'Salesforce', + defaultScopes: ['api', 'refresh_token', 'offline_access', 'id', 'openid'], + iconName: 'heroCloud', + scopesPlaceholder: + 'api, refresh_token, offline_access, id, openid, lightning, content', + customParametersPlaceholder: 'prompt=login\nlogin_hint=user@mycompany.com', + }, + { + type: 'zoom', + displayName: 'Zoom', + defaultScopes: ['user:read:user', 'meeting:read:meeting'], + iconName: 'heroVideoCamera', + scopesPlaceholder: + 'user:read:user, meeting:read:meeting, recording:read:recording', + }, + { + type: 'custom', + displayName: 'Custom (OIDC)', + defaultScopes: [], + iconName: 'heroLink', + scopesPlaceholder: 'openid, email, profile', + hint: 'Requires an OpenID Connect discovery URL', + }, +]; + +export function getConnectorPreset(type: ConnectorType): ConnectorPreset | undefined { + return CONNECTOR_PRESETS.find(preset => preset.type === type); +} + +/** + * True when the provider type needs an OIDC 
discovery URL. + */ +export function requiresDiscovery(type: ConnectorType): boolean { + return type === 'custom' || type === 'canvas'; +} diff --git a/frontend/ai.client/src/app/admin/connectors/pages/connector-form.page.ts b/frontend/ai.client/src/app/admin/connectors/pages/connector-form.page.ts new file mode 100644 index 00000000..b43df68c --- /dev/null +++ b/frontend/ai.client/src/app/admin/connectors/pages/connector-form.page.ts @@ -0,0 +1,1018 @@ +import { + Component, + ChangeDetectionStrategy, + inject, + signal, + computed, + OnInit, +} from '@angular/core'; +import { Router, ActivatedRoute } from '@angular/router'; +import { + FormBuilder, + FormGroup, + FormControl, + Validators, + ReactiveFormsModule, +} from '@angular/forms'; +import { NgIcon, provideIcons } from '@ng-icons/core'; +import { + heroArrowLeft, + heroInformationCircle, + heroEye, + heroEyeSlash, + heroExclamationTriangle, + heroCheckCircle, + heroClipboard, + heroClipboardDocumentCheck, + heroCloud, + heroCodeBracket, + heroAcademicCap, + heroLink, + heroChatBubbleLeftRight, + heroVideoCamera, +} from '@ng-icons/heroicons/outline'; +import { ConnectorsService } from '../services/connectors.service'; +import { AppRolesService } from '../../roles/services/app-roles.service'; +import { + Connector, + ConnectorCreateRequest, + ConnectorUpdateRequest, + ConnectorType, + CONNECTOR_PRESETS, + getConnectorPreset, + requiresDiscovery, +} from '../models/connector.model'; +import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'; + +interface ConnectorFormGroup { + providerId: FormControl; + displayName: FormControl; + providerType: FormControl; + clientId: FormControl; + clientSecret: FormControl; + oauthDiscoveryUrl: FormControl; + scopes: FormControl; + allowedRoles: FormControl; + grantAllRoles: FormControl; + enabled: FormControl; + iconName: FormControl; + /** + * Optional uploaded icon as a base64 data URL. `''` means no upload (fall + * back to `iconName`). 
On update, sending `''` to the backend clears any + * previously uploaded icon. + */ + iconData: FormControl; + /** + * Free-form `key=value` lines (one per line) for vendor-specific OAuth + * params. Parsed to `Record` before submit. Blank lines + * and lines without `=` are silently dropped. + */ + customParameters: FormControl; +} + +const ICON_DATA_MAX_BYTES = 100 * 1024; +const ICON_ACCEPTED_MIME_TYPES = [ + 'image/png', + 'image/jpeg', + 'image/gif', + 'image/webp', + 'image/svg+xml', +]; + +@Component({ + selector: 'app-connector-form', + changeDetection: ChangeDetectionStrategy.OnPush, + imports: [ReactiveFormsModule, NgIcon, TooltipDirective], + providers: [ + provideIcons({ + heroArrowLeft, + heroInformationCircle, + heroEye, + heroEyeSlash, + heroExclamationTriangle, + heroCheckCircle, + heroClipboard, + heroClipboardDocumentCheck, + heroCloud, + heroCodeBracket, + heroAcademicCap, + heroLink, + heroChatBubbleLeftRight, + heroVideoCamera, + }), + ], + host: { class: 'block' }, + template: ` +
+
+ + +
+

+ {{ pageTitle() }} +

+

+ {{ isEditMode() ? 'Update connector settings and credentials' : 'Register a new OAuth connector' }} +

+
+ + @if (loading()) { +
+
+
+

Loading connector...

+
+
+ } @else if (createdConnector(); as created) { + +
+
+
+ +
+

+ Connector created +

+

+ "{{ created.displayName }}" is registered with AgentCore Identity. +

+
+
+
+ +
+

+ Next step: register this callback URL +

+

+ Add the following URL to your OAuth provider's list of authorized redirect URIs + (in Google Cloud Console, Microsoft Entra, GitHub OAuth App settings, etc.). + Until this is done, users will see an error when they try to consent. +

+
+ + +
+
+ +
+ +
+
+ } @else { +
+ + @if (!isEditMode()) { +
+

Connector Type

+

+ Choose a preset or use Custom for any OIDC-compliant provider. +

+
+ @for (preset of presets; track preset.type) { + + } +
+
+ } + +
+

Basic Information

+
+
+ + +

+ Unique identifier. Lowercase letters, numbers, and hyphens only. +

+ @if (connectorForm.controls.providerId.invalid && connectorForm.controls.providerId.touched) { +

+ @if (connectorForm.controls.providerId.errors?.['required']) { Connector ID is required } + @else if (connectorForm.controls.providerId.errors?.['pattern']) { Must be lowercase letters, numbers, and hyphens only } + @else if (connectorForm.controls.providerId.errors?.['maxlength']) { Must be at most 64 characters } +

+ } +
+ +
+ + + @if (connectorForm.controls.displayName.invalid && connectorForm.controls.displayName.touched) { +

Display name is required

+ } +
+ +
+ +
+
+ @if (connectorForm.controls.iconData.value) { + Connector icon preview + } @else { +
+
+ + @if (connectorForm.controls.iconData.value) { + + } +
+
+

+ PNG, JPEG, GIF, WebP, or SVG. Max 100KB. Falls back to the default icon when no image is uploaded. +

+ @if (iconUploadError(); as iconErr) { +

{{ iconErr }}

+ } +
+ +
+ + +
+
+
+ + + @if (isEditMode() && loadedConnector(); as loaded) { +
+

AgentCore Identity

+
+
+ +
+ + +
+
+ @if (loaded.credentialProviderArn) { +
+ + +
+ } +
+
+ } + +
+

OAuth Credentials

+

+ @if (isEditMode()) { + Enter both fields to rotate credentials. Leave both blank to keep existing. + } @else { + Credentials are stored by AWS Bedrock AgentCore Identity — never by this application. + } +

+ +
+
+ + +
+ +
+ +
+ + +
+
+ + @if (credentialPairError()) { +

{{ credentialPairError() }}

+ } + + @if (needsDiscovery()) { +
+ + +

+ AgentCore fetches this URL to resolve authorization and token endpoints. +

+
+ } + +
+ + +

+ Comma-separated list of OAuth scopes to request during authorization. +

+
+ +
+ + +

+ One key=value pair per line, forwarded to AgentCore Identity as customParameters. + Required vendor params (e.g. Google's access_type=offline) are sent automatically and override any conflicting entries here. +

+
+
+
+ +
+

Access Control

+

+ Restrict which application roles can use this connector. +

+
+ +
+ + +
+ + @if (!connectorForm.controls.grantAllRoles.value) { + @if (rolesResource.isLoading() || rolesResource.value() === undefined) { +
+
+

Loading roles...

+
+ } @else if (availableRoles().length > 0) { +
+ @for (role of availableRoles(); track role.roleId) { + + } +
+ } @else { +

+ No roles available. Create roles in Role Management first. +

+ } + } +

+ Only users with selected roles will be able to use this connector. +

+
+
+ + @if (isEditMode()) { +
+
+ +
+

Security Notice

+

+ Changing scopes forces connected users to re-consent on their next tool call. + Rotating credentials requires re-entering both Client ID and Client Secret. +

+
+
+
+ } + +
+ + +
+
+ } +
+
+ `, +}) +export class ConnectorFormPage implements OnInit { + private fb = inject(FormBuilder); + private router = inject(Router); + private route = inject(ActivatedRoute); + private connectorsService = inject(ConnectorsService); + private appRolesService = inject(AppRolesService); + + readonly rolesResource = this.appRolesService.rolesResource; + + readonly presets = CONNECTOR_PRESETS; + readonly acceptedIconTypes = ICON_ACCEPTED_MIME_TYPES.join(','); + + readonly isEditMode = signal(false); + readonly providerId = signal(null); + readonly isSubmitting = signal(false); + readonly loading = signal(false); + readonly showClientSecret = signal(false); + readonly loadedConnector = signal(null); + readonly createdConnector = signal(null); + readonly callbackCopied = signal(false); + /** Validation error from the most recent file pick (null when ok). */ + readonly iconUploadError = signal(null); + /** + * Tracks the icon_data value that was loaded from the server, so we know + * whether to send `iconData: ""` on update (clear) when the admin removes + * the upload. `null` means no icon was loaded; a string means one was. 
+ */ + private readonly iconLoadedFromServer = signal(null); + + readonly connectorForm: FormGroup = this.fb.group({ + providerId: this.fb.control('', { + nonNullable: true, + validators: [ + Validators.required, + Validators.minLength(1), + Validators.maxLength(64), + Validators.pattern(/^[a-z0-9-]+$/), + ], + }), + displayName: this.fb.control('', { + nonNullable: true, + validators: [Validators.required, Validators.maxLength(100)], + }), + providerType: this.fb.control('custom', { nonNullable: true }), + clientId: this.fb.control('', { nonNullable: true }), + clientSecret: this.fb.control('', { nonNullable: true }), + oauthDiscoveryUrl: this.fb.control('', { nonNullable: true }), + scopes: this.fb.control('', { nonNullable: true }), + allowedRoles: this.fb.control(['*'], { nonNullable: true }), + grantAllRoles: this.fb.control(true, { nonNullable: true }), + enabled: this.fb.control(true, { nonNullable: true }), + iconName: this.fb.control('heroLink', { nonNullable: true }), + iconData: this.fb.control('', { nonNullable: true }), + customParameters: this.fb.control('', { nonNullable: true }), + }); + + readonly pageTitle = computed(() => (this.isEditMode() ? 'Edit Connector' : 'Add Connector')); + + readonly availableRoles = computed(() => this.appRolesService.getEnabledRoles()); + + readonly selectedRoles = signal(['*']); + + // Form controls aren't observable signals, so mirror providerType into a + // signal updated from valueChanges. This drives the template's @if for + // discovery, the placeholder lookups, and the submit-time discovery + // gating. + readonly needsDiscovery = signal( + requiresDiscovery(this.connectorForm.controls.providerType.value) + ); + + /** Mirrors `providerType` so computed placeholders react to changes. */ + private readonly providerTypeSignal = signal( + this.connectorForm.controls.providerType.value, + ); + + /** Vendor-relevant scopes example shown when the field is empty. 
*/ + readonly scopesPlaceholder = computed(() => { + const preset = getConnectorPreset(this.providerTypeSignal()); + return preset?.scopesPlaceholder ?? 'openid, email, profile'; + }); + + /** + * Vendor-relevant `key=value` example for the custom-parameters + * textarea. Generic `key=value` fallback for vendors with no + * commonly-used extras. + */ + readonly customParametersPlaceholder = computed(() => { + const preset = getConnectorPreset(this.providerTypeSignal()); + return preset?.customParametersPlaceholder ?? 'key=value'; + }); + + /** + * Returns a user-facing error string when clientId and clientSecret are + * inconsistent. Rotation requires both or neither. + */ + readonly credentialPairError = computed(() => { + const id = this.connectorForm.controls.clientId.value.trim(); + const secret = this.connectorForm.controls.clientSecret.value.trim(); + if (!this.isEditMode()) return null; // create mode requires both, enforced elsewhere + if (!!id === !!secret) return null; + return 'Client ID and Client Secret must be provided together to rotate credentials.'; + }); + + ngOnInit(): void { + const id = this.route.snapshot.paramMap.get('providerId'); + if (id && id !== 'new') { + this.isEditMode.set(true); + this.providerId.set(id); + this.loadConnectorData(id); + } else { + this.connectorForm.controls.clientId.setValidators([Validators.required]); + this.connectorForm.controls.clientSecret.setValidators([Validators.required]); + this.connectorForm.controls.clientId.updateValueAndValidity(); + this.connectorForm.controls.clientSecret.updateValueAndValidity(); + this.applyDiscoveryValidator(); + } + + this.connectorForm.controls.providerType.valueChanges.subscribe((value) => { + this.providerTypeSignal.set(value); + this.applyDiscoveryValidator(); + }); + } + + private applyDiscoveryValidator(): void { + const ctrl = this.connectorForm.controls.oauthDiscoveryUrl; + const needs = requiresDiscovery(this.connectorForm.controls.providerType.value); + 
this.needsDiscovery.set(needs); + if (needs) { + ctrl.setValidators([Validators.required, Validators.pattern(/^https?:\/\/.+/)]); + } else { + ctrl.clearValidators(); + ctrl.setValue(''); + } + ctrl.updateValueAndValidity({ emitEvent: false }); + } + + private async loadConnectorData(id: string): Promise { + this.loading.set(true); + try { + const connector = await this.connectorsService.fetchConnector(id); + this.loadedConnector.set(connector); + + this.connectorForm.patchValue({ + providerId: connector.providerId, + displayName: connector.displayName, + providerType: connector.providerType, + clientId: '', + clientSecret: '', + oauthDiscoveryUrl: connector.oauthDiscoveryUrl ?? '', + scopes: connector.scopes.join(', '), + allowedRoles: connector.allowedRoles.length > 0 ? connector.allowedRoles : ['*'], + grantAllRoles: connector.allowedRoles.length === 0, + enabled: connector.enabled, + iconName: connector.iconName || 'heroLink', + iconData: connector.iconData ?? '', + customParameters: this.serializeCustomParameters(connector.customParameters ?? null), + }); + this.iconLoadedFromServer.set(connector.iconData ?? null); + this.selectedRoles.set(connector.allowedRoles.length > 0 ? connector.allowedRoles : ['*']); + this.applyDiscoveryValidator(); + } catch (error) { + console.error('Error loading connector:', error); + alert('Failed to load connector. Returning to list.'); + this.router.navigate(['/admin/connectors']); + } finally { + this.loading.set(false); + } + } + + selectConnectorType(type: ConnectorType): void { + const preset = getConnectorPreset(type); + if (preset) { + this.connectorForm.patchValue({ + providerType: type, + displayName: preset.displayName, + scopes: preset.defaultScopes.join(', '), + iconName: preset.iconName, + // Only seed customParameters from a preset if the preset declares + // them — most don't, and we don't want to clobber whatever the + // admin has already typed. + ...(preset.defaultCustomParameters + ? 
{ + customParameters: this.serializeCustomParameters( + preset.defaultCustomParameters, + ), + } + : {}), + }); + } + this.applyDiscoveryValidator(); + } + + getPresetIconClasses(type: ConnectorType): string { + const base = 'flex size-10 items-center justify-center rounded-sm'; + switch (type) { + case 'google': + return `${base} bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400`; + case 'microsoft': + return `${base} bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400`; + case 'github': + return `${base} bg-gray-800 text-white dark:bg-gray-600`; + case 'slack': + return `${base} bg-fuchsia-100 text-fuchsia-700 dark:bg-fuchsia-900/30 dark:text-fuchsia-300`; + case 'salesforce': + return `${base} bg-sky-100 text-sky-700 dark:bg-sky-900/30 dark:text-sky-300`; + case 'zoom': + return `${base} bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300`; + case 'canvas': + return `${base} bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400`; + default: + return `${base} bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400`; + } + } + + onGrantAllRolesChange(): void { + const checked = this.connectorForm.controls.grantAllRoles.value; + if (checked) { + this.connectorForm.controls.allowedRoles.setValue(['*']); + this.selectedRoles.set(['*']); + } else { + this.connectorForm.controls.allowedRoles.setValue([]); + this.selectedRoles.set([]); + if (this.rolesResource.value() === undefined) { + this.rolesResource.reload(); + } + } + } + + isRoleSelected(roleId: string): boolean { + return this.selectedRoles().includes(roleId); + } + + toggleRole(roleId: string): void { + const currentRoles = this.selectedRoles().filter(r => r !== '*'); + const newRoles = currentRoles.includes(roleId) + ? 
currentRoles.filter(r => r !== roleId) + : [...currentRoles, roleId]; + this.connectorForm.controls.allowedRoles.setValue(newRoles); + this.selectedRoles.set(newRoles); + } + + async copyCallbackUrl(url: string): Promise { + if (!url) return; + try { + await navigator.clipboard.writeText(url); + this.callbackCopied.set(true); + setTimeout(() => this.callbackCopied.set(false), 2000); + } catch (err) { + console.error('Clipboard write failed', err); + } + } + + async onSubmit(): Promise { + if (this.connectorForm.invalid || this.credentialPairError()) { + this.connectorForm.markAllAsTouched(); + return; + } + + this.isSubmitting.set(true); + try { + const formValue = this.connectorForm.getRawValue(); + + const scopes = formValue.scopes + ? formValue.scopes.split(',').map((s: string) => s.trim()).filter(Boolean) + : []; + + const allowedRoles = formValue.grantAllRoles + ? [] + : (formValue.allowedRoles || []).filter((r: string) => r !== '*'); + + // Parse the textarea into a key/value map. The empty case sends `{}` + // on update (explicitly clears extras) and is omitted on create. + const customParameters = this.parseCustomParameters(formValue.customParameters); + + if (this.isEditMode() && this.providerId()) { + const updates: ConnectorUpdateRequest = { + displayName: formValue.displayName, + scopes, + allowedRoles, + enabled: formValue.enabled, + iconName: formValue.iconName, + customParameters, + }; + // Tri-state for iconData: only send when the admin actually changed + // it. Replaced upload → send the new data URL. Removed an existing + // upload → send `""` so the backend clears it. No change → omit. + const previousIcon = this.iconLoadedFromServer(); + const currentIcon = formValue.iconData || ''; + if (currentIcon !== (previousIcon ?? 
'')) { + updates.iconData = currentIcon; + } + if (formValue.clientId && formValue.clientSecret) { + updates.clientId = formValue.clientId; + updates.clientSecret = formValue.clientSecret; + } + if (this.needsDiscovery() && formValue.oauthDiscoveryUrl) { + updates.oauthDiscoveryUrl = formValue.oauthDiscoveryUrl; + } + await this.connectorsService.updateConnector(this.providerId()!, updates); + this.router.navigate(['/admin/connectors']); + } else { + const createData: ConnectorCreateRequest = { + providerId: formValue.providerId, + displayName: formValue.displayName, + providerType: formValue.providerType, + clientId: formValue.clientId, + clientSecret: formValue.clientSecret, + scopes, + allowedRoles, + enabled: formValue.enabled, + iconName: formValue.iconName, + }; + if (formValue.iconData) { + createData.iconData = formValue.iconData; + } + if (this.needsDiscovery() && formValue.oauthDiscoveryUrl) { + createData.oauthDiscoveryUrl = formValue.oauthDiscoveryUrl; + } + if (Object.keys(customParameters).length > 0) { + createData.customParameters = customParameters; + } + const created = await this.connectorsService.createConnector(createData); + this.createdConnector.set(created); + } + } catch (error: unknown) { + console.error('Error saving connector:', error); + alert(this.formatErrorMessage(error)); + } finally { + this.isSubmitting.set(false); + } + } + + goBack(): void { + this.router.navigate(['/admin/connectors']); + } + + /** + * FastAPI returns validation errors as `detail: [{loc, msg, type, ...}]` + * and business errors as `detail: "string"`. Collapse both shapes into a + * single human-readable string for the alert. 
+ */ + private formatErrorMessage(error: unknown): string { + const body = (error as { error?: { detail?: unknown } } | null)?.error; + const detail = body?.detail; + if (typeof detail === 'string') return detail; + if (Array.isArray(detail)) { + return detail + .map(d => (d as { msg?: string })?.msg) + .filter((m): m is string => typeof m === 'string') + .join('\n') || 'Failed to save connector.'; + } + const message = (error as { message?: string } | null)?.message; + return message ?? 'Failed to save connector.'; + } + + /** + * Parse the textarea contents (one `key=value` per line) into a map. + * Blank lines and lines without `=` are silently dropped — the admin + * sees the cleaned-up version when they re-open the form for editing. + */ + private parseCustomParameters(raw: string): Record { + const out: Record = {}; + if (!raw) return out; + for (const line of raw.split(/\r?\n/)) { + const trimmed = line.trim(); + if (!trimmed) continue; + const eq = trimmed.indexOf('='); + if (eq <= 0) continue; // drop lines with no `=` or empty key + const key = trimmed.slice(0, eq).trim(); + const value = trimmed.slice(eq + 1).trim(); + if (!key) continue; + out[key] = value; + } + return out; + } + + /** + * Serialize a saved map back into the `key=value\nkey=value` textarea + * format. Keys are sorted for deterministic display so admin diffs stay + * stable across edits. + */ + private serializeCustomParameters( + map: Record | null, + ): string { + if (!map) return ''; + return Object.keys(map) + .sort() + .map(key => `${key}=${map[key]}`) + .join('\n'); + } + + async onIconFileSelected(event: Event): Promise { + const input = event.target as HTMLInputElement; + const file = input.files?.[0]; + if (!file) return; + + if (!ICON_ACCEPTED_MIME_TYPES.includes(file.type)) { + this.iconUploadError.set( + 'Unsupported file type. 
Use PNG, JPEG, GIF, WebP, or SVG.', + ); + input.value = ''; + return; + } + if (file.size > ICON_DATA_MAX_BYTES) { + this.iconUploadError.set( + `Icon must be ${Math.floor(ICON_DATA_MAX_BYTES / 1024)}KB or smaller.`, + ); + input.value = ''; + return; + } + + try { + const dataUrl = await this.readFileAsDataUrl(file); + this.connectorForm.controls.iconData.setValue(dataUrl); + this.connectorForm.controls.iconData.markAsDirty(); + this.iconUploadError.set(null); + } catch (err) { + console.error('Icon upload read failed', err); + this.iconUploadError.set('Failed to read the file. Try again.'); + } finally { + // Reset so picking the same file again still re-fires the change event. + input.value = ''; + } + } + + removeUploadedIcon(): void { + this.connectorForm.controls.iconData.setValue(''); + this.connectorForm.controls.iconData.markAsDirty(); + this.iconUploadError.set(null); + } + + private readFileAsDataUrl(file: File): Promise { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result as string); + reader.onerror = () => reject(reader.error); + reader.readAsDataURL(file); + }); + } +} diff --git a/frontend/ai.client/src/app/admin/oauth-providers/pages/provider-list.page.ts b/frontend/ai.client/src/app/admin/connectors/pages/connector-list.page.ts similarity index 73% rename from frontend/ai.client/src/app/admin/oauth-providers/pages/provider-list.page.ts rename to frontend/ai.client/src/app/admin/connectors/pages/connector-list.page.ts index b836e2c0..720b9e0f 100644 --- a/frontend/ai.client/src/app/admin/oauth-providers/pages/provider-list.page.ts +++ b/frontend/ai.client/src/app/admin/connectors/pages/connector-list.page.ts @@ -7,6 +7,8 @@ import { } from '@angular/core'; import { Router, RouterLink } from '@angular/router'; import { FormsModule } from '@angular/forms'; +import { Dialog } from '@angular/cdk/dialog'; +import { firstValueFrom } from 'rxjs'; import { NgIcon, provideIcons } 
from '@ng-icons/core'; import { heroPlus, @@ -23,12 +25,16 @@ import { heroXCircle, heroShieldCheck, } from '@ng-icons/heroicons/outline'; -import { OAuthProvidersService } from '../services/oauth-providers.service'; -import { OAuthProvider, OAuthProviderType } from '../models/oauth-provider.model'; +import { ConnectorsService } from '../services/connectors.service'; +import { Connector, ConnectorType } from '../models/connector.model'; import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'; +import { + ConfirmationDialogComponent, + ConfirmationDialogData, +} from '../../../components/confirmation-dialog'; @Component({ - selector: 'app-provider-list', + selector: 'app-connector-list', changeDetection: ChangeDetectionStrategy.OnPush, imports: [RouterLink, FormsModule, NgIcon, TooltipDirective], providers: [ @@ -66,17 +72,17 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'
-

OAuth Providers

+

Connectors

- Configure third-party OAuth integrations for MCP tool authentication. + Configure third-party OAuth integrations that users can connect for MCP tool authentication.

- Add Provider + Add Connector
@@ -90,7 +96,7 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive' @if (searchQuery()) { @@ -110,7 +116,7 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive' (ngModelChange)="enabledFilter.set($event)" class="rounded-sm border border-gray-300 bg-white px-3 py-2.5 text-sm/6 dark:border-gray-600 dark:bg-gray-800 dark:text-white" > - + @@ -139,26 +145,26 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive' - @if (providersResource.isLoading() && providers().length === 0) { + @if (connectorsResource.isLoading() && connectors().length === 0) {

- Loading providers... + Loading connectors...

} - @if (providersResource.error()) { + @if (connectorsResource.error()) {
-

Failed to load providers

+

Failed to load connectors

Please check your connection and try again.

} - - @if (!providersResource.isLoading() || providers().length > 0) { + + @if (!connectorsResource.isLoading() || connectors().length > 0) {
- @for (provider of filteredProviders(); track provider.providerId) { + @for (connector of filteredConnectors(); track connector.providerId) { - +
- Provider + Connector
-
- -
+ @if (connector.iconData) { +
+ +
+ } @else { +
+ +
+ }

- {{ provider.displayName }} + {{ connector.displayName }}

- {{ provider.providerId }} + {{ connector.providerId }}

@@ -217,23 +233,23 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'
- @if (provider.enabled) { + @if (connector.enabled) { Active @@ -276,17 +292,17 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'
- @if (filteredProviders().length === 0 && !providersResource.isLoading()) { + @if (filteredConnectors().length === 0 && !connectorsResource.isLoading()) {
@if (hasActiveFilters()) { -

No providers match your filters

+

No connectors match your filters

Try adjusting your search or filter criteria.

@@ -332,15 +348,15 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive' } - @if (providers().length > 0) { + @if (connectors().length > 0) {
-

About OAuth Providers

+

About Connectors

- Provider Types: Choose from common presets (Google, Microsoft, GitHub, Canvas) or configure a custom OAuth 2.0 provider.
+ Connector Types: Choose from common presets (Google, Microsoft, GitHub, Slack, Salesforce, Zoom) or configure a custom OAuth 2.0 connector.

- Role Restrictions: Control which application roles can connect to each provider. Leave empty for unrestricted access. + Role Restrictions: Control which application roles can use each connector. Leave empty for unrestricted access.

Security: Client secrets are encrypted and stored securely. They are never exposed to the frontend after creation. @@ -352,47 +368,45 @@ import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'

`, }) -export class ProviderListPage { - oauthProvidersService = inject(OAuthProvidersService); +export class ConnectorListPage { + connectorsService = inject(ConnectorsService); private router = inject(Router); + private dialog = inject(Dialog); - readonly providersResource = this.oauthProvidersService.providersResource; + readonly connectorsResource = this.connectorsService.connectorsResource; - // Local state searchQuery = signal(''); enabledFilter = signal(''); typeFilter = signal(''); - // Computed - readonly providers = computed(() => this.oauthProvidersService.getProviders()); + readonly connectors = computed(() => this.connectorsService.getConnectors()); - readonly filteredProviders = computed(() => { - let providers = this.providers(); + readonly filteredConnectors = computed(() => { + let connectors = this.connectors(); const query = this.searchQuery().toLowerCase(); const enabled = this.enabledFilter(); const type = this.typeFilter(); if (query) { - providers = providers.filter( - p => - p.displayName.toLowerCase().includes(query) || - p.providerId.toLowerCase().includes(query) || - p.providerType.toLowerCase().includes(query) + connectors = connectors.filter( + c => + c.displayName.toLowerCase().includes(query) || + c.providerId.toLowerCase().includes(query) || + c.providerType.toLowerCase().includes(query) ); } if (enabled === 'enabled') { - providers = providers.filter(p => p.enabled); + connectors = connectors.filter(c => c.enabled); } else if (enabled === 'disabled') { - providers = providers.filter(p => !p.enabled); + connectors = connectors.filter(c => !c.enabled); } if (type) { - providers = providers.filter(p => p.providerType === type); + connectors = connectors.filter(c => c.providerType === type); } - // Sort: enabled first, then alphabetically by display name - return providers.sort((a, b) => { + return connectors.sort((a, b) => { if (a.enabled !== b.enabled) { return a.enabled ? 
-1 : 1; } @@ -410,12 +424,11 @@ export class ProviderListPage { this.typeFilter.set(''); } - getProviderIcon(provider: OAuthProvider): string { - if (provider.iconName) { - return provider.iconName; + getConnectorIcon(connector: Connector): string { + if (connector.iconName) { + return connector.iconName; } - // Default icons by type - switch (provider.providerType) { + switch (connector.providerType) { case 'google': case 'microsoft': return 'heroCloud'; @@ -428,7 +441,7 @@ export class ProviderListPage { } } - getProviderIconClasses(type: OAuthProviderType): string { + getConnectorIconClasses(type: ConnectorType): string { const baseClasses = 'flex size-10 shrink-0 items-center justify-center rounded-sm'; switch (type) { case 'google': @@ -444,7 +457,7 @@ export class ProviderListPage { } } - getProviderTypeBadgeClasses(type: OAuthProviderType): string { + getConnectorTypeBadgeClasses(type: ConnectorType): string { const baseClasses = 'inline-flex items-center rounded-xs px-2 py-0.5 text-xs font-medium'; switch (type) { case 'google': @@ -460,7 +473,7 @@ export class ProviderListPage { } } - getProviderTypeLabel(type: OAuthProviderType): string { + getConnectorTypeLabel(type: ConnectorType): string { switch (type) { case 'google': return 'Google'; @@ -475,16 +488,27 @@ export class ProviderListPage { } } - async deleteProvider(provider: OAuthProvider): Promise { - if (!confirm(`Are you sure you want to delete the provider "${provider.displayName}"?\n\nThis will disconnect all users currently using this provider. This action cannot be undone.`)) { - return; - } + async deleteConnector(connector: Connector): Promise { + const dialogRef = this.dialog.open(ConfirmationDialogComponent, { + data: { + title: `Delete ${connector.displayName}`, + message: + `This will disconnect all users currently using this connector ` + + `and delete it from AgentCore Identity. 
This action cannot be undone.`, + confirmText: 'Delete', + cancelText: 'Cancel', + destructive: true, + } as ConfirmationDialogData, + }); + + const confirmed = await firstValueFrom(dialogRef.closed); + if (confirmed !== true) return; try { - await this.oauthProvidersService.deleteProvider(provider.providerId); + await this.connectorsService.deleteConnector(connector.providerId); } catch (error: any) { - console.error('Error deleting provider:', error); - const message = error?.error?.detail || error?.message || 'Failed to delete provider.'; + console.error('Error deleting connector:', error); + const message = error?.error?.detail || error?.message || 'Failed to delete connector.'; alert(message); } } diff --git a/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.spec.ts b/frontend/ai.client/src/app/admin/connectors/services/connectors.service.spec.ts similarity index 59% rename from frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.spec.ts rename to frontend/ai.client/src/app/admin/connectors/services/connectors.service.spec.ts index bbea6c75..be7aee54 100644 --- a/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.spec.ts +++ b/frontend/ai.client/src/app/admin/connectors/services/connectors.service.spec.ts @@ -2,12 +2,12 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; import { TestBed } from '@angular/core/testing'; import { HttpClientTestingModule, HttpTestingController } from '@angular/common/http/testing'; import { signal } from '@angular/core'; -import { OAuthProvidersService } from './oauth-providers.service'; +import { ConnectorsService } from './connectors.service'; import { ConfigService } from '../../../services/config.service'; import { AuthService } from '../../../auth/auth.service'; -describe('OAuthProvidersService', () => { - let service: OAuthProvidersService; +describe('ConnectorsService', () => { + let service: 
ConnectorsService; let httpMock: HttpTestingController; beforeEach(() => { @@ -15,63 +15,63 @@ describe('OAuthProvidersService', () => { TestBed.configureTestingModule({ imports: [HttpClientTestingModule], providers: [ - OAuthProvidersService, + ConnectorsService, { provide: AuthService, useValue: { ensureAuthenticated: vi.fn().mockResolvedValue(undefined) } }, { provide: ConfigService, useValue: { appApiUrl: signal('http://localhost:8000') } }, ], }); - service = TestBed.inject(OAuthProvidersService); + service = TestBed.inject(ConnectorsService); httpMock = TestBed.inject(HttpTestingController); }); afterEach(() => { - httpMock.match(() => true); // discard pending requests + httpMock.match(() => true); TestBed.resetTestingModule(); }); - it('should fetch providers', async () => { + it('should fetch connectors', async () => { const mockResponse = { providers: [], total: 0 }; - const promise = service.fetchProviders(); + const promise = service.fetchConnectors(); await vi.waitFor(() => { httpMock.expectOne('http://localhost:8000/admin/oauth-providers/').flush(mockResponse); }); expect(await promise).toEqual(mockResponse); }); - it('should fetch provider by id', async () => { - const mockProvider = { provider_id: '1', name: 'Test Provider' }; - const promise = service.fetchProvider('1'); + it('should fetch connector by id', async () => { + const mockConnector = { provider_id: '1', name: 'Test Connector' }; + const promise = service.fetchConnector('1'); await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(mockProvider); + httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(mockConnector); }); - expect(await promise).toEqual({ providerId: '1', name: 'Test Provider' }); + expect(await promise).toEqual({ providerId: '1', name: 'Test Connector' }); }); - it('should create provider', async () => { - const providerData = { name: 'New Provider' } as any; - const mockProvider = { provider_id: '1', name: 
'New Provider' }; - const promise = service.createProvider(providerData); + it('should create connector', async () => { + const data = { name: 'New Connector' } as any; + const mockConnector = { provider_id: '1', name: 'New Connector' }; + const promise = service.createConnector(data); await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/admin/oauth-providers/').flush(mockProvider); + httpMock.expectOne('http://localhost:8000/admin/oauth-providers/').flush(mockConnector); }); - expect(await promise).toEqual({ providerId: '1', name: 'New Provider' }); + expect(await promise).toEqual({ providerId: '1', name: 'New Connector' }); }); - it('should update provider', async () => { - const updates = { name: 'Updated Provider' } as any; - const mockProvider = { provider_id: '1', name: 'Updated Provider' }; - const promise = service.updateProvider('1', updates); + it('should update connector', async () => { + const updates = { name: 'Updated Connector' } as any; + const mockConnector = { provider_id: '1', name: 'Updated Connector' }; + const promise = service.updateConnector('1', updates); await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(mockProvider); + httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(mockConnector); }); - expect(await promise).toEqual({ providerId: '1', name: 'Updated Provider' }); + expect(await promise).toEqual({ providerId: '1', name: 'Updated Connector' }); }); - it('should delete provider', async () => { - const promise = service.deleteProvider('1'); + it('should delete connector', async () => { + const promise = service.deleteConnector('1'); await vi.waitFor(() => { httpMock.expectOne('http://localhost:8000/admin/oauth-providers/1').flush(null); }); await promise; }); -}); \ No newline at end of file +}); diff --git a/frontend/ai.client/src/app/admin/connectors/services/connectors.service.ts 
b/frontend/ai.client/src/app/admin/connectors/services/connectors.service.ts new file mode 100644 index 00000000..fbf1ff78 --- /dev/null +++ b/frontend/ai.client/src/app/admin/connectors/services/connectors.service.ts @@ -0,0 +1,113 @@ +import { Injectable, inject, resource, computed } from '@angular/core'; +import { HttpClient } from '@angular/common/http'; +import { firstValueFrom } from 'rxjs'; +import { ConfigService } from '../../../services/config.service'; +import { AuthService } from '../../../auth/auth.service'; +import { + Connector, + ConnectorListResponse, + ConnectorCreateRequest, + ConnectorUpdateRequest, +} from '../models/connector.model'; + +function toSnakeCase(obj: Record): Record { + const result: Record = {}; + for (const [key, value] of Object.entries(obj)) { + if (value === undefined) continue; + const snakeKey = key.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`); + result[snakeKey] = value; + } + return result; +} + +function toCamelCase(obj: Record): Record { + const result: Record = {}; + for (const [key, value] of Object.entries(obj)) { + const camelKey = key.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()); + result[camelKey] = value; + } + return result; +} + +/** + * Admin service for managing connectors. + * + * The backend admin endpoint is still `/admin/oauth-providers` — that is the + * stable wire contract. We use the connectors vernacular throughout the + * frontend and translate at this layer. 
+ */ +@Injectable({ + providedIn: 'root' +}) +export class ConnectorsService { + private http = inject(HttpClient); + private authService = inject(AuthService); + private config = inject(ConfigService); + + private readonly baseUrl = computed(() => `${this.config.appApiUrl()}/admin/oauth-providers`); + + readonly connectorsResource = resource({ + loader: async () => { + await this.authService.ensureAuthenticated(); + return this.fetchConnectors(); + } + }); + + getConnectors(): Connector[] { + return this.connectorsResource.value()?.providers ?? []; + } + + getEnabledConnectors(): Connector[] { + return this.getConnectors().filter(c => c.enabled); + } + + getConnectorById(providerId: string): Connector | undefined { + return this.getConnectors().find(c => c.providerId === providerId); + } + + async fetchConnectors(): Promise { + const response = await firstValueFrom( + this.http.get(`${this.baseUrl()}/`) + ); + return { + providers: response.providers.map((p: any) => toCamelCase(p) as Connector), + total: response.total, + }; + } + + async fetchConnector(providerId: string): Promise { + const response = await firstValueFrom( + this.http.get(`${this.baseUrl()}/${providerId}`) + ); + return toCamelCase(response) as Connector; + } + + async createConnector(data: ConnectorCreateRequest): Promise { + const snakeCaseData = toSnakeCase(data as unknown as Record); + const response = await firstValueFrom( + this.http.post(`${this.baseUrl()}/`, snakeCaseData) + ); + this.connectorsResource.reload(); + return toCamelCase(response) as Connector; + } + + async updateConnector(providerId: string, updates: ConnectorUpdateRequest): Promise { + const snakeCaseData = toSnakeCase(updates as unknown as Record); + const response = await firstValueFrom( + this.http.patch(`${this.baseUrl()}/${providerId}`, snakeCaseData) + ); + this.connectorsResource.reload(); + return toCamelCase(response) as Connector; + } + + async deleteConnector(providerId: string): Promise { + await 
firstValueFrom( + this.http.delete(`${this.baseUrl()}/${providerId}`) + ); + this.connectorsResource.reload(); + } + + reload(): void { + this.connectorsResource.reload(); + } +} diff --git a/frontend/ai.client/src/app/admin/oauth-providers/index.ts b/frontend/ai.client/src/app/admin/oauth-providers/index.ts deleted file mode 100644 index 7b71465d..00000000 --- a/frontend/ai.client/src/app/admin/oauth-providers/index.ts +++ /dev/null @@ -1,9 +0,0 @@ -// Models -export * from './models/oauth-provider.model'; - -// Services -export * from './services/oauth-providers.service'; - -// Pages -export * from './pages/provider-list.page'; -export * from './pages/provider-form.page'; diff --git a/frontend/ai.client/src/app/admin/oauth-providers/models/oauth-provider.model.ts b/frontend/ai.client/src/app/admin/oauth-providers/models/oauth-provider.model.ts deleted file mode 100644 index 8f2b5577..00000000 --- a/frontend/ai.client/src/app/admin/oauth-providers/models/oauth-provider.model.ts +++ /dev/null @@ -1,186 +0,0 @@ -/** - * OAuth provider type enumeration. - */ -export type OAuthProviderType = 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; - -/** - * OAuth Provider configuration. 
- */ -export interface OAuthProvider { - /** Unique provider identifier (lowercase alphanumeric + underscore) */ - providerId: string; - /** Human-readable display name */ - displayName: string; - /** Provider type for preset configurations */ - providerType: OAuthProviderType; - /** OAuth authorization endpoint URL */ - authorizationEndpoint: string; - /** OAuth token endpoint URL */ - tokenEndpoint: string; - /** OAuth client ID (public) */ - clientId: string; - /** OAuth scopes to request */ - scopes: string[]; - /** AppRole IDs that can use this provider */ - allowedRoles: string[]; - /** Whether this provider is active */ - enabled: boolean; - /** Icon name for UI display (heroicons) */ - iconName: string; - /** Additional authorization URL parameters (e.g., access_type=offline for Google) */ - authorizationParams: Record; - /** ISO 8601 creation timestamp */ - createdAt: string; - /** ISO 8601 update timestamp */ - updatedAt: string; -} - -/** - * Response model for listing OAuth providers. - */ -export interface OAuthProviderListResponse { - providers: OAuthProvider[]; - total: number; -} - -/** - * Request model for creating a new OAuth provider. 
- */ -export interface OAuthProviderCreateRequest { - /** Unique provider identifier (lowercase alphanumeric + underscore, 3-50 chars) */ - providerId: string; - /** Human-readable display name (1-100 chars) */ - displayName: string; - /** Provider type */ - providerType: OAuthProviderType; - /** OAuth authorization endpoint URL */ - authorizationEndpoint: string; - /** OAuth token endpoint URL */ - tokenEndpoint: string; - /** OAuth client ID */ - clientId: string; - /** OAuth client secret (only sent on create, never returned) */ - clientSecret: string; - /** OAuth scopes to request */ - scopes: string[]; - /** AppRole IDs that can use this provider */ - allowedRoles?: string[]; - /** Whether this provider is active */ - enabled?: boolean; - /** Icon name for UI display */ - iconName?: string; - /** Additional authorization URL parameters (e.g., access_type=offline for Google) */ - authorizationParams?: Record; -} - -/** - * Request model for updating an OAuth provider. - * All fields are optional for partial updates. - */ -export interface OAuthProviderUpdateRequest { - /** Human-readable display name (1-100 chars) */ - displayName?: string; - /** OAuth authorization endpoint URL */ - authorizationEndpoint?: string; - /** OAuth token endpoint URL */ - tokenEndpoint?: string; - /** OAuth client ID */ - clientId?: string; - /** OAuth client secret (only set if updating) */ - clientSecret?: string; - /** OAuth scopes to request */ - scopes?: string[]; - /** AppRole IDs that can use this provider */ - allowedRoles?: string[]; - /** Whether this provider is active */ - enabled?: boolean; - /** Icon name for UI display */ - iconName?: string; - /** Additional authorization URL parameters (e.g., access_type=offline for Google) */ - authorizationParams?: Record; -} - -/** - * Form data model for creating/editing an OAuth provider. 
- */ -export interface OAuthProviderFormData { - providerId: string; - displayName: string; - providerType: OAuthProviderType; - authorizationEndpoint: string; - tokenEndpoint: string; - clientId: string; - clientSecret: string; - scopes: string; - allowedRoles: string[]; - enabled: boolean; - iconName: string; - authorizationParams: string; -} - -/** - * Preset configurations for common OAuth providers. - */ -export interface OAuthProviderPreset { - type: OAuthProviderType; - displayName: string; - authorizationEndpoint: string; - tokenEndpoint: string; - defaultScopes: string[]; - iconName: string; - authorizationParams?: Record; -} - -/** - * Common OAuth provider presets. - */ -export const OAUTH_PROVIDER_PRESETS: OAuthProviderPreset[] = [ - { - type: 'google', - displayName: 'Google', - authorizationEndpoint: 'https://accounts.google.com/o/oauth2/v2/auth', - tokenEndpoint: 'https://oauth2.googleapis.com/token', - defaultScopes: ['openid', 'email', 'profile'], - iconName: 'heroCloud', - authorizationParams: { access_type: 'offline', prompt: 'consent' }, - }, - { - type: 'microsoft', - displayName: 'Microsoft', - authorizationEndpoint: 'https://login.microsoftonline.com/common/oauth2/v2.0/authorize', - tokenEndpoint: 'https://login.microsoftonline.com/common/oauth2/v2.0/token', - defaultScopes: ['openid', 'email', 'profile', 'offline_access'], - iconName: 'heroCloud', - }, - { - type: 'github', - displayName: 'GitHub', - authorizationEndpoint: 'https://github.com/login/oauth/authorize', - tokenEndpoint: 'https://github.com/login/oauth/access_token', - defaultScopes: ['read:user', 'user:email'], - iconName: 'heroCodeBracket', - }, - { - type: 'canvas', - displayName: 'Canvas LMS', - authorizationEndpoint: '', // User must configure - tokenEndpoint: '', // User must configure - defaultScopes: [], - iconName: 'heroAcademicCap', - }, - { - type: 'custom', - displayName: 'Custom Provider', - authorizationEndpoint: '', - tokenEndpoint: '', - defaultScopes: [], - 
iconName: 'heroLink', - }, -]; - -/** - * Get preset configuration for a provider type. - */ -export function getProviderPreset(type: OAuthProviderType): OAuthProviderPreset | undefined { - return OAUTH_PROVIDER_PRESETS.find(preset => preset.type === type); -} diff --git a/frontend/ai.client/src/app/admin/oauth-providers/pages/provider-form.page.ts b/frontend/ai.client/src/app/admin/oauth-providers/pages/provider-form.page.ts deleted file mode 100644 index cf5dab75..00000000 --- a/frontend/ai.client/src/app/admin/oauth-providers/pages/provider-form.page.ts +++ /dev/null @@ -1,751 +0,0 @@ -import { - Component, - ChangeDetectionStrategy, - inject, - signal, - computed, - OnInit, -} from '@angular/core'; -import { Router, ActivatedRoute } from '@angular/router'; -import { - FormBuilder, - FormGroup, - FormControl, - Validators, - ReactiveFormsModule, -} from '@angular/forms'; -import { NgIcon, provideIcons } from '@ng-icons/core'; -import { - heroArrowLeft, - heroInformationCircle, - heroEye, - heroEyeSlash, - heroExclamationTriangle, - heroCheckCircle, -} from '@ng-icons/heroicons/outline'; -import { OAuthProvidersService } from '../services/oauth-providers.service'; -import { AppRolesService } from '../../roles/services/app-roles.service'; -import { - OAuthProviderCreateRequest, - OAuthProviderUpdateRequest, - OAuthProviderType, - OAUTH_PROVIDER_PRESETS, - getProviderPreset, -} from '../models/oauth-provider.model'; -import { TooltipDirective } from '../../../components/tooltip/tooltip.directive'; - -interface ProviderFormGroup { - providerId: FormControl; - displayName: FormControl; - providerType: FormControl; - authorizationEndpoint: FormControl; - tokenEndpoint: FormControl; - clientId: FormControl; - clientSecret: FormControl; - scopes: FormControl; - authorizationParams: FormControl; - allowedRoles: FormControl; - grantAllRoles: FormControl; - enabled: FormControl; - iconName: FormControl; -} - -@Component({ - selector: 'app-provider-form', - changeDetection: 
ChangeDetectionStrategy.OnPush, - imports: [ReactiveFormsModule, NgIcon, TooltipDirective], - providers: [ - provideIcons({ - heroArrowLeft, - heroInformationCircle, - heroEye, - heroEyeSlash, - heroExclamationTriangle, - heroCheckCircle, - }), - ], - host: { - class: 'block', - }, - template: ` -
-
- - - - -
-

- {{ pageTitle() }} -

-

- {{ isEditMode() ? 'Update OAuth provider settings and credentials' : 'Configure a new OAuth provider for tool authentication' }} -

-
- - - @if (loading()) { -
-
-
-

- Loading provider... -

-
-
- } @else { - -
- - - @if (!isEditMode()) { -
-

- Provider Type -

-

- Select a preset or configure a custom OAuth 2.0 provider. -

- -
- @for (preset of presets; track preset.type) { - - } -
-
- } - - -
-

- Basic Information -

- -
- -
- - -

- Unique identifier. Lowercase letters, numbers, and hyphens only. -

- @if (providerForm.controls.providerId.invalid && providerForm.controls.providerId.touched) { -

- @if (providerForm.controls.providerId.errors?.['required']) { - Provider ID is required - } @else if (providerForm.controls.providerId.errors?.['pattern']) { - Must be lowercase letters, numbers, and hyphens only - } @else if (providerForm.controls.providerId.errors?.['maxlength']) { - Must be at most 64 characters - } -

- } -
- - -
- - - @if (providerForm.controls.displayName.invalid && providerForm.controls.displayName.touched) { -

Display name is required

- } -
- - -
- - -
-
-
- - -
-

- OAuth Configuration -

-

- Configure the OAuth 2.0 endpoints and credentials. -

- -
- -
- - - @if (providerForm.controls.authorizationEndpoint.invalid && providerForm.controls.authorizationEndpoint.touched) { -

- @if (providerForm.controls.authorizationEndpoint.errors?.['required']) { - Authorization endpoint is required - } @else { - Must be a valid URL - } -

- } -
- - -
- - - @if (providerForm.controls.tokenEndpoint.invalid && providerForm.controls.tokenEndpoint.touched) { -

- @if (providerForm.controls.tokenEndpoint.errors?.['required']) { - Token endpoint is required - } @else { - Must be a valid URL - } -

- } -
- - -
- - - @if (providerForm.controls.clientId.invalid && providerForm.controls.clientId.touched) { -

Client ID is required

- } -
- - -
- -
- - -
- @if (isEditMode()) { -

- Leave blank to keep the existing secret. Enter a new value to update it. -

- } - @if (!isEditMode() && providerForm.controls.clientSecret.invalid && providerForm.controls.clientSecret.touched) { -

Client secret is required

- } -
- - -
- - -

- Comma-separated list of OAuth scopes to request during authorization. -

-
- - -
- - -

- Extra URL parameters for the authorization request. For Google, use "access_type=offline, prompt=consent" to enable refresh tokens. -

-
-
-
- - -
-

- Access Control -

-

- Restrict which application roles can use this provider. -

- -
- - - -
- - -
- - @if (!providerForm.controls.grantAllRoles.value) { - @if (rolesResource.isLoading() || rolesResource.value() === undefined) { -
-
-

Loading roles...

-
- } @else if (availableRoles().length > 0) { -
- @for (role of availableRoles(); track role.roleId) { - - } -
- } @else { -

- No roles available. Create roles in Role Management first. -

- } - } -

- Only users with selected roles will be able to connect to this provider. -

-
-
- - - @if (isEditMode()) { -
-
- -
-

- Security Notice -

-

- Changing scopes may invalidate existing user tokens. Users may need to re-authenticate after scope changes. -

-
-
-
- } - - -
- - -
-
- } -
-
- `, -}) -export class ProviderFormPage implements OnInit { - private fb = inject(FormBuilder); - private router = inject(Router); - private route = inject(ActivatedRoute); - private oauthProvidersService = inject(OAuthProvidersService); - private appRolesService = inject(AppRolesService); - - // Resources - readonly rolesResource = this.appRolesService.rolesResource; - - // Presets - readonly presets = OAUTH_PROVIDER_PRESETS; - - // State - readonly isEditMode = signal(false); - readonly providerId = signal(null); - readonly isSubmitting = signal(false); - readonly loading = signal(false); - readonly showClientSecret = signal(false); - - // Form - readonly providerForm: FormGroup = this.fb.group({ - providerId: this.fb.control('', { - nonNullable: true, - validators: [ - Validators.required, - Validators.minLength(1), - Validators.maxLength(64), - Validators.pattern(/^[a-z0-9-]+$/), - ], - }), - displayName: this.fb.control('', { - nonNullable: true, - validators: [Validators.required, Validators.maxLength(100)], - }), - providerType: this.fb.control('custom', { nonNullable: true }), - authorizationEndpoint: this.fb.control('', { - nonNullable: true, - validators: [Validators.required, Validators.pattern(/^https?:\/\/.+/)], - }), - tokenEndpoint: this.fb.control('', { - nonNullable: true, - validators: [Validators.required, Validators.pattern(/^https?:\/\/.+/)], - }), - clientId: this.fb.control('', { - nonNullable: true, - validators: [Validators.required], - }), - clientSecret: this.fb.control('', { nonNullable: true }), - scopes: this.fb.control('', { nonNullable: true }), - authorizationParams: this.fb.control('', { nonNullable: true }), - allowedRoles: this.fb.control(['*'], { nonNullable: true }), - grantAllRoles: this.fb.control(true, { nonNullable: true }), - enabled: this.fb.control(true, { nonNullable: true }), - iconName: this.fb.control('heroLink', { nonNullable: true }), - }); - - readonly pageTitle = computed(() => - this.isEditMode() ? 
'Edit OAuth Provider' : 'Add OAuth Provider' - ); - - readonly availableRoles = computed(() => - this.appRolesService.getEnabledRoles() - ); - - // Track selected roles as a signal for change detection with OnPush - readonly selectedRoles = signal(['*']); - - ngOnInit(): void { - const id = this.route.snapshot.paramMap.get('providerId'); - if (id && id !== 'new') { - this.isEditMode.set(true); - this.providerId.set(id); - this.loadProviderData(id); - } else { - // Set client secret as required for new providers - this.providerForm.controls.clientSecret.setValidators([Validators.required]); - this.providerForm.controls.clientSecret.updateValueAndValidity(); - } - } - - private async loadProviderData(id: string): Promise { - this.loading.set(true); - try { - const provider = await this.oauthProvidersService.fetchProvider(id); - - // Convert authorizationParams object to "key=value, key=value" string - const authParamsString = provider.authorizationParams - ? Object.entries(provider.authorizationParams) - .map(([k, v]) => `${k}=${v}`) - .join(', ') - : ''; - - this.providerForm.patchValue({ - providerId: provider.providerId, - displayName: provider.displayName, - providerType: provider.providerType, - authorizationEndpoint: provider.authorizationEndpoint, - tokenEndpoint: provider.tokenEndpoint, - clientId: provider.clientId, - clientSecret: '', // Never returned from API - scopes: provider.scopes.join(', '), - authorizationParams: authParamsString, - allowedRoles: provider.allowedRoles.length > 0 ? provider.allowedRoles : ['*'], - grantAllRoles: provider.allowedRoles.length === 0, - enabled: provider.enabled, - iconName: provider.iconName || 'heroLink', - }); - // Sync selectedRoles signal with loaded data - this.selectedRoles.set(provider.allowedRoles.length > 0 ? provider.allowedRoles : ['*']); - } catch (error) { - console.error('Error loading provider:', error); - alert('Failed to load provider. 
Returning to list.'); - this.router.navigate(['/admin/oauth-providers']); - } finally { - this.loading.set(false); - } - } - - selectProviderType(type: OAuthProviderType): void { - const preset = getProviderPreset(type); - if (preset) { - // Convert authorizationParams object to "key=value, key=value" string - const authParamsString = preset.authorizationParams - ? Object.entries(preset.authorizationParams) - .map(([k, v]) => `${k}=${v}`) - .join(', ') - : ''; - - this.providerForm.patchValue({ - providerType: type, - displayName: preset.displayName, - authorizationEndpoint: preset.authorizationEndpoint, - tokenEndpoint: preset.tokenEndpoint, - scopes: preset.defaultScopes.join(', '), - authorizationParams: authParamsString, - iconName: preset.iconName, - }); - } - } - - getPresetIconClasses(type: OAuthProviderType): string { - const baseClasses = 'flex size-10 items-center justify-center rounded-sm'; - switch (type) { - case 'google': - return `${baseClasses} bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400`; - case 'microsoft': - return `${baseClasses} bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400`; - case 'github': - return `${baseClasses} bg-gray-800 text-white dark:bg-gray-600`; - case 'canvas': - return `${baseClasses} bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400`; - default: - return `${baseClasses} bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400`; - } - } - - onGrantAllRolesChange(): void { - const checked = this.providerForm.controls.grantAllRoles.value; - if (checked) { - this.providerForm.controls.allowedRoles.setValue(['*']); - this.selectedRoles.set(['*']); - } else { - this.providerForm.controls.allowedRoles.setValue([]); - this.selectedRoles.set([]); - // Trigger roles load if not already loaded - if (this.rolesResource.value() === undefined) { - this.rolesResource.reload(); - } - } - } - - isRoleSelected(roleId: string): boolean { - return 
this.selectedRoles().includes(roleId); - } - - toggleRole(roleId: string): void { - const currentRoles = this.selectedRoles().filter(r => r !== '*'); - let newRoles: string[]; - if (currentRoles.includes(roleId)) { - newRoles = currentRoles.filter(r => r !== roleId); - } else { - newRoles = [...currentRoles, roleId]; - } - this.providerForm.controls.allowedRoles.setValue(newRoles); - this.selectedRoles.set(newRoles); - } - - async onSubmit(): Promise { - if (this.providerForm.invalid) { - this.providerForm.markAllAsTouched(); - return; - } - - this.isSubmitting.set(true); - - try { - const formValue = this.providerForm.value; - - // Parse scopes from comma-separated string - const scopes = formValue.scopes - ? formValue.scopes - .split(',') - .map((s: string) => s.trim()) - .filter((s: string) => s.length > 0) - : []; - - // Parse authorization params from "key=value, key=value" string - const authorizationParams: Record = {}; - if (formValue.authorizationParams) { - formValue.authorizationParams - .split(',') - .map((p: string) => p.trim()) - .filter((p: string) => p.length > 0) - .forEach((p: string) => { - const [key, ...valueParts] = p.split('='); - if (key && valueParts.length > 0) { - authorizationParams[key.trim()] = valueParts.join('=').trim(); - } - }); - } - - // Normalize allowed roles - const allowedRoles = formValue.grantAllRoles - ? 
[] - : (formValue.allowedRoles || []).filter((r: string) => r !== '*'); - - if (this.isEditMode() && this.providerId()) { - const updates: OAuthProviderUpdateRequest = { - displayName: formValue.displayName, - authorizationEndpoint: formValue.authorizationEndpoint, - tokenEndpoint: formValue.tokenEndpoint, - clientId: formValue.clientId, - scopes, - authorizationParams, - allowedRoles, - enabled: formValue.enabled, - iconName: formValue.iconName, - }; - // Only include client secret if provided - if (formValue.clientSecret) { - updates.clientSecret = formValue.clientSecret; - } - await this.oauthProvidersService.updateProvider(this.providerId()!, updates); - } else { - const createData: OAuthProviderCreateRequest = { - providerId: formValue.providerId!, - displayName: formValue.displayName!, - providerType: formValue.providerType!, - authorizationEndpoint: formValue.authorizationEndpoint!, - tokenEndpoint: formValue.tokenEndpoint!, - clientId: formValue.clientId!, - clientSecret: formValue.clientSecret!, - scopes, - authorizationParams, - allowedRoles, - enabled: formValue.enabled, - iconName: formValue.iconName, - }; - await this.oauthProvidersService.createProvider(createData); - } - - this.router.navigate(['/admin/oauth-providers']); - } catch (error: any) { - console.error('Error saving provider:', error); - const message = - error?.error?.detail || error?.message || 'Failed to save provider.'; - alert(message); - } finally { - this.isSubmitting.set(false); - } - } - - goBack(): void { - this.router.navigate(['/admin/oauth-providers']); - } -} diff --git a/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.ts b/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.ts deleted file mode 100644 index b433014e..00000000 --- a/frontend/ai.client/src/app/admin/oauth-providers/services/oauth-providers.service.ts +++ /dev/null @@ -1,154 +0,0 @@ -import { Injectable, inject, resource, computed } from 
'@angular/core'; -import { HttpClient } from '@angular/common/http'; -import { firstValueFrom } from 'rxjs'; -import { ConfigService } from '../../../services/config.service'; -import { AuthService } from '../../../auth/auth.service'; -import { - OAuthProvider, - OAuthProviderListResponse, - OAuthProviderCreateRequest, - OAuthProviderUpdateRequest, -} from '../models/oauth-provider.model'; - -/** - * Convert camelCase to snake_case for backend API. - */ -function toSnakeCase(obj: Record): Record { - const result: Record = {}; - for (const [key, value] of Object.entries(obj)) { - if (value === undefined) continue; - const snakeKey = key.replace(/[A-Z]/g, letter => `_${letter.toLowerCase()}`); - result[snakeKey] = value; - } - return result; -} - -/** - * Convert snake_case to camelCase for frontend models. - */ -function toCamelCase(obj: Record): Record { - const result: Record = {}; - for (const [key, value] of Object.entries(obj)) { - const camelKey = key.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()); - result[camelKey] = value; - } - return result; -} - -/** - * Service to manage OAuth Providers. - * - * Provides access to the provider list for use in forms and displays, - * as well as CRUD operations for provider management. - */ -@Injectable({ - providedIn: 'root' -}) -export class OAuthProvidersService { - private http = inject(HttpClient); - private authService = inject(AuthService); - private config = inject(ConfigService); - - private readonly baseUrl = computed(() => `${this.config.appApiUrl()}/admin/oauth-providers`); - - /** - * Reactive resource for fetching OAuth Providers. - */ - readonly providersResource = resource({ - loader: async () => { - await this.authService.ensureAuthenticated(); - return this.fetchProviders(); - } - }); - - /** - * Get all OAuth Providers (from resource). - */ - getProviders(): OAuthProvider[] { - return this.providersResource.value()?.providers ?? []; - } - - /** - * Get only enabled OAuth Providers. 
- */ - getEnabledProviders(): OAuthProvider[] { - return this.getProviders().filter(p => p.enabled); - } - - /** - * Get a provider by ID from the cached resource. - */ - getProviderById(providerId: string): OAuthProvider | undefined { - return this.getProviders().find(p => p.providerId === providerId); - } - - /** - * Fetch all OAuth Providers from the API. - */ - async fetchProviders(): Promise { - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/`) - ); - // Convert snake_case response to camelCase - return { - providers: response.providers.map((p: any) => toCamelCase(p) as OAuthProvider), - total: response.total, - }; - } - - /** - * Fetch a single provider by ID from the API. - */ - async fetchProvider(providerId: string): Promise { - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/${providerId}`) - ); - // Convert snake_case response to camelCase - return toCamelCase(response) as OAuthProvider; - } - - /** - * Create a new OAuth Provider. - */ - async createProvider(providerData: OAuthProviderCreateRequest): Promise { - // Convert camelCase request to snake_case - const snakeCaseData = toSnakeCase(providerData as unknown as Record); - const response = await firstValueFrom( - this.http.post(`${this.baseUrl()}/`, snakeCaseData) - ); - this.providersResource.reload(); - // Convert snake_case response to camelCase - return toCamelCase(response) as OAuthProvider; - } - - /** - * Update an existing OAuth Provider. - */ - async updateProvider(providerId: string, updates: OAuthProviderUpdateRequest): Promise { - // Convert camelCase request to snake_case - const snakeCaseData = toSnakeCase(updates as unknown as Record); - const response = await firstValueFrom( - this.http.patch(`${this.baseUrl()}/${providerId}`, snakeCaseData) - ); - this.providersResource.reload(); - // Convert snake_case response to camelCase - return toCamelCase(response) as OAuthProvider; - } - - /** - * Delete an OAuth Provider. 
- */ - async deleteProvider(providerId: string): Promise { - await firstValueFrom( - this.http.delete(`${this.baseUrl()}/${providerId}`) - ); - this.providersResource.reload(); - } - - /** - * Reload the providers resource. - */ - reload(): void { - this.providersResource.reload(); - } -} diff --git a/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts b/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts index 7df23c46..0fecfa5a 100644 --- a/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts +++ b/frontend/ai.client/src/app/admin/tools/pages/tool-form.page.ts @@ -19,7 +19,7 @@ import { heroShieldCheck, } from '@ng-icons/heroicons/outline'; import { AdminToolService } from '../services/admin-tool.service'; -import { OAuthProvidersService } from '../../oauth-providers/services/oauth-providers.service'; +import { ConnectorsService } from '../../connectors/services/connectors.service'; import { TOOL_CATEGORIES, TOOL_PROTOCOLS, @@ -340,7 +340,7 @@ import {
-

User OAuth Connection

+

User OAuth Connector

If this tool requires access to a user's external account (e.g., Google Workspace, Microsoft 365), @@ -361,8 +361,8 @@ import { }

- Users must connect this provider before using the tool. Manage providers in - OAuth Settings. + Users must connect this connector before using the tool. Manage connectors in + Connectors.

@@ -568,7 +568,7 @@ export class ToolFormPage implements OnInit { private router = inject(Router); private route = inject(ActivatedRoute); private adminToolService = inject(AdminToolService); - private oauthProvidersService = inject(OAuthProvidersService); + private connectorsService = inject(ConnectorsService); readonly categories = TOOL_CATEGORIES; readonly protocols = TOOL_PROTOCOLS; @@ -585,8 +585,8 @@ export class ToolFormPage implements OnInit { readonly isEditMode = computed(() => !!this.toolId()); readonly selectedProtocol = signal('local'); - /** Available OAuth providers for dropdown */ - readonly oauthProviders = computed(() => this.oauthProvidersService.getEnabledProviders()); + /** Available connectors for dropdown */ + readonly oauthProviders = computed(() => this.connectorsService.getEnabledConnectors()); form: FormGroup = this.fb.group({ toolId: ['', [Validators.required, Validators.pattern(/^[a-z][a-z0-9_]{2,49}$/)]], diff --git a/frontend/ai.client/src/app/app.routes.ts b/frontend/ai.client/src/app/app.routes.ts index a750cf21..3874cdd4 100644 --- a/frontend/ai.client/src/app/app.routes.ts +++ b/frontend/ai.client/src/app/app.routes.ts @@ -32,11 +32,6 @@ export const routes: Routes = [ path: 'auth/callback', loadComponent: () => import('./auth/callback/callback.page').then(m => m.CallbackPage), }, - { - path: 'connections', - redirectTo: 'settings/connections', - pathMatch: 'full', - }, { path: 'admin', loadComponent: () => import('./admin/admin.page').then(m => m.AdminPage), @@ -103,8 +98,8 @@ export const routes: Routes = [ canActivate: [authGuard], }, { - path: 'settings/oauth/callback', - loadComponent: () => import('./settings/oauth-callback/oauth-callback.page').then(m => m.OAuthCallbackPage), + path: 'oauth-complete', + loadComponent: () => import('./oauth-complete/oauth-complete.page').then(m => m.OAuthCompletePage), }, { path: 'settings', @@ -179,17 +174,32 @@ export const routes: Routes = [ }, { path: 'admin/oauth-providers', - 
loadComponent: () => import('./admin/oauth-providers/pages/provider-list.page').then(m => m.ProviderListPage), - canActivate: [adminGuard], + redirectTo: 'admin/connectors', + pathMatch: 'full', }, { path: 'admin/oauth-providers/new', - loadComponent: () => import('./admin/oauth-providers/pages/provider-form.page').then(m => m.ProviderFormPage), - canActivate: [adminGuard], + redirectTo: 'admin/connectors/new', + pathMatch: 'full', }, { path: 'admin/oauth-providers/edit/:providerId', - loadComponent: () => import('./admin/oauth-providers/pages/provider-form.page').then(m => m.ProviderFormPage), + redirectTo: 'admin/connectors/edit/:providerId', + pathMatch: 'full', + }, + { + path: 'admin/connectors', + loadComponent: () => import('./admin/connectors/pages/connector-list.page').then(m => m.ConnectorListPage), + canActivate: [adminGuard], + }, + { + path: 'admin/connectors/new', + loadComponent: () => import('./admin/connectors/pages/connector-form.page').then(m => m.ConnectorFormPage), + canActivate: [adminGuard], + }, + { + path: 'admin/connectors/edit/:providerId', + loadComponent: () => import('./admin/connectors/pages/connector-form.page').then(m => m.ConnectorFormPage), canActivate: [adminGuard], }, { diff --git a/frontend/ai.client/src/app/assistants/assistant-form/services/preview-chat.service.ts b/frontend/ai.client/src/app/assistants/assistant-form/services/preview-chat.service.ts index 5c9dea3d..2727fa79 100644 --- a/frontend/ai.client/src/app/assistants/assistant-form/services/preview-chat.service.ts +++ b/frontend/ai.client/src/app/assistants/assistant-form/services/preview-chat.service.ts @@ -214,6 +214,7 @@ export class PreviewChatService { 'Content-Type': 'application/json', Authorization: `Bearer ${token}`, Accept: 'text/event-stream', + OAuth2CallbackUrl: `${window.location.origin}/oauth-complete`, }, body: JSON.stringify(requestBody), signal: this.abortController.signal, diff --git a/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts 
b/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts new file mode 100644 index 00000000..be501537 --- /dev/null +++ b/frontend/ai.client/src/app/oauth-complete/oauth-complete.page.ts @@ -0,0 +1,371 @@ +import { + ChangeDetectionStrategy, + Component, + OnDestroy, + OnInit, + computed, + inject, + signal, +} from '@angular/core'; +import { ActivatedRoute, Router } from '@angular/router'; +import { NgIcon, provideIcons } from '@ng-icons/core'; +import { + heroCheckCircle, + heroExclamationCircle, +} from '@ng-icons/heroicons/outline'; +import { AuthService } from '../auth/auth.service'; +import { ConfigService } from '../services/config.service'; + +/** + * Landing page for AgentCore Identity's 3-legged OAuth flow. + * + * AgentCore Identity redirects the user here after consent completes (or + * fails). We detect whether this page was opened in a popup: + * + * - Popup: post a message to the opener so the chat can retry the tool + * call, then close the window. + * - Same tab: show a brief success/error message, then route back to chat. + * + * Query params AgentCore Identity may append are not strictly contractual, + * so we treat missing params as success and known error indicators as + * failure (`error`, `error_description`). This matches the defensive + * parsing pattern in `settings/oauth-callback`. + */ + +type CompleteState = 'success' | 'error'; + +/** postMessage payload shape — kept public so other code can type the listener. */ +export interface OAuthCompleteMessage { + type: 'agentcore-oauth-complete'; + status: CompleteState; + providerId: string | null; + error: string | null; +} + +@Component({ + selector: 'app-oauth-complete', + imports: [NgIcon], + providers: [ + provideIcons({ heroCheckCircle, heroExclamationCircle }), + ], + changeDetection: ChangeDetectionStrategy.OnPush, + template: ` +
+ @if (state() === 'success') { +
+
+ } @else { + + } +
+ `, + styles: ` + :host { + display: block; + min-height: 100dvh; + background: var(--color-gray-50); + } + + :host-context(html.dark) { + background: var(--color-gray-900); + } + + .page { + min-height: 100dvh; + display: flex; + align-items: center; + justify-content: center; + padding: 2rem; + } + + .card { + width: 100%; + max-width: 28rem; + padding: 2rem; + border-radius: 0.75rem; + background: var(--color-white); + border: 1px solid var(--color-gray-200); + box-shadow: 0 1px 2px rgb(0 0 0 / 0.05); + text-align: center; + } + + :host-context(html.dark) .card { + background: var(--color-gray-800); + border-color: var(--color-gray-700); + } + + .icon { + display: inline-flex; + width: 3rem; + height: 3rem; + margin-bottom: 1rem; + } + + .success-icon { + color: var(--color-green-600); + } + + :host-context(html.dark) .success-icon { + color: var(--color-green-400); + } + + .error-icon { + color: var(--color-red-600); + } + + :host-context(html.dark) .error-icon { + color: var(--color-red-400); + } + + .title { + font-weight: 600; + font-size: 1.5rem; + margin: 0 0 0.5rem; + color: var(--color-gray-900); + } + + :host-context(html.dark) .title { + color: var(--color-gray-100); + } + + .subtitle { + margin: 0 0 1rem; + color: var(--color-gray-700); + font-size: 0.95rem; + } + + :host-context(html.dark) .subtitle { + color: var(--color-gray-300); + } + + .hint { + margin: 0; + font-size: 0.8125rem; + color: var(--color-gray-500); + } + `, +}) +export class OAuthCompletePage implements OnInit, OnDestroy { + private readonly route = inject(ActivatedRoute); + private readonly router = inject(Router); + private readonly authService = inject(AuthService); + private readonly config = inject(ConfigService); + + private redirectTimer: ReturnType | null = null; + + readonly state = signal('success'); + readonly providerId = signal(null); + readonly sessionUri = signal(null); + readonly errorMessage = signal('Authorization was denied or did not complete.'); + private 
readonly isPopup = signal(false); + private readonly finalizing = signal(false); + + readonly providerLabel = computed(() => { + const id = this.providerId(); + if (!id) { + return null; + } + return id + .replace(/[-_]+/g, ' ') + .split(' ') + .filter((part) => part.length > 0) + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(' '); + }); + + readonly dismissHint = computed(() => { + if (!this.isPopup()) { + return 'Redirecting back to your chat…'; + } + return this.state() === 'error' + ? 'You can close this window once you\'re done reading the error.' + : 'You can close this window.'; + }); + + ngOnInit(): void { + const params = this.route.snapshot.queryParamMap; + const error = params.get('error'); + const errorDescription = params.get('error_description'); + // AgentCore echoes our custom_state back as `state`; we set it server-side + // to the providerId when initiating consent from the settings page. + const providerId = + params.get('provider_id') ?? + params.get('providerId') ?? + params.get('state'); + + // AgentCore's redirect also carries `session_id`, which is the + // `request_uri` from the initial authorize call. We must hand it back + // to CompleteResourceTokenAuth or the token vault stays empty. + const sessionUri = params.get('session_id') ?? params.get('sessionUri'); + + if (providerId) { + this.providerId.set(providerId); + } + if (sessionUri) { + this.sessionUri.set(sessionUri); + } + + if (error) { + this.state.set('error'); + this.errorMessage.set( + errorDescription?.trim() || this.describeError(error), + ); + } + + const inPopup = this.detectPopup(); + this.isPopup.set(inPopup); + + // Finalize AgentCore's OAuth session (exchanges the `request_uri` for a + // persisted token). Must happen BEFORE we tell the opener we're done — + // otherwise the opener's next tool call will still see "consent required". 
+ if (this.state() === 'success' && sessionUri) { + this.finalizing.set(true); + this.finalizeConsent(sessionUri, providerId) + .catch((err) => { + this.state.set('error'); + this.errorMessage.set( + err instanceof Error + ? `Couldn't finalize authorization: ${err.message}` + : "Couldn't finalize authorization.", + ); + }) + .finally(() => { + this.finalizing.set(false); + if (inPopup) { + this.notifyOpenerAndClose(); + } else { + this.redirectTimer = setTimeout(() => this.router.navigate(['/']), 2000); + } + }); + return; + } + + if (inPopup) { + this.notifyOpenerAndClose(); + } else { + this.redirectTimer = setTimeout(() => this.router.navigate(['/']), 2000); + } + } + + private async finalizeConsent( + sessionUri: string, + providerId: string | null, + ): Promise { + const baseUrl = this.config.inferenceApiUrl().replace(/\/invocations\/?$/, ''); + if (!baseUrl) { + throw new Error('inferenceApiUrl not configured'); + } + const token = this.authService.getAccessToken(); + if (!token) { + throw new Error('No access token available'); + } + const url = `${baseUrl}/connectors/complete-consent`; + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + }, + body: JSON.stringify({ + session_uri: sessionUri, + provider_id: providerId, + }), + }); + if (!response.ok) { + const text = await response.text().catch(() => ''); + throw new Error(`HTTP ${response.status}: ${text || response.statusText}`); + } + } + + ngOnDestroy(): void { + if (this.redirectTimer !== null) { + clearTimeout(this.redirectTimer); + } + } + + /** + * A page is "in a popup" only when it has an opener from the same origin + * that it can actually postMessage to. We guard the property reads + * defensively because cross-origin `window.opener` access can throw. 
+ */ + private detectPopup(): boolean { + try { + return typeof window !== 'undefined' && window.opener != null && window.opener !== window; + } catch { + return false; + } + } + + private notifyOpenerAndClose(): void { + const message: OAuthCompleteMessage = { + type: 'agentcore-oauth-complete', + status: this.state(), + providerId: this.providerId(), + error: this.state() === 'error' ? this.errorMessage() : null, + }; + + // Primary channel: BroadcastChannel. Survives the Cross-Origin-Opener-Policy + // split that severs window.opener once a popup navigates through external + // origins (AgentCore, Google) and back. Same-origin tabs sharing a channel + // name always see each other, regardless of opener relationship. + try { + const channel = new BroadcastChannel('agentcore-oauth-complete'); + channel.postMessage(message); + // Give the message a tick to propagate before closing the channel. + setTimeout(() => channel.close(), 200); + } catch { + // BroadcastChannel unavailable — fall back to postMessage below. + } + + // Fallback: postMessage to opener. Works when COOP isn't in play. + try { + window.opener?.postMessage(message, window.location.origin); + } catch { + // Cross-origin or COOP-isolated opener — BroadcastChannel above + // handles the handoff in that case. + } + + // Only auto-close on success. On error, leave the window open so the + // user can read the failure reason — they dismiss it manually. + if (this.state() !== 'success') { + return; + } + + this.redirectTimer = setTimeout(() => { + try { + window.close(); + } catch { + // Some browsers refuse to close pages they didn't open; show the + // static "you can close this window" hint and let the user dismiss. 
+ } + }, 400); + } + + private describeError(code: string): string { + switch (code) { + case 'access_denied': + return 'You declined the authorization request.'; + case 'invalid_scope': + return 'The requested permissions are not available for this account.'; + case 'server_error': + return 'The provider could not complete the request. Try again in a moment.'; + default: + return 'Authorization did not complete. Try again.'; + } + } +} diff --git a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.spec.ts b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.spec.ts new file mode 100644 index 00000000..efb77a8d --- /dev/null +++ b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.spec.ts @@ -0,0 +1,98 @@ +import { TestBed } from '@angular/core/testing'; +import { OAuthConsentService } from './oauth-consent.service'; +import { UserConnectorsService } from '../../settings/connectors/services/user-connectors.service'; +import { SessionService } from '../../session/services/session/session.service'; + +describe('OAuthConsentService', () => { + let service: OAuthConsentService; + + beforeEach(() => { + TestBed.resetTestingModule(); + TestBed.configureTestingModule({ + providers: [ + { provide: UserConnectorsService, useValue: {} }, + { + provide: SessionService, + useValue: { dismissPendingInterrupt: () => Promise.resolve() }, + }, + ], + }); + service = TestBed.inject(OAuthConsentService); + }); + + afterEach(() => { + TestBed.resetTestingModule(); + }); + + describe('requestConsent dedup by interruptId', () => { + it('drops a re-emission with the same interruptId', () => { + service.requestConsent( + 'google', + 'https://accounts.example/consent?req=1', + 'i-1', + 'msg-1', + 'sess-1', + ); + service.requestConsent( + 'google', + 'https://accounts.example/consent?req=2', + 'i-1', + 'msg-1', + 'sess-1', + ); + + const pending = service.pending(); + expect(pending.length).toBe(1); + // First emission's URL 
is preserved — the duplicate did not overwrite. + expect(pending[0].authorizationUrl).toBe('https://accounts.example/consent?req=1'); + }); + + it('surfaces a new interruptId for the same provider', () => { + service.requestConsent('google', 'https://accounts.example/c1', 'i-1'); + service.requestConsent('google', 'https://accounts.example/c2', 'i-2'); + + // Provider-keyed map still collapses to one entry, but the second + // call refreshed the interruptId — it was not dropped as a duplicate. + const pending = service.pending(); + expect(pending.length).toBe(1); + expect(pending[0].interruptId).toBe('i-2'); + expect(pending[0].authorizationUrl).toBe('https://accounts.example/c2'); + }); + + it('surfaces distinct providers independently', () => { + service.requestConsent('google', 'https://accounts.example/g', 'i-g'); + service.requestConsent('slack', 'https://slack.example/s', 'i-s'); + + expect(service.pending().length).toBe(2); + }); + + it('keeps interruptId-less requests (settings-page consent) unduped', () => { + service.requestConsent('google', 'https://accounts.example/c1'); + service.requestConsent('google', 'https://accounts.example/c2'); + + // No interruptId means no dedup key — the second call still refreshes + // the entry. Settings-page flows have no agent turn to resume. + expect(service.pending().length).toBe(1); + expect(service.pending()[0].authorizationUrl).toBe('https://accounts.example/c2'); + }); + + it('drops a re-emission even after the request was dismissed', () => { + service.requestConsent('google', 'https://accounts.example/c', 'i-1'); + service.dismiss('google', { syncServer: false }); + expect(service.pending().length).toBe(0); + + // Stream replay or late breadcrumb resurrection of the same id — + // the prompt must not come back. 
+ service.requestConsent('google', 'https://accounts.example/c', 'i-1'); + expect(service.pending().length).toBe(0); + }); + + it('clear() resets the dedup set so a fresh session can re-prompt', () => { + service.requestConsent('google', 'https://accounts.example/c', 'i-1'); + service.clear(); + + service.requestConsent('google', 'https://accounts.example/c', 'i-1'); + expect(service.pending().length).toBe(1); + }); + }); +}); diff --git a/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts new file mode 100644 index 00000000..4acbf333 --- /dev/null +++ b/frontend/ai.client/src/app/services/oauth-consent/oauth-consent.service.ts @@ -0,0 +1,542 @@ +import { Injectable, signal, computed, inject, DestroyRef } from '@angular/core'; +import { takeUntilDestroyed } from '@angular/core/rxjs-interop'; +import { fromEvent } from 'rxjs'; +import { UserConnectorsService } from '../../settings/connectors/services/user-connectors.service'; +import { SessionService } from '../../session/services/session/session.service'; + +/** + * Pending OAuth consent request surfaced by the backend when an external + * MCP tool needs the user to authorize AgentCore Identity. + * + * `interruptId` is set when the request comes from a paused agent turn + * (SSE `oauth_required` event) so the chat layer can resume the same turn + * after consent. It's omitted when the user proactively connects from the + * settings page — in that case there's no agent turn to resume. + */ +export interface OAuthConsentRequest { + providerId: string; + /** Authorization URL captured from a live `oauth_required` SSE event. + * Absent on requests hydrated from session metadata after a refresh — + * AgentCore's URLs expire quickly, so the service re-fetches a fresh one + * via `initiate-consent` when the user clicks Connect. 
+   */
+  authorizationUrl?: string;
+  interruptId?: string;
+  receivedAt: number;
+  /** Id of the assistant message whose tool call triggered this consent request.
+   * Used by the inline message renderer to anchor the prompt to the turn that
+   * needs it. Omitted for proactive consents from the settings page. */
+  messageId?: string;
+  /** Session id the request belongs to. Required for the backend dismiss
+   * endpoint to clear the persisted breadcrumb so a refresh doesn't
+   * resurrect a dismissed prompt. */
+  sessionId?: string;
+}
+
+/**
+ * postMessage payload shape broadcast by the `/oauth-complete` landing
+ * page. Kept in sync with `OAuthCompleteMessage` in
+ * `src/app/oauth-complete/oauth-complete.page.ts`.
+ */
+export interface OAuthCompleteMessage {
+  type: 'agentcore-oauth-complete';
+  status: 'success' | 'error';
+  providerId: string | null;
+  error: string | null;
+}
+
+/**
+ * Handler the chat layer registers to resume a paused agent turn after
+ * one or more OAuth consents complete. Receives the interrupt ids whose
+ * tokens are now available, plus the originating session id so the
+ * handler can resume even when the live ``lastRequestObject`` is gone
+ * (post-refresh hydration). The handler is expected to POST a resume
+ * request to `/invocations` with `interrupt_responses` populated.
+ */
+export type OAuthResumeHandler = (
+  interruptIds: string[],
+  context?: { sessionId?: string },
+) => void | Promise<void>;
+
+function isOAuthCompleteMessage(data: unknown): data is OAuthCompleteMessage {
+  if (!data || typeof data !== 'object') {
+    return false;
+  }
+  const msg = data as Partial<OAuthCompleteMessage>;
+  return msg.type === 'agentcore-oauth-complete';
+}
+
+/**
+ * Only https URLs are accepted for consent navigation. Guards against a
+ * compromised backend or a misconfigured AgentCore response smuggling a
+ * `javascript:` or `data:` URL through the `oauth_required` event and
+ * executing in our origin when the user clicks Connect.
+ */
+function isSafeConsentUrl(raw: string): boolean {
+  try {
+    return new URL(raw).protocol === 'https:';
+  } catch {
+    return false;
+  }
+}
+
+/**
+ * Tracks OAuth consent requests surfaced by the SSE stream and coordinates
+ * the popup + auto-resume flow.
+ *
+ * The stream parser calls {@link requestConsent} when an `oauth_required`
+ * event arrives; components render a "Connect" affordance bound to
+ * {@link pending}. When the user clicks, {@link openConsentPopup} opens the
+ * AgentCore Identity URL, and this service listens for the
+ * `agentcore-oauth-complete` postMessage from the `/oauth-complete` landing
+ * page. On success it dismisses the request and asks the registered
+ * {@link OAuthResumeHandler} to fire a resume request — the user does NOT
+ * have to retype the original prompt.
+ */
+@Injectable({ providedIn: 'root' })
+export class OAuthConsentService {
+  private readonly destroyRef = inject(DestroyRef);
+  private readonly userConnectorsService = inject(UserConnectorsService);
+  private readonly sessionService = inject(SessionService);
+
+  /** Map of providerId → request. A provider only appears once, even if
+   * the backend emits duplicates mid-stream. */
+  private readonly requests = signal<Map<string, OAuthConsentRequest>>(new Map());
+
+  /** Interrupt ids we've already surfaced or resolved this session. Used to
+   * ignore re-emissions of the same `oauth_required` event after a stream
+   * replay or a late server-side breadcrumb clear — without this, a popup
+   * that already completed consent would resurrect once dismissed. New
+   * tool calls always carry a fresh interrupt id (Strands generates it
+   * from `toolUseId`), so legitimate prompts are never suppressed. */
+  private readonly seenInterruptIds = new Set<string>();
+
+  /** ProviderIds whose popup is currently open.
+   */
+  private readonly inFlight = signal<Set<string>>(new Set());
+
+  /** Public read of inFlight so settings/chat UIs can react when a popup
+   * closes without completing (state needs to flip from "Awaiting" back
+   * to "Connect" so the user can retry). */
+  readonly inFlightProviders = this.inFlight.asReadonly();
+
+  /** Active close-watcher intervals keyed by providerId so we can cancel
+   * cleanly on completion / dismissal. */
+  private readonly closeWatchers = new Map<string, ReturnType<typeof setInterval>>();
+
+  /** ProviderIds whose popup was blocked on the last open attempt. */
+  private readonly blocked = signal<Set<string>>(new Set());
+
+  /** Most recent completion notice surfaced to the chat layer. */
+  private readonly lastCompletion = signal<OAuthCompleteMessage | null>(null);
+
+  /** Resume handler registered by the chat layer. Replayed when a
+   * consent completes successfully. */
+  private resumeHandler: OAuthResumeHandler | null = null;
+
+  readonly pending = computed(() =>
+    Array.from(this.requests().values()).sort((a, b) => a.receivedAt - b.receivedAt),
+  );
+
+  readonly hasPending = computed(() => this.requests().size > 0);
+
+  readonly completion = this.lastCompletion.asReadonly();
+
+  constructor() {
+    // Primary channel: BroadcastChannel. AgentCore's OAuth popup navigates
+    // through external origins (Google, AgentCore), which triggers Chrome's
+    // Cross-Origin-Opener-Policy and severs window.opener. window.postMessage
+    // from the /oauth-complete page is silently blocked in that case, so we
+    // rely on a same-origin BroadcastChannel to bridge popup → opener.
+    try {
+      const channel = new BroadcastChannel('agentcore-oauth-complete');
+      channel.addEventListener('message', (event) => {
+        if (!isOAuthCompleteMessage(event.data)) {
+          return;
+        }
+        this.handleCompletion(event.data);
+      });
+      this.destroyRef.onDestroy(() => channel.close());
+    } catch {
+      // BroadcastChannel unavailable — fall back to postMessage below.
+    }
+
+    // Fallback channel: window postMessage (pre-COOP browsers, or flows
+    // where the popup manages to retain window.opener). The origin guard
+    // makes sure cross-origin pages can't spoof a completion.
+    fromEvent<MessageEvent>(window, 'message')
+      .pipe(takeUntilDestroyed(this.destroyRef))
+      .subscribe((event) => {
+        if (event.origin !== window.location.origin) {
+          return;
+        }
+        if (!isOAuthCompleteMessage(event.data)) {
+          return;
+        }
+        this.handleCompletion(event.data);
+      });
+  }
+
+  /**
+   * Register a consent request coming off the SSE stream.
+   * Duplicate providerIds refresh the existing entry — the backend may
+   * reissue an interrupt with a new id if the user retried.
+   *
+   * Rejects non-https URLs — see {@link isSafeConsentUrl}.
+   */
+  requestConsent(
+    providerId: string,
+    authorizationUrl: string | undefined,
+    interruptId?: string,
+    messageId?: string,
+    sessionId?: string,
+  ): void {
+    // Hydration from session metadata passes undefined — the URL gets fetched
+    // lazily on Connect. Live SSE flows still pass the URL up front for the
+    // fast-path popup with no extra roundtrip.
+    if (authorizationUrl !== undefined && !isSafeConsentUrl(authorizationUrl)) {
+      console.error(
+        'OAuth consent rejected: authorizationUrl is not https',
+        { providerId },
+      );
+      return;
+    }
+    // Drop re-emissions of an already-handled interrupt. Stream replays after
+    // refresh, or a delayed server-side breadcrumb clear, can fire the same
+    // `oauth_required` event again — without this guard a successfully
+    // consented or dismissed prompt would resurrect.
+ if (interruptId && this.seenInterruptIds.has(interruptId)) { + return; + } + if (interruptId) { + this.seenInterruptIds.add(interruptId); + } + this.requests.update((map) => { + const next = new Map(map); + next.set(providerId, { + providerId, + authorizationUrl, + interruptId, + messageId, + sessionId, + receivedAt: Date.now(), + }); + return next; + }); + // A fresh request clears any prior blocked state for this provider. + this.blocked.update((set) => { + if (!set.has(providerId)) { + return set; + } + const next = new Set(set); + next.delete(providerId); + return next; + }); + } + + /** + * Open the AgentCore Identity consent URL in a popup window. + * + * If the browser blocks the popup, we mark the provider as blocked and + * surface that to the UI rather than navigating the parent tab away — + * a redirect would tear down the chat mid-conversation and leave the + * paused agent turn hanging. + * + * Returns true if the popup opened, false if it was blocked or the URL + * failed validation. Callers can use this to trigger a fallback UI. + */ + async openConsentPopup(providerId: string): Promise { + const request = this.requests().get(providerId); + if (!request) { + return false; + } + + // Hydrated requests don't carry a URL — fetch a fresh one. Stored URLs + // can also be stale if the live SSE event fired more than a few minutes + // ago, so we treat any missing/expired URL as a refresh trigger. + let authorizationUrl = request.authorizationUrl; + if (!authorizationUrl) { + this.inFlight.update((set) => { + const next = new Set(set); + next.add(providerId); + return next; + }); + try { + const response = await this.userConnectorsService.initiateConsent(providerId); + if (response.connected || !response.authorizationUrl) { + // Already consented while paused (e.g. the user authorized in + // another tab). Drop the request and let the resume handler — if + // any — fire so the agent can finish the turn. 
+ this.dismiss(providerId); + if (request.interruptId && this.resumeHandler) { + void Promise.resolve( + this.resumeHandler([request.interruptId], { sessionId: request.sessionId }), + ).catch((err) => + console.error('OAuth resume handler failed after pre-consented refresh', err), + ); + } + return false; + } + authorizationUrl = response.authorizationUrl; + this.requests.update((map) => { + const next = new Map(map); + const current = next.get(providerId); + if (current) { + next.set(providerId, { ...current, authorizationUrl }); + } + return next; + }); + } catch (err) { + console.error('Failed to fetch fresh authorization URL', err); + this.inFlight.update((set) => { + if (!set.has(providerId)) return set; + const next = new Set(set); + next.delete(providerId); + return next; + }); + return false; + } + } + + // Re-validate on the hot path even though requestConsent already + // checked — defensive against anyone mutating the stored entry. + if (!isSafeConsentUrl(authorizationUrl)) { + console.error( + 'OAuth consent rejected at open: authorizationUrl is not https', + { providerId }, + ); + this.inFlight.update((set) => { + if (!set.has(providerId)) return set; + const next = new Set(set); + next.delete(providerId); + return next; + }); + return false; + } + + const width = 520; + const height = 680; + const left = window.screenX + Math.max(0, (window.outerWidth - width) / 2); + const top = window.screenY + Math.max(0, (window.outerHeight - height) / 2); + + const features = [ + `width=${width}`, + `height=${height}`, + `left=${Math.round(left)}`, + `top=${Math.round(top)}`, + 'resizable=yes', + 'scrollbars=yes', + 'status=no', + 'toolbar=no', + 'menubar=no', + 'location=no', + ].join(','); + + const popup = window.open(authorizationUrl, `oauth-${providerId}`, features); + + if (!popup) { + this.blocked.update((set) => { + if (set.has(providerId)) { + return set; + } + const next = new Set(set); + next.add(providerId); + return next; + }); + return false; + } + 
+ this.blocked.update((set) => { + if (!set.has(providerId)) { + return set; + } + const next = new Set(set); + next.delete(providerId); + return next; + }); + + this.inFlight.update((set) => { + const next = new Set(set); + next.add(providerId); + return next; + }); + + // Watch for the user closing the popup without completing consent. + // Without this the provider stays "in-flight" forever and the Connect + // button remains disabled. We poll because there's no reliable + // cross-browser event for popup close, especially under COOP. + this.watchPopupClose(providerId, popup); + return true; + } + + /** Poll a popup window until it closes; on close, drop the provider out + * of `inFlight` so the UI re-enables the Connect button. The pending + * request stays so the chat banner can offer a retry. */ + private watchPopupClose(providerId: string, popup: Window): void { + // Cancel any prior watcher for this provider — only one popup at a time. + this.cancelCloseWatcher(providerId); + + const interval = setInterval(() => { + let closed = false; + try { + closed = popup.closed; + } catch { + // Cross-Origin-Opener-Policy can block reads of `closed` after the + // popup navigates externally. Give up — the user can dismiss the + // banner manually if needed. + this.cancelCloseWatcher(providerId); + return; + } + if (!closed) return; + + this.cancelCloseWatcher(providerId); + // Only act if still flagged in-flight: a successful completion already + // ran handleCompletion → dismiss() before the popup's own close. 
+ if (!this.inFlight().has(providerId)) return; + this.inFlight.update((set) => { + if (!set.has(providerId)) return set; + const next = new Set(set); + next.delete(providerId); + return next; + }); + }, 500); + this.closeWatchers.set(providerId, interval); + this.destroyRef.onDestroy(() => this.cancelCloseWatcher(providerId)); + } + + private cancelCloseWatcher(providerId: string): void { + const interval = this.closeWatchers.get(providerId); + if (interval !== undefined) { + clearInterval(interval); + this.closeWatchers.delete(providerId); + } + } + + /** Check whether a popup is still open for this provider. */ + isInFlight(providerId: string): boolean { + return this.inFlight().has(providerId); + } + + /** Check whether the last popup-open attempt was blocked. */ + isBlocked(providerId: string): boolean { + return this.blocked().has(providerId); + } + + /** + * Return the https authorization URL for a provider, or null if no + * pending request. Used by the banner to render an anchor-based fallback + * when the popup is blocked. + */ + getAuthorizationUrl(providerId: string): string | null { + const request = this.requests().get(providerId); + return request?.authorizationUrl ?? null; + } + + /** + * Register the chat-layer handler that resumes the paused agent turn + * after one or more OAuth consents complete. The handler receives the + * interrupt ids whose tokens are ready; replacing it (set to null) + * disables auto-resume. + */ + setResumeHandler(handler: OAuthResumeHandler | null): void { + this.resumeHandler = handler; + } + + /** + * Drop a single consent request from local state, and (when called from + * the UI's explicit dismiss button) clear the persisted breadcrumb so a + * refresh doesn't resurrect the prompt. 
+ * + * On completion-driven cleanup ({@link handleCompletion}) we set + * ``syncServer: false`` because the resume request that follows will + * remove the same breadcrumb server-side — a separate DELETE would just + * be redundant network noise. + */ + dismiss(providerId: string, options?: { syncServer?: boolean }): void { + const entry = this.requests().get(providerId); + const sessionId = entry?.sessionId; + const interruptId = entry?.interruptId; + + this.requests.update((map) => { + if (!map.has(providerId)) { + return map; + } + const next = new Map(map); + next.delete(providerId); + return next; + }); + this.inFlight.update((set) => { + if (!set.has(providerId)) { + return set; + } + const next = new Set(set); + next.delete(providerId); + return next; + }); + this.blocked.update((set) => { + if (!set.has(providerId)) { + return set; + } + const next = new Set(set); + next.delete(providerId); + return next; + }); + + if (options?.syncServer === false || !sessionId || !interruptId) { + return; + } + + // Best-effort: a backend cleanup failure shouldn't block the UI from + // hiding the prompt — the prompt is already gone locally. + void this.sessionService + .dismissPendingInterrupt(sessionId, interruptId) + .catch((err) => { + console.warn('Failed to clear persisted pending_interrupt; local dismiss still applied', err); + }); + } + + /** Reset all state (new session, logout). */ + clear(): void { + this.requests.set(new Map()); + this.inFlight.set(new Set()); + this.blocked.set(new Set()); + this.lastCompletion.set(null); + this.seenInterruptIds.clear(); + } + + /** Acknowledge the last completion signal after the UI has reacted. 
*/ + acknowledgeCompletion(): void { + this.lastCompletion.set(null); + } + + private handleCompletion(message: OAuthCompleteMessage): void { + this.lastCompletion.set(message); + if (message.providerId) { + // Completion arrived — the close watcher is no longer needed and + // would otherwise fire spuriously when the popup auto-closes after + // postMessage. + this.cancelCloseWatcher(message.providerId); + } + if (message.status !== 'success' || !message.providerId) { + return; + } + + // Capture the paused interrupt id BEFORE dismissing the request, since + // dismiss removes the entry the handler needs. A user-initiated + // settings-page consent has no interruptId — nothing to resume. + // Skip server sync: the resume request fired below clears the persisted + // interrupt server-side, so a separate DELETE would just be redundant. + const request = this.requests().get(message.providerId); + this.dismiss(message.providerId, { syncServer: false }); + + if (!request?.interruptId || !this.resumeHandler) { + return; + } + + void Promise.resolve( + this.resumeHandler([request.interruptId], { sessionId: request.sessionId }), + ).catch((err) => { + // Resume failures are surfaced through the resume request's own error + // handling — log here for diagnostics but don't crash the consent flow. 
+ console.error('OAuth resume handler failed', err); + }); + } +} diff --git a/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts index b0078598..f0b13b65 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/assistant-message.component.ts @@ -1,4 +1,4 @@ -import { ChangeDetectionStrategy, Component, input, computed } from '@angular/core'; +import { ChangeDetectionStrategy, Component, computed, inject, input } from '@angular/core'; import { Message, ContentBlock, ToolUseData } from '../../../services/models/message.model'; import { ToolUseComponent } from './tool-use'; import { ToolRailComponent } from './tool-rail'; @@ -6,6 +6,11 @@ import { ToolCallGroup, ToolCallDisplay } from './tool-rail/tool-rail.model'; import { ReasoningContentComponent } from './reasoning-content'; import { StreamingTextComponent } from './streaming-text.component'; import { InlineVisualComponent } from './inline-visual'; +import { OAuthConsentPromptComponent } from './oauth-consent-prompt/oauth-consent-prompt.component'; +import { + OAuthConsentRequest, + OAuthConsentService, +} from '../../../../services/oauth-consent/oauth-consent.service'; // ────────────────────────────────────────────────────────────── // 🔧 MOCK FLAG — set to true to render 10 fake tool calls @@ -98,7 +103,13 @@ const MOCK_TOOL_GROUP: ToolCallGroup = { * promoted visuals and grouped tool rails. 
*/ interface DisplayBlock { - type: 'text' | 'tool_group' | 'tool_use_minimized' | 'promoted_visual' | 'reasoningContent'; + type: + | 'text' + | 'tool_group' + | 'tool_use_minimized' + | 'promoted_visual' + | 'reasoningContent' + | 'oauth_required'; data?: ContentBlock; // For tool groups (inline rail) group?: ToolCallGroup; @@ -106,6 +117,8 @@ interface DisplayBlock { uiType?: string; payload?: unknown; toolUseId?: string; + // For inline OAuth consent prompts + oauthRequest?: OAuthConsentRequest; } @Component({ @@ -117,6 +130,7 @@ interface DisplayBlock { ReasoningContentComponent, StreamingTextComponent, InlineVisualComponent, + OAuthConsentPromptComponent, ], template: `
@@ -182,6 +196,14 @@ interface DisplayBlock { />
} + @case ('oauth_required') { +
+ +
+ } } } @@ -236,6 +258,8 @@ export class AssistantMessageComponent { message = input.required(); isStreaming = input(false); + private consentService = inject(OAuthConsentService); + /** * Transforms content blocks into display blocks. * - Consecutive non-promoted tool-use blocks are grouped into a single ToolCallGroup @@ -253,6 +277,14 @@ export class AssistantMessageComponent { } const blocks = this.message().content; + const messageId = this.message().id; + // Pending interrupts anchored to this message. Used to flip the matching + // tool_use blocks to ``awaiting_auth`` so the row reads as "paused for + // authorization" instead of an indefinite spinner. + const pendingInterruptsHere = this.consentService + .pending() + .filter((req) => req.messageId === messageId); + const hasPendingInterruptHere = pendingInterruptsHere.length > 0; const result: DisplayBlock[] = []; let pendingToolCalls: ToolCallDisplay[] = []; @@ -307,13 +339,22 @@ export class AssistantMessageComponent { toolUseId: toolUse.toolUseId }); } else { - // Accumulate into the current tool group + // Accumulate into the current tool group. A tool_use with no result + // on a message that has a pending OAuth interrupt is the row that + // got paused — surface that distinct state instead of a forever- + // spinning ``pending``. + const baseStatus = toolUse.status || 'pending'; + const hasNoResult = !toolUse.result; + const status: ToolCallDisplay['status'] = + hasPendingInterruptHere && hasNoResult && baseStatus === 'pending' + ? 'awaiting_auth' + : baseStatus; pendingToolCalls.push({ id: toolUse.toolUseId, toolName: toolUse.name, input: toolUse.input || {}, result: toolUse.result, - status: toolUse.status || 'pending', + status, }); } continue; @@ -323,6 +364,13 @@ export class AssistantMessageComponent { // Flush any remaining tool calls flushToolGroup(); + // Append any pending OAuth consent prompts anchored to this message. 
+ // Tracking through the consent service signal keeps the synthetic prompt + // out of message.content so it is never persisted to the backend. + for (const req of pendingInterruptsHere) { + result.push({ type: 'oauth_required', oauthRequest: req }); + } + return result; }); diff --git a/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts new file mode 100644 index 00000000..9cc9cfbb --- /dev/null +++ b/frontend/ai.client/src/app/session/components/message-list/components/oauth-consent-prompt/oauth-consent-prompt.component.ts @@ -0,0 +1,304 @@ +import { ChangeDetectionStrategy, Component, computed, inject, input } from '@angular/core'; +import { NgIcon, provideIcons } from '@ng-icons/core'; +import { + heroAcademicCap, + heroArrowTopRightOnSquare, + heroCloud, + heroCodeBracket, + heroLink, + heroLockClosed, + heroXMark, +} from '@ng-icons/heroicons/outline'; +import { + OAuthConsentRequest, + OAuthConsentService, +} from '../../../../../services/oauth-consent/oauth-consent.service'; +import { UserConnector } from '../../../../../settings/connectors/models/user-connector.model'; +import { UserConnectorsService } from '../../../../../settings/connectors/services/user-connectors.service'; + +/** + * Inline OAuth consent prompt rendered alongside the assistant message whose + * tool call needed authorization. Looks up the connector definition to display + * its icon (admin-uploaded base64 wins over heroicon name) and friendly name, + * delegates click handling to {@link OAuthConsentService}. 
+ */ +@Component({ + selector: 'app-oauth-consent-prompt', + changeDetection: ChangeDetectionStrategy.OnPush, + imports: [NgIcon], + providers: [ + provideIcons({ + heroAcademicCap, + heroArrowTopRightOnSquare, + heroCloud, + heroCodeBracket, + heroLink, + heroLockClosed, + heroXMark, + }), + ], + host: { class: 'block' }, + template: ` +
+ + +
+ @if (iconDataUrl(); as data) { + + } @else { +
+ +
+

+

+

+ @if (isBlocked()) { + Popup blocked. Open + {{ displayName() }} + in a new tab to continue. + } @else { + Connect + {{ displayName() }} + so the assistant can finish this request. + } +

+
+ +
+ @if (isBlocked() && request().authorizationUrl; as blockedUrl) { + + Open + + } @else { + + } + +
+
+ `, + styles: ` + @import 'tailwindcss'; + @custom-variant dark (&:where(.dark, .dark *)); + + :host { + display: block; + } + + .oauth-prompt { + animation: oauth-rise 0.32s cubic-bezier(0.16, 1, 0.3, 1); + } + + /* Override the global \`.message-block p\` rule (styles.css) which adds + a 16px margin-bottom for prose paragraphs. Inside the prompt the two +

s are a tight label + description pair. */ + .oauth-prompt p { + margin-bottom: 0; + } + + .action-btn { + display: inline-flex; + align-items: center; + gap: 0.25rem; + border-radius: 0.375rem; + padding: 0.25rem 0.625rem; + font-size: 0.75rem; + font-weight: 600; + color: white; + background: var(--color-secondary-500); + transition: + background-color 120ms ease, + transform 120ms ease; + } + + .action-btn:hover:not(:disabled) { + background: var(--color-secondary-600); + } + + .action-btn:active:not(:disabled) { + transform: translateY(1px); + } + + .action-btn:focus-visible { + outline: 2px solid var(--color-secondary-500); + outline-offset: 2px; + } + + .action-btn:disabled { + opacity: 0.85; + cursor: default; + } + + /* Dismiss is subtle until the row is hovered or focused. */ + .dismiss-btn { + opacity: 0; + } + + .group:hover .dismiss-btn, + .group:focus-within .dismiss-btn { + opacity: 1; + } + + @keyframes oauth-rise { + from { + opacity: 0; + transform: translateY(6px); + } + to { + opacity: 1; + transform: translateY(0); + } + } + + @media (prefers-reduced-motion: reduce) { + .oauth-prompt { + animation: none; + } + .action-btn { + transition: none; + } + } + `, +}) +export class OAuthConsentPromptComponent { + request = input.required(); + + protected consentService = inject(OAuthConsentService); + private connectorsService = inject(UserConnectorsService); + + /** Connector definition for this providerId, when the catalog is loaded. */ + private connector = computed(() => { + const connectors = this.connectorsService.connectorsResource.value(); + if (!connectors) return null; + return connectors.find((c) => c.providerId === this.request().providerId) ?? null; + }); + + /** Admin-uploaded base64 icon. Wins over heroicon when present. */ + protected iconDataUrl = computed(() => this.connector()?.iconData ?? null); + + /** Heroicon fallback when no iconData exists. 
Mirrors the cascade used by + * the connectors-settings page so the same connector renders identically. */ + protected iconName = computed(() => { + const c = this.connector(); + if (c?.iconName) return c.iconName; + switch (c?.providerType) { + case 'google': + case 'microsoft': + return 'heroCloud'; + case 'github': + return 'heroCodeBracket'; + case 'canvas': + return 'heroAcademicCap'; + default: + return 'heroLink'; + } + }); + + protected displayName = computed(() => { + const c = this.connector(); + if (c?.displayName) return c.displayName; + return this.titleCase(this.request().providerId); + }); + + protected isInFlight = computed(() => + this.consentService.isInFlight(this.request().providerId), + ); + + protected isBlocked = computed(() => + this.consentService.isBlocked(this.request().providerId), + ); + + connect(): void { + void this.consentService.openConsentPopup(this.request().providerId); + } + + dismiss(event: Event): void { + event.stopPropagation(); + this.consentService.dismiss(this.request().providerId); + } + + private titleCase(providerId: string): string { + if (!providerId) return 'This tool'; + return providerId + .replace(/[-_]+/g, ' ') + .split(' ') + .filter((part) => part.length > 0) + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(' '); + } +} diff --git a/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.component.ts b/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.component.ts index 209c3b6a..35c51277 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.component.ts @@ -81,9 +81,11 @@ export class ToolRailComponent { /** CSS class for status dot */ statusDotClass(call: ToolCallDisplay): string { switch (call.status) { - case 'complete': return 'status-dot bg-green-500'; - 
case 'pending': return 'status-dot bg-amber-400 shimmer'; - case 'error': return 'status-dot bg-red-500'; + case 'complete': return 'status-dot bg-green-500'; + case 'pending': return 'status-dot bg-amber-400 shimmer'; + case 'error': return 'status-dot bg-red-500'; + case 'awaiting_auth': return 'status-dot bg-primary-500 ring-2 ring-primary-300/40 dark:ring-primary-400/30'; + default: return 'status-dot bg-gray-400'; } } diff --git a/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.model.ts b/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.model.ts index 4dd8fda7..c669c333 100644 --- a/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.model.ts +++ b/frontend/ai.client/src/app/session/components/message-list/components/tool-rail/tool-rail.model.ts @@ -20,8 +20,11 @@ export interface ToolCallDisplay { content: ToolResultContent[]; }; - /** Execution status (from toolUseData.status, defaults to 'pending') */ - status: 'pending' | 'complete' | 'error'; + /** Execution status (from toolUseData.status, defaults to 'pending'). + * ``awaiting_auth`` is derived in the message renderer when the tool was + * paused on an OAuth consent gate — the tool didn't fail, it's waiting + * for the user to authorize. 
*/ + status: 'pending' | 'complete' | 'error' | 'awaiting_auth'; /** Optional LLM-generated one-line summary of this tool call's result */ summary?: string; diff --git a/frontend/ai.client/src/app/session/components/message-list/message-list.component.html b/frontend/ai.client/src/app/session/components/message-list/message-list.component.html index 45751424..a76a0edb 100644 --- a/frontend/ai.client/src/app/session/components/message-list/message-list.component.html +++ b/frontend/ai.client/src/app/session/components/message-list/message-list.component.html @@ -30,6 +30,15 @@ } } + + @for (request of unanchoredInterrupts(); track request.providerId) { +

+ +
+ } + @if (isChatLoading()) {
diff --git a/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts b/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts index 8ca7f641..e3d1ba12 100644 --- a/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts +++ b/frontend/ai.client/src/app/session/components/message-list/message-list.component.ts @@ -1,4 +1,4 @@ -import { Component, input, signal, effect, OnDestroy, inject, PLATFORM_ID } from '@angular/core'; +import { Component, computed, input, signal, effect, OnDestroy, inject, PLATFORM_ID } from '@angular/core'; import { isPlatformBrowser } from '@angular/common'; import { Message } from '../../services/models/message.model'; import { UserMessageComponent } from './components/user-message.component'; @@ -6,10 +6,22 @@ import { AssistantMessageComponent } from './components/assistant-message.compon import { MessageMetadataBadgesComponent } from './components/message-metadata-badges.component'; import { CitationDisplayComponent } from '../citation-display/citation-display.component'; import { PulsatingLoaderComponent } from '../../../components/pulsating-loader.component'; +import { OAuthConsentPromptComponent } from './components/oauth-consent-prompt/oauth-consent-prompt.component'; +import { + OAuthConsentRequest, + OAuthConsentService, +} from '../../../services/oauth-consent/oauth-consent.service'; @Component({ selector: 'app-message-list', - imports: [UserMessageComponent, AssistantMessageComponent, MessageMetadataBadgesComponent, CitationDisplayComponent, PulsatingLoaderComponent], + imports: [ + UserMessageComponent, + AssistantMessageComponent, + MessageMetadataBadgesComponent, + CitationDisplayComponent, + PulsatingLoaderComponent, + OAuthConsentPromptComponent, + ], templateUrl: './message-list.component.html', styleUrl: './message-list.component.css', }) @@ -27,6 +39,18 @@ export class MessageListComponent implements OnDestroy { streamingMessageId = 
input(null); embeddedMode = input(false); + private consentService = inject(OAuthConsentService); + + /** Pending consent prompts whose anchor message id isn't in the loaded + * message list — typically the case when an interrupt fires on a turn + * whose partial assistant message wasn't persisted to AgentCore Memory. + * Rendered at the end of the conversation so the user still sees the + * affordance instead of a silently stalled tool call. */ + protected unanchoredInterrupts = computed(() => { + const ids = new Set(this.messages().map((m) => m.id)); + return this.consentService.pending().filter((req) => !req.messageId || !ids.has(req.messageId)); + }); + // Calculate the spacer height dynamically // This creates space at the bottom so user messages can scroll to the top spacerHeight = signal(0); diff --git a/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts b/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts index 5743f91f..f28485c7 100644 --- a/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/chat-http.service.ts @@ -65,7 +65,8 @@ export class ChatHttpService { headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}`, - Accept: 'text/event-stream', + Accept: 'text/event-stream', + OAuth2CallbackUrl: `${window.location.origin}/oauth-complete`, }, body: JSON.stringify(requestObject), signal: abortController.signal, @@ -136,20 +137,10 @@ export class ChatHttpService { this.messageMapService.endStreaming(); this.chatStateService.setChatLoading(false); - // Generate title only for new sessions (fire and forget - don't block on this) + // Title is generated server-side concurrently with the stream + // (see /invocations). Refresh metadata so the sidebar reflects it. 
if (this.sessionService.isNewSession(requestObject.session_id)) { - this.generateTitle(requestObject.session_id, requestObject.message) - .then((response) => { - // Update the session title in the local cache - this.sessionService.updateSessionTitleInCache( - requestObject.session_id, - response.title, - ); - }) - .catch((error) => { - // Log error but don't block the user experience - console.error('Failed to generate session title:', error); - }); + this.refreshTitleFromServer(requestObject.session_id); } }, onerror: (err) => { @@ -186,6 +177,29 @@ export class ChatHttpService { this.chatStateService.resetState(); } + /** + * Pull the server-generated title into the local sidebar cache. + * + * Title generation runs concurrently with the agent stream on the backend. + * Nova Micro typically finishes well before the stream does, but on fast + * responses we may race past it — so on a "New Conversation" placeholder + * we retry once after a short delay before giving up. + */ + private async refreshTitleFromServer(sessionId: string, retried = false): Promise { + try { + const metadata = await this.sessionService.getSessionMetadata(sessionId); + if (metadata.title && metadata.title !== 'New Conversation') { + this.sessionService.updateSessionTitleInCache(sessionId, metadata.title); + return; + } + if (!retried) { + setTimeout(() => this.refreshTitleFromServer(sessionId, true), 1500); + } + } catch (error) { + console.error('Failed to refresh session title:', error); + } + } + /** * Generates a title for a session based on the user's input. 
* diff --git a/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts b/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts index e7ab45fa..be128a3f 100644 --- a/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/chat-request.service.ts @@ -1,4 +1,4 @@ -import { inject, Injectable } from '@angular/core'; +import { inject, Injectable, OnDestroy } from '@angular/core'; import { Router } from '@angular/router'; import { v4 as uuidv4 } from 'uuid'; import { ChatStateService } from './chat-state.service'; @@ -10,6 +10,10 @@ import { ModelService } from '../model/model.service'; import { ToolService } from '../../../services/tool/tool.service'; import { FileUploadService } from '../../../services/file-upload'; import { FileAttachmentData } from '../models/message.model'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; +import { ErrorService } from '../../../services/error/error.service'; +import { StreamParserService } from './stream-parser.service'; +import { HttpErrorResponse } from '@angular/common/http'; export interface ContentFile { fileName: string; @@ -21,7 +25,7 @@ export interface ContentFile { @Injectable({ providedIn: 'root', }) -export class ChatRequestService { +export class ChatRequestService implements OnDestroy { // private conversationService = inject(ConversationService); private chatHttpService = inject(ChatHttpService); private chatStateService = inject(ChatStateService); @@ -31,9 +35,22 @@ export class ChatRequestService { private modelService = inject(ModelService); private toolService = inject(ToolService); private fileUploadService = inject(FileUploadService); + private oauthConsentService = inject(OAuthConsentService); + private streamParserService = inject(StreamParserService); + private errorService = inject(ErrorService); private router = inject(Router); // TODO: Inject proper 
logging service + constructor() { + this.oauthConsentService.setResumeHandler((interruptIds, context) => + this.resumeFromOAuthConsent(interruptIds, context?.sessionId), + ); + } + + ngOnDestroy(): void { + this.oauthConsentService.setResumeHandler(null); + } + async submitChatRequest( userInput: string, sessionId: string | null, @@ -149,6 +166,83 @@ export class ChatRequestService { return requestObject; } + /** + * Resume the paused agent turn by POSTing the interrupt responses. The + * backend rebuilds the agent from its persisted ``PausedTurnSnapshot``, + * so this request only needs to identify the session and the interrupts — + * no model / tools / prompt context is sent or required. Triggered by + * OAuthConsentService after the user completes a consent popup. + */ + private async resumeFromOAuthConsent( + interruptIds: string[], + sessionId?: string, + ): Promise { + if (interruptIds.length === 0 || !sessionId) { + return; + } + + // Reset the parser so the resumed stream is treated as a fresh batch + // of events. Without this, the parser stays in Completed state from + // the prior `done` and ignores everything. + this.streamParserService.reset(sessionId); + this.messageMapService.startStreaming(sessionId); + this.chatStateService.createNewAbortController(); + this.chatStateService.setChatLoading(true); + + const resumeRequest: Record = { + session_id: sessionId, + // The original prompt is already in the agent's interrupt context; + // sending an empty string keeps the request valid without + // re-augmenting or re-charging quota. + message: '', + interrupt_responses: interruptIds.map((interruptId) => ({ + interruptId, + // The token is already in AgentCore Identity's vault by the time + // we resume; the response payload itself doesn't carry a secret — + // it's just the signal that consent completed. 
+ response: 'consented', + })), + }; + + try { + await this.chatHttpService.sendChatRequest(resumeRequest); + } catch (error) { + this.chatStateService.setChatLoading(false); + this.messageMapService.endStreaming(); + + // 400 from the resume route means either the persisted snapshot is + // missing/expired, or the agent's `_interrupt_state` doesn't recognize + // the submitted ids. Either way the user needs to retry the prompt. + if (this.isExpiredInterruptError(error)) { + this.errorService.addError( + 'Authorization expired', + 'The agent paused too long ago to resume this turn automatically. Please send your message again.', + ); + return; + } + throw error; + } + } + + /** Detect the 400 the inference-api returns for unknown/expired interrupt + * ids. Both fetch-based and HttpClient-based flows are checked because + * the resume path uses `fetch-event-source`, which surfaces errors as + * plain Error/Response objects rather than HttpErrorResponse. */ + private isExpiredInterruptError(error: unknown): boolean { + if (error instanceof HttpErrorResponse) { + return error.status === 400; + } + if (typeof error === 'object' && error !== null) { + const status = (error as { status?: unknown }).status; + if (status === 400) return true; + const message = (error as { message?: unknown }).message; + if (typeof message === 'string' && /expired interrupt/i.test(message)) { + return true; + } + } + return false; + } + /** * Get file attachment metadata for display in user messages. * Retrieves file metadata from FileUploadService for given upload IDs. 
diff --git a/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts b/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts index 14a343a8..c8e4c61d 100644 --- a/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts +++ b/frontend/ai.client/src/app/session/services/chat/stream-parser.service.ts @@ -14,6 +14,8 @@ import { QuotaWarning, QuotaExceeded, } from '../../../services/quota/quota-warning.service'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; +import type { OAuthRequiredEvent } from '../../../shared/utils/stream-parser'; import { processStreamEvent, createStreamLineParser, @@ -48,6 +50,7 @@ export class StreamParserService { private chatStateService = inject(ChatStateService); private errorService = inject(ErrorService); private quotaWarningService = inject(QuotaWarningService); + private oauthConsentService = inject(OAuthConsentService); // ========================================================================= // State Signals @@ -194,8 +197,11 @@ export class StreamParserService { } // Check if we should process this event - const isStartOrErrorEvent = event === 'message_start' || event === 'error'; - if (!isStartOrErrorEvent && !this.shouldProcessEvent()) { + // oauth_required arrives after message_stop/done by design (see CLAUDE.md SSE + // table) — allow it through even when the stream state is Completed. 
+ const isAlwaysAllowedEvent = + event === 'message_start' || event === 'error' || event === 'oauth_required'; + if (!isAlwaysAllowedEvent && !this.shouldProcessEvent()) { return; } @@ -289,6 +295,28 @@ export class StreamParserService { onQuotaWarning: (data) => this.quotaWarningService.setWarning(data as QuotaWarning), onQuotaExceeded: (data) => this.quotaWarningService.setQuotaExceeded(data as QuotaExceeded), + onOAuthRequired: (data: OAuthRequiredEvent) => { + // oauth_required arrives after message_stop, so the triggering + // assistant message is normally in completedMessages; fall back + // to the in-flight builder for tool_use stop reasons that keep + // the message active. + const messages = this.allMessages(); + let lastAssistantId: string | undefined; + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'assistant') { + lastAssistantId = messages[i].id; + break; + } + } + this.oauthConsentService.requestConsent( + data.providerId, + data.authorizationUrl, + data.interruptId, + lastAssistantId, + this.sessionId ?? 
undefined, + ); + }, + onError: (data) => this.handleError(data), onStreamError: (data) => this.errorService.handleConversationalStreamError(data as ConversationalStreamError), diff --git a/frontend/ai.client/src/app/session/services/session/message-map.service.spec.ts b/frontend/ai.client/src/app/session/services/session/message-map.service.spec.ts index 9e33f138..54b23c2e 100644 --- a/frontend/ai.client/src/app/session/services/session/message-map.service.spec.ts +++ b/frontend/ai.client/src/app/session/services/session/message-map.service.spec.ts @@ -4,6 +4,7 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; import { MessageMapService } from './message-map.service'; import { SessionService } from './session.service'; import { FileUploadService } from '../../../services/file-upload'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; import { signal } from '@angular/core'; describe('MessageMapService', () => { @@ -11,6 +12,7 @@ describe('MessageMapService', () => { let httpMock: HttpTestingController; let mockSessionService: any; let mockFileUploadService: any; + let mockOAuthConsentService: any; beforeEach(() => { TestBed.resetTestingModule(); @@ -22,12 +24,16 @@ describe('MessageMapService', () => { mockFileUploadService = { listSessionFiles: vi.fn().mockResolvedValue([]) }; + mockOAuthConsentService = { + requestConsent: vi.fn() + }; TestBed.configureTestingModule({ imports: [HttpClientTestingModule], providers: [ MessageMapService, { provide: SessionService, useValue: mockSessionService }, - { provide: FileUploadService, useValue: mockFileUploadService } + { provide: FileUploadService, useValue: mockFileUploadService }, + { provide: OAuthConsentService, useValue: mockOAuthConsentService } ] }); service = TestBed.inject(MessageMapService); @@ -85,11 +91,48 @@ describe('MessageMapService', () => { expect(mockSessionService.getMessages).toHaveBeenCalledWith('session-1'); 
expect(mockFileUploadService.listSessionFiles).toHaveBeenCalledWith('session-1'); - + const messagesSignal = service.getMessagesForSession('session-1'); expect(messagesSignal()).toEqual(mockMessages); }); + it('should hydrate pending OAuth interrupts from camelCase wire response', async () => { + // Regression: backend serializes with by_alias=True so the wire payload uses + // camelCase (pendingInterrupts, interruptId, providerId, ...). If the consumer + // reads snake_case fields, the consent prompt silently fails to re-render + // after a refresh. + const mockMessages = [ + { id: 'msg-assistant-7', role: 'assistant', content: [{ type: 'text', text: 'ok' }] }, + ]; + mockSessionService.getMessages.mockResolvedValue({ + messages: mockMessages, + pendingInterrupts: [ + { + interruptId: 'v1:before_tool_call:tooluse_abc:xyz', + providerId: 'google-calendar-employee', + createdAt: '2026-04-26T01:13:54.543143+00:00', + }, + ], + }); + + await service.loadMessagesForSession('session-with-interrupt'); + + expect(mockOAuthConsentService.requestConsent).toHaveBeenCalledTimes(1); + expect(mockOAuthConsentService.requestConsent).toHaveBeenCalledWith( + 'google-calendar-employee', + undefined, + 'v1:before_tool_call:tooluse_abc:xyz', + 'msg-assistant-7', + 'session-with-interrupt', + ); + }); + + it('should not call requestConsent when no pending interrupts are returned', async () => { + mockSessionService.getMessages.mockResolvedValue({ messages: [] }); + await service.loadMessagesForSession('session-clean'); + expect(mockOAuthConsentService.requestConsent).not.toHaveBeenCalled(); + }); + it('should set loading session state', () => { service.setLoadingSession('session-1'); expect(service.isLoadingSession()).toBe('session-1'); diff --git a/frontend/ai.client/src/app/session/services/session/message-map.service.ts b/frontend/ai.client/src/app/session/services/session/message-map.service.ts index f4d7197d..446f84dc 100644 --- 
a/frontend/ai.client/src/app/session/services/session/message-map.service.ts +++ b/frontend/ai.client/src/app/session/services/session/message-map.service.ts @@ -2,8 +2,9 @@ import { Injectable, Signal, WritableSignal, signal, effect, inject } from '@angular/core'; import { ContentBlock, FileAttachmentData, Message } from '../models/message.model'; import { StreamParserService } from '../chat/stream-parser.service'; -import { SessionService } from './session.service'; +import { PendingInterrupt, SessionService } from './session.service'; import { FileUploadService, FileMetadata } from '../../../services/file-upload'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; /** Regex to match file attachment marker in message text: [Attached files: file1.pdf, file2.png] */ const ATTACHED_FILES_PATTERN = /\n\n\[Attached files: ([^\]]+)\]$/; @@ -42,6 +43,7 @@ export class MessageMapService { private streamParser = inject(StreamParserService); private sessionService = inject(SessionService); private fileUploadService = inject(FileUploadService); + private oauthConsentService = inject(OAuthConsentService); constructor() { // Reactive effect: automatically sync streaming messages to the message map @@ -278,6 +280,13 @@ export class MessageMapService { return updated; }); + + // Hydrate OAuth consent service from any persisted pending interrupts + // so a reload restores the consent prompt anchored to the right turn. + // Anchor falls back to the most recent assistant message in history, + // matching how the live SSE flow already attaches prompts. Authorization + // URL is intentionally omitted — fresh URL is fetched lazily on Connect. 
+ this.hydratePendingInterrupts(sessionId, messagesResponse.pendingInterrupts, processedMessages); } catch (error) { console.error('Failed to load messages for session:', sessionId, error); throw error; @@ -287,6 +296,38 @@ export class MessageMapService { } } + /** + * Replay persisted pending OAuth interrupts into the consent service so + * the inline prompt re-renders. If the backend records a triggering message + * id we use it directly; otherwise we anchor to the most recent assistant + * message (mirroring the live-stream behavior in stream-parser.service.ts). + */ + private hydratePendingInterrupts( + sessionId: string, + interrupts: PendingInterrupt[] | undefined, + messages: Message[], + ): void { + if (!interrupts || interrupts.length === 0) { + return; + } + let lastAssistantId: string | undefined; + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'assistant') { + lastAssistantId = messages[i].id; + break; + } + } + for (const interrupt of interrupts) { + this.oauthConsentService.requestConsent( + interrupt.providerId, + undefined, // URL is fetched lazily on Connect — stored URLs go stale + interrupt.interruptId, + interrupt.triggeringMessageId ?? lastAssistantId, + sessionId, + ); + } + } + /** * Fetch file metadata for a session from the backend. * Returns empty array on error to allow graceful degradation. 
diff --git a/frontend/ai.client/src/app/session/services/session/session.service.spec.ts b/frontend/ai.client/src/app/session/services/session/session.service.spec.ts index cff98f49..8b90a112 100644 --- a/frontend/ai.client/src/app/session/services/session/session.service.spec.ts +++ b/frontend/ai.client/src/app/session/services/session/session.service.spec.ts @@ -55,7 +55,7 @@ describe('SessionService', () => { describe('getMessages (no ensureAuthenticated)', () => { it('should GET messages', async () => { - const resp: MessagesListResponse = { messages: [mockMessage], next_token: null }; + const resp: MessagesListResponse = { messages: [mockMessage], nextToken: null }; const promise = service.getMessages('s1'); httpMock.expectOne('http://localhost:8000/sessions/s1/messages').flush(resp); expect(await promise).toEqual(resp); diff --git a/frontend/ai.client/src/app/session/services/session/session.service.ts b/frontend/ai.client/src/app/session/services/session/session.service.ts index 96dd45e1..05a521ca 100644 --- a/frontend/ai.client/src/app/session/services/session/session.service.ts +++ b/frontend/ai.client/src/app/session/services/session/session.service.ts @@ -28,16 +28,36 @@ export interface SessionsListResponse { nextToken: string | null; } +/** + * Pending OAuth consent interrupt that paused an agent turn for this session. + * + * Returned from `GET /sessions/{id}/messages` so the frontend can rediscover + * pending consents after a browser refresh. `authorization_url` is intentionally + * absent — those expire quickly; the frontend re-fetches on Connect. 
+ */ +export interface PendingInterrupt { + /** Strands interrupt id used to resume the paused turn */ + interruptId: string; + /** Connector providerId needing consent */ + providerId: string; + /** Id of the assistant message whose tool call triggered this interrupt, if known */ + triggeringMessageId?: string | null; + /** ISO 8601 timestamp when the interrupt was recorded */ + createdAt: string; +} + /** * Response model for listing messages with pagination support. - * + * * Matches the MessagesListResponse model from the Python API. */ export interface MessagesListResponse { /** List of messages in the session */ messages: Message[]; /** Pagination token for retrieving the next page of results */ - next_token: string | null; + nextToken: string | null; + /** OAuth consent interrupts that paused agent turns and are awaiting user action */ + pendingInterrupts?: PendingInterrupt[]; } /** @@ -369,7 +389,7 @@ export class SessionService { * // Get next page * const nextPage = await sessionService.getSessions({ * limit: 20, - * next_token: response.next_token + * next_token: response.nextToken * }); * ``` */ @@ -417,17 +437,17 @@ export class SessionService { * // Get next page * const nextPage = await sessionService.getMessages( * '8e70ae89-93af-4db7-ba60-f13ea201f4cd', - * { limit: 20, next_token: response.next_token } + * { limit: 20, next_token: response.nextToken } * ); * ``` */ async getMessages(sessionId: string, params?: GetMessagesParams): Promise { let httpParams = new HttpParams(); - + if (params?.limit !== undefined) { httpParams = httpParams.set('limit', params.limit.toString()); } - + if (params?.next_token) { httpParams = httpParams.set('next_token', params.next_token); } @@ -446,6 +466,19 @@ export class SessionService { } } + /** + * Dismiss a persisted OAuth pending interrupt for a session. Idempotent — + * the backend returns 204 even if the entry is already gone. 
+ */ + async dismissPendingInterrupt(sessionId: string, interruptId: string): Promise { + // The interrupt id contains a colon (e.g. ``oauth:google-calendar``); + // encode it so it survives URL parsing on the path. + const encoded = encodeURIComponent(interruptId); + await firstValueFrom( + this.http.delete(`${this.baseUrl()}/${sessionId}/pending-interrupts/${encoded}`), + ); + } + /** * Fetches metadata for a specific session from the Python API. * diff --git a/frontend/ai.client/src/app/settings/connections/connections.page.ts b/frontend/ai.client/src/app/settings/connections/connections.page.ts deleted file mode 100644 index 904104c8..00000000 --- a/frontend/ai.client/src/app/settings/connections/connections.page.ts +++ /dev/null @@ -1,399 +0,0 @@ -import { - Component, - ChangeDetectionStrategy, - inject, - signal, - computed, - OnInit, -} from '@angular/core'; -import { Router, RouterLink, ActivatedRoute } from '@angular/router'; -import { NgIcon, provideIcons } from '@ng-icons/core'; -import { - heroArrowLeft, - heroLink, - heroCloud, - heroCodeBracket, - heroAcademicCap, - heroCheck, - heroExclamationTriangle, - heroArrowPath, - heroKey, -} from '@ng-icons/heroicons/outline'; -import { ConnectionsService } from './services'; -import { OAuthConnection, OAuthProviderType } from './models'; -import { ToastService } from '../../services/toast/toast.service'; - -@Component({ - selector: 'app-connections', - changeDetection: ChangeDetectionStrategy.OnPush, - imports: [RouterLink, NgIcon], - providers: [ - provideIcons({ - heroArrowLeft, - heroLink, - heroCloud, - heroCodeBracket, - heroAcademicCap, - heroCheck, - heroExclamationTriangle, - heroArrowPath, - heroKey, - }), - ], - host: { - class: 'block', - }, - template: ` -
-
- - - - Back to Chat - - - -
-

Connected Apps

-

- Connect your accounts to enable tools that require third-party authentication. -

-
- - - - - - - @if (connectionsResource.isLoading() && connections().length === 0) { -
-
-
-

- Loading connections... -

-
-
- } - - - @if (connectionsResource.error()) { -
-

Failed to load connections

-

Please check your connection and try again.

- -
- } - - - @if (!connectionsResource.isLoading() || connections().length > 0) { - @if (connections().length === 0) { - -
-
- -
-

No connections available

-

- There are no OAuth providers configured for your account. Contact an administrator if you need access to external tools. -

-
- } @else { -
- @for (connection of connections(); track connection.providerId) { -
- -
-
- -
-
-

- {{ connection.displayName }} -

- - - @if (isConnected(connection)) { -
- - - Connected - - @if (connection.connectedAt) { - - since {{ formatDate(connection.connectedAt) }} - - } -
- } @else if (connection.needsReauth || connection.status === 'needs_reauth' || connection.status === 'expired') { -
- - - Needs Re-authorization - -
- } @else { -

- Not connected -

- } -
-
- - -
- @if (isConnected(connection) && !connection.needsReauth && connection.status !== 'needs_reauth' && connection.status !== 'expired') { - - } @else { - - } -
-
- } -
- } - } - - - @if (connections().length > 0) { -
-

About Connections

-
-

- Connected apps allow certain tools to access external services on your behalf. -

-

- Re-authorization may be required if the app's permissions change or your token expires. -

-

- Disconnecting will revoke the app's access to your account. You can reconnect at any time. -

-
-
- } -
-
- `, -}) -export class ConnectionsPage implements OnInit { - connectionsService = inject(ConnectionsService); - private router = inject(Router); - private route = inject(ActivatedRoute); - private toast = inject(ToastService); - - readonly connectionsResource = this.connectionsService.connectionsResource; - - // Local state - connecting = signal(null); - disconnecting = signal(null); - - // Computed - readonly connections = computed(() => this.connectionsService.getConnections()); - - ngOnInit(): void { - this.handleCallbackParams(); - } - - /** - * Handle OAuth callback query parameters. - */ - private handleCallbackParams(): void { - const params = this.route.snapshot.queryParams; - - if (params['success'] === 'true') { - const provider = params['provider'] || 'the service'; - this.toast.success('Connected!', `Successfully connected to ${provider}.`); - // Refresh connections after successful OAuth - this.connectionsService.reload(); - // Clear query params - this.router.navigate([], { - relativeTo: this.route, - queryParams: {}, - replaceUrl: true, - }); - } else if (params['error']) { - const error = params['error']; - const provider = params['provider'] || 'the service'; - const description = params['error_description']; - - let message = `Failed to connect to ${provider}.`; - if (description) { - message = description; - } else if (error === 'access_denied') { - message = 'Authorization was denied. Please try again.'; - } else if (error === 'missing_params') { - message = 'Invalid callback. Please try again.'; - } - - this.toast.error('Connection Failed', message); - // Clear query params - this.router.navigate([], { - relativeTo: this.route, - queryParams: {}, - replaceUrl: true, - }); - } - } - - /** - * Check if a connection is actively connected. - */ - isConnected(connection: OAuthConnection): boolean { - return connection.status === 'connected'; - } - - /** - * Initiate OAuth connection flow. 
- */ - async connect(connection: OAuthConnection): Promise { - this.connecting.set(connection.providerId); - - try { - const redirectUrl = window.location.origin + '/settings/oauth/callback'; - const authUrl = await this.connectionsService.connect(connection.providerId, redirectUrl); - // Redirect to OAuth authorization - window.location.href = authUrl; - } catch (error: any) { - console.error('Error initiating connection:', error); - const message = error?.error?.detail || error?.message || 'Failed to initiate connection.'; - this.toast.error('Connection Error', message); - this.connecting.set(null); - } - } - - /** - * Disconnect from a provider. - */ - async disconnect(connection: OAuthConnection): Promise { - if (!confirm(`Are you sure you want to disconnect from ${connection.displayName}?`)) { - return; - } - - this.disconnecting.set(connection.providerId); - - try { - await this.connectionsService.disconnect(connection.providerId); - this.toast.success('Disconnected', `Successfully disconnected from ${connection.displayName}.`); - } catch (error: any) { - console.error('Error disconnecting:', error); - const message = error?.error?.detail || error?.message || 'Failed to disconnect.'; - this.toast.error('Disconnect Error', message); - } finally { - this.disconnecting.set(null); - } - } - - /** - * Get icon name for a provider. - */ - getProviderIcon(connection: OAuthConnection): string { - if (connection.iconName && connection.iconName !== 'heroLink') { - return connection.iconName; - } - // Default icons by type - switch (connection.providerType) { - case 'google': - case 'microsoft': - return 'heroCloud'; - case 'github': - return 'heroCodeBracket'; - case 'canvas': - return 'heroAcademicCap'; - default: - return 'heroLink'; - } - } - - /** - * Get icon container classes for a provider type. 
- */ - getProviderIconClasses(type: OAuthProviderType): string { - const baseClasses = 'flex size-12 shrink-0 items-center justify-center rounded-sm'; - switch (type) { - case 'google': - return `${baseClasses} bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400`; - case 'microsoft': - return `${baseClasses} bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400`; - case 'github': - return `${baseClasses} bg-gray-800 text-white dark:bg-gray-700`; - case 'canvas': - return `${baseClasses} bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400`; - default: - return `${baseClasses} bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400`; - } - } - - /** - * Format a date string for display. - */ - formatDate(dateString: string): string { - try { - const date = new Date(dateString); - return date.toLocaleDateString(undefined, { - month: 'short', - day: 'numeric', - year: 'numeric', - }); - } catch { - return dateString; - } - } -} diff --git a/frontend/ai.client/src/app/settings/connections/index.ts b/frontend/ai.client/src/app/settings/connections/index.ts deleted file mode 100644 index b7348c17..00000000 --- a/frontend/ai.client/src/app/settings/connections/index.ts +++ /dev/null @@ -1,3 +0,0 @@ -export * from './models'; -export * from './services'; -export * from './connections.page'; diff --git a/frontend/ai.client/src/app/settings/connections/models/index.ts b/frontend/ai.client/src/app/settings/connections/models/index.ts deleted file mode 100644 index df906738..00000000 --- a/frontend/ai.client/src/app/settings/connections/models/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './oauth-connection.model'; diff --git a/frontend/ai.client/src/app/settings/connections/models/oauth-connection.model.ts b/frontend/ai.client/src/app/settings/connections/models/oauth-connection.model.ts deleted file mode 100644 index af88bfb7..00000000 --- 
a/frontend/ai.client/src/app/settings/connections/models/oauth-connection.model.ts +++ /dev/null @@ -1,57 +0,0 @@ -/** - * OAuth connection models for user-facing connections UI. - */ - -/** Connection status for user OAuth tokens */ -export type OAuthConnectionStatus = 'connected' | 'expired' | 'revoked' | 'needs_reauth'; - -/** Supported OAuth provider types */ -export type OAuthProviderType = 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; - -/** - * User's OAuth connection to a provider. - * Returned from GET /oauth/connections - */ -export interface OAuthConnection { - providerId: string; - displayName: string; - providerType: OAuthProviderType; - iconName: string; - status: OAuthConnectionStatus; - connectedAt: string | null; - needsReauth: boolean; -} - -/** - * Response from GET /oauth/connections - */ -export interface OAuthConnectionListResponse { - connections: OAuthConnection[]; -} - -/** - * Available OAuth provider for connection. - * Returned from GET /oauth/providers (filtered by user roles) - */ -export interface OAuthProvider { - providerId: string; - displayName: string; - providerType: OAuthProviderType; - iconName: string; - scopes: string[]; -} - -/** - * Response from GET /oauth/providers - */ -export interface OAuthProviderListResponse { - providers: OAuthProvider[]; - total: number; -} - -/** - * Response from GET /oauth/connect/{provider_id} - */ -export interface OAuthConnectResponse { - authorizationUrl: string; -} diff --git a/frontend/ai.client/src/app/settings/connections/services/connections.service.spec.ts b/frontend/ai.client/src/app/settings/connections/services/connections.service.spec.ts deleted file mode 100644 index 22fec454..00000000 --- a/frontend/ai.client/src/app/settings/connections/services/connections.service.spec.ts +++ /dev/null @@ -1,81 +0,0 @@ -import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; -import { TestBed } from '@angular/core/testing'; -import { HttpClientTestingModule, 
HttpTestingController } from '@angular/common/http/testing'; -import { signal } from '@angular/core'; -import { ConnectionsService } from './connections.service'; -import { ConfigService } from '../../../services/config.service'; -import { AuthService } from '../../../auth/auth.service'; - -describe('ConnectionsService', () => { - let service: ConnectionsService; - let httpMock: HttpTestingController; - - beforeEach(() => { - TestBed.resetTestingModule(); - TestBed.configureTestingModule({ - imports: [HttpClientTestingModule], - providers: [ - ConnectionsService, - { provide: AuthService, useValue: { ensureAuthenticated: vi.fn().mockResolvedValue(undefined) } }, - { provide: ConfigService, useValue: { appApiUrl: signal('http://localhost:8000') } }, - ], - }); - service = TestBed.inject(ConnectionsService); - httpMock = TestBed.inject(HttpTestingController); - }); - - afterEach(() => { - httpMock.match(() => true); // discard pending requests - TestBed.resetTestingModule(); - }); - - it('should fetch connections', async () => { - const mockResponse = { connections: [{ provider_id: 'google', status: 'connected' }] }; - - const connectionsPromise = service.fetchConnections(); - - await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/oauth/connections').flush(mockResponse); - }); - - const connections = await connectionsPromise; - expect(connections.connections[0].providerId).toBe('google'); - }); - - it('should fetch providers', async () => { - const mockResponse = { providers: [{ provider_id: 'google', name: 'Google' }], total: 1 }; - - const providersPromise = service.fetchProviders(); - - await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/oauth/providers').flush(mockResponse); - }); - - const providers = await providersPromise; - expect(providers.providers[0].providerId).toBe('google'); - expect(providers.total).toBe(1); - }); - - it('should connect to provider', async () => { - const mockResponse = { authorization_url: 
'https://oauth.example.com/auth' }; - - const connectPromise = service.connect('google'); - - await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/oauth/connect/google').flush(mockResponse); - }); - - const authUrl = await connectPromise; - expect(authUrl).toBe('https://oauth.example.com/auth'); - }); - - it('should disconnect from provider', async () => { - const disconnectPromise = service.disconnect('google'); - - await vi.waitFor(() => { - httpMock.expectOne('http://localhost:8000/oauth/connections/google').flush({}); - }); - - await disconnectPromise; - }); -}); \ No newline at end of file diff --git a/frontend/ai.client/src/app/settings/connections/services/connections.service.ts b/frontend/ai.client/src/app/settings/connections/services/connections.service.ts deleted file mode 100644 index b356ba11..00000000 --- a/frontend/ai.client/src/app/settings/connections/services/connections.service.ts +++ /dev/null @@ -1,136 +0,0 @@ -import { Injectable, inject, resource, computed } from '@angular/core'; -import { HttpClient } from '@angular/common/http'; -import { firstValueFrom } from 'rxjs'; -import { ConfigService } from '../../../services/config.service'; -import { AuthService } from '../../../auth/auth.service'; -import { - OAuthConnection, - OAuthConnectionListResponse, - OAuthProvider, - OAuthProviderListResponse, -} from '../models'; - -/** - * Convert snake_case to camelCase for frontend models. - */ -function toCamelCase(obj: Record): Record { - const result: Record = {}; - for (const [key, value] of Object.entries(obj)) { - const camelKey = key.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()); - result[camelKey] = value; - } - return result; -} - -/** - * Service for managing user OAuth connections. - * - * Provides access to available providers and user's connections, - * as well as connect/disconnect operations. 
- */ -@Injectable({ - providedIn: 'root' -}) -export class ConnectionsService { - private http = inject(HttpClient); - private authService = inject(AuthService); - private config = inject(ConfigService); - - private readonly baseUrl = computed(() => `${this.config.appApiUrl()}/oauth`); - - /** - * Reactive resource for fetching user's OAuth connections. - */ - readonly connectionsResource = resource({ - loader: async () => { - await this.authService.ensureAuthenticated(); - return this.fetchConnections(); - } - }); - - /** - * Reactive resource for fetching available OAuth providers. - */ - readonly providersResource = resource({ - loader: async () => { - await this.authService.ensureAuthenticated(); - return this.fetchProviders(); - } - }); - - /** - * Get all user connections (from resource). - */ - getConnections(): OAuthConnection[] { - return this.connectionsResource.value()?.connections ?? []; - } - - /** - * Get all available providers (from resource). - */ - getProviders(): OAuthProvider[] { - return this.providersResource.value()?.providers ?? []; - } - - /** - * Get a connection by provider ID. - */ - getConnectionByProviderId(providerId: string): OAuthConnection | undefined { - return this.getConnections().find(c => c.providerId === providerId); - } - - /** - * Fetch user's OAuth connections from the API. - */ - async fetchConnections(): Promise { - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/connections`) - ); - return { - connections: response.connections.map((c: any) => toCamelCase(c) as OAuthConnection), - }; - } - - /** - * Fetch available OAuth providers from the API. - */ - async fetchProviders(): Promise { - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/providers`) - ); - return { - providers: response.providers.map((p: any) => toCamelCase(p) as OAuthProvider), - total: response.total, - }; - } - - /** - * Initiate OAuth connection flow. - * Returns the authorization URL to redirect to. 
- */ - async connect(providerId: string, redirectUrl?: string): Promise { - const params = redirectUrl ? `?redirect=${encodeURIComponent(redirectUrl)}` : ''; - const response = await firstValueFrom( - this.http.get(`${this.baseUrl()}/connect/${providerId}${params}`) - ); - return response.authorization_url; - } - - /** - * Disconnect from an OAuth provider. - */ - async disconnect(providerId: string): Promise { - await firstValueFrom( - this.http.delete(`${this.baseUrl()}/connections/${providerId}`) - ); - this.connectionsResource.reload(); - } - - /** - * Reload both resources. - */ - reload(): void { - this.connectionsResource.reload(); - this.providersResource.reload(); - } -} diff --git a/frontend/ai.client/src/app/settings/connections/services/index.ts b/frontend/ai.client/src/app/settings/connections/services/index.ts deleted file mode 100644 index d8c6cca4..00000000 --- a/frontend/ai.client/src/app/settings/connections/services/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './connections.service'; diff --git a/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts b/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts new file mode 100644 index 00000000..85557f46 --- /dev/null +++ b/frontend/ai.client/src/app/settings/connectors/models/user-connector.model.ts @@ -0,0 +1,37 @@ +/** + * Connector shape returned by the user-facing catalog endpoint. + * Strips admin-only fields (ARN, callback URL, allow-list). + */ +export interface UserConnector { + providerId: string; + displayName: string; + providerType: 'google' | 'microsoft' | 'github' | 'canvas' | 'custom'; + iconName: string; + /** Optional admin-uploaded icon (base64 data URL). Wins over `iconName`. */ + iconData?: string | null; + scopes: string[]; +} + +export interface UserConnectorListResponse { + connectors: UserConnector[]; +} + +/** + * Inference-API response for `/connectors/{id}/initiate-consent`. 
+ * Exactly one of `connected` (true) or `authorizationUrl` (populated) will + * be meaningful — `connected: false` with a URL is the consent path. + */ +export interface InitiateConsentResponse { + connected: boolean; + authorizationUrl: string | null; +} + +/** + * Inference-API response for `GET /connectors/{id}/status`. Side-effect-free: + * unlike initiate-consent, this never remembers a session_uri server-side + * or hands back an authorization URL. Use it to render "Connected" badges + * without committing the user to a consent flow. + */ +export interface ConnectorStatusResponse { + connected: boolean; +} diff --git a/frontend/ai.client/src/app/settings/connectors/services/user-connectors.service.ts b/frontend/ai.client/src/app/settings/connectors/services/user-connectors.service.ts new file mode 100644 index 00000000..14bb678b --- /dev/null +++ b/frontend/ai.client/src/app/settings/connectors/services/user-connectors.service.ts @@ -0,0 +1,118 @@ +import { Injectable, inject, resource, computed } from '@angular/core'; +import { HttpClient } from '@angular/common/http'; +import { firstValueFrom } from 'rxjs'; +import { ConfigService } from '../../../services/config.service'; +import { AuthService } from '../../../auth/auth.service'; +import { + ConnectorStatusResponse, + InitiateConsentResponse, + UserConnector, +} from '../models/user-connector.model'; + +function toCamelCase(obj: Record): T { + const out: Record = {}; + for (const [k, v] of Object.entries(obj)) { + const camel = k.replace(/_([a-z])/g, (_, c: string) => c.toUpperCase()); + out[camel] = v; + } + return out as T; +} + +/** + * User-facing connectors service. + * + * Catalog lives on app-api (`GET /connectors`). Consent initiation lives + * on inference-api (`POST /connectors/{id}/initiate-consent`) because + * AgentCore's IdentityClient needs the workload token that the runtime + * middleware injects only on the inference path. 
+ */ +@Injectable({ providedIn: 'root' }) +export class UserConnectorsService { + private readonly http = inject(HttpClient); + private readonly auth = inject(AuthService); + private readonly config = inject(ConfigService); + + private readonly appApiUrl = computed(() => `${this.config.appApiUrl()}/connectors`); + private readonly inferenceUrl = computed( + () => `${this.config.inferenceApiUrl()}/connectors`, + ); + + readonly connectorsResource = resource({ + loader: async () => { + await this.auth.ensureAuthenticated(); + const response = await firstValueFrom( + this.http.get<{ connectors: Record[] }>(`${this.appApiUrl()}/`), + ); + return response.connectors.map((c) => toCamelCase(c)); + }, + }); + + /** + * Side-effect-free check of whether AgentCore's vault has a usable token + * for this user + provider. Use it to render a "Connected" badge without + * committing the user to a consent flow — `initiateConsent` records a + * server-side pending session every time it returns an auth URL, which is + * wasteful when we only want a badge. + * + * Sends `OAuth2CallbackUrl` for the same reason `initiateConsent` does: + * the runtime injects it in prod, but the settings page bypasses the + * runtime, so without it the backend's identity client has no callback + * URL to give AgentCore Identity and rejects the request 503. + */ + async getStatus(providerId: string): Promise { + await this.auth.ensureAuthenticated(); + const callback = new URL('/oauth-complete', window.location.origin); + callback.searchParams.set('provider_id', providerId); + return await firstValueFrom( + this.http.get( + `${this.inferenceUrl()}/${providerId}/status`, + { + headers: { + OAuth2CallbackUrl: callback.toString(), + }, + }, + ), + ); + } + + async initiateConsent(providerId: string): Promise { + await this.auth.ensureAuthenticated(); + // AgentCore's GetResourceOauth2Token requires a callback URL. 
The runtime + // reads this header in prod via AgentCoreContextMiddleware; the settings + // page bypasses the runtime, so we set it explicitly on the request. + // We also tack the provider_id onto the callback URL as a query param so + // /oauth-complete can surface the right providerId in its postMessage. + const callback = new URL('/oauth-complete', window.location.origin); + callback.searchParams.set('provider_id', providerId); + const raw = await firstValueFrom( + this.http.post>( + `${this.inferenceUrl()}/${providerId}/initiate-consent`, + {}, + { + headers: { + OAuth2CallbackUrl: callback.toString(), + }, + }, + ), + ); + return toCamelCase(raw); + } + + /** + * Best-effort disconnect: clears the inference-api's local token cache + * for this user/provider and flags it for forced re-consent on next + * use. AgentCore Identity exposes no per-user vault-delete, so the + * upstream token at the provider keeps existing until it expires or + * the user revokes our app from their provider account. 
+ */ + async disconnect(providerId: string): Promise { + await this.auth.ensureAuthenticated(); + await firstValueFrom( + this.http.delete(`${this.inferenceUrl()}/${providerId}/connection`), + ); + } + + reload(): void { + this.connectorsResource.reload(); + } +} diff --git a/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts b/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts deleted file mode 100644 index 3ab5489a..00000000 --- a/frontend/ai.client/src/app/settings/oauth-callback/oauth-callback.page.ts +++ /dev/null @@ -1,655 +0,0 @@ -import { - Component, - ChangeDetectionStrategy, - inject, - OnInit, - OnDestroy, - signal, -} from '@angular/core'; -import { Router, ActivatedRoute } from '@angular/router'; -import { NgIcon, provideIcons } from '@ng-icons/core'; -import { - heroCheck, - heroXMark, - heroArrowPath, - heroLink, -} from '@ng-icons/heroicons/outline'; -import { SidenavService } from '../../services/sidenav/sidenav.service'; - -type CallbackState = 'processing' | 'success' | 'error'; - -@Component({ - selector: 'app-oauth-callback', - imports: [NgIcon], - providers: [ - provideIcons({ - heroCheck, - heroXMark, - heroArrowPath, - heroLink, - }), - ], - changeDetection: ChangeDetectionStrategy.OnPush, - template: ` -
- - - - - - - -
- - @if (state() === 'processing') { -
-
- -
-
-
-
-
-
-

Connecting

-

- Establishing secure connection - - . - . - . - -

-
- } - - - @if (state() === 'success') { -
-
- -
-
-
-
-

Connected

-

- @if (providerName()) { - Successfully linked to {{ providerName() }} - } @else { - Authorization complete - } -

-

- Redirecting to your connections... -

-
- } - - - @if (state() === 'error') { -
-
- -
-
-
-
-

Connection Failed

-

- {{ errorMessage() }} -

-

- Redirecting back... -

-
- } - - - -
- - - -
- `, - styles: ` - :host { - display: block; - min-height: 100dvh; - background: var(--color-gray-50); - } - - :host-context(html.dark) { - background: var(--color-gray-900); - } - - .callback-container { - position: relative; - min-height: 100dvh; - display: flex; - flex-direction: column; - align-items: center; - justify-content: center; - overflow: hidden; - padding: 2rem; - } - - /* Animated grid background */ - .grid-background { - position: absolute; - inset: 0; - display: grid; - grid-template-columns: repeat(8, 1fr); - opacity: 0.04; - pointer-events: none; - } - - :host-context(html.dark) .grid-background { - opacity: 0.06; - } - - .grid-line { - border-right: 1px solid var(--color-primary-500); - height: 100%; - animation: pulse-line 4s ease-in-out infinite; - } - - @keyframes pulse-line { - 0%, 100% { opacity: 0.3; } - 50% { opacity: 1; } - } - - /* Floating accent shapes */ - .accent-shape { - position: absolute; - border-radius: 50%; - filter: blur(80px); - pointer-events: none; - animation: float 8s ease-in-out infinite; - } - - .shape-1 { - width: 350px; - height: 350px; - background: var(--color-primary-500); - opacity: 0.12; - top: -80px; - right: -80px; - animation-delay: 0s; - } - - .shape-2 { - width: 280px; - height: 280px; - background: var(--color-secondary-500); - opacity: 0.1; - bottom: -60px; - left: -60px; - animation-delay: 3s; - } - - @keyframes float { - 0%, 100% { transform: translate(0, 0) scale(1); } - 25% { transform: translate(10px, -20px) scale(1.05); } - 50% { transform: translate(-5px, 10px) scale(0.95); } - 75% { transform: translate(-15px, -10px) scale(1.02); } - } - - /* Main content */ - .content-wrapper { - position: relative; - z-index: 10; - display: flex; - flex-direction: column; - align-items: center; - text-align: center; - animation: fade-up 0.6s ease-out; - min-width: 320px; - } - - @keyframes fade-up { - from { - opacity: 0; - transform: translateY(20px); - } - to { - opacity: 1; - transform: translateY(0); - } - 
} - - /* Status display */ - .status-display { - margin-bottom: 2rem; - animation: scale-in 0.5s cubic-bezier(0.34, 1.56, 0.64, 1); - } - - @keyframes scale-in { - from { - opacity: 0; - transform: scale(0.8); - } - to { - opacity: 1; - transform: scale(1); - } - } - - .icon-container { - position: relative; - width: 120px; - height: 120px; - display: flex; - align-items: center; - justify-content: center; - border-radius: 50%; - } - - /* Processing state */ - .processing-icon { - background: linear-gradient(135deg, var(--color-primary-100) 0%, var(--color-primary-50) 100%); - color: var(--color-primary-600); - animation: rotate-subtle 8s linear infinite; - } - - :host-context(html.dark) .processing-icon { - background: linear-gradient(135deg, var(--color-primary-900) 0%, var(--color-primary-950) 100%); - color: var(--color-primary-400); - } - - @keyframes rotate-subtle { - from { transform: rotate(0deg); } - to { transform: rotate(360deg); } - } - - .pulse-ring { - position: absolute; - inset: -8px; - border-radius: 50%; - border: 2px solid var(--color-primary-400); - animation: pulse-out 2s ease-out infinite; - } - - .pulse-ring.delay-1 { - animation-delay: 0.6s; - } - - .pulse-ring.delay-2 { - animation-delay: 1.2s; - } - - @keyframes pulse-out { - 0% { - opacity: 0.6; - transform: scale(1); - } - 100% { - opacity: 0; - transform: scale(1.6); - } - } - - /* Success state */ - .success-icon { - background: linear-gradient(135deg, var(--color-green-100) 0%, var(--color-green-50) 100%); - color: var(--color-green-600); - animation: success-bounce 0.6s cubic-bezier(0.34, 1.56, 0.64, 1); - } - - :host-context(html.dark) .success-icon { - background: linear-gradient(135deg, var(--color-green-900) 0%, var(--color-green-950) 100%); - color: var(--color-green-400); - } - - @keyframes success-bounce { - 0% { - opacity: 0; - transform: scale(0.3); - } - 50% { - transform: scale(1.1); - } - 100% { - opacity: 1; - transform: scale(1); - } - } - - .check-ring { - position: 
absolute; - inset: -4px; - border-radius: 50%; - border: 3px solid var(--color-green-400); - animation: ring-appear 0.4s ease-out 0.3s both; - } - - @keyframes ring-appear { - from { - opacity: 0; - transform: scale(0.8); - } - to { - opacity: 1; - transform: scale(1); - } - } - - /* Error state */ - .error-icon { - background: linear-gradient(135deg, var(--color-red-100) 0%, var(--color-red-50) 100%); - color: var(--color-red-600); - animation: error-shake 0.5s ease-out; - } - - :host-context(html.dark) .error-icon { - background: linear-gradient(135deg, var(--color-red-900) 0%, var(--color-red-950) 100%); - color: var(--color-red-400); - } - - @keyframes error-shake { - 0%, 100% { transform: translateX(0); } - 20% { transform: translateX(-8px); } - 40% { transform: translateX(8px); } - 60% { transform: translateX(-4px); } - 80% { transform: translateX(4px); } - } - - .error-ring { - position: absolute; - inset: -4px; - border-radius: 50%; - border: 3px solid var(--color-red-400); - animation: ring-appear 0.4s ease-out 0.3s both; - } - - /* Message section */ - .message-section { - animation: fade-up 0.6s ease-out 0.2s both; - } - - .title { - font-family: 'Outfit', system-ui, sans-serif; - font-weight: 700; - font-size: clamp(1.75rem, 5vw, 2.5rem); - color: var(--color-gray-900); - margin: 0 0 0.75rem; - letter-spacing: -0.02em; - } - - :host-context(html.dark) .title { - color: var(--color-gray-100); - } - - .success-title { - color: var(--color-green-600); - } - - :host-context(html.dark) .success-title { - color: var(--color-green-400); - } - - .error-title { - color: var(--color-red-600); - } - - :host-context(html.dark) .error-title { - color: var(--color-red-400); - } - - .subtitle { - font-family: 'Space Mono', monospace; - font-size: clamp(0.875rem, 2vw, 1rem); - color: var(--color-gray-600); - margin: 0; - max-width: 320px; - display: flex; - align-items: center; - justify-content: center; - gap: 0; - } - - :host-context(html.dark) .subtitle { - color: 
var(--color-gray-400); - } - - .error-subtitle { - color: var(--color-red-600); - } - - :host-context(html.dark) .error-subtitle { - color: var(--color-red-400); - } - - .typing-text { - overflow: hidden; - white-space: nowrap; - animation: typing 1.5s steps(30) forwards; - } - - @keyframes typing { - from { width: 0; } - to { width: 100%; } - } - - .dots { - display: inline-flex; - margin-left: 2px; - } - - .dot { - animation: dot-bounce 1.4s ease-in-out infinite; - opacity: 0; - } - - .dot:nth-child(1) { animation-delay: 0s; } - .dot:nth-child(2) { animation-delay: 0.2s; } - .dot:nth-child(3) { animation-delay: 0.4s; } - - @keyframes dot-bounce { - 0%, 80%, 100% { - opacity: 0; - transform: translateY(0); - } - 40% { - opacity: 1; - transform: translateY(-4px); - } - } - - .redirect-notice { - font-family: 'Space Mono', monospace; - font-size: 0.75rem; - color: var(--color-gray-500); - margin-top: 1.5rem; - animation: fade-up 0.4s ease-out 0.5s both; - } - - :host-context(html.dark) .redirect-notice { - color: var(--color-gray-500); - } - - /* Progress track */ - .progress-track { - width: 200px; - height: 4px; - background: var(--color-gray-200); - border-radius: 2px; - margin-top: 2.5rem; - overflow: hidden; - animation: fade-up 0.4s ease-out 0.4s both; - } - - :host-context(html.dark) .progress-track { - background: var(--color-gray-700); - } - - .progress-fill { - height: 100%; - width: 0%; - background: var(--color-primary-500); - border-radius: 2px; - animation: progress-loading 2.5s ease-out forwards; - } - - .progress-fill.success { - background: var(--color-green-500); - animation: progress-complete 0.4s ease-out forwards; - } - - .progress-fill.error { - background: var(--color-red-500); - animation: progress-complete 0.4s ease-out forwards; - } - - @keyframes progress-loading { - 0% { width: 0%; } - 20% { width: 15%; } - 50% { width: 45%; } - 80% { width: 75%; } - 100% { width: 90%; } - } - - @keyframes progress-complete { - to { width: 100%; } - } - - 
/* Bottom accent bar */ - .bottom-bar { - position: absolute; - bottom: 0; - left: 0; - right: 0; - height: 6px; - display: flex; - } - - .bar-segment { - flex: 1; - background: var(--color-primary-500); - animation: bar-grow 0.8s ease-out both; - } - - .bar-segment.success { - background: var(--color-green-500); - } - - .bar-segment.error { - background: var(--color-red-500); - } - - @keyframes bar-grow { - from { transform: scaleX(0); } - to { transform: scaleX(1); } - } - `, -}) -export class OAuthCallbackPage implements OnInit, OnDestroy { - private router = inject(Router); - private route = inject(ActivatedRoute); - private sidenavService = inject(SidenavService); - - gridLines = Array.from({ length: 8 }, (_, i) => i); - - // State signals - state = signal('processing'); - providerName = signal(null); - errorMessage = signal('Authorization was denied or failed.'); - - ngOnInit(): void { - this.sidenavService.hide(); - this.handleCallback(); - } - - ngOnDestroy(): void { - this.sidenavService.show(); - } - - private handleCallback(): void { - const params = this.route.snapshot.queryParams; - - // Simulate brief processing delay for visual feedback - setTimeout(() => { - if (params['success'] === 'true') { - this.handleSuccess(params); - } else if (params['error']) { - this.handleError(params); - } else { - // No valid params, redirect to connections - this.redirectToConnections(); - } - }, 800); - } - - private handleSuccess(params: Record): void { - const provider = params['provider']; - if (provider) { - this.providerName.set(this.formatProviderName(provider)); - } - this.state.set('success'); - - // Redirect after showing success - setTimeout(() => { - this.redirectToConnections({ success: 'true', provider }); - }, 1500); - } - - private handleError(params: Record): void { - const error = params['error']; - const description = params['error_description']; - const provider = params['provider']; - - let message = 'Authorization was denied or failed.'; - if 
(description) { - message = description; - } else if (error === 'access_denied') { - message = 'Authorization was denied. Please try again.'; - } else if (error === 'missing_params') { - message = 'Invalid callback parameters.'; - } else if (error === 'invalid_state') { - message = 'Session expired. Please try again.'; - } else if (error === 'token_exchange_failed') { - message = 'Failed to complete authorization.'; - } - - this.errorMessage.set(message); - if (provider) { - this.providerName.set(this.formatProviderName(provider)); - } - this.state.set('error'); - - // Redirect after showing error - setTimeout(() => { - this.redirectToConnections({ error, provider }); - }, 2500); - } - - private redirectToConnections(queryParams?: Record): void { - this.router.navigate(['/settings/connections'], { - queryParams, - replaceUrl: true, - }); - } - - private formatProviderName(providerId: string): string { - // Convert provider_id to display name - return providerId - .replace(/_/g, ' ') - .replace(/-/g, ' ') - .split(' ') - .map(word => word.charAt(0).toUpperCase() + word.slice(1)) - .join(' '); - } -} diff --git a/frontend/ai.client/src/app/settings/pages/connections-settings/connections-settings.page.ts b/frontend/ai.client/src/app/settings/pages/connections-settings/connections-settings.page.ts deleted file mode 100644 index a1247468..00000000 --- a/frontend/ai.client/src/app/settings/pages/connections-settings/connections-settings.page.ts +++ /dev/null @@ -1,308 +0,0 @@ -import { - Component, - ChangeDetectionStrategy, - inject, - signal, - computed, - OnInit, -} from '@angular/core'; -import { Router, ActivatedRoute } from '@angular/router'; -import { NgIcon, provideIcons } from '@ng-icons/core'; -import { - heroLink, - heroCloud, - heroCodeBracket, - heroAcademicCap, - heroCheck, - heroExclamationTriangle, - heroArrowPath, - heroKey, -} from '@ng-icons/heroicons/outline'; -import { ConnectionsService } from '../../connections/services'; -import { OAuthConnection, 
OAuthProviderType } from '../../connections/models'; -import { ToastService } from '../../../services/toast/toast.service'; - -@Component({ - selector: 'app-connections-settings', - changeDetection: ChangeDetectionStrategy.OnPush, - imports: [NgIcon], - providers: [ - provideIcons({ - heroLink, - heroCloud, - heroCodeBracket, - heroAcademicCap, - heroCheck, - heroExclamationTriangle, - heroArrowPath, - heroKey, - }), - ], - host: { class: 'block' }, - template: ` -
- -
-

Connections

-

- Connect your accounts to enable tools that require third-party authentication. -

-
- - - @if (connectionsResource.isLoading() && connections().length === 0) { -
-
-
-

Loading connections...

-
-
- } - - - @if (connectionsResource.error()) { -
-

Failed to load connections

-

Please check your connection and try again.

- -
- } - - - @if (!connectionsResource.isLoading() || connections().length > 0) { - @if (connections().length === 0 && !connectionsResource.error()) { - -
-
- -
-

No connections available

-

- There are no OAuth providers configured for your account. Contact an administrator if you need access to external tools. -

-
- } @else { -
- @for (connection of connections(); track connection.providerId) { -
- -
-
- -
-
-

- {{ connection.displayName }} -

- - - @if (isConnected(connection)) { -
- - - Connected - - @if (connection.connectedAt) { - - since {{ formatDate(connection.connectedAt) }} - - } -
- } @else if (connection.needsReauth || connection.status === 'needs_reauth' || connection.status === 'expired') { -
- - - Needs Re-authorization - -
- } @else { -

- Not connected -

- } -
-
- - -
- @if (isConnected(connection) && !connection.needsReauth && connection.status !== 'needs_reauth' && connection.status !== 'expired') { - - } @else { - - } -
-
- } -
- } - } -
- `, -}) -export class ConnectionsSettingsPage implements OnInit { - readonly connectionsService = inject(ConnectionsService); - private router = inject(Router); - private route = inject(ActivatedRoute); - private toast = inject(ToastService); - - readonly connectionsResource = this.connectionsService.connectionsResource; - - connecting = signal(null); - disconnecting = signal(null); - - readonly connections = computed(() => this.connectionsService.getConnections()); - - ngOnInit(): void { - this.handleCallbackParams(); - } - - private handleCallbackParams(): void { - const params = this.route.snapshot.queryParams; - - if (params['success'] === 'true') { - const provider = params['provider'] || 'the service'; - this.toast.success('Connected!', `Successfully connected to ${provider}.`); - this.connectionsService.reload(); - this.router.navigate([], { - relativeTo: this.route, - queryParams: {}, - replaceUrl: true, - }); - } else if (params['error']) { - const error = params['error']; - const provider = params['provider'] || 'the service'; - const description = params['error_description']; - - let message = `Failed to connect to ${provider}.`; - if (description) { - message = description; - } else if (error === 'access_denied') { - message = 'Authorization was denied. Please try again.'; - } else if (error === 'missing_params') { - message = 'Invalid callback. 
Please try again.'; - } - - this.toast.error('Connection Failed', message); - this.router.navigate([], { - relativeTo: this.route, - queryParams: {}, - replaceUrl: true, - }); - } - } - - isConnected(connection: OAuthConnection): boolean { - return connection.status === 'connected'; - } - - async connect(connection: OAuthConnection): Promise { - this.connecting.set(connection.providerId); - - try { - const redirectUrl = window.location.origin + '/settings/oauth/callback'; - const authUrl = await this.connectionsService.connect(connection.providerId, redirectUrl); - window.location.href = authUrl; - } catch (error: unknown) { - const err = error as { error?: { detail?: string }; message?: string }; - const message = err?.error?.detail || err?.message || 'Failed to initiate connection.'; - this.toast.error('Connection Error', message); - this.connecting.set(null); - } - } - - async disconnect(connection: OAuthConnection): Promise { - if (!confirm(`Are you sure you want to disconnect from ${connection.displayName}?`)) { - return; - } - - this.disconnecting.set(connection.providerId); - - try { - await this.connectionsService.disconnect(connection.providerId); - this.toast.success('Disconnected', `Successfully disconnected from ${connection.displayName}.`); - } catch (error: unknown) { - const err = error as { error?: { detail?: string }; message?: string }; - const message = err?.error?.detail || err?.message || 'Failed to disconnect.'; - this.toast.error('Disconnect Error', message); - } finally { - this.disconnecting.set(null); - } - } - - getProviderIcon(connection: OAuthConnection): string { - if (connection.iconName && connection.iconName !== 'heroLink') { - return connection.iconName; - } - switch (connection.providerType) { - case 'google': - case 'microsoft': - return 'heroCloud'; - case 'github': - return 'heroCodeBracket'; - case 'canvas': - return 'heroAcademicCap'; - default: - return 'heroLink'; - } - } - - getProviderIconClasses(type: OAuthProviderType): 
string { - const baseClasses = 'flex size-12 shrink-0 items-center justify-center rounded-sm'; - switch (type) { - case 'google': - return `${baseClasses} bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400`; - case 'microsoft': - return `${baseClasses} bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400`; - case 'github': - return `${baseClasses} bg-gray-800 text-white dark:bg-gray-700`; - case 'canvas': - return `${baseClasses} bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400`; - default: - return `${baseClasses} bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400`; - } - } - - formatDate(dateString: string): string { - try { - const date = new Date(dateString); - return date.toLocaleDateString(undefined, { - month: 'short', - day: 'numeric', - year: 'numeric', - }); - } catch { - return dateString; - } - } -} diff --git a/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts new file mode 100644 index 00000000..e8dc615c --- /dev/null +++ b/frontend/ai.client/src/app/settings/pages/connectors-settings/connectors-settings.page.ts @@ -0,0 +1,373 @@ +import { + Component, + ChangeDetectionStrategy, + inject, + signal, + computed, + effect, +} from '@angular/core'; +import { Dialog } from '@angular/cdk/dialog'; +import { firstValueFrom } from 'rxjs'; +import { NgIcon, provideIcons } from '@ng-icons/core'; +import { + heroLink, + heroCloud, + heroCodeBracket, + heroAcademicCap, + heroCheckCircle, + heroArrowPath, + heroExclamationTriangle, +} from '@ng-icons/heroicons/outline'; +import { UserConnectorsService } from '../../connectors/services/user-connectors.service'; +import { OAuthConsentService } from '../../../services/oauth-consent/oauth-consent.service'; +import { UserConnector } from '../../connectors/models/user-connector.model'; +import { ToastService } from 
'../../../services/toast/toast.service'; +import { + ConfirmationDialogComponent, + ConfirmationDialogData, +} from '../../../components/confirmation-dialog'; + +type ConnectState = + | 'probing' + | 'idle' + | 'initiating' + | 'awaiting' + | 'connected' + | 'error'; + +@Component({ + selector: 'app-connectors-settings', + changeDetection: ChangeDetectionStrategy.OnPush, + imports: [NgIcon], + providers: [ + provideIcons({ + heroLink, + heroCloud, + heroCodeBracket, + heroAcademicCap, + heroCheckCircle, + heroArrowPath, + heroExclamationTriangle, + }), + ], + host: { class: 'block' }, + template: ` +
+
+

Connectors

+

+ Connect your third-party accounts so agents can call tools on your behalf. +

+
+ + @if (resource.isLoading()) { +
+
+ Loading connectors... +
+ } @else if (resource.error()) { +
+ +
+

Couldn't load connectors

+

+ {{ resource.error()?.message || 'Try again in a moment.' }} +

+ +
+
+ } @else if (connectors().length === 0) { +
+ +

+ No connectors are available to you yet. +

+

+ Ask an administrator to enable a connector for your role. +

+
+ } @else { +
    + @for (connector of connectors(); track connector.providerId) { +
  • +
    + @if (connector.iconData) { +
    + +
    + } @else { +
    + +
    + } +
    +

    + {{ connector.displayName }} +

    +
    +
    + + @let state = getState(connector.providerId); + @if (state === 'probing') { + + + } @else { +
    + @if (state === 'connected') { + + + Connected + + } @else if (state === 'error') { + + + Failed + + } + + @if (state === 'connected') { + + } @else { + + } +
    + } +
  • + } +
+ } +
+ `, +}) +export class ConnectorsSettingsPage { + private readonly connectorsService = inject(UserConnectorsService); + private readonly consentService = inject(OAuthConsentService); + private readonly toast = inject(ToastService); + private readonly dialog = inject(Dialog); + + protected readonly resource = this.connectorsService.connectorsResource; + + protected readonly connectors = computed( + () => this.resource.value() ?? [], + ); + + private readonly states = signal>(new Map()); + + constructor() { + // Flip a provider to `connected` when the /oauth-complete landing page + // postMessages success. This is the same signal the chat-input banner + // listens to, so both UIs stay in sync. + effect(() => { + const completion = this.consentService.completion(); + if (!completion || !completion.providerId) return; + if (completion.status === 'success') { + this.setState(completion.providerId, 'connected'); + } else { + this.setState(completion.providerId, 'error'); + } + this.consentService.acknowledgeCompletion(); + }); + + // Probe AgentCore on load (and whenever the connector list changes) + // to restore the "Connected" badge without the user having to click. + // Uses the side-effect-free `/status` endpoint — `initiateConsent` + // would record a pending session on the server every time the vault is + // empty, which is wasteful when we only want a badge. + effect(() => { + const connectors = this.connectors(); + if (connectors.length === 0) return; + void this.probeConnectedStatus(connectors); + }); + + // If a user closes the OAuth popup without completing consent, the + // service drops the providerId from inFlight. Reset our local state + // back to `idle` so the Connect button becomes interactive again. 
+ effect(() => { + const inFlight = this.consentService.inFlightProviders(); + this.states.update((states) => { + let changed = false; + const next = new Map(states); + for (const [providerId, state] of next.entries()) { + if (state === 'awaiting' && !inFlight.has(providerId)) { + next.set(providerId, 'idle'); + changed = true; + } + } + return changed ? next : states; + }); + }); + } + + private async probeConnectedStatus(connectors: UserConnector[]): Promise { + const unknown = connectors.filter((c) => !this.states().has(c.providerId)); + if (unknown.length === 0) return; + + // Flip to `probing` synchronously so the skeleton renders before the + // first network round-trip resolves, instead of flashing the Connect + // button and replacing it half a second later. + unknown.forEach((c) => this.setState(c.providerId, 'probing')); + + await Promise.all( + unknown.map(async (c) => { + try { + const status = await this.connectorsService.getStatus(c.providerId); + // Only resolve from `probing` — if the user clicked Connect + // mid-probe we don't want to clobber their in-flight state. + if (this.getState(c.providerId) === 'probing') { + this.setState(c.providerId, status.connected ? 'connected' : 'idle'); + } + } catch { + // Status check failed (e.g. backend 503). Fall back to idle so the + // Connect button is interactive and the user can retry manually. + if (this.getState(c.providerId) === 'probing') { + this.setState(c.providerId, 'idle'); + } + } + }), + ); + } + + protected getState(providerId: string): ConnectState { + return this.states().get(providerId) ?? 'idle'; + } + + private setState(providerId: string, state: ConnectState): void { + this.states.update((map) => { + const next = new Map(map); + next.set(providerId, state); + return next; + }); + } + + protected async disconnect(connector: UserConnector): Promise { + // The "destructive" styling matches the existing pattern for delete + // affordances in this codebase (see file-browser bulk delete). 
The + // message flags that the upstream provider may still hold an + // authorization the user should revoke separately for full removal. + const dialogRef = this.dialog.open(ConfirmationDialogComponent, { + data: { + title: `Disconnect ${connector.displayName}`, + message: + `Agents will stop using this connector for you, and you'll be ` + + `prompted to re-authorize the next time it's needed. For full ` + + `revocation (e.g. removing this app from your Google account), ` + + `visit your account settings at the provider.`, + confirmText: 'Disconnect', + cancelText: 'Cancel', + destructive: true, + } as ConfirmationDialogData, + }); + + const confirmed = await firstValueFrom(dialogRef.closed); + if (confirmed !== true) return; + + try { + await this.connectorsService.disconnect(connector.providerId); + this.setState(connector.providerId, 'idle'); + this.toast.success(`${connector.displayName} disconnected.`); + } catch (err: unknown) { + console.error('Disconnect failed', err); + const detail = (err as { error?: { detail?: string }; message?: string })?.error?.detail; + this.toast.error(detail ?? 
'Could not disconnect.'); + } + } + + protected async connect(providerId: string): Promise { + this.setState(providerId, 'initiating'); + try { + const result = await this.connectorsService.initiateConsent(providerId); + if (result.connected) { + this.setState(providerId, 'connected'); + this.toast.success(`${this.displayNameFor(providerId)} is already connected.`); + return; + } + if (!result.authorizationUrl) { + this.setState(providerId, 'error'); + this.toast.error('Unexpected response from the server.'); + return; + } + this.consentService.requestConsent(providerId, result.authorizationUrl); + this.consentService.openConsentPopup(providerId); + this.setState(providerId, 'awaiting'); + } catch (err: unknown) { + console.error('Consent initiation failed', err); + this.setState(providerId, 'error'); + const detail = (err as { error?: { detail?: string }; message?: string })?.error?.detail; + this.toast.error(detail ?? 'Could not start the consent flow.'); + } + } + + private displayNameFor(providerId: string): string { + return this.connectors().find((c) => c.providerId === providerId)?.displayName ?? 
providerId; + } + + protected defaultIcon(providerType: UserConnector['providerType']): string { + switch (providerType) { + case 'google': + case 'microsoft': + return 'heroCloud'; + case 'github': + return 'heroCodeBracket'; + case 'canvas': + return 'heroAcademicCap'; + default: + return 'heroLink'; + } + } + + protected iconClasses(providerType: UserConnector['providerType']): string { + const base = 'flex size-10 items-center justify-center rounded-sm'; + switch (providerType) { + case 'google': + return `${base} bg-red-100 text-red-600 dark:bg-red-900/30 dark:text-red-400`; + case 'microsoft': + return `${base} bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400`; + case 'github': + return `${base} bg-gray-800 text-white dark:bg-gray-600`; + case 'canvas': + return `${base} bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400`; + default: + return `${base} bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400`; + } + } +} diff --git a/frontend/ai.client/src/app/settings/settings.page.ts b/frontend/ai.client/src/app/settings/settings.page.ts index c9d064ab..97079221 100644 --- a/frontend/ai.client/src/app/settings/settings.page.ts +++ b/frontend/ai.client/src/app/settings/settings.page.ts @@ -114,7 +114,7 @@ export class SettingsPage { { label: 'Profile', icon: 'heroUser', route: '/settings/profile', description: 'Your personal information' }, { label: 'Appearance', icon: 'heroPaintBrush', route: '/settings/appearance', description: 'Theme and display' }, { label: 'Chat', icon: 'heroChatBubbleLeftRight', route: '/settings/chat', description: 'Chat preferences' }, - { label: 'Connections', icon: 'heroLink', route: '/settings/connections', description: 'Connected apps' }, + { label: 'Connectors', icon: 'heroLink', route: '/settings/connectors', description: 'Connected accounts' }, { label: 'API Keys', icon: 'heroKey', route: '/settings/api-keys', description: 'API key management' }, { label: 'Usage', icon: 
'heroChartBar', route: '/settings/usage', description: 'Usage and billing' }, ]; diff --git a/frontend/ai.client/src/app/settings/settings.routes.ts b/frontend/ai.client/src/app/settings/settings.routes.ts index 959ff2c7..90d3243f 100644 --- a/frontend/ai.client/src/app/settings/settings.routes.ts +++ b/frontend/ai.client/src/app/settings/settings.routes.ts @@ -22,9 +22,9 @@ export const settingsRoutes: Routes = [ import('./pages/chat-preferences/chat-preferences-settings.page').then(m => m.ChatPreferencesSettingsPage), }, { - path: 'connections', + path: 'connectors', loadComponent: () => - import('./pages/connections-settings/connections-settings.page').then(m => m.ConnectionsSettingsPage), + import('./pages/connectors-settings/connectors-settings.page').then(m => m.ConnectorsSettingsPage), }, { path: 'api-keys', diff --git a/frontend/ai.client/src/app/shared/utils/stream-parser/index.ts b/frontend/ai.client/src/app/shared/utils/stream-parser/index.ts index 9a2085eb..abfe4f72 100644 --- a/frontend/ai.client/src/app/shared/utils/stream-parser/index.ts +++ b/frontend/ai.client/src/app/shared/utils/stream-parser/index.ts @@ -55,6 +55,7 @@ export { validateQuotaExceededEvent, validateConversationalStreamError, validateCitation, + validateOAuthRequiredEvent, } from './stream-parser-core'; // Types @@ -74,6 +75,7 @@ export type { QuotaExceededEvent, StreamErrorEvent, ConversationalStreamErrorEvent, + OAuthRequiredEvent, StreamEventType, StreamEventData, ParsedStreamEvent, diff --git a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts index 339d697d..de03d17c 100644 --- a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts +++ b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-core.ts @@ -36,6 +36,7 @@ import type { QuotaExceededEvent, StreamErrorEvent, ConversationalStreamErrorEvent, + OAuthRequiredEvent, ToolProgress, } 
from './stream-parser-types'; import type { MetadataEvent } from '../../../session/services/models/content-types'; @@ -75,6 +76,9 @@ export interface StreamParserCallbacks { onQuotaWarning?: (data: QuotaWarningEvent) => void; onQuotaExceeded?: (data: QuotaExceededEvent) => void; + // OAuth consent required (external MCP tool needs user authorization) + onOAuthRequired?: (data: OAuthRequiredEvent) => void; + // Error handling onError?: (data: StreamErrorEvent | ConversationalStreamErrorEvent | string) => void; onStreamError?: (data: ConversationalStreamErrorEvent) => void; @@ -318,6 +322,27 @@ export function validateConversationalStreamError( ); } +/** + * Validate OAuthRequiredEvent structure + */ +export function validateOAuthRequiredEvent(data: unknown): data is OAuthRequiredEvent { + if (!data || typeof data !== 'object') { + return false; + } + + const event = data as Partial<OAuthRequiredEvent>; + + return ( + event.type === 'oauth_required' && + typeof event.providerId === 'string' && + event.providerId.length > 0 && + typeof event.authorizationUrl === 'string' && + event.authorizationUrl.length > 0 && + typeof event.interruptId === 'string' && + event.interruptId.length > 0 + ); +} + /** * Validate Citation structure */ @@ -483,6 +508,14 @@ export function processStreamEvent( } break; + case 'oauth_required': + if (validateOAuthRequiredEvent(data)) { + callbacks.onOAuthRequired?.(data); + } else { + callbacks.onParseError?.('oauth_required: invalid data structure'); + } + break; + default: // Ignore unknown events (ping, etc.)
break; diff --git a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts index 8b9ed116..59b6586b 100644 --- a/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts +++ b/frontend/ai.client/src/app/shared/utils/stream-parser/stream-parser-types.ts @@ -86,6 +86,19 @@ export interface ReasoningEvent { reasoningText?: string; } +/** + * OAuth required event — emitted when an external MCP tool needs the user + * to grant consent via AgentCore Identity. The agent's tool call is paused + * (Strands interrupt) and the frontend resumes the same turn after the + * user completes consent by POSTing back the carried `interruptId`. + */ +export interface OAuthRequiredEvent { + type: 'oauth_required'; + providerId: string; + authorizationUrl: string; + interruptId: string; +} + /** * Tool result event data structure */ @@ -123,7 +136,8 @@ export type StreamEventType = | 'quota_warning' | 'quota_exceeded' | 'stream_error' - | 'citation'; + | 'citation' + | 'oauth_required'; /** * Union type of all possible event data types @@ -143,6 +157,7 @@ export type StreamEventData = | StreamErrorEvent | ConversationalStreamErrorEvent | Citation + | OAuthRequiredEvent | null | undefined; diff --git a/infrastructure/lib/app-api-stack.ts b/infrastructure/lib/app-api-stack.ts index 20524a40..b3176944 100644 --- a/infrastructure/lib/app-api-stack.ts +++ b/infrastructure/lib/app-api-stack.ts @@ -917,6 +917,46 @@ export class AppApiStack extends cdk.Stack { resources: [`${oauthClientSecretsArn}*`], // Wildcard for random suffix }) ); + + // Admin CRUD for OAuth2 credential providers stored in AgentCore Identity. + // Provider-scoped actions are scoped to the default token vault; List + // requires a broader resource since it enumerates the vault itself. 
+ taskDefinition.taskRole.addToPrincipalPolicy( + new iam.PolicyStatement({ + sid: 'AgentCoreCredentialProviderAdmin', + effect: iam.Effect.ALLOW, + actions: [ + 'bedrock-agentcore:CreateOauth2CredentialProvider', + 'bedrock-agentcore:UpdateOauth2CredentialProvider', + 'bedrock-agentcore:DeleteOauth2CredentialProvider', + 'bedrock-agentcore:GetOauth2CredentialProvider', + 'bedrock-agentcore:ListOauth2CredentialProviders', + ], + resources: [ + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:token-vault/default`, + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:token-vault/default/oauth2credentialprovider/*`, + ], + }) + ); + + // Custom metrics for OAuth admin flows (e.g. ProviderOrphaned emitted + // by `_emit_orphan_metric` when a failed DB write + failed rollback + // leaves an AgentCore credential provider stranded). PutMetricData + // cannot be resource-scoped; we scope via the namespace condition. + taskDefinition.taskRole.addToPrincipalPolicy( + new iam.PolicyStatement({ + sid: 'OAuthAdminMetrics', + effect: iam.Effect.ALLOW, + actions: ['cloudwatch:PutMetricData'], + resources: ['*'], + conditions: { + StringEquals: { + 'cloudwatch:namespace': 'Agentcore/OAuth', + }, + }, + }) + ); + // Grant permissions for API Keys table (imported from Infrastructure Stack) taskDefinition.taskRole.addToPrincipalPolicy( new iam.PolicyStatement({ diff --git a/infrastructure/lib/inference-api-stack.ts b/infrastructure/lib/inference-api-stack.ts index 35f3413c..e6a78cd3 100644 --- a/infrastructure/lib/inference-api-stack.ts +++ b/infrastructure/lib/inference-api-stack.ts @@ -199,7 +199,7 @@ export class InferenceApiStack extends cdk.Stack { resources: [`arn:aws:ssm:${config.awsRegion}:${config.awsAccount}:parameter/${config.projectPrefix}/*`], })); - // Secrets Manager read permissions for OAuth client secrets (imported from App API Stack) + // Secrets Manager read permissions for OAuth client secrets (imported from Infrastructure 
Stack) const oauthClientSecretsArn = ssm.StringParameter.valueForStringParameter( this, `/${config.projectPrefix}/oauth/client-secrets-arn` @@ -218,7 +218,7 @@ export class InferenceApiStack extends cdk.Stack { ], })); - // DynamoDB Users Table permissions (imported from App API Stack) + // DynamoDB Users Table permissions (imported from Infrastructure Stack) const usersTableArn = ssm.StringParameter.valueForStringParameter( this, `/${config.projectPrefix}/users/users-table-arn` @@ -240,7 +240,7 @@ export class InferenceApiStack extends cdk.Stack { ], })); - // DynamoDB AppRoles Table permissions (imported from App API Stack) + // DynamoDB AppRoles Table permissions (imported from Infrastructure Stack) // This table stores both RBAC roles AND tool catalog definitions const appRolesTableArn = ssm.StringParameter.valueForStringParameter( this, @@ -262,7 +262,7 @@ export class InferenceApiStack extends cdk.Stack { ], })); - // DynamoDB OAuth Providers Table permissions (imported from App API Stack) + // DynamoDB OAuth Providers Table permissions (imported from Infrastructure Stack) const oauthProvidersTableArn = ssm.StringParameter.valueForStringParameter( this, `/${config.projectPrefix}/oauth/providers-table-arn` @@ -283,7 +283,7 @@ export class InferenceApiStack extends cdk.Stack { ], })); - // DynamoDB OAuth User Tokens Table permissions (imported from App API Stack) + // DynamoDB OAuth User Tokens Table permissions (imported from Infrastructure Stack) const oauthUserTokensTableArn = ssm.StringParameter.valueForStringParameter( this, `/${config.projectPrefix}/oauth/user-tokens-table-arn` @@ -482,7 +482,26 @@ export class InferenceApiStack extends cdk.Stack { ], })); - // DynamoDB Quota Tables permissions (imported from App API Stack) + // AgentCore Identity OAuth2 token vault access — lets the Runtime fetch + // user-federated OAuth tokens for external MCP tools (e.g. Google, Slack) + // via IdentityClient.get_token. 
When the user has not consented, this + // call returns an authorization URL instead of a token, which the + // inference route surfaces as an `oauth_required` SSE event. + runtimeExecutionRole.addToPolicy(new iam.PolicyStatement({ + sid: 'GetResourceOauth2Token', + effect: iam.Effect.ALLOW, + actions: [ + 'bedrock-agentcore:GetResourceOauth2Token', + ], + resources: [ + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:token-vault/default`, + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:token-vault/default/*`, + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:workload-identity-directory/default`, + `arn:aws:bedrock-agentcore:${config.awsRegion}:${config.awsAccount}:workload-identity-directory/default/workload-identity/hosted_agent_*`, + ], + })); + + // DynamoDB Quota Tables permissions (imported from Infrastructure Stack) const userQuotasTableArn = ssm.StringParameter.valueForStringParameter( this, `/${config.projectPrefix}/quota/user-quotas-table-arn` @@ -510,7 +529,7 @@ export class InferenceApiStack extends cdk.Stack { ], })); - // DynamoDB Cost Tracking Tables permissions (imported from App API Stack) + // DynamoDB Cost Tracking Tables permissions (imported from Infrastructure Stack) const sessionsMetadataTableArn = ssm.StringParameter.valueForStringParameter( this, `/${config.projectPrefix}/cost-tracking/sessions-metadata-table-arn` @@ -545,7 +564,7 @@ export class InferenceApiStack extends cdk.Stack { ], })); - // DynamoDB Managed Models Table permissions (imported from App API Stack) + // DynamoDB Managed Models Table permissions (imported from Infrastructure Stack) const managedModelsTableArn = ssm.StringParameter.valueForStringParameter( this, `/${config.projectPrefix}/admin/managed-models-table-arn` @@ -585,7 +604,7 @@ export class InferenceApiStack extends cdk.Stack { ], })); - // DynamoDB Auth Providers Table permissions (imported from App API Stack) + // DynamoDB Auth Providers Table 
permissions (imported from Infrastructure Stack) const authProvidersTableArn = ssm.StringParameter.valueForStringParameter( this, `/${config.projectPrefix}/auth/auth-providers-table-arn`