Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
Expand Down Expand Up @@ -944,7 +945,7 @@ def post(self):
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None

# Create provider
# Create provider in transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
Expand All @@ -960,7 +961,11 @@ def post(self):
configuration=configuration,
authentication=authentication,
)
return jsonable_encoder(result)

# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(tenant_id)

return jsonable_encoder(result)

@console_ns.expect(parser_mcp_put)
@setup_required
Expand All @@ -972,17 +977,23 @@ def put(self):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()

# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=args["provider_id"]
)

# No need to check for errors here, exceptions will be raised directly
# Step 2: Perform URL validation with network I/O OUTSIDE of any database session
# This prevents holding database locks during potentially slow network operations
validation_result = MCPToolManageService.validate_server_url_standalone(
tenant_id=current_tenant_id,
new_server_url=args["server_url"],
validation_data=validation_data,
)

# Step 2: Perform database update in a transaction
# Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
Expand All @@ -999,7 +1010,11 @@ def put(self):
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}

# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)

return {"result": "success"}

@console_ns.expect(parser_mcp_delete)
@setup_required
Expand All @@ -1012,7 +1027,11 @@ def delete(self):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}

# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)

return {"result": "success"}


parser_auth = (
Expand Down
55 changes: 36 additions & 19 deletions api/core/mcp/auth/auth_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
"""
Build a list of URLs to try for Protected Resource Metadata discovery.

Per SEP-985, supports fallback when discovery fails at one URL.
Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
Priority order:
1. URL from WWW-Authenticate header (if provided)
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
"""
urls = []

Expand All @@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
path = parsed.path.rstrip("/")

# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
if path:
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
if path_url not in urls:
urls.append(path_url)

# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
root_url = f"{base_url}/.well-known/oauth-protected-resource"
if root_url not in urls:
urls.append(root_url)

return urls

Expand All @@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st

Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.

Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}

Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
- OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
- OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
- OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
- OpenID Connect at root: https://example.com/.well-known/openid-configuration
"""
urls = []
base_url = auth_server_url or server_url

parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
path = parsed.path.rstrip("/")
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
urls.append(f"{base}/.well-known/oauth-authorization-server")

# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OpenID Connect Discovery at root
urls.append(f"{base}/.well-known/openid-configuration")

# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
# OpenID Connect Discovery with path insertion
urls.append(f"{base}/.well-known/openid-configuration{path}")

# OpenID Connect Discovery path appending
urls.append(f"{base}{path}/.well-known/openid-configuration")

# OAuth 2.0 Authorization Server Metadata with path insertion
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")

return urls

Expand Down
2 changes: 1 addition & 1 deletion api/core/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _initialize(
try:
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
self.connect_server(sse_client, "sse")
except MCPConnectionError:
except (MCPConnectionError, ValueError):
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")

Expand Down
120 changes: 83 additions & 37 deletions api/services/tools/mcp_tools_manage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
Expand Down Expand Up @@ -65,6 +64,15 @@ def should_update_server_url(self) -> bool:
return self.needs_validation and self.validation_passed and self.reconnect_result is not None


class ProviderUrlValidationData(BaseModel):
    """
    Snapshot of the provider fields needed to validate a server URL.

    The values are copied out of the database up front so that the network
    part of the validation can run without an open session.
    """

    # Hash of the server URL currently stored for the provider; compared
    # against the hash of the candidate URL to detect an actual change.
    current_server_url_hash: str
    # Headers forwarded to the MCP client when probing the new URL.
    headers: dict[str, str]
    # Timeouts forwarded to the MCP client as-is; None is passed through unchanged.
    timeout: float | None
    sse_read_timeout: float | None


class MCPToolManageService:
"""Service class for managing MCP tools and providers."""

Expand Down Expand Up @@ -166,9 +174,6 @@ def create_provider(
self._session.add(mcp_tool)
self._session.flush()

# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)

mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers

Expand All @@ -192,7 +197,7 @@ def update_provider(
Update an MCP provider.

Args:
validation_result: Pre-validation result from validate_server_url_change.
validation_result: Pre-validation result from validate_server_url_standalone.
If provided and contains reconnect_result, it will be used
instead of performing network operations.
"""
Expand Down Expand Up @@ -251,8 +256,6 @@ def update_provider(
# Flush changes to database
self._session.flush()

# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier)

Expand All @@ -261,9 +264,6 @@ def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)

# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)

