Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions cecli/coders/agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,25 +203,29 @@ def get_local_tool_schemas(self):
return schemas

async def initialize_mcp_tools(self):
await super().initialize_mcp_tools()
if not self.mcp_manager:
self.mcp_manager = McpServerManager()

server_name = "Local"
if server_name not in [name for name, _ in self.mcp_tools]:
local_tools = self.get_local_tool_schemas()
if not local_tools:
return
server = self.mcp_manager.get_server(server_name)

# We have already initialized local server and its connected
# then no need to duplicate work
if server is not None and server.is_connected:
return

# If we dont have any tools for local server to use, no point in creating it then
local_tools = self.get_local_tool_schemas()
if not local_tools:
return

if server is None:
local_server_config = {"name": server_name}
local_server = LocalServer(local_server_config)

if not self.mcp_manager:
self.mcp_manager = McpServerManager()
if not self.mcp_manager.get_server(server_name):
await self.mcp_manager.add_server(local_server)
if not self.mcp_tools:
self.mcp_tools = []

if server_name not in [name for name, _ in self.mcp_tools]:
self.mcp_tools.append((local_server.name, local_tools))
await self.mcp_manager.add_server(local_server, connect=True)
else:
await self.mcp_manager.connect_server(server_name)

async def _execute_local_tool_calls(self, tool_calls_list):
tool_responses = []
Expand Down
72 changes: 11 additions & 61 deletions cecli/coders/base_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ class Coder:
commit_language = None
file_watcher = None
mcp_manager = None
mcp_tools = None
run_one_completed = True
compact_context_completed = True
suppress_announcements_for_next_prompt = False
Expand Down Expand Up @@ -228,6 +227,7 @@ async def create(
total_tokens_sent=from_coder.total_tokens_sent,
total_tokens_received=from_coder.total_tokens_received,
file_watcher=from_coder.file_watcher,
mcp_manager=from_coder.mcp_manager,
)
use_kwargs.update(update) # override to complete the switch
use_kwargs.update(kwargs) # override passed kwargs
Expand All @@ -251,7 +251,6 @@ async def create(
if from_coder:
if from_coder.mcp_manager:
res.mcp_manager = from_coder.mcp_manager
res.mcp_tools = from_coder.mcp_tools

# Transfer TUI app weak reference
res.tui = from_coder.tui
Expand Down Expand Up @@ -2747,69 +2746,20 @@ async def _execute_all_tool_calls():

async def initialize_mcp_tools(self):
Comment thread
gopar marked this conversation as resolved.
"""
Initialize tools from all configured MCP servers. MCP Servers that fail to be
initialized will not be available to the Coder instance.
Any setup that needs to happen for MCP Servers so that coder can use it properly
"""
# TODO(@gopar): refactor here once we have fully moved over to use the mcp manager
tools = []

async def get_server_tools(server):
# Check if we already have tools for this server in mcp_tools
if self.mcp_tools:
for server_name, server_tools in self.mcp_tools:
if server_name == server.name:
return (server.name, server_tools)

try:
did_connect = await self.mcp_manager.connect_server(server.name)
if not did_connect:
raise Exception("Failed to load tools")

server = self.mcp_manager.get_server(server.name)
server_tools = await experimental_mcp_client.load_mcp_tools(
session=server.session, format="openai"
)
return (server.name, server_tools)
except Exception as e:
if server.name != "unnamed-server" and server.name != "Local":
self.io.tool_warning(f"Error initializing MCP server {server.name}: {e}")
return None

async def get_all_server_tools():
tasks = [get_server_tools(server) for server in self.mcp_manager if server.is_enabled]
results = await asyncio.gather(*tasks)
return [result for result in results if result is not None]

if self.mcp_manager:
# Retry initialization in case of CancelledError
max_retries = 3
for i in range(max_retries):
try:
tools = await get_all_server_tools()
break
except asyncio.exceptions.CancelledError:
if i < max_retries - 1:
await asyncio.sleep(0.1) # Brief pause before retrying
else:
self.io.tool_warning(
"MCP tool initialization failed after multiple retries due to"
" cancellation."
)
tools = []

if len(tools) > 0:
if self.verbose:
self.io.tool_output("MCP servers configured:")
pass

for server_name, server_tools in tools:
self.io.tool_output(f" - {server_name}")
@property
def mcp_tools(self):
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to reduce size of PR (and minimize breaking things), i'm creating a "wrapper" that keeps the original mcp_tools intent.

if not self.mcp_manager:
return []

for tool in server_tools:
tool_name = tool.get("function", {}).get("name", "unknown")
tool_desc = tool.get("function", {}).get("description", "").split("\n")[0]
self.io.tool_output(f" - {tool_name}: {tool_desc}")
return list(self.mcp_manager.all_tools.items())

self.mcp_tools = tools
@mcp_tools.setter
def mcp_tools(self, value):
raise AttributeError("mcp_tools is read only.")

def get_tool_list(self):
"""Get a flattened list of all MCP tools."""
Expand Down
6 changes: 6 additions & 0 deletions cecli/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .lint import LintCommand
from .list_sessions import ListSessionsCommand
from .load import LoadCommand
from .load_mcp import LoadMcpCommand
from .load_session import LoadSessionCommand
from .load_skill import LoadSkillCommand
from .ls import LsCommand
Expand All @@ -44,6 +45,7 @@
from .read_only import ReadOnlyCommand
from .read_only_stub import ReadOnlyStubCommand
from .reasoning_effort import ReasoningEffortCommand
from .remove_mcp import RemoveMcpCommand
from .remove_skill import RemoveSkillCommand
from .report import ReportCommand
from .reset import ResetCommand
Expand Down Expand Up @@ -125,6 +127,8 @@
CommandRegistry.register(LoadSkillCommand)
CommandRegistry.register(RemoveSkillCommand)
CommandRegistry.register(TerminalSetupCommand)
CommandRegistry.register(LoadMcpCommand)
CommandRegistry.register(RemoveMcpCommand)


__all__ = [
Expand Down Expand Up @@ -192,4 +196,6 @@
"TerminalSetupCommand",
"SwitchCoderSignal",
"Commands",
"LoadMcpCommand",
"RemoveMcpCommand",
]
77 changes: 77 additions & 0 deletions cecli/commands/load_mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import List

from cecli.commands.utils.base_command import BaseCommand
from cecli.commands.utils.helpers import format_command_result


class LoadMcpCommand(BaseCommand):
NORM_NAME = "load-mcp"
DESCRIPTION = "Load a MCP server by name"

@classmethod
async def execute(cls, io, coder, args, **kwargs):
"""Execute the load-mcp command with given parameters."""
if not args.strip():
return format_command_result(io, cls.NORM_NAME, "Usage: /load-mcp <mcp-name>")

if not coder.mcp_manager or not coder.mcp_manager.servers:
return format_command_result(
io, cls.NORM_NAME, "No MCP servers found, nothing to load."
)

server_name = args.strip()
server = coder.mcp_manager.get_server(server_name)
if server is None:
return format_command_result(
io, cls.NORM_NAME, "", f"MCP server {server_name} does not exist."
)

did_connect = await coder.mcp_manager.connect_server(server.name)

if not did_connect:
return format_command_result(io, cls.NORM_NAME, f"Unable to load server: {server_name}")

try:
if did_connect:
return format_command_result(io, cls.NORM_NAME, f"Loaded server: {server_name}")
else:
return format_command_result(
io, cls.NORM_NAME, "", f"Unable to Load server: {server_name}"
)
finally:
from . import SwitchCoderSignal

raise SwitchCoderSignal(
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like raising signals is the only way to make things persistent when changing internal states? Otherwise when i would remove/load an mcp, all my changes disappeared when switching do a different mode

edit_format=coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
show_announcements=True,
)

@classmethod
def get_completions(cls, io, coder, args) -> List[str]:
"""Get completion options for load-mcp command."""
if not coder.mcp_manager or not coder.mcp_manager.servers:
return []

try:
server_names = [
server.name
for server in coder.mcp_manager
if server not in coder.mcp_manager.connected_servers
]
return server_names
except Exception:
return []

@classmethod
def get_help(cls) -> str:
"""Get help text for the load-mcp command."""
help_text = super().get_help()
help_text += "\nUsage:\n"
help_text += " /load-mcp <mcp-name> # Load a mcp by name\n"
help_text += "\nExamples:\n"
help_text += " /load-mcp context7 # Load the context7 mcp\n"
help_text += " /load-mcp github # Load the github mcp\n"
help_text += "\nThis command loads a MCP server by name.\n"
return help_text
65 changes: 65 additions & 0 deletions cecli/commands/remove_mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import List

from cecli.commands.utils.base_command import BaseCommand
from cecli.commands.utils.helpers import format_command_result


class RemoveMcpCommand(BaseCommand):
NORM_NAME = "remove-mcp"
DESCRIPTION = "Remove a MCP server by name"

@classmethod
async def execute(cls, io, coder, args, **kwargs):
"""Execute the remove-mcp command with given parameters."""
if not args.strip():
return format_command_result(io, cls.NORM_NAME, "Usage: /remove-mcp <mcp-name>")

if not coder.mcp_manager or not coder.mcp_manager.servers:
return format_command_result(
io, cls.NORM_NAME, "No MCP servers connected, nothing to remove."
)

server_name = args.strip()
was_disconnected = await coder.mcp_manager.disconnect_server(server_name)

try:
if was_disconnected:
return format_command_result(io, cls.NORM_NAME, f"Removed server: {server_name}")
else:
return format_command_result(
io, cls.NORM_NAME, "", f"Unable to remove server: {server_name}"
)
finally:
from . import SwitchCoderSignal

raise SwitchCoderSignal(
edit_format=coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
show_announcements=True,
mcp_manager=coder.mcp_manager,
)

@classmethod
def get_completions(cls, io, coder, args) -> List[str]:
"""Get completion options for remove-mcp command."""
if not coder.mcp_manager or not coder.mcp_manager.servers:
return []

try:
server_names = [server.name for server in coder.mcp_manager if server.is_connected]
return server_names
except Exception:
return []

@classmethod
def get_help(cls) -> str:
"""Get help text for the remove-mcp command."""
help_text = super().get_help()
help_text += "\nUsage:\n"
help_text += " /remove-mcp <mcp-name> # Remove a mcp by name\n"
help_text += "\nExamples:\n"
help_text += " /remove-mcp context7 # Remove the context7 mcp\n"
help_text += " /remove-mcp github # Remove the github mcp\n"
help_text += "\nThis command removes a MCP server by name.\n"
return help_text
15 changes: 13 additions & 2 deletions cecli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

from cecli import __version__, models, urls, utils
from cecli.args import get_parser
from cecli.coders import Coder
from cecli.coders import AgentCoder, Coder
from cecli.coders.base_coder import UnknownEditFormat
from cecli.commands import Commands, SwitchCoderSignal
from cecli.deprecated_args import handle_deprecated_model_args
Expand Down Expand Up @@ -983,7 +983,7 @@ def apply_model_overrides(model_name):
mcp_servers = load_mcp_servers(
args.mcp_servers, args.mcp_servers_file, io, args.verbose, args.mcp_transport
)
mcp_manager = McpServerManager(mcp_servers, io, args.verbose)
mcp_manager = await McpServerManager.from_servers(mcp_servers, io, args.verbose)

coder = await Coder.create(
main_model=main_model,
Expand Down Expand Up @@ -1189,17 +1189,28 @@ def apply_model_overrides(model_name):
return await graceful_exit(coder)
except SwitchCoderSignal as switch:
coder.ok_to_warm_cache = False

if hasattr(switch, "placeholder") and switch.placeholder is not None:
io.placeholder = switch.placeholder
kwargs = dict(io=io, from_coder=coder)
kwargs.update(switch.kwargs)

if "show_announcements" in kwargs:
del kwargs["show_announcements"]
kwargs["num_cache_warming_pings"] = 0
kwargs["args"] = coder.args

if kwargs["edit_format"] != AgentCoder.edit_format and (
coder := kwargs.get("from_coder")
):
if coder.mcp_manager.get_server("Local"):
await coder.mcp_manager.disconnect_server("Local")

coder = await Coder.create(**kwargs)

if switch.kwargs.get("show_announcements") is False:
coder.suppress_announcements_for_next_prompt = True

except SystemExit:
sys.settrace(None)
return await graceful_exit(coder)
Expand Down
Loading
Loading