Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions cecli/coders/agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from cecli.helpers import nested, responses
from cecli.helpers.background_commands import BackgroundCommandManager
from cecli.helpers.conversation import ConversationService, MessageTag
from cecli.helpers.coroutines import interruptible # isort:skip
from cecli.helpers.similarity import (
cosine_similarity,
create_bigram_vector,
Expand All @@ -27,10 +28,11 @@
from cecli.mcp import LocalServer, McpServerManager
from cecli.tools.utils.base_tool import BaseTool
from cecli.tools.utils.registry import ToolRegistry
from cecli.helpers.coroutines import interruptible
from cecli.utils import copy_tool_call, tool_call_to_dict

from .base_coder import Coder

Check failure on line 35 in cecli/coders/agent_coder.py

View workflow job for this annotation

GitHub Actions / pre-commit

F811 redefinition of unused 'interruptible' from line 19

class AgentCoder(Coder):
"""Mode where the LLM autonomously manages which files are in context."""
Expand Down Expand Up @@ -301,8 +303,20 @@
else:
all_results_content.append(f"Error: Unknown tool name '{tool_name}'")
if tasks:
task_results = await asyncio.gather(*tasks)
all_results_content.extend(str(res) for res in task_results)
gather_coro = asyncio.gather(*tasks, return_exceptions=True)
task_results, interrupted = await interruptible(
gather_coro, self.interrupt_event
)

if interrupted:
self.io.tool_warning("Tool execution interrupted.")
all_results_content.append("Tool execution interrupted by user.")
elif task_results:
for res in task_results:
if isinstance(res, Exception):
all_results_content.append(f"Error in tool execution: {res}")
else:
all_results_content.append(str(res))

if not await HookIntegration.call_post_tool_hooks(
self, tool_name, args_string, "\n\n".join(all_results_content)
Expand Down Expand Up @@ -393,7 +407,11 @@
""")
return f"Error executing tool call {tool_name}: {e}"

return await _exec_async()
result, interrupted = await interruptible(_exec_async(), self.interrupt_event)

if interrupted:
return "Tool execution interrupted by user."
return result

def _calculate_context_block_tokens(self, force=False):
"""
Expand Down
88 changes: 48 additions & 40 deletions cecli/coders/base_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,11 +1370,6 @@ async def _run_parallel(self, with_message=None, preproc=True):
except (SwitchCoderSignal, SystemExit):
# Re-raise SwitchCoder to be handled by outer try block
raise
except KeyboardInterrupt:
# Handle keyboard interrupt gracefully
self.io.set_placeholder("")
self.io.stop_spinner()
self.keyboard_interrupt()
finally:
# Signal tasks to stop
self.input_running = False
Expand Down Expand Up @@ -1454,10 +1449,6 @@ async def input_task(self, preproc):

await asyncio.sleep(0.1) # Small yield to prevent tight loop

except KeyboardInterrupt:
self.io.set_placeholder("")
self.keyboard_interrupt()
await self.io.stop_task_streams()
except (SwitchCoderSignal, SystemExit):
raise
except Exception as e:
Expand Down Expand Up @@ -1739,7 +1730,6 @@ def keyboard_interrupt(self):
Console().show_cursor(True)

self.io.tool_warning("\n\n^C KeyboardInterrupt")

self.interrupt_event.set()
self.last_keyboard_interrupt = time.time()

Expand Down Expand Up @@ -2260,9 +2250,16 @@ async def send_message(self, inp):
self.io.tool_error(err_msg)

self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...")
await asyncio.sleep(retry_delay)

_res, interrupted_sleep = await coroutines.interruptible(
asyncio.sleep(retry_delay), self.interrupt_event
)
if interrupted_sleep:
interrupted = True
break

continue
except KeyboardInterrupt:
except (KeyboardInterrupt, asyncio.CancelledError):
interrupted = True
break
except FinishReasonLength:
Expand Down Expand Up @@ -2627,11 +2624,19 @@ async def _execute_mcp_tools(self, server, tool_calls):
all_results_content.append("Tool Request Aborted.")
continue

call_result = await experimental_mcp_client.call_openai_tool(
session=session,
openai_tool=new_tool_call,
async def do_tool_call():
return await experimental_mcp_client.call_openai_tool(
session=session,
openai_tool=new_tool_call,
)

call_result, interrupted = await coroutines.interruptible(
do_tool_call(), self.interrupt_event
)

if interrupted:
raise KeyboardInterrupt("Tool call interrupted")

content_parts = []
if call_result.content:
for item in call_result.content:
Expand Down Expand Up @@ -2676,6 +2681,9 @@ async def _execute_mcp_tools(self, server, tool_calls):
}
)

except KeyboardInterrupt:
self.io.tool_warning(f"Tool call {tool_call.function.name} interrupted.")
raise
except Exception as e:
tool_error = f"Error executing tool call {tool_call.function.name}: \n{e}"
self.io.tool_warning(
Expand All @@ -2692,6 +2700,9 @@ async def _execute_mcp_tools(self, server, tool_calls):
tool_responses.append(
{"role": "tool", "tool_call_id": tool_call.id, "content": connection_error}
)
except asyncio.CancelledError:
# Re-raise CancelledError to ensure the task cancellation propagates
raise
except Exception as e:
connection_error = f"Could not connect to server {server.name}\n{e}"
self.io.tool_warning(connection_error)
Expand Down Expand Up @@ -2726,7 +2737,15 @@ async def process_tool_calls(self, tool_call_response):
return False

# 5. Execute tools
tool_responses_by_server = await self._execute_tool_groups(tool_groups)
self.interrupt_event.clear()

tool_responses_by_server, interrupted = await coroutines.interruptible(
self._execute_tool_groups(tool_groups), self.interrupt_event
)

if interrupted:
self.io.tool_warning("Tool execution interrupted.")
return False

# 6. Add responses to conversation (re-prefixing if necessary)
tool_responses = []
Expand Down Expand Up @@ -3038,33 +3057,22 @@ async def send(self, messages, model=None, functions=None, tools=None):
self.token_profiler.start()

try:
completion_task = asyncio.create_task(
model.send_completion(
messages,
functions,
self.stream,
self.temperature,
# This could include any tools, but for now it is just MCP tools
tools=tools,
override_kwargs=self.model_kwargs.copy(),
)
completion_coro = model.send_completion(
messages,
functions,
self.stream,
self.temperature,
# This could include any tools, but for now it is just MCP tools
tools=tools,
override_kwargs=self.model_kwargs.copy(),
interrupt_event=self.interrupt_event,
)
interrupt_task = asyncio.create_task(self.interrupt_event.wait())

done, pending = await asyncio.wait(
{completion_task, interrupt_task},
return_when=asyncio.FIRST_COMPLETED,
(hash_object, completion), interrupted = await coroutines.interruptible(
completion_coro, self.interrupt_event
)

if interrupt_task in done:
completion_task.cancel()
try:
await completion_task
except asyncio.CancelledError:
pass
if interrupted:
raise KeyboardInterrupt

hash_object, completion = completion_task.result()
self.chat_completion_call_hashes.append(hash_object.hexdigest())

if not isinstance(completion, ModelResponse):
Expand All @@ -3087,7 +3095,7 @@ async def send(self, messages, model=None, functions=None, tools=None):
self.token_profiler.on_error()
self.calculate_and_show_tokens_and_cost(messages, completion)
raise
except KeyboardInterrupt as kbi:
except (KeyboardInterrupt, asyncio.CancelledError) as kbi:
self.keyboard_interrupt()
raise kbi
finally:
Expand Down
62 changes: 41 additions & 21 deletions cecli/commands/load_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,48 +20,68 @@ async def execute(cls, io, coder, args, **kwargs):
)

server_names = args.strip().split()

results = []

servers_to_load = []

# Handle '*' wildcard to load all servers enabled by default
if server_names == ["*"]:
for server in coder.mcp_manager.servers:
if server in coder.mcp_manager.connected_servers:
results.append(f"Server already loaded: {server.name}")
continue

auto_connect = server.config.get("enabled", True)
if not auto_connect:
results.append(f"Skipping server (not enabled by default): {server.name}")
continue
did_connect = await coder.mcp_manager.connect_server(server.name)
if did_connect:
results.append(f"Loaded server: {server.name}")
else:
results.append(f"Unable to load server: {server.name}")

servers_to_load.append(server)
else:
for server_name in server_names:
server = coder.mcp_manager.get_server(server_name)
if server is None:
io.tool_error(f"MCP server {server_name} does not exist.")
results.append(f"MCP server {server_name} does not exist.")
continue

did_connect = await coder.mcp_manager.connect_server(server.name)
if did_connect:
results.append(f"Loaded server: {server_name}")
else:
results.append(f"Unable to load server: {server_name}")
servers_to_load.append(server)

try:
return format_command_result(io, cls.NORM_NAME, "\n".join(results))
finally:
from . import SwitchCoderSignal

raise SwitchCoderSignal(
edit_format=coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
show_announcements=True,
# Early exit if nothing valid to process
if not servers_to_load and results:
return format_command_result(io, cls.NORM_NAME, "", "\n".join(results))

# Process connections with interrupt support
for server in servers_to_load:
server_name = server.name
coder.interrupt_event.clear()

did_connect, interrupted = await coder.coroutines.interruptible(
coder.mcp_manager.connect_server(server_name),
coder.interrupt_event,
)

if interrupted:
io.tool_warning(f"MCP connection interrupted: {server_name}")
results.append(f"Interrupted: {server_name}")
continue

if did_connect:
results.append(f"Loaded server: {server_name}")
else:
results.append(f"Unable to load server: {server_name}")

io.tool_output("\n".join(results))

from . import SwitchCoderSignal

raise SwitchCoderSignal(
edit_format=coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
show_announcements=True,
)

@classmethod
def get_completions(cls, io, coder, args) -> List[str]:
"""Get completion options for load-mcp command."""
Expand Down
59 changes: 40 additions & 19 deletions cecli/commands/remove_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,59 @@ async def execute(cls, io, coder, args, **kwargs):
)

server_names = args.strip().split()

results = []
servers_to_disconnect = []

# Handle '*' wildcard to disconnect all servers
if server_names == ["*"]:
connected = [s for s in coder.mcp_manager.servers if s.is_connected]

if not connected:
results.append("No MCP servers connected, nothing to remove.")
else:
for server in connected:
await coder.mcp_manager.disconnect_server(server.name)
results.append(f"Removed server: {server.name}")
servers_to_disconnect.extend(connected)
else:
for server_name in server_names:
was_disconnected = await coder.mcp_manager.disconnect_server(server_name)
if was_disconnected:
results.append(f"Removed server: {server_name}")
else:
results.append(f"Unable to remove server: {server_name}")
servers_to_disconnect.append(server_name)

try:
return format_command_result(io, cls.NORM_NAME, "\n".join(results))
finally:
from . import SwitchCoderSignal

raise SwitchCoderSignal(
edit_format=coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
show_announcements=True,
mcp_manager=coder.mcp_manager,
# Early exit if nothing to process
if not servers_to_disconnect and results:
return format_command_result(io, cls.NORM_NAME, "", "\n".join(results))

# Process disconnections with interrupt support
for item in servers_to_disconnect:
server_name = item.name if hasattr(item, "name") else item

coder.interrupt_event.clear()

was_disconnected, interrupted = await coder.coroutines.interruptible(
coder.mcp_manager.disconnect_server(server_name),
coder.interrupt_event,
)

if interrupted:
io.tool_warning(f"MCP disconnection interrupted: {server_name}")
results.append(f"Interrupted: {server_name}")
continue

if was_disconnected:
results.append(f"Removed server: {server_name}")
else:
results.append(f"Unable to remove server: {server_name}")

io.tool_output("\n".join(results))

from . import SwitchCoderSignal

raise SwitchCoderSignal(
edit_format=coder.edit_format,
summarize_from_coder=False,
from_coder=coder,
show_announcements=True,
mcp_manager=coder.mcp_manager,
)

@classmethod
def get_completions(cls, io, coder, args) -> List[str]:
"""Get completion options for remove-mcp command."""
Expand Down
Loading
Loading