diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index 30f52a193fb..21cbe00aa61 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -16,6 +16,7 @@ from cecli.helpers import nested, responses from cecli.helpers.background_commands import BackgroundCommandManager from cecli.helpers.conversation import ConversationService, MessageTag +from cecli.helpers.coroutines import interruptible # isort:skip from cecli.helpers.similarity import ( cosine_similarity, create_bigram_vector, @@ -31,6 +32,8 @@ from .base_coder import Coder +from cecli.helpers.coroutines import interruptible # isort:skip + class AgentCoder(Coder): """Mode where the LLM autonomously manages which files are in context.""" @@ -301,8 +304,20 @@ async def _execute_local_tool_calls(self, tool_calls_list): else: all_results_content.append(f"Error: Unknown tool name '{tool_name}'") if tasks: - task_results = await asyncio.gather(*tasks) - all_results_content.extend(str(res) for res in task_results) + gather_coro = asyncio.gather(*tasks, return_exceptions=True) + task_results, interrupted = await interruptible( + gather_coro, self.interrupt_event + ) + + if interrupted: + self.io.tool_warning("Tool execution interrupted.") + all_results_content.append("Tool execution interrupted by user.") + elif task_results: + for res in task_results: + if isinstance(res, Exception): + all_results_content.append(f"Error in tool execution: {res}") + else: + all_results_content.append(str(res)) if not await HookIntegration.call_post_tool_hooks( self, tool_name, args_string, "\n\n".join(all_results_content) @@ -393,7 +408,11 @@ async def _exec_async(): """) return f"Error executing tool call {tool_name}: {e}" - return await _exec_async() + result, interrupted = await interruptible(_exec_async(), self.interrupt_event) + + if interrupted: + return "Tool execution interrupted by user." + return result def _calculate_context_block_tokens(self, force=False): """ diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 0b7f847d436..b4de5ab2b8c 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -1370,11 +1370,6 @@ async def _run_parallel(self, with_message=None, preproc=True): except (SwitchCoderSignal, SystemExit): # Re-raise SwitchCoder to be handled by outer try block raise - except KeyboardInterrupt: - # Handle keyboard interrupt gracefully - self.io.set_placeholder("") - self.io.stop_spinner() - self.keyboard_interrupt() finally: # Signal tasks to stop self.input_running = False @@ -1454,10 +1449,6 @@ async def input_task(self, preproc): await asyncio.sleep(0.1) # Small yield to prevent tight loop - except KeyboardInterrupt: - self.io.set_placeholder("") - self.keyboard_interrupt() - await self.io.stop_task_streams() except (SwitchCoderSignal, SystemExit): raise except Exception as e: @@ -1739,7 +1730,6 @@ def keyboard_interrupt(self): Console().show_cursor(True) self.io.tool_warning("\n\n^C KeyboardInterrupt") - self.interrupt_event.set() self.last_keyboard_interrupt = time.time() @@ -2260,9 +2250,16 @@ async def send_message(self, inp): self.io.tool_error(err_msg) self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...") - await asyncio.sleep(retry_delay) + + _res, interrupted_sleep = await coroutines.interruptible( + asyncio.sleep(retry_delay), self.interrupt_event + ) + if interrupted_sleep: + interrupted = True + break + continue - except KeyboardInterrupt: + except (KeyboardInterrupt, asyncio.CancelledError): interrupted = True break except FinishReasonLength: @@ -2627,11 +2624,19 @@ async def _execute_mcp_tools(self, server, tool_calls): all_results_content.append("Tool Request Aborted.") continue - call_result = await experimental_mcp_client.call_openai_tool( - session=session, - openai_tool=new_tool_call, + async def do_tool_call(): + return await experimental_mcp_client.call_openai_tool( + session=session, + openai_tool=new_tool_call, + ) + + call_result, interrupted = await coroutines.interruptible( + do_tool_call(), self.interrupt_event ) + if interrupted: + raise KeyboardInterrupt("Tool call interrupted") + content_parts = [] if call_result.content: for item in call_result.content: @@ -2676,6 +2681,9 @@ async def _execute_mcp_tools(self, server, tool_calls): } ) + except KeyboardInterrupt: + self.io.tool_warning(f"Tool call {tool_call.function.name} interrupted.") + raise except Exception as e: tool_error = f"Error executing tool call {tool_call.function.name}: \n{e}" self.io.tool_warning( @@ -2692,6 +2700,9 @@ async def _execute_mcp_tools(self, server, tool_calls): tool_responses.append( {"role": "tool", "tool_call_id": tool_call.id, "content": connection_error} ) + except asyncio.CancelledError: + # Re-raise CancelledError to ensure the task cancellation propagates + raise except Exception as e: connection_error = f"Could not connect to server {server.name}\n{e}" self.io.tool_warning(connection_error) @@ -2726,7 +2737,15 @@ async def process_tool_calls(self, tool_call_response): return False # 5. Execute tools - tool_responses_by_server = await self._execute_tool_groups(tool_groups) + self.interrupt_event.clear() + + tool_responses_by_server, interrupted = await coroutines.interruptible( + self._execute_tool_groups(tool_groups), self.interrupt_event + ) + + if interrupted: + self.io.tool_warning("Tool execution interrupted.") + return False # 6. Add responses to conversation (re-prefixing if necessary) tool_responses = [] @@ -3038,33 +3057,22 @@ async def send(self, messages, model=None, functions=None, tools=None): self.token_profiler.start() try: - completion_task = asyncio.create_task( - model.send_completion( - messages, - functions, - self.stream, - self.temperature, - # This could include any tools, but for now it is just MCP tools - tools=tools, - override_kwargs=self.model_kwargs.copy(), - ) + completion_coro = model.send_completion( + messages, + functions, + self.stream, + self.temperature, + # This could include any tools, but for now it is just MCP tools + tools=tools, + override_kwargs=self.model_kwargs.copy(), + interrupt_event=self.interrupt_event, ) - interrupt_task = asyncio.create_task(self.interrupt_event.wait()) - done, pending = await asyncio.wait( - {completion_task, interrupt_task}, - return_when=asyncio.FIRST_COMPLETED, + (hash_object, completion), interrupted = await coroutines.interruptible( + completion_coro, self.interrupt_event ) - - if interrupt_task in done: - completion_task.cancel() - try: - await completion_task - except asyncio.CancelledError: - pass + if interrupted: raise KeyboardInterrupt - - hash_object, completion = completion_task.result() self.chat_completion_call_hashes.append(hash_object.hexdigest()) if not isinstance(completion, ModelResponse): @@ -3087,7 +3095,7 @@ async def send(self, messages, model=None, functions=None, tools=None): self.token_profiler.on_error() self.calculate_and_show_tokens_and_cost(messages, completion) raise - except KeyboardInterrupt as kbi: + except (KeyboardInterrupt, asyncio.CancelledError) as kbi: self.keyboard_interrupt() raise kbi finally: diff --git a/cecli/commands/load_mcp.py b/cecli/commands/load_mcp.py index eb1e6d2e402..302d568640f 100644 --- a/cecli/commands/load_mcp.py +++ b/cecli/commands/load_mcp.py @@ -20,48 +20,68 @@ async def execute(cls, io, coder, args, **kwargs): ) server_names = args.strip().split() + results = [] + servers_to_load = [] + # Handle '*' wildcard to load all servers enabled by default if server_names == ["*"]: for server in coder.mcp_manager.servers: if server in coder.mcp_manager.connected_servers: results.append(f"Server already loaded: {server.name}") continue + auto_connect = server.config.get("enabled", True) if not auto_connect: results.append(f"Skipping server (not enabled by default): {server.name}") continue - did_connect = await coder.mcp_manager.connect_server(server.name) - if did_connect: - results.append(f"Loaded server: {server.name}") - else: - results.append(f"Unable to load server: {server.name}") + + servers_to_load.append(server) else: for server_name in server_names: server = coder.mcp_manager.get_server(server_name) if server is None: + io.tool_error(f"MCP server {server_name} does not exist.") results.append(f"MCP server {server_name} does not exist.") - continue - - did_connect = await coder.mcp_manager.connect_server(server.name) - if did_connect: - results.append(f"Loaded server: {server_name}") else: - results.append(f"Unable to load server: {server_name}") + servers_to_load.append(server) - try: - return format_command_result(io, cls.NORM_NAME, "\n".join(results)) - finally: - from . import SwitchCoderSignal - - raise SwitchCoderSignal( - edit_format=coder.edit_format, - summarize_from_coder=False, - from_coder=coder, - show_announcements=True, + # Early exit if nothing valid to process + if not servers_to_load and results: + return format_command_result(io, cls.NORM_NAME, "", "\n".join(results)) + + # Process connections with interrupt support + for server in servers_to_load: + server_name = server.name + coder.interrupt_event.clear() + + did_connect, interrupted = await coder.coroutines.interruptible( + coder.mcp_manager.connect_server(server_name), + coder.interrupt_event, ) + if interrupted: + io.tool_warning(f"MCP connection interrupted: {server_name}") + results.append(f"Interrupted: {server_name}") + continue + + if did_connect: + results.append(f"Loaded server: {server_name}") + else: + results.append(f"Unable to load server: {server_name}") + + io.tool_output("\n".join(results)) + + from . import SwitchCoderSignal + + raise SwitchCoderSignal( + edit_format=coder.edit_format, + summarize_from_coder=False, + from_coder=coder, + show_announcements=True, + ) + @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for load-mcp command.""" diff --git a/cecli/commands/remove_mcp.py b/cecli/commands/remove_mcp.py index 2239d7ba883..ad212da4051 100644 --- a/cecli/commands/remove_mcp.py +++ b/cecli/commands/remove_mcp.py @@ -20,38 +20,59 @@ async def execute(cls, io, coder, args, **kwargs): ) server_names = args.strip().split() + results = [] + servers_to_disconnect = [] # Handle '*' wildcard to disconnect all servers if server_names == ["*"]: connected = [s for s in coder.mcp_manager.servers if s.is_connected] + if not connected: results.append("No MCP servers connected, nothing to remove.") else: - for server in connected: - await coder.mcp_manager.disconnect_server(server.name) - results.append(f"Removed server: {server.name}") + servers_to_disconnect.extend(connected) else: for server_name in server_names: - was_disconnected = await coder.mcp_manager.disconnect_server(server_name) - if was_disconnected: - results.append(f"Removed server: {server_name}") - else: - results.append(f"Unable to remove server: {server_name}") + servers_to_disconnect.append(server_name) - try: - return format_command_result(io, cls.NORM_NAME, "\n".join(results)) - finally: - from . import SwitchCoderSignal - - raise SwitchCoderSignal( - edit_format=coder.edit_format, - summarize_from_coder=False, - from_coder=coder, - show_announcements=True, - mcp_manager=coder.mcp_manager, + # Early exit if nothing to process + if not servers_to_disconnect and results: + return format_command_result(io, cls.NORM_NAME, "", "\n".join(results)) + + # Process disconnections with interrupt support + for item in servers_to_disconnect: + server_name = item.name if hasattr(item, "name") else item + + coder.interrupt_event.clear() + + was_disconnected, interrupted = await coder.coroutines.interruptible( + coder.mcp_manager.disconnect_server(server_name), + coder.interrupt_event, ) + if interrupted: + io.tool_warning(f"MCP disconnection interrupted: {server_name}") + results.append(f"Interrupted: {server_name}") + continue + + if was_disconnected: + results.append(f"Removed server: {server_name}") + else: + results.append(f"Unable to remove server: {server_name}") + + io.tool_output("\n".join(results)) + + from . import SwitchCoderSignal + + raise SwitchCoderSignal( + edit_format=coder.edit_format, + summarize_from_coder=False, + from_coder=coder, + show_announcements=True, + mcp_manager=coder.mcp_manager, + ) + @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for remove-mcp command.""" diff --git a/cecli/helpers/coroutines.py b/cecli/helpers/coroutines.py index 77cee82b162..07f1a669d5a 100644 --- a/cecli/helpers/coroutines.py +++ b/cecli/helpers/coroutines.py @@ -1,8 +1,45 @@ -import asyncio # noqa: F401 +import asyncio -def is_active(coroutine): - if not coroutine or coroutine.done() or coroutine.cancelled(): +def is_active(task): + if not task or task.done() or task.cancelled(): return False return True + + +async def interruptible(coroutine, interrupt_event): + """ + Runs a coroutine and allows it to be interrupted by an asyncio.Event. + + Args: + coroutine: The coroutine to run. + interrupt_event: The asyncio.Event that signals an interruption. + + Returns: + A tuple of (result, interrupted). + - If not interrupted: (coroutine_result, False) + - If interrupted: (None, True) + """ + main_task = asyncio.create_task(coroutine) + interrupt_task = asyncio.create_task(interrupt_event.wait()) + + done, pending = await asyncio.wait( + {main_task, interrupt_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass # Expected + + if interrupt_task in done: + return None, True + + try: + return main_task.result(), False + except asyncio.CancelledError: + return None, True diff --git a/cecli/io.py b/cecli/io.py index 4f50b9f6a02..c3f207bade8 100644 --- a/cecli/io.py +++ b/cecli/io.py @@ -762,6 +762,8 @@ def interrupt_input(self): coder = self.coder() if coder and hasattr(coder, "interrupt_event"): coder.interrupt_event.set() + if self.output_task and not self.output_task.done(): + self.output_task.cancel() if self.prompt_session and self.prompt_session.app: # Store any partial input before interrupting diff --git a/cecli/main.py b/cecli/main.py index bf8b89fa99d..d0c8a2a31c2 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -1247,6 +1247,9 @@ def get_io(pretty): if switch.kwargs.get("show_announcements") is False: coder.suppress_announcements_for_next_prompt = True + except KeyboardInterrupt: + coder.keyboard_interrupt() + continue except SystemExit: sys.settrace(None) await coder.auto_save_session(force=True) diff --git a/cecli/models.py b/cecli/models.py index 04e47c7d4ac..495895bda12 100644 --- a/cecli/models.py +++ b/cecli/models.py @@ -19,7 +19,7 @@ from cecli import __version__ from cecli.dump import dump from cecli.exceptions import LiteLLMExceptions -from cecli.helpers import nested +from cecli.helpers import coroutines, nested from cecli.helpers.file_searcher import generate_search_path_list, handle_core_files from cecli.helpers.model_providers import ModelProviderManager from cecli.helpers.nested import deep_merge @@ -1132,6 +1132,7 @@ async def send_completion( min_wait=0, max_wait=2, override_kwargs={}, + interrupt_event=None, ): if os.environ.get("CECLI_SANITY_CHECK_TURNS"): sanity_check_messages(messages) @@ -1290,7 +1291,14 @@ async def send_completion( return hash_object, self.model_error_response() print(f"Retrying in {retry_delay:.1f} seconds...") - await asyncio.sleep(retry_delay) + if interrupt_event: + _res, interrupted = await coroutines.interruptible( + asyncio.sleep(retry_delay), interrupt_event + ) + if interrupted: + raise KeyboardInterrupt("Interrupted during retry sleep") + else: + await asyncio.sleep(retry_delay) continue async def simple_send_with_retries( diff --git a/cecli/tools/command.py b/cecli/tools/command.py index 28c1bec9ba6..4bf1ec941c4 100644 --- a/cecli/tools/command.py +++ b/cecli/tools/command.py @@ -228,6 +228,15 @@ async def _execute_with_timeout(cls, coder, command_string, timeout, use_pty=Fal start_time = time.time() while True: + if coder.interrupt_event.is_set(): + process.terminate() + try: + process.wait(timeout=1) + except subprocess.TimeoutExpired: + process.kill() + BackgroundCommandManager.stop_background_command(command_key) + return "Command execution interrupted by user." + # Check if process has completed exit_code = process.poll() if exit_code is not None: diff --git a/cecli/tui/app.py b/cecli/tui/app.py index 464726a61c1..f4121d36270 100644 --- a/cecli/tui/app.py +++ b/cecli/tui/app.py @@ -105,7 +105,10 @@ def __init__(self, coder_worker, output_queue, input_queue, args): show=True, ) self.bind( - self._encode_keys(self.get_keys_for("cancel")), "noop", description="Cancel", show=True + self._encode_keys(self.get_keys_for("cancel")), + "interrupt", + description="Cancel", + show=True, ) self.bind( self._encode_keys(self.get_keys_for("editor")),