diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index 508aa25cb92..97ce53842fa 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -22,7 +22,7 @@ normalize_vector, ) from cecli.helpers.skills import SkillsManager -from cecli.mcp.server import LocalServer +from cecli.mcp import LocalServer, McpServerManager from cecli.repo import ANY_GIT_ERROR from cecli.tools.utils.registry import ToolRegistry @@ -197,14 +197,17 @@ async def initialize_mcp_tools(self): local_tools = self.get_local_tool_schemas() if not local_tools: return + local_server_config = {"name": server_name} local_server = LocalServer(local_server_config) - if not self.mcp_servers: - self.mcp_servers = [] - if not any(isinstance(s, LocalServer) for s in self.mcp_servers): - self.mcp_servers.append(local_server) + + if not self.mcp_manager: + self.mcp_manager = McpServerManager() + if not self.mcp_manager.get_server(server_name): + await self.mcp_manager.add_server(local_server) if not self.mcp_tools: self.mcp_tools = [] + if server_name not in [name for name, _ in self.mcp_tools]: self.mcp_tools.append((local_server.name, local_tools)) @@ -245,9 +248,7 @@ async def _execute_local_tool_calls(self, tool_calls_list): t.get("function", {}).get("name") == norm_tool_name for t in server_tools ): - server = next( - (s for s in self.mcp_servers if s.name == server_name), None - ) + server = self.mcp_manager.get_server(server_name) if server: for params in parsed_args_list: tasks.append( @@ -943,7 +944,7 @@ async def _execute_tool_with_registry(self, norm_tool_name, params): if self.mcp_tools: for server_name, server_tools in self.mcp_tools: if any(t.get("function", {}).get("name") == norm_tool_name for t in server_tools): - server = next((s for s in self.mcp_servers if s.name == server_name), None) + server = self.mcp_manager.get_server(server_name) if server: return await self._execute_mcp_tool(server, norm_tool_name, params) else: diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 0991fc011e2..38b49c34d2d 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -44,7 +44,7 @@ from cecli.io import ConfirmGroup, InputOutput from cecli.linter import Linter from cecli.llm import litellm -from cecli.mcp.server import LocalServer +from cecli.mcp import LocalServer from cecli.models import RETRY_TIMEOUT from cecli.reasoning_tags import ( REASONING_TAG, @@ -138,7 +138,7 @@ class Coder: chat_language = None commit_language = None file_watcher = None - mcp_servers = None + mcp_manager = None mcp_tools = None run_one_completed = True compact_context_completed = True @@ -249,8 +249,8 @@ async def create( if res is not None: if from_coder: - if from_coder.mcp_servers and kwargs.get("mcp_servers", False): - res.mcp_servers = from_coder.mcp_servers + if from_coder.mcp_manager: + res.mcp_manager = from_coder.mcp_manager res.mcp_tools = from_coder.mcp_tools # Transfer TUI app weak reference @@ -316,7 +316,7 @@ def __init__( file_watcher=None, auto_copy_context=False, auto_accept_architect=True, - mcp_servers=None, + mcp_manager=None, enable_context_compaction=False, context_compaction_max_tokens=None, context_compaction_summary_tokens=8192, @@ -350,7 +350,7 @@ def __init__( self.args = args self.num_cache_warming_pings = num_cache_warming_pings - self.mcp_servers = mcp_servers + self.mcp_manager = mcp_manager self.enable_context_compaction = enable_context_compaction self.context_compaction_max_tokens = context_compaction_max_tokens @@ -2546,7 +2546,7 @@ def _gather_server_tool_calls(self, tool_calls): and tool_name_from_schema.lower() == tool_call.function.name.lower() ): # Find the McpServer instance that will be used for communication - for server in self.mcp_servers: + for server in self.mcp_manager: if server.name == server_name: if server not in server_tool_calls: server_tool_calls[server] = [] @@ -2724,6 +2724,7 @@ async def initialize_mcp_tools(self): Initialize tools from all configured MCP servers. MCP Servers that fail to be initialized will not be available to the Coder instance. """ + # TODO(@gopar): refactor here once we have fully moved over to use the mcp manager tools = [] async def get_server_tools(server): @@ -2734,9 +2735,13 @@ async def get_server_tools(server): return (server.name, server_tools) try: - session = await server.connect() + did_connect = await self.mcp_manager.connect_server(server.name) + if not did_connect: + raise Exception("Failed to load tools") + + server = self.mcp_manager.get_server(server.name) server_tools = await experimental_mcp_client.load_mcp_tools( - session=session, format="openai" + session=server.session, format="openai" ) return (server.name, server_tools) except Exception as e: @@ -2745,11 +2750,11 @@ async def get_server_tools(server): return None async def get_all_server_tools(): - tasks = [get_server_tools(server) for server in self.mcp_servers] + tasks = [get_server_tools(server) for server in self.mcp_manager] results = await asyncio.gather(*tasks) return [result for result in results if result is not None] - if self.mcp_servers: + if self.mcp_manager: # Retry initialization in case of CancelledError max_retries = 3 for i in range(max_retries): diff --git a/cecli/commands/exit.py b/cecli/commands/exit.py index 73f3e5d4d61..46b96576df5 100644 --- a/cecli/commands/exit.py +++ b/cecli/commands/exit.py @@ -14,14 +14,6 @@ class ExitCommand(BaseCommand): @classmethod async def execute(cls, io, coder, args, **kwargs): """Execute the exit command with given parameters.""" - for server in coder.mcp_servers: - try: - await server.exit_stack.aclose() - except Exception: - pass - - await asyncio.sleep(0) - # Check if running in TUI mode - use graceful exit to restore terminal if hasattr(io, "request_exit"): io.request_exit() diff --git a/cecli/main.py b/cecli/main.py index 757d35299a1..a35c56f20d2 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -47,7 +47,7 @@ from cecli.history import ChatSummary from cecli.io import InputOutput from cecli.llm import litellm -from cecli.mcp import load_mcp_servers +from cecli.mcp import McpServerManager, load_mcp_servers from cecli.models import ModelSettings from cecli.onboarding import offer_openrouter_oauth, select_default_model from cecli.repo import ANY_GIT_ERROR, GitRepo @@ -976,8 +976,8 @@ def apply_model_overrides(model_name): mcp_servers = load_mcp_servers( args.mcp_servers, args.mcp_servers_file, io, args.verbose, args.mcp_transport ) - if not mcp_servers: - mcp_servers = [] + mcp_manager = McpServerManager(mcp_servers, io, args.verbose) + coder = await Coder.create( main_model=main_model, edit_format=args.edit_format, @@ -1013,7 +1013,7 @@ def apply_model_overrides(model_name): detect_urls=args.detect_urls, auto_copy_context=args.copy_paste, auto_accept_architect=args.auto_accept_architect, - mcp_servers=mcp_servers, + mcp_manager=mcp_manager, add_gitignore_files=args.add_gitignore_files, enable_context_compaction=args.enable_context_compaction, context_compaction_max_tokens=args.context_compaction_max_tokens, @@ -1267,11 +1267,9 @@ async def graceful_exit(coder=None, exit_code=0): if coder: if hasattr(coder, "_autosave_future"): await coder._autosave_future - for server in coder.mcp_servers: - try: - await server.exit_stack.aclose() - except Exception: - pass + + if coder.mcp_manager and coder.mcp_manager.is_connected: + await coder.mcp_manager.disconnect_all() return exit_code diff --git a/cecli/mcp/__init__.py b/cecli/mcp/__init__.py index 44d1e6a5f15..19eb8f60ae1 100644 --- a/cecli/mcp/__init__.py +++ b/cecli/mcp/__init__.py @@ -1,154 +1,14 @@ -import json -from pathlib import Path - -from cecli.mcp.server import HttpStreamingServer, McpServer, SseServer - - -def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_transport="stdio"): - """Parse MCP servers from a JSON string.""" - servers = [] - - try: - config = json.loads(json_string) - if verbose: - io.tool_output("Loading MCP servers from provided JSON") - - if "mcpServers" in config: - for name, server_config in config["mcpServers"].items(): - if verbose: - io.tool_output(f"Loading MCP server: {name}") - - # Create a server config with name included - server_config["name"] = name - transport = server_config.get("transport", mcp_transport) - if transport == "stdio": - servers.append(McpServer(server_config, io=io, verbose=verbose)) - elif transport == "http": - servers.append(HttpStreamingServer(server_config, io=io, verbose=verbose)) - elif transport == "sse": - servers.append(SseServer(server_config, io=io, verbose=verbose)) - - if verbose: - io.tool_output(f"Loaded {len(servers)} MCP servers") - return servers - else: - io.tool_warning("No 'mcpServers' key found in MCP config") - except json.JSONDecodeError: - io.tool_error("Invalid JSON in MCP config") - except Exception as e: - io.tool_error(f"Error loading MCP config: {e}") - - return servers - - -def _resolve_mcp_config_path(file_path, io, verbose=False): - """Resolve MCP config file path relative to closest cecli.conf.yml, git directory, or CWD.""" - if not file_path: - return None - - # If the path is absolute or already exists, use it as-is - path = Path(file_path) - if path.is_absolute() or path.exists(): - return str(path.resolve()) - - # Search for the closest cecli.conf.yml in parent directories - current_dir = Path.cwd() - conf_path = None - - for parent in [current_dir] + list(current_dir.parents): - conf_file = parent / ".cecli.conf.yml" - if conf_file.exists(): - conf_path = parent - break - - # If cecli.conf.yml found, try relative to that directory - if conf_path: - resolved_path = conf_path / file_path - if resolved_path.exists(): - if verbose: - io.tool_output(f"Resolved MCP config relative to cecli.conf.yml: {resolved_path}") - return str(resolved_path.resolve()) - - # Try to find git root directory - git_root = None - try: - import git - - repo = git.Repo(search_parent_directories=True) - git_root = Path(repo.working_tree_dir) - except (ImportError, git.InvalidGitRepositoryError, FileNotFoundError): - pass - - # If git root found, try relative to that directory - if git_root: - resolved_path = git_root / file_path - if resolved_path.exists(): - if verbose: - io.tool_output(f"Resolved MCP config relative to git root: {resolved_path}") - return str(resolved_path.resolve()) - - # Finally, try relative to current working directory - resolved_path = current_dir / file_path - if resolved_path.exists(): - if verbose: - io.tool_output(f"Resolved MCP config relative to CWD: {resolved_path}") - return str(resolved_path.resolve()) - - # If none found, return the original path (will trigger FileNotFoundError) - return str(path.resolve()) - - -def _parse_mcp_servers_from_file(file_path, io, verbose=False, mcp_transport="stdio"): - """Parse MCP servers from a JSON file.""" - # Resolve the file path relative to closest cecli.conf.yml, git directory, or CWD - resolved_file_path = _resolve_mcp_config_path(file_path, io, verbose) - - try: - with open(resolved_file_path, "r") as f: - json_string = f.read() - - if verbose: - io.tool_output(f"Loading MCP servers from file: {file_path}") - - return _parse_mcp_servers_from_json_string(json_string, io, verbose, mcp_transport) - - except FileNotFoundError: - io.tool_warning(f"MCP config file not found: {file_path}") - except Exception as e: - io.tool_error(f"Error reading MCP config file: {e}") - - return [] - - -def load_mcp_servers(mcp_servers, mcp_servers_file, io, verbose=False, mcp_transport="stdio"): - """Load MCP servers from a JSON string or file.""" - servers = [] - - # First try to load from the JSON string (preferred) - if mcp_servers: - servers = _parse_mcp_servers_from_json_string(mcp_servers, io, verbose, mcp_transport) - if servers: - return servers - - # If JSON string failed or wasn't provided, try the file - if mcp_servers_file: - servers = _parse_mcp_servers_from_file(mcp_servers_file, io, verbose, mcp_transport) - - if not servers: - # A default MCP server is actually now necessary for the overall agentic loop - # and a dummy server does suffice for the job - # because I am not smart enough to figure out why - # on coder switch, the agent actually initializes the prompt area twice - # once immediately after input for the old coder - # and immediately again for the new target coder - # which causes a race condition where we are awaiting a coroutine - # that can no longer yield control (somehow?) - # but somehow having to run through the MCP server checks - # allows control to be yielded again somehow - # and I cannot figure out just how that is happening - # and maybe it is actually prompt_toolkit's fault - # but this hack works swimmingly because ??? - # so sure! why not - servers = [McpServer(json.loads('{"cecli_default": {}}'), io=io, verbose=verbose)] - - return servers +from .manager import McpServerManager +from .server import HttpStreamingServer, LocalServer, McpServer, SseServer +from .utils import find_available_port, generate_pkce_codes, load_mcp_servers + +__all__ = [ + "McpServerManager", + "McpServer", + "HttpStreamingServer", + "SseServer", + "LocalServer", + "load_mcp_servers", + "find_available_port", + "generate_pkce_codes", +] diff --git a/cecli/mcp/manager.py b/cecli/mcp/manager.py new file mode 100644 index 00000000000..8ec30bdd58c --- /dev/null +++ b/cecli/mcp/manager.py @@ -0,0 +1,249 @@ +import asyncio +import logging + +from cecli.mcp.server import McpServer + + +class McpServerManager: + """ + Centralized manager for MCP server connections. + + Handles connection lifecycle for all MCP servers, ensuring + connections are established once and reused across all Coder instances. + """ + + def __init__( + self, + servers: list[McpServer], + io=None, + verbose: bool = False, + ): + """ + Initialize the MCP server manager. + + Args: + mcp_servers: JSON string containing MCP server configurations + mcp_servers_file: Path to a JSON file containing MCP server configurations + io: InputOutput instance for user interaction + verbose: Whether to output verbose logging + """ + self.io = io + self.verbose = verbose + self._servers = servers + self._server_tools: dict[str, list] = {} # Maps server name to its tools + self._connected_servers: set[McpServer] = set() + + def _log_verbose(self, message: str) -> None: + """Log a verbose message if verbose mode is enabled and IO is available.""" + if self.verbose and self.io: + self.io.tool_output(message) + + def _log_error(self, message: str) -> None: + """Log an error message if IO is available.""" + if self.io: + self.io.tool_error(message) + + def _log_warning(self, message: str) -> None: + """Log a warning message if IO is available.""" + if self.io: + self.io.tool_warning(message) + + @property + def servers(self) -> list["McpServer"]: + """Get the list of managed MCP servers.""" + return self._servers + + @property + def is_connected(self) -> bool: + """Check if any servers are connected.""" + return len(self._connected_servers) > 0 + + def get_server(self, name: str) -> McpServer | None: + """ + Get a server by name. + + Args: + name: Name of the server to retrieve + + Returns: + The server instance or None if not found + """ + try: + return next(server for server in self._servers if server.name == name) + except StopIteration: + return None + + async def connect_all(self) -> None: + """Connect to all MCP servers.""" + if self.is_connected: + self._log_verbose("Some MCP servers already connected") + return + + self._log_verbose(f"Connecting to {len(self._servers)} MCP servers") + + async def connect_server(server: McpServer) -> tuple[McpServer, bool]: + try: + session = await server.connect() + tools_result = await session.list_tools() + self._server_tools[server.name] = tools_result.tools + self._log_verbose(f"Connected to MCP server: {server.name}") + return (server, True) + except Exception as e: + logging.error(f"Error connecting to MCP server {server.name}: {e}") + self._log_error(f"Failed to connect to MCP server {server.name}: {e}") + return (server, False) + + results = await asyncio.gather(*[connect_server(server) for server in self._servers]) + + for server, success in results: + if success: + self._connected_servers.add(server) + + async def disconnect_all(self) -> None: + """Disconnect from all MCP servers.""" + if not self._connected_servers: + self._log_verbose("MCP servers already disconnected") + return + + self._log_verbose("Disconnecting from all MCP servers") + + async def disconnect_server(server: McpServer) -> tuple[McpServer, bool]: + try: + await server.disconnect() + if server.name in self._server_tools: + del self._server_tools[server.name] + self._log_verbose(f"Disconnected from MCP server: {server.name}") + return (server, True) + except Exception: + self._log_warning(f"Error disconnected from MCP server: {server.name}") + return (server, False) + + # Create a copy to avoid modifying during iteration + servers_to_disconnect = list(self._connected_servers) + tasks = [disconnect_server(server) for server in servers_to_disconnect] + results = await asyncio.gather(*tasks) + + for server, success in results: + if success: + self._connected_servers.remove(server) + + async def connect_server(self, name: str) -> bool: + """ + Connect to a specific MCP server by name. + + Args: + name: Name of the server to connect to + + Returns: + Boolean indicating success or failure + """ + server = self.get_server(name) + if not server: + self._log_warning(f"MCP server not found: {name}") + return False + + if server in self._connected_servers: + self._log_verbose(f"MCP server already connected: {name}") + return True + + try: + session = await server.connect() + tools_result = await session.list_tools() + self._server_tools[server.name] = tools_result.tools + self._connected_servers.add(server) + self._log_verbose(f"Connected to MCP server: {name}") + return True + except Exception as e: + logging.error(f"Error connecting to MCP server {name}: {e}") + self._log_error(f"Failed to connect to MCP server {name}: {e}") + return False + + async def disconnect_server(self, name: str) -> bool: + """ + Disconnect from a specific MCP server by name. + + Args: + name: Name of the server to disconnect from + + Returns: + Boolean indicating success or failure + """ + server = self.get_server(name) + if not server: + self._log_warning(f"MCP server not found: {name}") + return False + + if server not in self._connected_servers: + self._log_verbose(f"MCP server not connected: {name}") + return True + + try: + await server.disconnect() + if server.name in self._server_tools: + del self._server_tools[server.name] + self._connected_servers.remove(server) + self._log_verbose(f"Disconnected from MCP server: {name}") + return True + except Exception as e: + self._log_warning(f"Error disconnecting from MCP server {name}: {e}") + return False + + async def add_server(self, server: McpServer, connect: bool = False) -> bool: + """ + Add a new MCP server to the manager. + + Args: + server: McpServer instance to add + connect: Whether to immediately connect to the server + + Returns: + Boolean indicating success or failure + """ + existing_server = self.get_server(server.name) + if existing_server: + self._log_warning(f"MCP server with name '{server.name}' already exists") + return False + + self._servers.append(server) + self._log_verbose(f"Added MCP server: {server.name}") + + if connect: + return await self.connect_server(server.name) + + return True + + @property + def connected_servers(self) -> list["McpServer"]: + """Get the list of successfully connected servers.""" + return list(self._connected_servers) + + @property + def failed_servers(self) -> list["McpServer"]: + """Get the list of servers that failed to connect.""" + return [server for server in self._servers if server not in self._connected_servers] + + def __iter__(self): + for server in self._servers: + yield server + + def get_server_tools(self, name: str) -> list | None: + """ + Get the tools for a specific server. + + Args: + name: Name of the server + + Returns: + List of tools or None if server not found or not connected + """ + return self._server_tools.get(name) + + @property + def all_tools(self) -> dict[str, list]: + """ + Get all tools from all connected servers. + + Returns: + Dictionary mapping server names to their tools + """ + return self._server_tools.copy() diff --git a/cecli/mcp/oauth.py b/cecli/mcp/oauth.py index c9d9897116f..82611c61462 100644 --- a/cecli/mcp/oauth.py +++ b/cecli/mcp/oauth.py @@ -1,10 +1,7 @@ import asyncio -import base64 -import hashlib import http.server import json import os -import secrets import socketserver import threading import time @@ -16,19 +13,6 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken -def find_available_port(start_port=8484, end_port=8584): - """Find an available port in the given range.""" - for port in range(start_port, end_port + 1): - try: - # Check if the port is available by trying to bind to it - with socketserver.TCPServer(("localhost", port), None): - return port - except OSError: - # Port is likely already in use - continue - return None - - def create_oauth_callback_server( port, path="/callback" ) -> Tuple[Callable[[], Awaitable[Tuple[str, str]]], Callable[[], None]]: @@ -139,15 +123,6 @@ async def get_auth_code() -> Tuple[str, str]: return get_auth_code, shutdown -def generate_pkce_codes(): - """Generate PKCE code verifier and challenge.""" - code_verifier = secrets.token_urlsafe(64) - hasher = hashlib.sha256() - hasher.update(code_verifier.encode("utf-8")) - code_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8") - return code_verifier, code_challenge - - def get_token_file_path(): """Get the path to the MCP OAuth tokens file.""" config_dir = Path.home() / ".cecli" diff --git a/cecli/mcp/server.py b/cecli/mcp/server.py index 58c2bcb661e..65a97af00af 100644 --- a/cecli/mcp/server.py +++ b/cecli/mcp/server.py @@ -13,10 +13,9 @@ from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientMetadata -from cecli.mcp.oauth import ( +from .oauth import ( FileBasedTokenStorage, create_oauth_callback_server, - find_available_port, get_mcp_oauth_token, save_mcp_oauth_token, ) @@ -94,9 +93,14 @@ async def disconnect(self): async with self._cleanup_lock: try: await self.exit_stack.aclose() - self.session = None + except (asyncio.CancelledError, RuntimeError, GeneratorExit): + # Expected during shutdown - anyio cancel scopes don't play + # well with asyncio teardown. Resources are still cleaned up. + pass except Exception as e: logging.error(f"Error during cleanup of server {self.name}: {e}") + finally: + self.session = None class HttpBasedMcpServer(McpServer): @@ -122,6 +126,8 @@ async def _create_oauth_provider(self): f"Found existing redirect URI: {existing_redirect_uri}", log_only=True ) + from .utils import find_available_port + # If we have an existing redirect URI, parse it to get the port if existing_redirect_uri: try: @@ -236,9 +242,14 @@ async def disconnect(self): if hasattr(self, "_oauth_shutdown"): self._oauth_shutdown() await self.exit_stack.aclose() - self.session = None + except (asyncio.CancelledError, RuntimeError, GeneratorExit): + # Expected during shutdown - anyio cancel scopes don't play + # well with asyncio teardown. Resources are still cleaned up. + pass except Exception as e: logging.error(f"Error during cleanup of server {self.name}: {e}") + finally: + self.session = None class HttpStreamingServer(HttpBasedMcpServer): diff --git a/cecli/mcp/utils.py b/cecli/mcp/utils.py new file mode 100644 index 00000000000..5642a9b9aae --- /dev/null +++ b/cecli/mcp/utils.py @@ -0,0 +1,184 @@ +import base64 +import hashlib +import json +import secrets +import socketserver +from pathlib import Path + +from .server import McpServer + + +def find_available_port(start_port=8484, end_port=8584): + """Find an available port in the given range.""" + for port in range(start_port, end_port + 1): + try: + # Check if the port is available by trying to bind to it + with socketserver.TCPServer(("localhost", port), None): + return port + except OSError: + # Port is likely already in use + continue + return None + + +def generate_pkce_codes(): + """Generate PKCE code verifier and challenge.""" + code_verifier = secrets.token_urlsafe(64) + hasher = hashlib.sha256() + hasher.update(code_verifier.encode("utf-8")) + code_challenge = base64.urlsafe_b64encode(hasher.digest()).rstrip(b"=").decode("utf-8") + return code_verifier, code_challenge + + +def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_transport="stdio"): + """Parse MCP servers from a JSON string.""" + from .server import HttpStreamingServer, McpServer, SseServer + + servers = [] + + try: + config = json.loads(json_string) + if verbose: + io.tool_output("Loading MCP servers from provided JSON") + + if "mcpServers" in config: + for name, server_config in config["mcpServers"].items(): + if verbose: + io.tool_output(f"Loading MCP server: {name}") + + # Create a server config with name included + server_config["name"] = name + transport = server_config.get("transport", mcp_transport) + if transport == "stdio": + servers.append(McpServer(server_config, io=io, verbose=verbose)) + elif transport == "http": + servers.append(HttpStreamingServer(server_config, io=io, verbose=verbose)) + elif transport == "sse": + servers.append(SseServer(server_config, io=io, verbose=verbose)) + + if verbose: + io.tool_output(f"Loaded {len(servers)} MCP servers") + return servers + else: + io.tool_warning("No 'mcpServers' key found in MCP config") + except json.JSONDecodeError: + io.tool_error("Invalid JSON in MCP config") + except Exception as e: + io.tool_error(f"Error loading MCP config: {e}") + + return servers + + +def _resolve_mcp_config_path(file_path, io, verbose=False): + """Resolve MCP config file path relative to closest cecli.conf.yml, git directory, or CWD.""" + if not file_path: + return None + + # If the path is absolute or already exists, use it as-is + path = Path(file_path) + if path.is_absolute() or path.exists(): + return str(path.resolve()) + + # Search for the closest cecli.conf.yml in parent directories + current_dir = Path.cwd() + conf_path = None + + for parent in [current_dir] + list(current_dir.parents): + conf_file = parent / ".cecli.conf.yml" + if conf_file.exists(): + conf_path = parent + break + + # If cecli.conf.yml found, try relative to that directory + if conf_path: + resolved_path = conf_path / file_path + if resolved_path.exists(): + if verbose: + io.tool_output(f"Resolved MCP config relative to cecli.conf.yml: {resolved_path}") + return str(resolved_path.resolve()) + + # Try to find git root directory + git_root = None + try: + import git + + repo = git.Repo(search_parent_directories=True) + git_root = Path(repo.working_tree_dir) + except (ImportError, git.InvalidGitRepositoryError, FileNotFoundError): + pass + + # If git root found, try relative to that directory + if git_root: + resolved_path = git_root / file_path + if resolved_path.exists(): + if verbose: + io.tool_output(f"Resolved MCP config relative to git root: {resolved_path}") + return str(resolved_path.resolve()) + + # Finally, try relative to current working directory + resolved_path = current_dir / file_path + if resolved_path.exists(): + if verbose: + io.tool_output(f"Resolved MCP config relative to CWD: {resolved_path}") + return str(resolved_path.resolve()) + + # If none found, return the original path (will trigger FileNotFoundError) + return str(path.resolve()) + + +def _parse_mcp_servers_from_file(file_path, io, verbose=False, mcp_transport="stdio"): + """Parse MCP servers from a JSON file.""" + # Resolve the file path relative to closest cecli.conf.yml, git directory, or CWD + resolved_file_path = _resolve_mcp_config_path(file_path, io, verbose) + + try: + with open(resolved_file_path, "r") as f: + json_string = f.read() + + if verbose: + io.tool_output(f"Loading MCP servers from file: {file_path}") + + return _parse_mcp_servers_from_json_string(json_string, io, verbose, mcp_transport) + + except FileNotFoundError: + io.tool_warning(f"MCP config file not found: {file_path}") + except Exception as e: + io.tool_error(f"Error reading MCP config file: {e}") + + return [] + + +def load_mcp_servers( + mcp_servers, mcp_servers_file, io, verbose=False, mcp_transport="stdio" +) -> list["McpServer"]: + """Load MCP servers from a JSON string or file.""" + servers = [] + + # First try to load from the JSON string (preferred) + if mcp_servers: + servers = _parse_mcp_servers_from_json_string(mcp_servers, io, verbose, mcp_transport) + if servers: + return servers + + # If JSON string failed or wasn't provided, try the file + if mcp_servers_file: + servers = _parse_mcp_servers_from_file(mcp_servers_file, io, verbose, mcp_transport) + + if not servers: + # A default MCP server is actually now necessary for the overall agentic loop + # and a dummy server does suffice for the job + # because I am not smart enough to figure out why + # on coder switch, the agent actually initializes the prompt area twice + # once immediately after input for the old coder + # and immediately again for the new target coder + # which causes a race condition where we are awaiting a coroutine + # that can no longer yield control (somehow?) + # but somehow having to run through the MCP server checks + # allows control to be yielded again somehow + # and I cannot figure out just how that is happening + # and maybe it is actually prompt_toolkit's fault + # but this hack works swimmingly because ??? + # so sure! why not + servers = [McpServer(json.loads('{"cecli_default": {}}'), io=io, verbose=verbose)] + + return servers diff --git a/cecli/onboarding.py b/cecli/onboarding.py index a0037520d1f..63470bb88c8 100644 --- a/cecli/onboarding.py +++ b/cecli/onboarding.py @@ -10,7 +10,7 @@ from cecli import urls from cecli.io import InputOutput -from cecli.mcp.oauth import find_available_port, generate_pkce_codes +from cecli.mcp import find_available_port, generate_pkce_codes def check_openrouter_tier(api_key): diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py index febe38028e3..07e3a161892 100644 --- a/tests/basic/test_coder.py +++ b/tests/basic/test_coder.py @@ -12,6 +12,7 @@ from cecli.commands import SwitchCoderSignal from cecli.dump import dump # noqa: F401 from cecli.io import InputOutput +from cecli.mcp import McpServerManager from cecli.models import Model from cecli.repo import GitRepo from cecli.sendchat import sanity_check_messages @@ -1450,7 +1451,7 @@ async def test_mcp_server_connection(self, mock_mcp_client): # Create coder with mock MCP server with patch.object(Coder, "initialize_mcp_tools", return_value=mock_tools): - coder = await Coder.create(self.GPT35, "diff", io=io, mcp_servers=[mock_server]) + coder = await Coder.create(self.GPT35, "diff", io=io) # Manually set mcp_tools since we're bypassing initialize_mcp_tools coder.mcp_tools = mock_tools @@ -1478,9 +1479,12 @@ async def test_coder_creation_with_partial_failed_mcp_server(self, mock_mcp_clie failing_server.connect = AsyncMock() failing_server.disconnect = AsyncMock() + manager = McpServerManager([working_server, failing_server]) + manager._connected_servers = [working_server] + # Mock load_mcp_tools to succeed for working_server and fail for failing_server async def mock_load_mcp_tools(session, format): - if session == await working_server.connect(): + if session == working_server.session: return [{"function": {"name": "working_tool"}}] else: raise Exception("Failed to load tools") @@ -1492,7 +1496,7 @@ async def mock_load_mcp_tools(session, format): self.GPT35, "diff", io=io, - mcp_servers=[working_server, failing_server], + mcp_manager=manager, verbose=True, ) @@ -1526,6 +1530,9 @@ async def test_coder_creation_with_all_failed_mcp_server(self, mock_mcp_client): failing_server.connect = AsyncMock() failing_server.disconnect = AsyncMock() + manager = McpServerManager([failing_server]) + manager._connected_servers = [] + # Mock load_mcp_tools to succeed for working_server and fail for failing_server async def mock_load_mcp_tools(session, format): raise Exception("Failed to load tools") @@ -1537,7 +1544,7 @@ async def mock_load_mcp_tools(session, format): self.GPT35, "diff", io=io, - mcp_servers=[failing_server], + mcp_manager=manager, verbose=True, ) @@ -1594,6 +1601,9 @@ async def test_process_tool_calls_with_tools(self): mock_server.connect = AsyncMock() mock_server.disconnect = AsyncMock() + manager = McpServerManager([mock_server]) + manager._connected_servers = [mock_server] + # Create a tool call tool_call = MagicMock() tool_call.id = "test_id" @@ -1612,9 +1622,8 @@ async def test_process_tool_calls_with_tools(self): ) # Create coder with mock MCP tools and servers - coder = await Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io, mcp_manager=manager) coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] - coder.mcp_servers = [mock_server] # Mock _execute_tool_calls to return tool responses tool_responses = [ @@ -1661,12 +1670,16 @@ async def test_process_tool_calls_max_calls_exceeded(self): # Create mock MCP server mock_server = MagicMock() mock_server.name = "test_server" + mock_server.connect = AsyncMock() + mock_server.session = AsyncMock() + + manager = McpServerManager([mock_server]) + manager._connected_servers = [mock_server] # Create coder with max tool calls exceeded - coder = await Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io, mcp_manager=manager) coder.num_tool_calls = coder.max_tool_calls coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] - coder.mcp_servers = [mock_server] # Test process_tool_calls result = await coder.process_tool_calls(response) @@ -1702,10 +1715,12 @@ async def test_process_tool_calls_user_rejects(self): mock_server.connect = AsyncMock() mock_server.disconnect = AsyncMock() + manager = McpServerManager([mock_server]) + manager._connected_servers = [mock_server] + # Create coder with mock MCP tools - coder = await Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io, mcp_manager=manager) coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] - coder.mcp_servers = [mock_server] # Test process_tool_calls result = await coder.process_tool_calls(response) diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index fa1912d382c..08037750ed6 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -334,6 +334,7 @@ async def mock_run(*args, **kwargs): MockCoder = mocker.patch("cecli.coders.Coder.create") mock_coder_instance = MagicMock() mock_coder_instance.run = AsyncMock() + mock_coder_instance.mcp_manager = False mock_coder_instance._autosave_future = mock_autosave_future() MockCoder.return_value = mock_coder_instance main(["--yes-always", "--message-file", str(message_file)], **dummy_io) @@ -973,6 +974,7 @@ def test_model_overrides_suffix_applied(dummy_io, git_temp_dir, mocker): MockCoder = mocker.patch("cecli.coders.Coder.create") mock_coder_instance = MagicMock() mock_coder_instance._autosave_future = mock_autosave_future() + mock_coder_instance.mcp_manager = False MockCoder.return_value = mock_coder_instance mock_instance = MockModel.return_value mock_instance.info = {} @@ -1004,6 +1006,7 @@ def test_model_overrides_no_match_preserves_model_name(dummy_io, git_temp_dir, m MockModel = mocker.patch("cecli.models.Model") MockCoder = mocker.patch("cecli.coders.Coder.create") mock_coder_instance = MagicMock() + mock_coder_instance.mcp_manager = False mock_coder_instance._autosave_future = mock_autosave_future() MockCoder.return_value = mock_coder_instance mock_instance = MockModel.return_value @@ -1345,6 +1348,7 @@ def test_load_dotenv_files_override(dummy_io, git_temp_dir, mocker): def test_mcp_servers_parsing(dummy_io, git_temp_dir, mocker): mock_coder_create = mocker.patch("cecli.coders.Coder.create") mock_coder_instance = MagicMock() + mock_coder_instance.mcp_manager = False mock_coder_instance._autosave_future = mock_autosave_future() mock_coder_create.return_value = mock_coder_instance main( @@ -1358,10 +1362,12 @@ def test_mcp_servers_parsing(dummy_io, git_temp_dir, mocker): ) mock_coder_create.assert_called_once() _, kwargs = mock_coder_create.call_args - assert "mcp_servers" in kwargs - assert kwargs["mcp_servers"] is not None - assert len(kwargs["mcp_servers"]) > 0 - assert hasattr(kwargs["mcp_servers"][0], "name") + + assert "mcp_manager" in kwargs + assert kwargs["mcp_manager"] is not None + assert len(kwargs["mcp_manager"].servers) > 0 + assert hasattr(kwargs["mcp_manager"].servers[0], "name") + mock_coder_create.reset_mock() mock_coder_instance._autosave_future = mock_autosave_future() mcp_file = Path("mcp_servers.json") @@ -1370,7 +1376,8 @@ def test_mcp_servers_parsing(dummy_io, git_temp_dir, mocker): main(["--mcp-servers-file", str(mcp_file), "--exit", "--yes-always"], **dummy_io) mock_coder_create.assert_called_once() _, kwargs = mock_coder_create.call_args - assert "mcp_servers" in kwargs - assert kwargs["mcp_servers"] is not None - assert len(kwargs["mcp_servers"]) > 0 - assert hasattr(kwargs["mcp_servers"][0], "name") + + assert "mcp_manager" in kwargs + assert kwargs["mcp_manager"] is not None + assert len(kwargs["mcp_manager"].servers) > 0 + assert hasattr(kwargs["mcp_manager"].servers[0], "name")