def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]:
Expand Down Expand Up @@ -546,30 +546,39 @@ def auth_with_actions(
)
return self.execute_auth_actions(auth_result)

def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
"""Attempt to reconnect to MCP provider with new server URL."""
provider_entity = provider.to_entity()
headers = provider_entity.headers
def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
"""
Get provider data required for URL validation.
This method performs database read and should be called within a session.

try:
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
Returns:
ProviderUrlValidationData: Data needed for standalone URL validation
"""
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = provider.to_entity()
return ProviderUrlValidationData(
current_server_url_hash=provider.server_url_hash,
headers=provider_entity.headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
)

def validate_server_url_change(
self, *, tenant_id: str, provider_id: str, new_server_url: str
@staticmethod
def validate_server_url_standalone(
*,
tenant_id: str,
new_server_url: str,
validation_data: ProviderUrlValidationData,
) -> ServerUrlValidationResult:
"""
Validate server URL change by attempting to connect to the new server.
This method should be called BEFORE update_provider to perform network operations
outside of the database transaction.
This method performs network operations and MUST be called OUTSIDE of any database session
to avoid holding locks during network I/O.

Args:
tenant_id: Tenant ID for encryption
new_server_url: The new server URL to validate
validation_data: Provider data obtained from get_provider_for_url_validation

Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful
Expand All @@ -579,25 +588,30 @@ def validate_server_url_change(
return ServerUrlValidationResult(needs_validation=False)

# Validate URL format
if not self._is_valid_url(new_server_url):
parsed = urlparse(new_server_url)
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
raise ValueError("Server URL is not valid.")

# Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()

# Get current provider
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)

# Check if URL is actually different
if new_server_url_hash == provider.server_url_hash:
if new_server_url_hash == validation_data.current_server_url_hash:
# URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult(
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
needs_validation=False,
encrypted_server_url=encrypted_server_url,
server_url_hash=new_server_url_hash,
)

# Perform validation by attempting to connect
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
# Perform network validation - this is the expensive operation that should be outside session
reconnect_result = MCPToolManageService._reconnect_with_url(
server_url=new_server_url,
headers=validation_data.headers,
timeout=validation_data.timeout,
sse_read_timeout=validation_data.sse_read_timeout,
)
return ServerUrlValidationResult(
needs_validation=True,
validation_passed=True,
Expand All @@ -606,6 +620,38 @@ def validate_server_url_change(
server_url_hash=new_server_url_hash,
)

@staticmethod
def _reconnect_with_url(
    *,
    server_url: str,
    headers: dict[str, str],
    timeout: float | None,
    sse_read_timeout: float | None,
) -> ReconnectResult:
    """
    Probe an MCP server at the given URL and report what the connection yielded.

    Static on purpose: it only performs network I/O through ``MCPClient`` and
    never touches a database session.

    Args:
        server_url: URL of the MCP server to connect to.
        headers: Headers handed to the MCP client.
        timeout: Timeout handed to the MCP client.
        sse_read_timeout: SSE read timeout handed to the MCP client.

    Returns:
        ReconnectResult: ``authed=True`` with the JSON-serialized tool list when
        listing tools succeeds; ``authed=False`` with the empty tools payload
        when the client raises ``MCPAuthError``. Credentials are always the
        empty credentials payload.

    Raises:
        ValueError: If the client raises any other ``MCPError``.
    """
    # Imported locally, as in the original block, rather than at module level.
    from core.mcp.mcp_client import MCPClient

    client_options = {
        "server_url": server_url,
        "headers": headers,
        "timeout": timeout,
        "sse_read_timeout": sse_read_timeout,
    }

    try:
        with MCPClient(**client_options) as client:
            # Serialize while the client is still open; the result is built and
            # returned from inside the context, so client teardown errors are
            # still mapped by the handlers below.
            serialized_tools = json.dumps([remote_tool.model_dump() for remote_tool in client.list_tools()])
            return ReconnectResult(
                authed=True,
                tools=serialized_tools,
                encrypted_credentials=EMPTY_CREDENTIALS_JSON,
            )
    except MCPAuthError:
        # Server reachable but not authorized: report an unauthenticated, tool-less result.
        return ReconnectResult(
            authed=False,
            tools=EMPTY_TOOLS_JSON,
            encrypted_credentials=EMPTY_CREDENTIALS_JSON,
        )
    except MCPError as e:
        raise ValueError(f"Failed to re-connect MCP server: {e}") from e

def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
Expand Down
Loading
Loading