From d65cb48f53c30e8b40a8ce15e213bbdfa071bce7 Mon Sep 17 00:00:00 2001 From: Gopar Date: Tue, 13 Jan 2026 10:13:50 -0800 Subject: [PATCH 1/3] [gh-392] Add custom commands to load/remove MCPs dynamically --- cecli/coders/agent_coder.py | 37 +++++---- cecli/coders/base_coder.py | 68 +++------------- cecli/commands/__init__.py | 6 ++ cecli/commands/load_mcp.py | 77 ++++++++++++++++++ cecli/commands/remove_mcp.py | 65 +++++++++++++++ cecli/main.py | 2 +- cecli/mcp/manager.py | 117 +++++++++++++++++---------- cecli/mcp/server.py | 16 ++-- cecli/website/docs/usage/commands.md | 2 + 9 files changed, 261 insertions(+), 129 deletions(-) create mode 100644 cecli/commands/load_mcp.py create mode 100644 cecli/commands/remove_mcp.py diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index 64a89a1f0e3..0e1134580ca 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -203,25 +203,28 @@ def get_local_tool_schemas(self): return schemas async def initialize_mcp_tools(self): + # TODO(Gopar): update this part await super().initialize_mcp_tools() + + if not self.mcp_manager: + self.mcp_manager = McpServerManager() + server_name = "Local" - if server_name not in [name for name, _ in self.mcp_tools]: - local_tools = self.get_local_tool_schemas() - if not local_tools: - return - - local_server_config = {"name": server_name} - local_server = LocalServer(local_server_config) - - if not self.mcp_manager: - self.mcp_manager = McpServerManager() - if not self.mcp_manager.get_server(server_name): - await self.mcp_manager.add_server(local_server) - if not self.mcp_tools: - self.mcp_tools = [] - - if server_name not in [name for name, _ in self.mcp_tools]: - self.mcp_tools.append((local_server.name, local_tools)) + server = self.mcp_manager.get_server(server_name) + + # We have already initialized local server, no need to duplicate work + if server is not None: + return + + # If we dont have any tools for local server to use, no point in creating it then + local_tools = self.get_local_tool_schemas() + if not local_tools: + return + + local_server_config = {"name": server_name} + local_server = LocalServer(local_server_config) + + await self.mcp_manager.add_server(local_server, connect=True) async def _execute_local_tool_calls(self, tool_calls_list): tool_responses = [] diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 5645fb78458..e0443f06a40 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -139,7 +139,6 @@ class Coder: commit_language = None file_watcher = None mcp_manager = None - mcp_tools = None run_one_completed = True compact_context_completed = True suppress_announcements_for_next_prompt = False @@ -228,6 +227,7 @@ async def create( total_tokens_sent=from_coder.total_tokens_sent, total_tokens_received=from_coder.total_tokens_received, file_watcher=from_coder.file_watcher, + mcp_manager=from_coder.mcp_manager, ) use_kwargs.update(update) # override to complete the switch use_kwargs.update(kwargs) # override passed kwargs @@ -251,7 +251,6 @@ async def create( if from_coder: if from_coder.mcp_manager: res.mcp_manager = from_coder.mcp_manager - res.mcp_tools = from_coder.mcp_tools # Transfer TUI app weak reference res.tui = from_coder.tui @@ -2750,66 +2749,17 @@ async def initialize_mcp_tools(self): Initialize tools from all configured MCP servers. MCP Servers that fail to be initialized will not be available to the Coder instance. """ - # TODO(@gopar): refactor here once we have fully moved over to use the mcp manager - tools = [] - async def get_server_tools(server): - # Check if we already have tools for this server in mcp_tools - if self.mcp_tools: - for server_name, server_tools in self.mcp_tools: - if server_name == server.name: - return (server.name, server_tools) - - try: - did_connect = await self.mcp_manager.connect_server(server.name) - if not did_connect: - raise Exception("Failed to load tools") - - server = self.mcp_manager.get_server(server.name) - server_tools = await experimental_mcp_client.load_mcp_tools( - session=server.session, format="openai" - ) - return (server.name, server_tools) - except Exception as e: - if server.name != "unnamed-server" and server.name != "Local": - self.io.tool_warning(f"Error initializing MCP server {server.name}: {e}") - return None - - async def get_all_server_tools(): - tasks = [get_server_tools(server) for server in self.mcp_manager if server.is_enabled] - results = await asyncio.gather(*tasks) - return [result for result in results if result is not None] - - if self.mcp_manager: - # Retry initialization in case of CancelledError - max_retries = 3 - for i in range(max_retries): - try: - tools = await get_all_server_tools() - break - except asyncio.exceptions.CancelledError: - if i < max_retries - 1: - await asyncio.sleep(0.1) # Brief pause before retrying - else: - self.io.tool_warning( - "MCP tool initialization failed after multiple retries due to" - " cancellation." - ) - tools = [] - - if len(tools) > 0: - if self.verbose: - self.io.tool_output("MCP servers configured:") - - for server_name, server_tools in tools: - self.io.tool_output(f" - {server_name}") + @property + def mcp_tools(self): + if not self.mcp_manager: + return [] - for tool in server_tools: - tool_name = tool.get("function", {}).get("name", "unknown") - tool_desc = tool.get("function", {}).get("description", "").split("\n")[0] - self.io.tool_output(f" - {tool_name}: {tool_desc}") + return list(self.mcp_manager.all_tools.items()) - self.mcp_tools = tools + @mcp_tools.setter + def mcp_tools(self, value): + raise AttributeError("mcp_tools is read only.") def get_tool_list(self): """Get a flattened list of all MCP tools.""" diff --git a/cecli/commands/__init__.py b/cecli/commands/__init__.py index b399b855792..d860deebd24 100644 --- a/cecli/commands/__init__.py +++ b/cecli/commands/__init__.py @@ -31,6 +31,7 @@ from .lint import LintCommand from .list_sessions import ListSessionsCommand from .load import LoadCommand +from .load_mcp import LoadMcpCommand from .load_session import LoadSessionCommand from .load_skill import LoadSkillCommand from .ls import LsCommand @@ -44,6 +45,7 @@ from .read_only import ReadOnlyCommand from .read_only_stub import ReadOnlyStubCommand from .reasoning_effort import ReasoningEffortCommand +from .remove_mcp import RemoveMcpCommand from .remove_skill import RemoveSkillCommand from .report import ReportCommand from .reset import ResetCommand @@ -125,6 +127,8 @@ CommandRegistry.register(LoadSkillCommand) CommandRegistry.register(RemoveSkillCommand) CommandRegistry.register(TerminalSetupCommand) +CommandRegistry.register(LoadMcpCommand) +CommandRegistry.register(RemoveMcpCommand) __all__ = [ @@ -192,4 +196,6 @@ "TerminalSetupCommand", "SwitchCoderSignal", "Commands", + "LoadMcpCommand", + "RemoveMcpCommand", ] diff --git a/cecli/commands/load_mcp.py b/cecli/commands/load_mcp.py new file mode 100644 index 00000000000..ad19ebc0b62 --- /dev/null +++ b/cecli/commands/load_mcp.py @@ -0,0 +1,77 @@ +from typing import List + +from cecli.commands.utils.base_command import BaseCommand +from cecli.commands.utils.helpers import format_command_result + + +class LoadMcpCommand(BaseCommand): + NORM_NAME = "load-mcp" + DESCRIPTION = "Load a MCP server by name" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the load-mcp command with given parameters.""" + if not args.strip(): + return format_command_result(io, cls.NORM_NAME, "Usage: /load-mcp ") + + if not coder.mcp_manager or not coder.mcp_manager.servers: + return format_command_result( + io, cls.NORM_NAME, "No MCP servers found, nothing to load." + ) + + server_name = args.strip() + server = coder.mcp_manager.get_server(server_name) + if server is None: + return format_command_result( + io, cls.NORM_NAME, "", f"MCP server {server_name} does not exist." + ) + + did_connect = await coder.mcp_manager.connect_server(server.name) + + if not did_connect: + return format_command_result(io, cls.NORM_NAME, f"Unable to load server: {server_name}") + + try: + if did_connect: + return format_command_result(io, cls.NORM_NAME, f"Loaded server: {server_name}") + else: + return format_command_result( + io, cls.NORM_NAME, "", f"Unable to Load server: {server_name}" + ) + finally: + from . import SwitchCoderSignal + + raise SwitchCoderSignal( + edit_format=coder.edit_format, + summarize_from_coder=False, + from_coder=coder, + show_announcements=True, + ) + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for load-mcp command.""" + if not coder.mcp_manager or not coder.mcp_manager.servers: + return [] + + try: + server_names = [ + server.name + for server in coder.mcp_manager + if server not in coder.mcp_manager.connected_servers + ] + return server_names + except Exception: + return [] + + @classmethod + def get_help(cls) -> str: + """Get help text for the load-mcp command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /load-mcp # Load a mcp by name\n" + help_text += "\nExamples:\n" + help_text += " /load-mcp context7 # Load the context7 mcp\n" + help_text += " /load-mcp github # Load the github mcp\n" + help_text += "\nThis command loads a MCP server by name.\n" + return help_text diff --git a/cecli/commands/remove_mcp.py b/cecli/commands/remove_mcp.py new file mode 100644 index 00000000000..9350a9670d8 --- /dev/null +++ b/cecli/commands/remove_mcp.py @@ -0,0 +1,65 @@ +from typing import List + +from cecli.commands.utils.base_command import BaseCommand +from cecli.commands.utils.helpers import format_command_result + + +class RemoveMcpCommand(BaseCommand): + NORM_NAME = "remove-mcp" + DESCRIPTION = "Remove a MCP server by name" + + @classmethod + async def execute(cls, io, coder, args, **kwargs): + """Execute the remove-mcp command with given parameters.""" + if not args.strip(): + return format_command_result(io, cls.NORM_NAME, "Usage: /remove-mcp ") + + if not coder.mcp_manager or not coder.mcp_manager.servers: + return format_command_result( + io, cls.NORM_NAME, "No MCP servers connected, nothing to remove." + ) + + server_name = args.strip() + was_disconnected = await coder.mcp_manager.disconnect_server(server_name) + + try: + if was_disconnected: + return format_command_result(io, cls.NORM_NAME, f"Removed server: {server_name}") + else: + return format_command_result( + io, cls.NORM_NAME, "", f"Unable to remove server: {server_name}" + ) + finally: + from . import SwitchCoderSignal + + raise SwitchCoderSignal( + edit_format=coder.edit_format, + summarize_from_coder=False, + from_coder=coder, + show_announcements=True, + mcp_manager=coder.mcp_manager, + ) + + @classmethod + def get_completions(cls, io, coder, args) -> List[str]: + """Get completion options for remove-mcp command.""" + if not coder.mcp_manager or not coder.mcp_manager.servers: + return [] + + try: + server_names = [server.name for server in coder.mcp_manager if server.is_connected] + return server_names + except Exception: + return [] + + @classmethod + def get_help(cls) -> str: + """Get help text for the remove-mcp command.""" + help_text = super().get_help() + help_text += "\nUsage:\n" + help_text += " /remove-mcp # Remove a mcp by name\n" + help_text += "\nExamples:\n" + help_text += " /remove-mcp context7 # Remove the context7 mcp\n" + help_text += " /remove-mcp github # Remove the github mcp\n" + help_text += "\nThis command removes a MCP server by name.\n" + return help_text diff --git a/cecli/main.py b/cecli/main.py index 5c0286bcddf..32fba8fee5d 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -983,7 +983,7 @@ def apply_model_overrides(model_name): mcp_servers = load_mcp_servers( args.mcp_servers, args.mcp_servers_file, io, args.verbose, args.mcp_transport ) - mcp_manager = McpServerManager(mcp_servers, io, args.verbose) + mcp_manager = await McpServerManager.from_servers(mcp_servers, io, args.verbose) coder = await Coder.create( main_model=main_model, diff --git a/cecli/mcp/manager.py b/cecli/mcp/manager.py index 78d25f2896c..6e795da397d 100644 --- a/cecli/mcp/manager.py +++ b/cecli/mcp/manager.py @@ -1,7 +1,9 @@ import asyncio -import logging -from cecli.mcp.server import McpServer +from litellm import experimental_mcp_client + +from cecli.mcp.server import LocalServer, McpServer +from cecli.tools.utils.registry import ToolRegistry class McpServerManager: @@ -73,35 +75,6 @@ def get_server(self, name: str) -> McpServer | None: except StopIteration: return None - async def connect_all(self) -> None: - """Connect to all MCP servers while skipping ones that are not enabled.""" - if self.is_connected: - self._log_verbose("Some MCP servers already connected") - return - - self._log_verbose(f"Connecting to {len(self._servers)} MCP servers") - - async def connect_server(server: McpServer) -> tuple[McpServer, bool]: - try: - session = await server.connect() - tools_result = await session.list_tools() - self._server_tools[server.name] = tools_result.tools - self._log_verbose(f"Connected to MCP server: {server.name}") - return (server, True) - except Exception as e: - if server.name != "unnamed-server": - logging.error(f"Error connecting to MCP server {server.name}: {e}") - self._log_error(f"Failed to connect to MCP server {server.name}: {e}") - return (server, False) - - results = await asyncio.gather( - *[connect_server(server) for server in self._servers if server.is_enabled] - ) - - for server, success in results: - if success: - self._connected_servers.add(server) - async def disconnect_all(self) -> None: """Disconnect from all MCP servers.""" if not self._connected_servers: @@ -145,24 +118,27 @@ async def connect_server(self, name: str) -> bool: self._log_warning(f"MCP server not found: {name}") return False - if not server.is_enabled: - self._log_verbose("MCP is not enabled.") - return False - if server in self._connected_servers: self._log_verbose(f"MCP server already connected: {name}") return True + # We will handle local server differently since its only used for internal usage + # We'll pretend we connect and fetched all tools + if isinstance(server, LocalServer): + await server.connect() + self._connected_servers.add(server) + self._server_tools[server.name] = get_local_tool_schemas() + return True + try: session = await server.connect() - tools_result = await session.list_tools() - self._server_tools[server.name] = tools_result.tools + tools = await experimental_mcp_client.load_mcp_tools(session=session, format="openai") + self._server_tools[server.name] = tools self._connected_servers.add(server) self._log_verbose(f"Connected to MCP server: {name}") return True except Exception as e: if server.name != "unnamed-server": - logging.error(f"Error connecting to MCP server {name}: {e}") self._log_error(f"Failed to connect to MCP server {name}: {e}") return False @@ -234,7 +210,7 @@ def __iter__(self): for server in self._servers: yield server - def get_server_tools(self, name: str) -> list | None: + def get_server_tools(self, name: str) -> list: """ Get the tools for a specific server. @@ -242,9 +218,9 @@ def get_server_tools(self, name: str) -> list | None: name: Name of the server Returns: - List of tools or None if server not found or not connected + List of tools or empty list if server not found or not connected """ - return self._server_tools.get(name) + return self._server_tools.get(name, list()) @property def all_tools(self) -> dict[str, list]: @@ -255,3 +231,62 @@ def all_tools(self) -> dict[str, list]: Dictionary mapping server names to their tools """ return self._server_tools.copy() + + @classmethod + async def from_servers( + cls, servers: list[McpServer], io=None, verbose: bool = False + ) -> "McpServerManager": + """ + Create an MCP Server Manager from a list of servers it should manage. + Automatically connects if the server is set to auto connect (by default it is) + """ + mcp_manager = cls(servers=[], io=io, verbose=verbose) + + async def add_server_with_retry( + server: McpServer, connect: bool = True, max_retries: int = 3 + ) -> tuple[McpServer, bool]: + """Try to add and connect to a server with retries.""" + if not connect: + success = await mcp_manager.add_server(server, connect=False) + return (server, success) + + for _attempt in range(max_retries): + success = await mcp_manager.add_server(server, connect=True) + if success: + return (server, True) + return (server, False) + + tasks = [] + for server in servers: + auto_connect = server.config.get("enabled", True) + tasks.append(add_server_with_retry(server, connect=auto_connect)) + + results = await asyncio.gather(*tasks) + for server, did_connect in results: + if not did_connect and server.name not in ["unnamed-server", "Local"]: + io.tool_warning( + f"MCP tool initialization failed after multiple retries: {server.name}" + ) + + if verbose: + io.tool_output("MCP servers configured:") + + for server, _ in results: + io.tool_output(f" - {server.name}") + + for tool in mcp_manager.get_server_tools(server.name): + tool_name = tool.get("function", {}).get("name", "unknown") + tool_desc = tool.get("function", {}).get("description", "").split("\n")[0] + io.tool_output(f" - {tool_name}: {tool_desc}") + + return mcp_manager + + +def get_local_tool_schemas(): + """Returns the JSON schemas for all local tools using the tool registry.""" + schemas = [] + for tool_name in ToolRegistry.get_registered_tools(): + tool_module = ToolRegistry.get_tool(tool_name) + if hasattr(tool_module, "SCHEMA"): + schemas.append(tool_module.SCHEMA) + return schemas diff --git a/cecli/mcp/server.py b/cecli/mcp/server.py index 221739288b4..e4769f8d0ac 100644 --- a/cecli/mcp/server.py +++ b/cecli/mcp/server.py @@ -39,13 +39,17 @@ def __init__(self, server_config, io=None, verbose=False): """ self.config = server_config self.name = server_config.get("name", "unnamed-server") - self.is_enabled = server_config.get("enabled", True) self.io = io self.verbose = verbose self.session = None self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.exit_stack = AsyncExitStack() + @property + def is_connected(self) -> bool: + """Check if this server is currently connected.""" + return self.session is not None + async def connect(self): """Connect to the MCP server and return the session. @@ -55,11 +59,6 @@ async def connect(self): Returns: ClientSession: The active session if mcp is not disabled """ - if not self.is_enabled: - if self.verbose and self.io: - self.io.tool_output(f"Enabled option is set to false for MCP server: {self.name}") - return None - if self.session is not None: if self.verbose and self.io: self.io.tool_output(f"Using existing session for MCP server: {self.name}") @@ -194,11 +193,6 @@ def _create_transport(self, url, http_client): raise NotImplementedError("Subclasses must implement _create_transport") async def connect(self): - if not self.is_enabled: - if self.verbose and self.io: - self.io.tool_output(f"Enabled option is set to false for MCP server: {self.name}") - return None - if self.session is not None: if self.verbose and self.io: self.io.tool_output(f"Using existing session for {self.name}") diff --git a/cecli/website/docs/usage/commands.md b/cecli/website/docs/usage/commands.md index 34f13a026ae..b8007c84bb3 100644 --- a/cecli/website/docs/usage/commands.md +++ b/cecli/website/docs/usage/commands.md @@ -42,6 +42,7 @@ cog.out(get_help_md()) | **/history-search** | Fuzzy search your command history and paste the selected command into the chat. | | **/lint** | Lint and fix in-chat files or all dirty files if none in chat | | **/load** | Load and execute commands from a file | +| **/load-mcp** | Load a MCP server by name | | **/ls** | List all known files and indicate which are included in the chat session | | **/map** | Print out the current repository map | | **/map-refresh** | Force a refresh of the repository map | @@ -53,6 +54,7 @@ cog.out(get_help_md()) | **/read-only** | Add files to the chat that are for reference only, or turn added files to read-only | | **/reasoning-effort** | Set the reasoning effort level (values: number or low/medium/high depending on model) | | **/report** | Report a problem by opening a GitHub Issue | +| **/remove-mcp** | Remove a MCP server by name | | **/reset** | Drop all files and clear the chat history | | **/run** | Run a shell command and optionally add the output to the chat (alias: !) | | **/save** | Save commands to a file that can reconstruct the current chat session's files | From 49c53b28e76e38e476598056b1f7e87323cea8d4 Mon Sep 17 00:00:00 2001 From: Gopar Date: Tue, 13 Jan 2026 11:11:52 -0800 Subject: [PATCH 2/3] [gh-392] Add tests --- cecli/coders/agent_coder.py | 6 +- cecli/coders/base_coder.py | 8 - tests/basic/test_coder.py | 136 +-------------- tests/mcp/__init__.py | 0 tests/mcp/test_manager.py | 319 ++++++++++++++++++++++++++++++++++++ 5 files changed, 324 insertions(+), 145 deletions(-) create mode 100644 tests/mcp/__init__.py create mode 100644 tests/mcp/test_manager.py diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index 0e1134580ca..c1b67faf5a3 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -165,6 +165,7 @@ def _get_agent_config(self): config["tools_excludelist"].append("removeskill") self._initialize_skills_manager(config) + self._initialize_mcp_tools() return config def _initialize_skills_manager(self, config): @@ -202,10 +203,7 @@ def get_local_tool_schemas(self): schemas.append(tool_module.SCHEMA) return schemas - async def initialize_mcp_tools(self): - # TODO(Gopar): update this part - await super().initialize_mcp_tools() - + async def _initialize_mcp_tools(self): if not self.mcp_manager: self.mcp_manager = McpServerManager() diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index e0443f06a40..7c5938907c0 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -255,8 +255,6 @@ async def create( # Transfer TUI app weak reference res.tui = from_coder.tui - await res.initialize_mcp_tools() - res.original_kwargs = dict(kwargs) return res @@ -2744,12 +2742,6 @@ async def _execute_all_tool_calls(): return tool_responses - async def initialize_mcp_tools(self): - """ - Initialize tools from all configured MCP servers. MCP Servers that fail to be - initialized will not be available to the Coder instance. - """ - @property def mcp_tools(self): if not self.mcp_manager: diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py index 14ed1ff6652..e302dc6be35 100644 --- a/tests/basic/test_coder.py +++ b/tests/basic/test_coder.py @@ -1434,136 +1434,6 @@ async def test_architect_coder_auto_accept_false_rejected(self): io.confirm_ask.assert_called_once_with("Edit the files?", allow_tweak=False) mock_create.assert_not_called() - @patch("cecli.coders.base_coder.experimental_mcp_client") - async def test_mcp_server_connection(self, mock_mcp_client): - """Test that the coder connects to MCP servers for tools.""" - with GitTemporaryDirectory(): - io = InputOutput(yes=True) - - # Create mock MCP server - mock_server = MagicMock() - mock_server.name = "test_server" - mock_server.connect = MagicMock() - mock_server.disconnect = MagicMock() - - # Setup mock for initialize_mcp_tools - mock_tools = [("test_server", [{"function": {"name": "test_tool"}}])] - - # Create coder with mock MCP server - with patch.object(Coder, "initialize_mcp_tools", return_value=mock_tools): - coder = await Coder.create(self.GPT35, "diff", io=io) - - # Manually set mcp_tools since we're bypassing initialize_mcp_tools - coder.mcp_tools = mock_tools - - # Verify that mcp_tools contains the expected data - assert coder.mcp_tools is not None - assert len(coder.mcp_tools) == 1 - assert coder.mcp_tools[0][0] == "test_server" - - @patch("cecli.coders.base_coder.experimental_mcp_client") - async def test_coder_creation_with_partial_failed_mcp_server(self, mock_mcp_client): - """Test that a coder can still be created even if an MCP server fails to initialize.""" - with GitTemporaryDirectory(): - io = InputOutput(yes=True) - io.tool_warning = MagicMock() - - # Create mock MCP servers - one working, one failing - working_server = AsyncMock() - working_server.name = "working_server" - working_server.connect = AsyncMock() - working_server.disconnect = AsyncMock() - - failing_server = AsyncMock() - failing_server.name = "failing_server" - failing_server.connect = AsyncMock() - failing_server.disconnect = AsyncMock() - - manager = McpServerManager([working_server, failing_server]) - manager._connected_servers = [working_server] - - # Mock load_mcp_tools to succeed for working_server and fail for failing_server - async def mock_load_mcp_tools(session, format): - if session == working_server.session: - return [{"function": {"name": "working_tool"}}] - else: - raise Exception("Failed to load tools") - - mock_mcp_client.load_mcp_tools = AsyncMock(side_effect=mock_load_mcp_tools) - - # Create coder with both servers - coder = await Coder.create( - self.GPT35, - "diff", - io=io, - mcp_manager=manager, - verbose=True, - ) - - # Verify that coder was created successfully - assert isinstance(coder, Coder) - - # Verify that only the working server's tools were added - assert coder.mcp_tools is not None - assert len(coder.mcp_tools) == 1 - assert coder.mcp_tools[0][0] == "working_server" - - # Verify that the tool list contains only working tools - tool_list = coder.get_tool_list() - assert len(tool_list) == 1 - assert tool_list[0]["function"]["name"] == "working_tool" - - # Verify that the warning was logged for the failing server - io.tool_warning.assert_called_with( - "Error initializing MCP server failing_server: Failed to load tools" - ) - - @patch("cecli.coders.base_coder.experimental_mcp_client") - async def test_coder_creation_with_all_failed_mcp_server(self, mock_mcp_client): - """Test that a coder can still be created even if an MCP server fails to initialize.""" - with GitTemporaryDirectory(): - io = InputOutput(yes=True) - io.tool_warning = MagicMock() - - failing_server = AsyncMock() - failing_server.name = "failing_server" - failing_server.connect = AsyncMock() - failing_server.disconnect = AsyncMock() - - manager = McpServerManager([failing_server]) - manager._connected_servers = [] - - # Mock load_mcp_tools to succeed for working_server and fail for failing_server - async def mock_load_mcp_tools(session, format): - raise Exception("Failed to load tools") - - mock_mcp_client.load_mcp_tools = AsyncMock(side_effect=mock_load_mcp_tools) - - # Create coder with both servers - coder = await Coder.create( - self.GPT35, - "diff", - io=io, - mcp_manager=manager, - verbose=True, - ) - - # Verify that coder was created successfully - assert isinstance(coder, Coder) - - # Verify that only the working server's tools were added - assert coder.mcp_tools is not None - assert len(coder.mcp_tools) == 0 - - # Verify that the tool list contains only working tools - tool_list = coder.get_tool_list() - assert len(tool_list) == 0 - - # Verify that the warning was logged for the failing server - io.tool_warning.assert_called_with( - "Error initializing MCP server failing_server: Failed to load tools" - ) - async def test_process_tool_calls_none_response(self): """Test that process_tool_calls handles None response correctly.""" with GitTemporaryDirectory(): @@ -1622,8 +1492,8 @@ async def test_process_tool_calls_with_tools(self): ) # Create coder with mock MCP tools and servers + manager._server_tools[mock_server.name] = [{"function": {"name": "test_tool"}}] coder = await Coder.create(self.GPT35, "diff", io=io, mcp_manager=manager) - coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] # Mock _execute_tool_calls to return tool responses tool_responses = [ @@ -1677,9 +1547,9 @@ async def test_process_tool_calls_max_calls_exceeded(self): manager._connected_servers = [mock_server] # Create coder with max tool calls exceeded + manager._server_tools[mock_server.name] = [{"function": {"name": "test_tool"}}] coder = await Coder.create(self.GPT35, "diff", io=io, mcp_manager=manager) coder.num_tool_calls = coder.max_tool_calls - coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] # Test process_tool_calls result = await coder.process_tool_calls(response) @@ -1719,8 +1589,8 @@ async def test_process_tool_calls_user_rejects(self): manager._connected_servers = [mock_server] # Create coder with mock MCP tools + manager._server_tools[mock_server.name] = [{"function": {"name": "test_tool"}}] coder = await Coder.create(self.GPT35, "diff", io=io, mcp_manager=manager) - coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] # Test process_tool_calls result = await coder.process_tool_calls(response) diff --git a/tests/mcp/__init__.py b/tests/mcp/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/mcp/test_manager.py b/tests/mcp/test_manager.py new file mode 100644 index 00000000000..8c5ee5eb6b1 --- /dev/null +++ b/tests/mcp/test_manager.py @@ -0,0 +1,319 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from cecli.mcp.manager import McpServerManager +from cecli.mcp.server import LocalServer, McpServer + + +@pytest.fixture +def mock_io(): + io = MagicMock() + io.tool_output = MagicMock() + io.tool_error = MagicMock() + io.tool_warning = MagicMock() + return io + + +@pytest.fixture +def mock_server(): + server = MagicMock(spec=McpServer) + server.name = "test-server" + server.config = {"name": "test-server", "enabled": True} + server.connect = AsyncMock() + server.disconnect = AsyncMock() + server.is_connected = False + return server + + +@pytest.fixture +def mock_local_server(): + server = MagicMock(spec=LocalServer) + server.name = "Local" + server.config = {"name": "Local", "enabled": True} + server.connect = AsyncMock() + server.disconnect = AsyncMock() + server.is_connected = False + return server + + +@pytest.fixture +def mock_tools(): + return [ + { + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": {}, + } + } + ] + + +class TestMcpServerManager: + def test_manager_init(self, mock_io): + manager = McpServerManager(servers=[], io=mock_io, verbose=True) + + assert manager.io == mock_io + assert manager.verbose is True + assert manager._servers == [] + assert manager._server_tools == {} + assert manager._connected_servers == set() + + def test_manager_servers_property(self, mock_server): + manager = McpServerManager(servers=[mock_server]) + + assert manager.servers == [mock_server] + + def test_manager_is_connected_false_initially(self): + manager = McpServerManager(servers=[]) + + assert manager.is_connected is False + assert manager.connected_servers == [] + + def test_manager_failed_servers(self, mock_server): + manager = McpServerManager(servers=[mock_server]) + + assert manager.failed_servers == [mock_server] + + # Add to connected set + manager._connected_servers.add(mock_server) + + assert manager.failed_servers == [] + + def test_get_server_found(self, mock_server): + manager = McpServerManager(servers=[mock_server]) + + result = manager.get_server("test-server") + + assert result is mock_server + + def test_get_server_not_found(self, mock_server): + manager = McpServerManager(servers=[mock_server]) + + result = manager.get_server("nonexistent-server") + + assert result is None + + @pytest.mark.asyncio + async def test_connect_server_not_found(self, mock_io): + manager = McpServerManager(servers=[], io=mock_io) + + result = await manager.connect_server("nonexistent-server") + + assert result is False + mock_io.tool_warning.assert_called_once() + + @pytest.mark.asyncio + async def test_connect_server_already_connected(self, mock_server, mock_io): + manager = McpServerManager(servers=[mock_server], io=mock_io, verbose=True) + manager._connected_servers.add(mock_server) + + result = await manager.connect_server("test-server") + + assert result is True + mock_io.tool_output.assert_called_once() + mock_server.connect.assert_not_called() + + @pytest.mark.asyncio + async def test_connect_server_local_server(self, mock_local_server): + manager = McpServerManager(servers=[mock_local_server]) + + with patch("cecli.mcp.manager.get_local_tool_schemas") as mock_get_schemas: + mock_get_schemas.return_value = [{"name": "local_tool"}] + result = await manager.connect_server("Local") + + assert result is True + mock_local_server.connect.assert_called_once() + assert mock_local_server in manager._connected_servers + assert manager._server_tools["Local"] == [{"name": "local_tool"}] + + @pytest.mark.asyncio + async def test_connect_server_success(self, mock_server, mock_tools): + manager = McpServerManager(servers=[mock_server]) + mock_session = MagicMock() + mock_server.connect.return_value = mock_session + + with patch("litellm.experimental_mcp_client.load_mcp_tools") as mock_load_tools: + mock_load_tools.return_value = mock_tools + result = await manager.connect_server("test-server") + + assert result is True + mock_server.connect.assert_called_once() + mock_load_tools.assert_called_once_with(session=mock_session, format="openai") + assert mock_server in manager._connected_servers + assert manager._server_tools["test-server"] == mock_tools + + @pytest.mark.asyncio + async def test_connect_server_failure(self, mock_server, mock_io): + manager = McpServerManager(servers=[mock_server], io=mock_io) + mock_server.connect.side_effect = Exception("Connection failed") + + result = await manager.connect_server("test-server") + + assert result is False + mock_server.connect.assert_called_once() + mock_io.tool_error.assert_called_once() + assert mock_server not in manager._connected_servers + + @pytest.mark.asyncio + async def test_disconnect_server_not_found(self, mock_io): + manager = McpServerManager(servers=[], io=mock_io) + + result = await manager.disconnect_server("nonexistent-server") + + assert result is False + mock_io.tool_warning.assert_called_once() + + @pytest.mark.asyncio + async def test_disconnect_server_not_connected(self, mock_server, mock_io): + manager = McpServerManager(servers=[mock_server], io=mock_io, verbose=True) + + result = await manager.disconnect_server("test-server") + + assert result is True + mock_io.tool_output.assert_called_once() + mock_server.disconnect.assert_not_called() + + @pytest.mark.asyncio + async def test_disconnect_server_success(self, mock_server, mock_io): + manager = McpServerManager(servers=[mock_server], io=mock_io, verbose=True) + manager._connected_servers.add(mock_server) + manager._server_tools["test-server"] = [{"name": "test_tool"}] + + result = await manager.disconnect_server("test-server") + + assert result is True + mock_server.disconnect.assert_called_once() + assert "test-server" not in manager._server_tools + assert mock_server not in manager._connected_servers + + @pytest.mark.asyncio + async def test_disconnect_all_no_servers(self, mock_io): + manager = McpServerManager(servers=[], io=mock_io, verbose=True) + + await manager.disconnect_all() + + mock_io.tool_output.assert_called_once_with("MCP servers already disconnected") + + @pytest.mark.asyncio + async def test_disconnect_all_multiple_servers(self, mock_server, mock_io): + server1 = MagicMock(spec=McpServer) + server1.name = "server1" + server1.disconnect = AsyncMock() + + server2 = MagicMock(spec=McpServer) + server2.name = "server2" + server2.disconnect = AsyncMock() + + manager = McpServerManager(servers=[server1, server2], io=mock_io, verbose=True) + manager._connected_servers.add(server1) + manager._connected_servers.add(server2) + manager._server_tools = {"server1": [], "server2": []} + + await manager.disconnect_all() + + server1.disconnect.assert_called_once() + server2.disconnect.assert_called_once() + assert manager._connected_servers == set() + assert "server1" not in manager._server_tools + assert "server2" not in manager._server_tools + + @pytest.mark.asyncio + async def test_add_server_success(self, mock_server, mock_io): + manager = McpServerManager(servers=[], io=mock_io, verbose=True) + + result = await manager.add_server(mock_server, connect=False) + + assert result is True + assert manager._servers == [mock_server] + mock_io.tool_output.assert_called_once() + mock_server.connect.assert_not_called() + + @pytest.mark.asyncio + async def test_add_server_duplicate_name(self, mock_server, mock_io): + manager = McpServerManager(servers=[mock_server], io=mock_io) + + duplicate_server = MagicMock(spec=McpServer) + duplicate_server.name = "test-server" + + result = await manager.add_server(duplicate_server) + + assert result is False + mock_io.tool_warning.assert_called_once() + + @pytest.mark.asyncio + async def test_add_server_with_connect(self, mock_server, mock_io): + manager = McpServerManager(servers=[], io=mock_io) + + # Mock connect_server to return True + manager.connect_server = AsyncMock(return_value=True) + + result = await manager.add_server(mock_server, connect=True) + + assert result is True + assert manager._servers == [mock_server] + manager.connect_server.assert_called_once_with("test-server") + + def test_get_server_tools_found(self, mock_server): + manager = McpServerManager(servers=[mock_server]) + tools = [{"name": "test_tool"}] + manager._server_tools["test-server"] = tools + + result = manager.get_server_tools("test-server") + + assert result == tools + + def test_get_server_tools_not_found(self, mock_server): + manager = McpServerManager(servers=[mock_server]) + + result = manager.get_server_tools("nonexistent-server") + + assert result == [] + + def test_all_tools_returns_copy(self, mock_server): + manager = McpServerManager(servers=[mock_server]) + tools = {"test-server": [{"name": "test_tool"}]} + manager._server_tools = tools + + result = manager.all_tools + + assert result == tools + assert result is not tools # Should be a copy + + @pytest.mark.asyncio + async def test_from_servers_creates_manager(self, mock_server, mock_io, mock_tools): + with patch("litellm.experimental_mcp_client.load_mcp_tools") as mock_load_tools: + mock_load_tools.return_value = mock_tools + mock_session = MagicMock() + mock_server.connect.return_value = mock_session + + manager = await McpServerManager.from_servers( + servers=[mock_server], io=mock_io, verbose=True + ) + + assert isinstance(manager, McpServerManager) + assert manager._servers == [mock_server] + assert mock_server in manager._connected_servers + mock_server.connect.assert_called_once() + mock_load_tools.assert_called_once() + + @pytest.mark.asyncio + async def test_from_servers_skips_disabled(self, mock_io): + disabled_server = MagicMock(spec=McpServer) + disabled_server.name = "disabled-server" + disabled_server.config = {"name": "disabled-server", "enabled": False} + disabled_server.connect = AsyncMock() + + manager = await McpServerManager.from_servers(servers=[disabled_server], io=mock_io) + + assert manager._servers == [disabled_server] + assert disabled_server not in manager._connected_servers + disabled_server.connect.assert_not_called() + + def test_manager_iteration(self, mock_server): + manager = McpServerManager(servers=[mock_server]) + + servers = list(manager) + + assert servers == [mock_server] From 569ff9053f476a1277cd5361170ddcc4c5d8bd5f Mon Sep 17 00:00:00 2001 From: Gopar Date: Thu, 15 Jan 2026 18:38:31 -0800 Subject: [PATCH 3/3] [gh-392] Remove local server when not in agent mode --- cecli/coders/agent_coder.py | 17 ++++++++++------- cecli/coders/base_coder.py | 8 ++++++++ cecli/main.py | 13 ++++++++++++- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index c1b67faf5a3..d88851c0f20 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -165,7 +165,6 @@ def _get_agent_config(self): config["tools_excludelist"].append("removeskill") self._initialize_skills_manager(config) - self._initialize_mcp_tools() return config def _initialize_skills_manager(self, config): @@ -203,15 +202,16 @@ def get_local_tool_schemas(self): schemas.append(tool_module.SCHEMA) return schemas - async def _initialize_mcp_tools(self): + async def initialize_mcp_tools(self): if not self.mcp_manager: self.mcp_manager = McpServerManager() server_name = "Local" server = self.mcp_manager.get_server(server_name) - # We have already initialized local server, no need to duplicate work - if server is not None: + # We have already initialized local server and its connected + # then no need to duplicate work + if server is not None and server.is_connected: return # If we dont have any tools for local server to use, no point in creating it then @@ -219,10 +219,13 @@ async def _initialize_mcp_tools(self): if not local_tools: return - local_server_config = {"name": server_name} - local_server = LocalServer(local_server_config) + if server is None: + local_server_config = {"name": server_name} + local_server = LocalServer(local_server_config) - await self.mcp_manager.add_server(local_server, connect=True) + await self.mcp_manager.add_server(local_server, connect=True) + else: + await self.mcp_manager.connect_server(server_name) async def _execute_local_tool_calls(self, tool_calls_list): tool_responses = [] diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 7c5938907c0..5fb6d83868a 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -255,6 +255,8 @@ async def create( # Transfer TUI app weak reference res.tui = from_coder.tui + await res.initialize_mcp_tools() + res.original_kwargs = dict(kwargs) return res @@ -2742,6 +2744,12 @@ async def _execute_all_tool_calls(): return tool_responses + async def initialize_mcp_tools(self): + """ + Any setup that needs to happen for MCP Servers so that coder can use it properly + """ + pass + @property def mcp_tools(self): if not self.mcp_manager: diff --git a/cecli/main.py b/cecli/main.py index 32fba8fee5d..9ca64c8f2a4 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -37,7 +37,7 @@ from cecli import __version__, models, urls, utils from cecli.args import get_parser -from cecli.coders import Coder +from cecli.coders import AgentCoder, Coder from cecli.coders.base_coder import UnknownEditFormat from cecli.commands import Commands, SwitchCoderSignal from cecli.deprecated_args import handle_deprecated_model_args @@ -1189,17 +1189,28 @@ def apply_model_overrides(model_name): return await graceful_exit(coder) except SwitchCoderSignal as switch: coder.ok_to_warm_cache = False + if hasattr(switch, "placeholder") and switch.placeholder is not None: io.placeholder = switch.placeholder kwargs = dict(io=io, from_coder=coder) kwargs.update(switch.kwargs) + if "show_announcements" in kwargs: del kwargs["show_announcements"] kwargs["num_cache_warming_pings"] = 0 kwargs["args"] = coder.args + + if kwargs["edit_format"] != AgentCoder.edit_format and ( + coder := kwargs.get("from_coder") + ): + if coder.mcp_manager.get_server("Local"): + await coder.mcp_manager.disconnect_server("Local") + coder = await Coder.create(**kwargs) + if switch.kwargs.get("show_announcements") is False: coder.suppress_announcements_for_next_prompt = True + except SystemExit: sys.settrace(None) return await graceful_exit(coder)