diff --git a/cecli/mcp/__init__.py b/cecli/mcp/__init__.py index f40c27af978..44d1e6a5f15 100644 --- a/cecli/mcp/__init__.py +++ b/cecli/mcp/__init__.py @@ -11,7 +11,7 @@ def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_tran try: config = json.loads(json_string) if verbose: - io.tool_output("Loading MCP servers from provided JSON string") + io.tool_output("Loading MCP servers from provided JSON") if "mcpServers" in config: for name, server_config in config["mcpServers"].items(): @@ -22,21 +22,21 @@ def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_tran server_config["name"] = name transport = server_config.get("transport", mcp_transport) if transport == "stdio": - servers.append(McpServer(server_config)) + servers.append(McpServer(server_config, io=io, verbose=verbose)) elif transport == "http": - servers.append(HttpStreamingServer(server_config)) + servers.append(HttpStreamingServer(server_config, io=io, verbose=verbose)) elif transport == "sse": - servers.append(SseServer(server_config)) + servers.append(SseServer(server_config, io=io, verbose=verbose)) if verbose: - io.tool_output(f"Loaded {len(servers)} MCP servers from JSON string") + io.tool_output(f"Loaded {len(servers)} MCP servers") return servers else: - io.tool_warning("No 'mcpServers' key found in MCP config JSON string") + io.tool_warning("No 'mcpServers' key found in MCP config") except json.JSONDecodeError: - io.tool_error("Invalid JSON in MCP config string") + io.tool_error("Invalid JSON in MCP config") except Exception as e: - io.tool_error(f"Error loading MCP config from string: {e}") + io.tool_error(f"Error loading MCP config: {e}") return servers @@ -100,44 +100,24 @@ def _resolve_mcp_config_path(file_path, io, verbose=False): def _parse_mcp_servers_from_file(file_path, io, verbose=False, mcp_transport="stdio"): """Parse MCP servers from a JSON file.""" - servers = [] - # Resolve the file path relative to closest cecli.conf.yml, git directory, or CWD resolved_file_path = _resolve_mcp_config_path(file_path, io, verbose) try: with open(resolved_file_path, "r") as f: - config = json.load(f) + json_string = f.read() if verbose: io.tool_output(f"Loading MCP servers from file: {file_path}") - if "mcpServers" in config: - for name, server_config in config["mcpServers"].items(): - if verbose: - io.tool_output(f"Loading MCP server: {name}") + return _parse_mcp_servers_from_json_string(json_string, io, verbose, mcp_transport) - # Create a server config with name included - server_config["name"] = name - transport = server_config.get("transport", mcp_transport) - if transport == "stdio": - servers.append(McpServer(server_config)) - elif transport == "http": - servers.append(HttpStreamingServer(server_config)) - - if verbose: - io.tool_output(f"Loaded {len(servers)} MCP servers from {file_path}") - return servers - else: - io.tool_warning(f"No 'mcpServers' key found in MCP config file: {file_path}") except FileNotFoundError: io.tool_warning(f"MCP config file not found: {file_path}") - except json.JSONDecodeError: - io.tool_error(f"Invalid JSON in MCP config file: {file_path}") except Exception as e: - io.tool_error(f"Error loading MCP config from file: {e}") + io.tool_error(f"Error reading MCP config file: {e}") - return servers + return [] def load_mcp_servers(mcp_servers, mcp_servers_file, io, verbose=False, mcp_transport="stdio"): @@ -169,6 +149,6 @@ def load_mcp_servers(mcp_servers, mcp_servers_file, io, verbose=False, mcp_trans # and maybe it is actually prompt_toolkit's fault # but this hack works swimmingly because ??? # so sure! why not - servers = [McpServer(json.loads('{"cecli_default": {}}'))] + servers = [McpServer(json.loads('{"cecli_default": {}}'), io=io, verbose=verbose)] return servers diff --git a/cecli/mcp/oauth.py b/cecli/mcp/oauth.py new file mode 100644 index 00000000000..c9d9897116f --- /dev/null +++ b/cecli/mcp/oauth.py @@ -0,0 +1,250 @@ +import asyncio +import base64 +import hashlib +import http.server +import json +import os +import secrets +import socketserver +import threading +import time +from pathlib import Path +from typing import Awaitable, Callable, Optional, Tuple +from urllib.parse import parse_qs, urlparse + +from mcp.client.auth import TokenStorage +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + + +def find_available_port(start_port=8484, end_port=8584): + """Find an available port in the given range.""" + for port in range(start_port, end_port + 1): + try: + # Check if the port is available by trying to bind to it + with socketserver.TCPServer(("localhost", port), None): + return port + except OSError: + # Port is likely already in use + continue + return None + + +def create_oauth_callback_server( + port, path="/callback" +) -> Tuple[Callable[[], Awaitable[Tuple[str, str]]], Callable[[], None]]: + """ + Create a local HTTP server to handle OAuth callback. + + Returns: + Tuple of (async callback handler function, shutdown function) + """ + auth_code = None + state = None + server_error = None + callback_received = threading.Event() + server = None + + class OAuthCallbackHandler(http.server.SimpleHTTPRequestHandler): + def do_GET(self): + nonlocal auth_code, state, server_error + parsed_path = urlparse(self.path) + + if parsed_path.path == path: + query_params = parse_qs(parsed_path.query) + if "code" in query_params: + auth_code = query_params["code"][0] + if "state" in query_params: + state = query_params["state"][0] + + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + b"

