diff --git a/cecli/mcp/__init__.py b/cecli/mcp/__init__.py
index f40c27af978..44d1e6a5f15 100644
--- a/cecli/mcp/__init__.py
+++ b/cecli/mcp/__init__.py
@@ -11,7 +11,7 @@ def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_tran
try:
config = json.loads(json_string)
if verbose:
- io.tool_output("Loading MCP servers from provided JSON string")
+ io.tool_output("Loading MCP servers from provided JSON")
if "mcpServers" in config:
for name, server_config in config["mcpServers"].items():
@@ -22,21 +22,21 @@ def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_tran
server_config["name"] = name
transport = server_config.get("transport", mcp_transport)
if transport == "stdio":
- servers.append(McpServer(server_config))
+ servers.append(McpServer(server_config, io=io, verbose=verbose))
elif transport == "http":
- servers.append(HttpStreamingServer(server_config))
+ servers.append(HttpStreamingServer(server_config, io=io, verbose=verbose))
elif transport == "sse":
- servers.append(SseServer(server_config))
+ servers.append(SseServer(server_config, io=io, verbose=verbose))
if verbose:
- io.tool_output(f"Loaded {len(servers)} MCP servers from JSON string")
+ io.tool_output(f"Loaded {len(servers)} MCP servers")
return servers
else:
- io.tool_warning("No 'mcpServers' key found in MCP config JSON string")
+ io.tool_warning("No 'mcpServers' key found in MCP config")
except json.JSONDecodeError:
- io.tool_error("Invalid JSON in MCP config string")
+ io.tool_error("Invalid JSON in MCP config")
except Exception as e:
- io.tool_error(f"Error loading MCP config from string: {e}")
+ io.tool_error(f"Error loading MCP config: {e}")
return servers
@@ -100,44 +100,24 @@ def _resolve_mcp_config_path(file_path, io, verbose=False):
def _parse_mcp_servers_from_file(file_path, io, verbose=False, mcp_transport="stdio"):
"""Parse MCP servers from a JSON file."""
- servers = []
-
# Resolve the file path relative to closest cecli.conf.yml, git directory, or CWD
resolved_file_path = _resolve_mcp_config_path(file_path, io, verbose)
try:
with open(resolved_file_path, "r") as f:
- config = json.load(f)
+ json_string = f.read()
if verbose:
io.tool_output(f"Loading MCP servers from file: {file_path}")
- if "mcpServers" in config:
- for name, server_config in config["mcpServers"].items():
- if verbose:
- io.tool_output(f"Loading MCP server: {name}")
+ return _parse_mcp_servers_from_json_string(json_string, io, verbose, mcp_transport)
- # Create a server config with name included
- server_config["name"] = name
- transport = server_config.get("transport", mcp_transport)
- if transport == "stdio":
- servers.append(McpServer(server_config))
- elif transport == "http":
- servers.append(HttpStreamingServer(server_config))
-
- if verbose:
- io.tool_output(f"Loaded {len(servers)} MCP servers from {file_path}")
- return servers
- else:
- io.tool_warning(f"No 'mcpServers' key found in MCP config file: {file_path}")
except FileNotFoundError:
io.tool_warning(f"MCP config file not found: {file_path}")
- except json.JSONDecodeError:
- io.tool_error(f"Invalid JSON in MCP config file: {file_path}")
except Exception as e:
- io.tool_error(f"Error loading MCP config from file: {e}")
+ io.tool_error(f"Error reading MCP config file: {e}")
- return servers
+ return []
def load_mcp_servers(mcp_servers, mcp_servers_file, io, verbose=False, mcp_transport="stdio"):
@@ -169,6 +149,6 @@ def load_mcp_servers(mcp_servers, mcp_servers_file, io, verbose=False, mcp_trans
# and maybe it is actually prompt_toolkit's fault
# but this hack works swimmingly because ???
# so sure! why not
- servers = [McpServer(json.loads('{"cecli_default": {}}'))]
+ servers = [McpServer(json.loads('{"cecli_default": {}}'), io=io, verbose=verbose)]
return servers
diff --git a/cecli/mcp/oauth.py b/cecli/mcp/oauth.py
new file mode 100644
index 00000000000..c9d9897116f
--- /dev/null
+++ b/cecli/mcp/oauth.py
@@ -0,0 +1,250 @@
+import asyncio
+import base64
+import hashlib
+import http.server
+import json
+import os
+import secrets
+import socketserver
+import threading
+import time
+from pathlib import Path
+from typing import Awaitable, Callable, Optional, Tuple
+from urllib.parse import parse_qs, urlparse
+
+from mcp.client.auth import TokenStorage
+from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
+
+
+def find_available_port(start_port=8484, end_port=8584):
+ """Find an available port in the given range."""
+ for port in range(start_port, end_port + 1):
+ try:
+ # Check if the port is available by trying to bind to it
+ with socketserver.TCPServer(("localhost", port), None):
+ return port
+ except OSError:
+ # Port is likely already in use
+ continue
+ return None
+
+
+def create_oauth_callback_server(
+ port, path="/callback"
+) -> Tuple[Callable[[], Awaitable[Tuple[str, str]]], Callable[[], None]]:
+ """
+ Create a local HTTP server to handle OAuth callback.
+
+ Returns:
+ Tuple of (async callback handler function, shutdown function)
+ """
+ auth_code = None
+ state = None
+ server_error = None
+ callback_received = threading.Event()
+ server = None
+
+ class OAuthCallbackHandler(http.server.SimpleHTTPRequestHandler):
+ def do_GET(self):
+ nonlocal auth_code, state, server_error
+ parsed_path = urlparse(self.path)
+
+ if parsed_path.path == path:
+ query_params = parse_qs(parsed_path.query)
+ if "code" in query_params:
+ auth_code = query_params["code"][0]
+ if "state" in query_params:
+ state = query_params["state"][0]
+
+ self.send_response(200)
+ self.send_header("Content-type", "text/html")
+ self.end_headers()
+ self.wfile.write(
+ b"
Success!
"
+ b"Authentication successful. You can close this browser tab.
"
+ b""
+ )
+ callback_received.set()
+ elif "error" in query_params:
+ error = query_params["error"][0]
+ error_desc = query_params.get("error_description", [""])[0]
+ server_error = f"OAuth error: {error} - {error_desc}"
+
+ self.send_response(400)
+ self.send_header("Content-type", "text/html")
+ self.end_headers()
+ self.wfile.write(
+ "Authentication Failed
"
+ f"{error}: {error_desc}
".encode()
+ )
+ callback_received.set()
+ else:
+ self.send_response(400)
+ self.send_header("Content-type", "text/html")
+ self.end_headers()
+ self.wfile.write(b"Invalid Request
")
+ else:
+ self.send_response(404)
+ self.send_header("Content-type", "text/html")
+ self.end_headers()
+ self.wfile.write(b"Not Found
")
+
+ def log_message(self, format, *args):
+ pass
+
+ # Start server in a separate thread
+ def start_server():
+ nonlocal server
+ try:
+ server = socketserver.TCPServer(("localhost", port), OAuthCallbackHandler)
+ server.serve_forever()
+ except Exception as e:
+ server_error = f"Server error: {e}" # noqa
+ callback_received.set()
+
+ server_thread = threading.Thread(target=start_server, daemon=True)
+ server_thread.start()
+
+ # Shutdown function
+ def shutdown():
+ nonlocal server
+ if server:
+ server.shutdown()
+ server = None
+
+ async def get_auth_code() -> Tuple[str, str]:
+ # Wait for callback to be received
+ MINUTES = 5
+ timeout = MINUTES * 60
+
+ start_time = time.time()
+ while not callback_received.is_set():
+ if time.time() - start_time > timeout:
+ shutdown()
+ raise Exception(f"OAuth callback timed out after {MINUTES} minutes")
+
+ # Small sleep to avoid busy waiting
+ await asyncio.sleep(0.1)
+
+ if server_error:
+ shutdown()
+ raise Exception(server_error)
+
+ if not auth_code:
+ shutdown()
+ raise Exception("No authorization code received")
+
+ return auth_code, state
+
+ return get_auth_code, shutdown
+
+
+def generate_pkce_codes():
+ """Generate PKCE code verifier and challenge."""
+ code_verifier = secrets.token_urlsafe(64)
+ hasher = hashlib.sha256()
+ hasher.update(code_verifier.encode("utf-8"))
+ code_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8")
+ return code_verifier, code_challenge
+
+
+def get_token_file_path():
+ """Get the path to the MCP OAuth tokens file."""
+ config_dir = Path.home() / ".cecli"
+ config_dir.mkdir(parents=True, exist_ok=True)
+ return config_dir / "mcp-oauth-tokens.json"
+
+
+def load_mcp_oauth_tokens():
+ """Load stored OAuth tokens from file."""
+ token_file = get_token_file_path()
+ if not token_file.exists():
+ return {}
+
+ try:
+ with open(token_file, "r", encoding="utf-8") as f:
+ # File might be empty
+ return json.load(f) or {}
+ except Exception:
+ return {}
+
+
+def save_mcp_oauth_token(server_name, token_data):
+ """Save OAuth token for an MCP server."""
+ tokens = load_mcp_oauth_tokens()
+ tokens[server_name] = token_data
+
+ token_file = get_token_file_path()
+ try:
+ with open(token_file, "w", encoding="utf-8") as f:
+ json.dump(tokens, f, indent=2)
+ # Set restrictive permissions (owner read/write only)
+ os.chmod(token_file, 0o600)
+ except Exception as e:
+ raise Exception(f"Failed to save OAuth token: {e}")
+
+
+def save_mcp_oauth_tokens(tokens_dict):
+ """Save all OAuth tokens to file."""
+ token_file = get_token_file_path()
+ try:
+ with open(token_file, "w", encoding="utf-8") as f:
+ json.dump(tokens_dict, f, indent=2)
+ # Set restrictive permissions (owner read/write only)
+ os.chmod(token_file, 0o600)
+ except Exception as e:
+ raise Exception(f"Failed to save OAuth tokens: {e}")
+
+
+def get_mcp_oauth_token(server_name):
+ """Retrieve stored OAuth token for an MCP server."""
+ tokens = load_mcp_oauth_tokens()
+ return tokens.get(server_name, {})
+
+
+class FileBasedTokenStorage(TokenStorage):
+ """File-based token storage for MCP OAuth using the SDK's TokenStorage interface."""
+
+ def __init__(self, server_name: str):
+ self.server_name = server_name
+
+ async def get_tokens(self) -> Optional[OAuthToken]:
+ """Get stored tokens for this server."""
+ all_tokens = load_mcp_oauth_tokens()
+ server_data = all_tokens.get(self.server_name, {})
+
+ if "tokens" not in server_data:
+ return None
+
+ return OAuthToken.model_validate(server_data["tokens"])
+
+ async def set_tokens(self, tokens: OAuthToken) -> None:
+ """Store tokens for this server."""
+ all_tokens = load_mcp_oauth_tokens()
+
+ if self.server_name not in all_tokens:
+ all_tokens[self.server_name] = {}
+
+ tokens_dict = tokens.model_dump()
+ all_tokens[self.server_name]["tokens"] = tokens_dict
+ save_mcp_oauth_tokens(all_tokens)
+
+ async def get_client_info(self) -> Optional[OAuthClientInformationFull]:
+ """Get stored client information."""
+ all_tokens = load_mcp_oauth_tokens()
+ server_data = all_tokens.get(self.server_name, {})
+
+ if "client_info" not in server_data:
+ return None
+
+ return OAuthClientInformationFull.model_validate(server_data["client_info"])
+
+ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
+ """Store client information."""
+ all_tokens = load_mcp_oauth_tokens()
+
+ if self.server_name not in all_tokens:
+ all_tokens[self.server_name] = {}
+
+ all_tokens[self.server_name]["client_info"] = json.loads(client_info.model_dump_json())
+ save_mcp_oauth_tokens(all_tokens)
diff --git a/cecli/mcp/server.py b/cecli/mcp/server.py
index 9d4162fb666..58c2bcb661e 100644
--- a/cecli/mcp/server.py
+++ b/cecli/mcp/server.py
@@ -1,12 +1,25 @@
import asyncio
import logging
import os
+import webbrowser
from contextlib import AsyncExitStack
+from urllib.parse import urlparse
+import httpx
from mcp import ClientSession, StdioServerParameters
+from mcp.client.auth import OAuthClientProvider
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
-from mcp.client.streamable_http import streamablehttp_client
+from mcp.client.streamable_http import streamable_http_client
+from mcp.shared.auth import OAuthClientMetadata
+
+from cecli.mcp.oauth import (
+ FileBasedTokenStorage,
+ create_oauth_callback_server,
+ find_available_port,
+ get_mcp_oauth_token,
+ save_mcp_oauth_token,
+)
class McpServer:
@@ -17,14 +30,18 @@ class McpServer:
Uses the mcp library to create and initialize ClientSession objects.
"""
- def __init__(self, server_config):
+ def __init__(self, server_config, io=None, verbose=False):
"""Initialize the MCP tool provider.
Args:
server_config: Configuration for the MCP server
+ io: InputOutput object for user interaction
+ verbose: Whether to output verbose logging
"""
self.config = server_config
self.name = server_config.get("name", "unnamed-server")
+ self.io = io
+ self.verbose = verbose
self.session = None
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.exit_stack = AsyncExitStack()
@@ -39,15 +56,21 @@ async def connect(self):
ClientSession: The active session
"""
if self.session is not None:
- logging.info(f"Using existing session for MCP server: {self.name}")
+ if self.verbose and self.io:
+ self.io.tool_output(f"Using existing session for MCP server: {self.name}")
return self.session
- logging.info(f"Establishing new connection to MCP server: {self.name}")
+ if self.verbose and self.io:
+ self.io.tool_output(f"Establishing new connection to MCP server: {self.name}")
+
command = self.config["command"]
+
+ env = {**os.environ, **self.config["env"]} if self.config.get("env") else None
+
server_params = StdioServerParameters(
command=command,
args=self.config.get("args"),
- env={**os.environ, **self.config["env"]} if self.config.get("env") else None,
+ env=env,
)
try:
@@ -76,58 +99,163 @@ async def disconnect(self):
logging.error(f"Error during cleanup of server {self.name}: {e}")
-class HttpStreamingServer(McpServer):
- """HTTP streaming MCP server using mcp.client.streamablehttp_client."""
+class HttpBasedMcpServer(McpServer):
+ """Base class for HTTP-based MCP servers (HTTP streaming and SSE)."""
- async def connect(self):
- if self.session is not None:
- logging.info(f"Using existing session for MCP server: {self.name}")
- return self.session
+ async def _create_oauth_provider(self):
+ """Create an OAuthClientProvider using the MCP SDK."""
+ parsed = urlparse(self.config.get("url"))
+ server_url = f"{parsed.scheme}://{parsed.netloc}"
+ if self.verbose and self.io:
+ self.io.tool_output(f"Auto-derived OAuth server URL: {server_url}", log_only=True)
- logging.info(f"Establishing new connection to HTTP MCP server: {self.name}")
- try:
- url = self.config.get("url")
- headers = self.config.get("headers", {})
- http_transport = await self.exit_stack.enter_async_context(
- streamablehttp_client(url, headers=headers)
- )
- read, write, _response = http_transport
+ # Check if we have existing client info with a redirect URI
+ server_info = get_mcp_oauth_token(self.name)
+ existing_redirect_uri = None
- session = await self.exit_stack.enter_async_context(ClientSession(read, write))
- await session.initialize()
- self.session = session
- return session
- except Exception as e:
- logging.error(f"Error initializing HTTP server {self.name}: {e}")
- await self.disconnect()
- raise
+ if "client_info" in server_info and "redirect_uris" in server_info["client_info"]:
+ redirect_uris = server_info["client_info"].get("redirect_uris", [])
+ if redirect_uris:
+ existing_redirect_uri = redirect_uris[0]
+ if self.verbose and self.io:
+ self.io.tool_output(
+ f"Found existing redirect URI: {existing_redirect_uri}", log_only=True
+ )
+
+ # If we have an existing redirect URI, parse it to get the port
+ if existing_redirect_uri:
+ try:
+ parsed_uri = urlparse(existing_redirect_uri)
+ port = int(parsed_uri.netloc.split(":")[1])
+ if self.verbose and self.io:
+ self.io.tool_output(f"Reusing existing port: {port}", log_only=True)
+ except (ValueError, IndexError):
+ # If we can't parse the port, find a new one
+ port = find_available_port()
+ else:
+ # No existing redirect URI, find an available port
+ port = find_available_port()
+ if not port:
+ raise Exception("Could not find available port for OAuth callback")
-class SseServer(McpServer):
- """SSE (Server-Sent Events) MCP server using mcp.client.sse_client."""
+ redirect_uri = f"http://localhost:{port}/callback"
+
+ get_auth_code, shutdown = create_oauth_callback_server(port)
+
+ # Store shutdown function for cleanup
+ self._oauth_shutdown = shutdown
+
+ async def handle_redirect(auth_url: str) -> None:
+ if self.io:
+ self.io.tool_output(f"\nAuthentication required for MCP server: {self.name}")
+ self.io.tool_output("\nPlease open this URL in your browser to authenticate:")
+ self.io.tool_output(f"\n{auth_url}\n")
+ self.io.tool_output("\nWaiting for you to complete authentication...")
+ self.io.tool_output("Use Control-C to interrupt.")
+ try:
+ webbrowser.open(auth_url)
+ except Exception:
+ pass
+
+ client_metadata = OAuthClientMetadata(
+ client_name="Cecli",
+ redirect_uris=[redirect_uri],
+ grant_types=["authorization_code", "refresh_token"],
+ )
+ oauth_provider = OAuthClientProvider(
+ server_url=server_url,
+ client_metadata=client_metadata,
+ storage=FileBasedTokenStorage(self.name),
+ redirect_handler=handle_redirect,
+ callback_handler=get_auth_code,
+ )
+
+ return oauth_provider
+
+ def _create_transport(self, url, http_client):
+ """
+ Create the transport for this server type.
+ Must be implemented by subclasses.
+ """
+ raise NotImplementedError("Subclasses must implement _create_transport")
async def connect(self):
if self.session is not None:
- logging.info(f"Using existing session for SSE MCP server: {self.name}")
+ if self.verbose and self.io:
+ self.io.tool_output(f"Using existing session for {self.name}")
return self.session
- logging.info(f"Establishing new connection to SSE MCP server: {self.name}")
+ if self.verbose and self.io:
+ self.io.tool_output(f"Establishing new connection to {self.name}")
+
try:
url = self.config.get("url")
headers = self.config.get("headers", {})
- sse_transport = await self.exit_stack.enter_async_context(
- sse_client(url, headers=headers)
+ oauth_provider = await self._create_oauth_provider()
+
+ http_client = await self.exit_stack.enter_async_context(
+ httpx.AsyncClient(
+ auth=oauth_provider,
+ follow_redirects=True,
+ headers=headers,
+ timeout=30,
+ )
+ )
+
+ transport = await self.exit_stack.enter_async_context(
+ self._create_transport(url, http_client=http_client)
)
- read, write, _response = sse_transport
+
+ read, write, _ = transport
+
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
self.session = session
+
+ if oauth_provider.context.oauth_metadata:
+ token_endpoint = oauth_provider._get_token_endpoint()
+ server_info = get_mcp_oauth_token(self.name)
+ if "client_info" not in server_info:
+ server_info["client_info"] = {}
+
+ server_info["client_info"]["token_endpoint"] = token_endpoint
+
+ save_mcp_oauth_token(self.name, server_info)
+
return session
except Exception as e:
- logging.error(f"Error initializing SSE server {self.name}: {e}")
+ logging.error(f"Error initializing {self.name}: {e}")
await self.disconnect()
raise
+ async def disconnect(self):
+ """Disconnect from the MCP server and clean up resources."""
+ async with self._cleanup_lock:
+ try:
+ if hasattr(self, "_oauth_shutdown"):
+ self._oauth_shutdown()
+ await self.exit_stack.aclose()
+ self.session = None
+ except Exception as e:
+ logging.error(f"Error during cleanup of server {self.name}: {e}")
+
+
+class HttpStreamingServer(HttpBasedMcpServer):
+ """HTTP streaming MCP server using mcp.client.streamable_http_client."""
+
+ def _create_transport(self, url, http_client):
+ """Create the HTTP streaming transport."""
+ return streamable_http_client(url, http_client=http_client)
+
+
+class SseServer(HttpBasedMcpServer):
+ """SSE (Server-Sent Events) MCP server using mcp.client.sse_client."""
+
+ def _create_transport(self, url, http_client):
+ """Create the SSE transport."""
+ return sse_client(url, http_client=http_client)
+
class LocalServer(McpServer):
"""
@@ -138,7 +266,8 @@ class LocalServer(McpServer):
async def connect(self):
"""Local tools don't need a connection."""
if self.session is not None:
- logging.info(f"Using existing session for local tools: {self.name}")
+ if self.verbose and self.io:
+ self.io.tool_output(f"Using existing session for local tools: {self.name}")
return self.session
self.session = object() # Dummy session object
diff --git a/cecli/onboarding.py b/cecli/onboarding.py
index 72d49ff4c51..a0037520d1f 100644
--- a/cecli/onboarding.py
+++ b/cecli/onboarding.py
@@ -1,8 +1,5 @@
-import base64
-import hashlib
import http.server
import os
-import secrets
import socketserver
import threading
import time
@@ -13,6 +10,7 @@
from cecli import urls
from cecli.io import InputOutput
+from cecli.mcp.oauth import find_available_port, generate_pkce_codes
def check_openrouter_tier(api_key):
@@ -117,24 +115,7 @@ async def select_default_model(args, io):
await io.offer_url(urls.models_and_keys, "Open documentation URL for more info?")
-def find_available_port(start_port=8484, end_port=8584):
- for port in range(start_port, end_port + 1):
- try:
- with socketserver.TCPServer(("localhost", port), None):
- return port
- except OSError:
- continue
- return None
-
-
-def generate_pkce_codes():
- code_verifier = secrets.token_urlsafe(64)
- hasher = hashlib.sha256()
- hasher.update(code_verifier.encode("utf-8"))
- code_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8")
- return code_verifier, code_challenge
-
-
+# Function to exchange the authorization code for an API key
def exchange_code_for_key(code, code_verifier, io):
try:
response = requests.post(