Success!

" + b"

Authentication successful. You can close this browser tab.

" + b"" + ) + callback_received.set() + elif "error" in query_params: + error = query_params["error"][0] + error_desc = query_params.get("error_description", [""])[0] + server_error = f"OAuth error: {error} - {error_desc}" + + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + "

Authentication Failed

" + f"

{error}: {error_desc}

".encode() + ) + callback_received.set() + else: + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b"

Invalid Request

") + else: + self.send_response(404) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b"

Not Found

") + + def log_message(self, format, *args): + pass + + # Start server in a separate thread + def start_server(): + nonlocal server + try: + server = socketserver.TCPServer(("localhost", port), OAuthCallbackHandler) + server.serve_forever() + except Exception as e: + server_error = f"Server error: {e}" # noqa + callback_received.set() + + server_thread = threading.Thread(target=start_server, daemon=True) + server_thread.start() + + # Shutdown function + def shutdown(): + nonlocal server + if server: + server.shutdown() + server = None + + async def get_auth_code() -> Tuple[str, str]: + # Wait for callback to be received + MINUTES = 5 + timeout = MINUTES * 60 + + start_time = time.time() + while not callback_received.is_set(): + if time.time() - start_time > timeout: + shutdown() + raise Exception(f"OAuth callback timed out after {MINUTES} minutes") + + # Small sleep to avoid busy waiting + await asyncio.sleep(0.1) + + if server_error: + shutdown() + raise Exception(server_error) + + if not auth_code: + shutdown() + raise Exception("No authorization code received") + + return auth_code, state + + return get_auth_code, shutdown + + +def generate_pkce_codes(): + """Generate PKCE code verifier and challenge.""" + code_verifier = secrets.token_urlsafe(64) + hasher = hashlib.sha256() + hasher.update(code_verifier.encode("utf-8")) + code_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8") + return code_verifier, code_challenge + + +def get_token_file_path(): + """Get the path to the MCP OAuth tokens file.""" + config_dir = Path.home() / ".cecli" + config_dir.mkdir(parents=True, exist_ok=True) + return config_dir / "mcp-oauth-tokens.json" + + +def load_mcp_oauth_tokens(): + """Load stored OAuth tokens from file.""" + token_file = get_token_file_path() + if not token_file.exists(): + return {} + + try: + with open(token_file, "r", encoding="utf-8") as f: + # File might be empty + return json.load(f) or {} + except Exception: + return {} + + +def save_mcp_oauth_token(server_name, token_data): + """Save OAuth token for an MCP server.""" + tokens = load_mcp_oauth_tokens() + tokens[server_name] = token_data + + token_file = get_token_file_path() + try: + with open(token_file, "w", encoding="utf-8") as f: + json.dump(tokens, f, indent=2) + # Set restrictive permissions (owner read/write only) + os.chmod(token_file, 0o600) + except Exception as e: + raise Exception(f"Failed to save OAuth token: {e}") + + +def save_mcp_oauth_tokens(tokens_dict): + """Save all OAuth tokens to file.""" + token_file = get_token_file_path() + try: + with open(token_file, "w", encoding="utf-8") as f: + json.dump(tokens_dict, f, indent=2) + # Set restrictive permissions (owner read/write only) + os.chmod(token_file, 0o600) + except Exception as e: + raise Exception(f"Failed to save OAuth tokens: {e}") + + +def get_mcp_oauth_token(server_name): + """Retrieve stored OAuth token for an MCP server.""" + tokens = load_mcp_oauth_tokens() + return tokens.get(server_name, {}) + + +class FileBasedTokenStorage(TokenStorage): + """File-based token storage for MCP OAuth using the SDK's TokenStorage interface.""" + + def __init__(self, server_name: str): + self.server_name = server_name + + async def get_tokens(self) -> Optional[OAuthToken]: + """Get stored tokens for this server.""" + all_tokens = load_mcp_oauth_tokens() + server_data = all_tokens.get(self.server_name, {}) + + if "tokens" not in server_data: + return None + + return OAuthToken.model_validate(server_data["tokens"]) + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store tokens for this server.""" + all_tokens = load_mcp_oauth_tokens() + + if self.server_name not in all_tokens: + all_tokens[self.server_name] = {} + + tokens_dict = tokens.model_dump() + all_tokens[self.server_name]["tokens"] = tokens_dict + save_mcp_oauth_tokens(all_tokens) + + async def get_client_info(self) -> Optional[OAuthClientInformationFull]: + """Get stored client information.""" + all_tokens = load_mcp_oauth_tokens() + server_data = all_tokens.get(self.server_name, {}) + + if "client_info" not in server_data: + return None + + return OAuthClientInformationFull.model_validate(server_data["client_info"]) + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store client information.""" + all_tokens = load_mcp_oauth_tokens() + + if self.server_name not in all_tokens: + all_tokens[self.server_name] = {} + + all_tokens[self.server_name]["client_info"] = json.loads(client_info.model_dump_json()) + save_mcp_oauth_tokens(all_tokens) diff --git a/cecli/mcp/server.py b/cecli/mcp/server.py index 9d4162fb666..58c2bcb661e 100644 --- a/cecli/mcp/server.py +++ b/cecli/mcp/server.py @@ -1,12 +1,25 @@ import asyncio import logging import os +import webbrowser from contextlib import AsyncExitStack +from urllib.parse import urlparse +import httpx from mcp import ClientSession, StdioServerParameters +from mcp.client.auth import OAuthClientProvider from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client +from mcp.shared.auth import OAuthClientMetadata + +from cecli.mcp.oauth import ( + FileBasedTokenStorage, + create_oauth_callback_server, + find_available_port, + get_mcp_oauth_token, + save_mcp_oauth_token, +) class McpServer: @@ -17,14 +30,18 @@ class McpServer: Uses the mcp library to create and initialize ClientSession objects. """ - def __init__(self, server_config): + def __init__(self, server_config, io=None, verbose=False): """Initialize the MCP tool provider. Args: server_config: Configuration for the MCP server + io: InputOutput object for user interaction + verbose: Whether to output verbose logging """ self.config = server_config self.name = server_config.get("name", "unnamed-server") + self.io = io + self.verbose = verbose self.session = None self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.exit_stack = AsyncExitStack() @@ -39,15 +56,21 @@ async def connect(self): ClientSession: The active session """ if self.session is not None: - logging.info(f"Using existing session for MCP server: {self.name}") + if self.verbose and self.io: + self.io.tool_output(f"Using existing session for MCP server: {self.name}") return self.session - logging.info(f"Establishing new connection to MCP server: {self.name}") + if self.verbose and self.io: + self.io.tool_output(f"Establishing new connection to MCP server: {self.name}") + command = self.config["command"] + + env = {**os.environ, **self.config["env"]} if self.config.get("env") else None + server_params = StdioServerParameters( command=command, args=self.config.get("args"), - env={**os.environ, **self.config["env"]} if self.config.get("env") else None, + env=env, ) try: @@ -76,58 +99,163 @@ async def disconnect(self): logging.error(f"Error during cleanup of server {self.name}: {e}") -class HttpStreamingServer(McpServer): - """HTTP streaming MCP server using mcp.client.streamablehttp_client.""" +class HttpBasedMcpServer(McpServer): + """Base class for HTTP-based MCP servers (HTTP streaming and SSE).""" - async def connect(self): - if self.session is not None: - logging.info(f"Using existing session for MCP server: {self.name}") - return self.session + async def _create_oauth_provider(self): + """Create an OAuthClientProvider using the MCP SDK.""" + parsed = urlparse(self.config.get("url")) + server_url = f"{parsed.scheme}://{parsed.netloc}" + if self.verbose and self.io: + self.io.tool_output(f"Auto-derived OAuth server URL: {server_url}", log_only=True) - logging.info(f"Establishing new connection to HTTP MCP server: {self.name}") - try: - url = self.config.get("url") - headers = self.config.get("headers", {}) - http_transport = await self.exit_stack.enter_async_context( - streamablehttp_client(url, headers=headers) - ) - read, write, _response = http_transport + # Check if we have existing client info with a redirect URI + server_info = get_mcp_oauth_token(self.name) + existing_redirect_uri = None - session = await self.exit_stack.enter_async_context(ClientSession(read, write)) - await session.initialize() - self.session = session - return session - except Exception as e: - logging.error(f"Error initializing HTTP server {self.name}: {e}") - await self.disconnect() - raise + if "client_info" in server_info and "redirect_uris" in server_info["client_info"]: + redirect_uris = server_info["client_info"].get("redirect_uris", []) + if redirect_uris: + existing_redirect_uri = redirect_uris[0] + if self.verbose and self.io: + self.io.tool_output( + f"Found existing redirect URI: {existing_redirect_uri}", log_only=True + ) + + # If we have an existing redirect URI, parse it to get the port + if existing_redirect_uri: + try: + parsed_uri = urlparse(existing_redirect_uri) + port = int(parsed_uri.netloc.split(":")[1]) + if self.verbose and self.io: + self.io.tool_output(f"Reusing existing port: {port}", log_only=True) + except (ValueError, IndexError): + # If we can't parse the port, find a new one + port = find_available_port() + else: + # No existing redirect URI, find an available port + port = find_available_port() + if not port: + raise Exception("Could not find available port for OAuth callback") -class SseServer(McpServer): - """SSE (Server-Sent Events) MCP server using mcp.client.sse_client.""" + redirect_uri = f"http://localhost:{port}/callback" + + get_auth_code, shutdown = create_oauth_callback_server(port) + + # Store shutdown function for cleanup + self._oauth_shutdown = shutdown + + async def handle_redirect(auth_url: str) -> None: + if self.io: + self.io.tool_output(f"\nAuthentication required for MCP server: {self.name}") + self.io.tool_output("\nPlease open this URL in your browser to authenticate:") + self.io.tool_output(f"\n{auth_url}\n") + self.io.tool_output("\nWaiting for you to complete authentication...") + self.io.tool_output("Use Control-C to interrupt.") + try: + webbrowser.open(auth_url) + except Exception: + pass + + client_metadata = OAuthClientMetadata( + client_name="Cecli", + redirect_uris=[redirect_uri], + grant_types=["authorization_code", "refresh_token"], + ) + oauth_provider = OAuthClientProvider( + server_url=server_url, + client_metadata=client_metadata, + storage=FileBasedTokenStorage(self.name), + redirect_handler=handle_redirect, + callback_handler=get_auth_code, + ) + + return oauth_provider + + def _create_transport(self, url, http_client): + """ + Create the transport for this server type. + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement _create_transport") async def connect(self): if self.session is not None: - logging.info(f"Using existing session for SSE MCP server: {self.name}") + if self.verbose and self.io: + self.io.tool_output(f"Using existing session for {self.name}") return self.session - logging.info(f"Establishing new connection to SSE MCP server: {self.name}") + if self.verbose and self.io: + self.io.tool_output(f"Establishing new connection to {self.name}") + try: url = self.config.get("url") headers = self.config.get("headers", {}) - sse_transport = await self.exit_stack.enter_async_context( - sse_client(url, headers=headers) + oauth_provider = await self._create_oauth_provider() + + http_client = await self.exit_stack.enter_async_context( + httpx.AsyncClient( + auth=oauth_provider, + follow_redirects=True, + headers=headers, + timeout=30, + ) + ) + + transport = await self.exit_stack.enter_async_context( + self._create_transport(url, http_client=http_client) ) - read, write, _response = sse_transport + + read, write, _ = transport + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) await session.initialize() self.session = session + + if oauth_provider.context.oauth_metadata: + token_endpoint = oauth_provider._get_token_endpoint() + server_info = get_mcp_oauth_token(self.name) + if "client_info" not in server_info: + server_info["client_info"] = {} + + server_info["client_info"]["token_endpoint"] = token_endpoint + + save_mcp_oauth_token(self.name, server_info) + return session except Exception as e: - logging.error(f"Error initializing SSE server {self.name}: {e}") + logging.error(f"Error initializing {self.name}: {e}") await self.disconnect() raise + async def disconnect(self): + """Disconnect from the MCP server and clean up resources.""" + async with self._cleanup_lock: + try: + if hasattr(self, "_oauth_shutdown"): + self._oauth_shutdown() + await self.exit_stack.aclose() + self.session = None + except Exception as e: + logging.error(f"Error during cleanup of server {self.name}: {e}") + + +class HttpStreamingServer(HttpBasedMcpServer): + """HTTP streaming MCP server using mcp.client.streamable_http_client.""" + + def _create_transport(self, url, http_client): + """Create the HTTP streaming transport.""" + return streamable_http_client(url, http_client=http_client) + + +class SseServer(HttpBasedMcpServer): + """SSE (Server-Sent Events) MCP server using mcp.client.sse_client.""" + + def _create_transport(self, url, http_client): + """Create the SSE transport.""" + return sse_client(url, http_client=http_client) + class LocalServer(McpServer): """ @@ -138,7 +266,8 @@ class LocalServer(McpServer): async def connect(self): """Local tools don't need a connection.""" if self.session is not None: - logging.info(f"Using existing session for local tools: {self.name}") + if self.verbose and self.io: + self.io.tool_output(f"Using existing session for local tools: {self.name}") return self.session self.session = object() # Dummy session object diff --git a/cecli/onboarding.py b/cecli/onboarding.py index 72d49ff4c51..a0037520d1f 100644 --- a/cecli/onboarding.py +++ b/cecli/onboarding.py @@ -1,8 +1,5 @@ -import base64 -import hashlib import http.server import os -import secrets import socketserver import threading import time @@ -13,6 +10,7 @@ from cecli import urls from cecli.io import InputOutput +from cecli.mcp.oauth import find_available_port, generate_pkce_codes def check_openrouter_tier(api_key): @@ -117,24 +115,7 @@ async def select_default_model(args, io): await io.offer_url(urls.models_and_keys, "Open documentation URL for more info?") -def find_available_port(start_port=8484, end_port=8584): - for port in range(start_port, end_port + 1): - try: - with socketserver.TCPServer(("localhost", port), None): - return port - except OSError: - continue - return None - - -def generate_pkce_codes(): - code_verifier = secrets.token_urlsafe(64) - hasher = hashlib.sha256() - hasher.update(code_verifier.encode("utf-8")) - code_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8") - return code_verifier, code_challenge - - +# Function to exchange the authorization code for an API key def exchange_code_for_key(code, code_verifier, io): try: response = requests.post(