diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index a3cf86b629e..ee73f3b4028 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -827,11 +827,12 @@ async def reply_completed(self): ) self.io.tool_output(waiting_msg) await asyncio.sleep(command_timeout / 2) - return True + return False # Check for recently finished commands that need reflection if recently_finished_commands and not self.agent_finished: - return True # Retrigger reflection to process recently finished command outputs + self.reflected_message = "Background command finished, processing output." + return False # Retrigger reflection to process recently finished command outputs # 3. If no content and no tools, we might be done or just empty response if (not content or not content.strip()) and not tool_calls_found: diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index 0dd6203e298..fb10f29a057 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -2,6 +2,7 @@ import asyncio import base64 +import asyncio import hashlib import json import locale @@ -366,6 +367,7 @@ def __init__( self.context_compaction_max_tokens = context_compaction_max_tokens self.context_compaction_summary_tokens = context_compaction_summary_tokens + self.globally_approved_tool_calls = False self.max_reflections = ( 3 if self.edit_format == "agent" else nested.getter(self.args, "max_reflections", 3) ) @@ -1278,7 +1280,11 @@ async def _run_linear(self, with_message=None, preproc=True): try: if with_message: self.io.user_input(with_message) - await self.run_one(with_message, preproc) + self.io.is_processing_prompt = True + try: + await self.run_one(with_message, preproc) + finally: + self.io.is_processing_prompt = False return self.partial_response_content user_message = None @@ -1340,7 +1346,11 @@ async def _run_parallel(self, with_message=None, preproc=True): try: if with_message: self.io.user_input(with_message) - await self.run_one(with_message, preproc) + self.io.is_processing_prompt = True + try: + await self.run_one(with_message, preproc) + finally: + self.io.is_processing_prompt = False return self.partial_response_content # Initialize state for task coordination @@ -1534,7 +1544,11 @@ async def generate(self, user_message, preproc): self.compact_context_completed = True self.run_one_completed = False - await self.run_one(user_message, preproc) + self.io.is_processing_prompt = True + try: + await self.run_one(user_message, preproc) + finally: + self.io.is_processing_prompt = False self.show_undo_hint() except asyncio.CancelledError: # Don't show undo hint if cancelled @@ -1738,6 +1752,7 @@ def keyboard_interrupt(self): Console().show_cursor(True) self.io.tool_warning("\n\n^C KeyboardInterrupt") + self.interrupt_event.set() self.interrupt_event.set() self.last_keyboard_interrupt = time.time() @@ -2261,7 +2276,7 @@ async def send_message(self, inp): self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...") await asyncio.sleep(retry_delay) continue - except KeyboardInterrupt: + except (KeyboardInterrupt, asyncio.CancelledError): interrupted = True break except FinishReasonLength: @@ -2721,11 +2736,41 @@ async def process_tool_calls(self, tool_call_response): self._print_tool_call_info(server_tool_calls=tool_groups) # 4. Ask for user confirmation - if not await self.io.confirm_ask("Run tools?", group_response="Run MCP Tools"): - return False + try: + self.globally_approved_tool_calls = False + if not await self.io.confirm_ask("Run tools?", group_response="Run MCP Tools"): + return False + + # 5. Execute tools + tool_execution_task = asyncio.create_task(self._execute_tool_groups(tool_groups)) + interrupt_task = asyncio.create_task(self.interrupt_event.wait()) + + tool_responses_by_server = {} + try: + done, pending = await asyncio.wait( + {tool_execution_task, interrupt_task}, + return_when=asyncio.FIRST_COMPLETED, + ) - # 5. Execute tools - tool_responses_by_server = await self._execute_tool_groups(tool_groups) + if interrupt_task in done: + tool_execution_task.cancel() + try: + await tool_execution_task + except asyncio.CancelledError: + pass + self.io.tool_warning("Tool execution interrupted.") + return False + + if tool_execution_task in done: + tool_responses_by_server = tool_execution_task.result() + + except asyncio.CancelledError: + self.io.tool_warning("Tool execution cancelled.") + return False + if self.io.group_responses.get("Run MCP Tools"): + self.globally_approved_tool_calls = True + finally: + self.globally_approved_tool_calls = False # 6. Add responses to conversation (re-prefixing if necessary) tool_responses = [] @@ -2745,7 +2790,6 @@ async def process_tool_calls(self, tool_call_response): def _print_tool_call_info(self, server_tool_calls): """Print information about an MCP tool call.""" - self.io.ring_bell() # self.io.tool_output("Preparing to run MCP tools", bold=False) for server, tool_calls in server_tool_calls.items(): @@ -3045,7 +3089,7 @@ async def send(self, messages, model=None, functions=None, tools=None): self.temperature, # This could include any tools, but for now it is just MCP tools tools=tools, - override_kwargs=self.model_kwargs.copy(), + override_kwargs=self.model_kwargs, ) ) interrupt_task = asyncio.create_task(self.interrupt_event.wait()) @@ -3086,7 +3130,7 @@ async def send(self, messages, model=None, functions=None, tools=None): self.token_profiler.on_error() self.calculate_and_show_tokens_and_cost(messages, completion) raise - except KeyboardInterrupt as kbi: + except (KeyboardInterrupt, asyncio.CancelledError) as kbi: self.keyboard_interrupt() raise kbi finally: diff --git a/cecli/io.py b/cecli/io.py index 4f50b9f6a02..2303161fe9e 100644 --- a/cecli/io.py +++ b/cecli/io.py @@ -385,6 +385,7 @@ def __init__( self.verbose = verbose self.profile_start_time = None self.profile_last_time = None + self.is_processing_prompt = False # Variables used to interface with base_coder self.coder = None @@ -762,6 +763,8 @@ def interrupt_input(self): coder = self.coder() if coder and hasattr(coder, "interrupt_event"): coder.interrupt_event.set() + if self.output_task and not self.output_task.done(): + self.output_task.cancel() if self.prompt_session and self.prompt_session.app: # Store any partial input before interrupting @@ -812,267 +815,272 @@ async def get_input( abs_read_only_stubs_fnames=None, edit_format=None, ): - self.rule() - - rel_fnames = list(rel_fnames) - show = "" - if rel_fnames: - rel_read_only_fnames = [ - get_rel_fname(fname, root) for fname in abs_read_only_fnames or [] - ] - rel_read_only_stubs_fnames = [ - get_rel_fname(fname, root) for fname in abs_read_only_stubs_fnames or [] - ] - show = self.format_files_for_input( - rel_fnames, rel_read_only_fnames, rel_read_only_stubs_fnames - ) - - prompt_prefix = "" - - if edit_format: - prompt_prefix += edit_format - if self.multiline_mode: - prompt_prefix += (" " if edit_format else "") + "multi" - prompt_prefix += "> " - - show += prompt_prefix - self.prompt_prefix = prompt_prefix - - inp = "" - multiline_input = False + self.is_processing_prompt = True + try: + self.rule() + + rel_fnames = list(rel_fnames) + show = "" + if rel_fnames: + rel_read_only_fnames = [ + get_rel_fname(fname, root) for fname in abs_read_only_fnames or [] + ] + rel_read_only_stubs_fnames = [ + get_rel_fname(fname, root) for fname in abs_read_only_stubs_fnames or [] + ] + show = self.format_files_for_input( + rel_fnames, rel_read_only_fnames, rel_read_only_stubs_fnames + ) - style = self._get_style() + prompt_prefix = "" - completer_instance = ThreadedCompleter( - AutoCompleter( - root, - rel_fnames, - addable_rel_fnames, - commands, - self.encoding, - abs_read_only_fnames=(abs_read_only_fnames or set()) - | (abs_read_only_stubs_fnames or set()), + if edit_format: + prompt_prefix += edit_format + if self.multiline_mode: + prompt_prefix += (" " if edit_format else "") + "multi" + prompt_prefix += "> " + + show += prompt_prefix + self.prompt_prefix = prompt_prefix + + inp = "" + multiline_input = False + + style = self._get_style() + + completer_instance = ThreadedCompleter( + AutoCompleter( + root, + rel_fnames, + addable_rel_fnames, + commands, + self.encoding, + abs_read_only_fnames=(abs_read_only_fnames or set()) + | (abs_read_only_stubs_fnames or set()), + ) ) - ) - def suspend_to_bg(event): - """Suspend currently running application.""" - event.app.suspend_to_background() + def suspend_to_bg(event): + """Suspend currently running application.""" + event.app.suspend_to_background() - kb = KeyBindings() + kb = KeyBindings() - @kb.add(Keys.ControlZ, filter=Condition(lambda: hasattr(signal, "SIGTSTP"))) - def _(event): - "Suspend to background with ctrl-z" - suspend_to_bg(event) + @kb.add(Keys.ControlZ, filter=Condition(lambda: hasattr(signal, "SIGTSTP"))) + def _(event): + "Suspend to background with ctrl-z" + suspend_to_bg(event) - @kb.add("c-space") - def _(event): - "Ignore Ctrl when pressing space bar" - event.current_buffer.insert_text(" ") + @kb.add("c-space") + def _(event): + "Ignore Ctrl when pressing space bar" + event.current_buffer.insert_text(" ") - @kb.add("c-up") - def _(event): - "Navigate backward through history" - event.current_buffer.history_backward() + @kb.add("c-up") + def _(event): + "Navigate backward through history" + event.current_buffer.history_backward() - @kb.add("c-down") - def _(event): - "Navigate forward through history" - event.current_buffer.history_forward() + @kb.add("c-down") + def _(event): + "Navigate forward through history" + event.current_buffer.history_forward() - @kb.add("c-x", "c-e") - def _(event): - "Edit current input in external editor (like Bash)" - buffer = event.current_buffer - current_text = buffer.text + @kb.add("c-x", "c-e") + def _(event): + "Edit current input in external editor (like Bash)" + buffer = event.current_buffer + current_text = buffer.text - # Open the editor with the current text - edited_text = pipe_editor(input_data=current_text, suffix="md") + # Open the editor with the current text + edited_text = pipe_editor(input_data=current_text, suffix="md") - # Replace the buffer with the edited text, strip any trailing newlines - buffer.text = edited_text.rstrip("\n") + # Replace the buffer with the edited text, strip any trailing newlines + buffer.text = edited_text.rstrip("\n") - # Move cursor to the end of the text - buffer.cursor_position = len(buffer.text) - - @kb.add("c-t", filter=Condition(lambda: self.fzf_available)) - def _(event): - "Fuzzy find files to add to the chat" - buffer = event.current_buffer - if not buffer.text.strip().startswith("/add "): - return - - files = run_fzf(addable_rel_fnames, multi=True) - if files: - buffer.text = "/add " + " ".join(files) - buffer.cursor_position = len(buffer.text) - - @kb.add("c-r", filter=Condition(lambda: self.fzf_available)) - def _(event): - "Fuzzy search in history and paste it in the prompt" - buffer = event.current_buffer - history_lines = self.get_input_history() - selected_lines = run_fzf(history_lines) - if selected_lines: - buffer.text = "".join(selected_lines) + # Move cursor to the end of the text buffer.cursor_position = len(buffer.text) - @kb.add("enter", eager=True, filter=~is_searching) - def _(event): - "Handle Enter key press" - if self.multiline_mode and not ( - self.editingmode == EditingMode.VI - and event.app.vi_state.input_mode == InputMode.NAVIGATION - ): - # In multiline mode and if not in vi-mode or vi navigation/normal mode, - # Enter adds a newline - event.current_buffer.insert_text("\n") - else: - # In normal mode, Enter submits - event.current_buffer.validate_and_handle() - - @kb.add("escape", "enter", eager=True, filter=~is_searching) # This is Alt+Enter - def _(event): - "Handle Alt+Enter key press" - if self.multiline_mode: - # In multiline mode, Alt+Enter submits - event.current_buffer.validate_and_handle() - else: - # In normal mode, Alt+Enter adds a newline - event.current_buffer.insert_text("\n") - - while True: - if multiline_input: - show = self.prompt_prefix - - try: - self.interrupted = False - if not multiline_input: - if self.file_watcher: - self.file_watcher.start() - if self.clipboard_watcher: - self.clipboard_watcher.start() + @kb.add("c-t", filter=Condition(lambda: self.fzf_available)) + def _(event): + "Fuzzy find files to add to the chat" + buffer = event.current_buffer + if not buffer.text.strip().startswith("/add "): + return + + files = run_fzf(addable_rel_fnames, multi=True) + if files: + buffer.text = "/add " + " ".join(files) + buffer.cursor_position = len(buffer.text) + + @kb.add("c-r", filter=Condition(lambda: self.fzf_available)) + def _(event): + "Fuzzy search in history and paste it in the prompt" + buffer = event.current_buffer + history_lines = self.get_input_history() + selected_lines = run_fzf(history_lines) + if selected_lines: + buffer.text = "".join(selected_lines) + buffer.cursor_position = len(buffer.text) + + @kb.add("enter", eager=True, filter=~is_searching) + def _(event): + "Handle Enter key press" + if self.multiline_mode and not ( + self.editingmode == EditingMode.VI + and event.app.vi_state.input_mode == InputMode.NAVIGATION + ): + # In multiline mode and if not in vi-mode or vi navigation/normal mode, + # Enter adds a newline + event.current_buffer.insert_text("\n") + else: + # In normal mode, Enter submits + event.current_buffer.validate_and_handle() + + @kb.add("escape", "enter", eager=True, filter=~is_searching) # This is Alt+Enter + def _(event): + "Handle Alt+Enter key press" + if self.multiline_mode: + # In multiline mode, Alt+Enter submits + event.current_buffer.validate_and_handle() + else: + # In normal mode, Alt+Enter adds a newline + event.current_buffer.insert_text("\n") - if self.prompt_session: - # Use placeholder if set, then clear it - default = self.placeholder or "" - self.placeholder = None + while True: + if multiline_input: + show = self.prompt_prefix - def get_continuation(width, line_number, is_soft_wrap): - return self.prompt_prefix + try: + self.interrupted = False + if not multiline_input: + if self.file_watcher: + self.file_watcher.start() + if self.clipboard_watcher: + self.clipboard_watcher.start() + + if self.prompt_session: + # Use placeholder if set, then clear it + default = self.placeholder or "" + self.placeholder = None + + def get_continuation(width, line_number, is_soft_wrap): + return self.prompt_prefix + + line = await self.prompt_session.prompt_async( + show, + default=default, + completer=completer_instance, + reserve_space_for_menu=4, + complete_style=CompleteStyle.MULTI_COLUMN, + style=style, + key_bindings=kb, + complete_while_typing=True, + prompt_continuation=get_continuation, + ) + else: + try: + self.interruptible_input = InterruptibleInput() + except RuntimeError: + # Fallback to non-interruptible input (Windows ...) + line = await asyncio.get_event_loop().run_in_executor(None, input, show) + + if self.interruptible_input: + try: + line = await asyncio.get_event_loop().run_in_executor( + None, self.interruptible_input.input, show + ) + except InterruptedError: + self.interrupted = True + line = "" + finally: + self.interruptible_input.close() + self.interruptible_input = None + + # Check if we were interrupted by a file change + if self.interrupted: + line = line or "" + if self.file_watcher: + cmd = self.file_watcher.process_changes() + return cmd + + except EOFError: + coder = self.get_coder() + + if coder: + await coder.commands.execute("exit", "") + return "" + else: + raise SystemExit - line = await self.prompt_session.prompt_async( - show, - default=default, - completer=completer_instance, - reserve_space_for_menu=4, - complete_style=CompleteStyle.MULTI_COLUMN, - style=style, - key_bindings=kb, - complete_while_typing=True, - prompt_continuation=get_continuation, - ) - else: + except KeyboardInterrupt: + self.console.print() + return "" + except UnicodeEncodeError as err: + self.tool_error(str(err)) + return "" + except Exception as err: try: - self.interruptible_input = InterruptibleInput() - except RuntimeError: - # Fallback to non-interruptible input (Windows ...) - line = await asyncio.get_event_loop().run_in_executor(None, input, show) + self.prompt_session.app.exit() + except Exception: + pass - if self.interruptible_input: - try: - line = await asyncio.get_event_loop().run_in_executor( - None, self.interruptible_input.input, show - ) - except InterruptedError: - self.interrupted = True - line = "" - finally: - self.interruptible_input.close() - self.interruptible_input = None - - # Check if we were interrupted by a file change - if self.interrupted: - line = line or "" - if self.file_watcher: - cmd = self.file_watcher.process_changes() - return cmd + import traceback - except EOFError: - coder = self.get_coder() - - if coder: - await coder.commands.execute("exit", "") + self.tool_error(str(err)) + self.tool_error(traceback.format_exc()) return "" - else: - raise SystemExit - - except KeyboardInterrupt: - self.console.print() - return "" - except UnicodeEncodeError as err: - self.tool_error(str(err)) - return "" - except Exception as err: - try: - self.prompt_session.app.exit() - except Exception: - pass + finally: + if self.file_watcher: + self.file_watcher.stop() + if self.clipboard_watcher: + self.clipboard_watcher.stop() - import traceback + line = line or "" - self.tool_error(str(err)) - self.tool_error(traceback.format_exc()) - return "" - finally: - if self.file_watcher: - self.file_watcher.stop() - if self.clipboard_watcher: - self.clipboard_watcher.stop() - - line = line or "" - - if line.strip("\r\n") and not multiline_input: - stripped = line.strip("\r\n") - if stripped == "{": - multiline_input = True - multiline_tag = None - inp += "" - elif stripped[0] == "{": - # Extract tag if it exists (only alphanumeric chars) - tag = "".join(c for c in stripped[1:] if c.isalnum()) - if stripped == "{" + tag: + if line.strip("\r\n") and not multiline_input: + stripped = line.strip("\r\n") + if stripped == "{": multiline_input = True - multiline_tag = tag + multiline_tag = None inp += "" + elif stripped[0] == "{": + # Extract tag if it exists (only alphanumeric chars) + tag = "".join(c for c in stripped[1:] if c.isalnum()) + if stripped == "{" + tag: + multiline_input = True + multiline_tag = tag + inp += "" + else: + inp = line + break else: inp = line break - else: - inp = line - break - continue - elif multiline_input and line.strip(): - if multiline_tag: - # Check if line is exactly "tag}" - if line.strip("\r\n") == f"{multiline_tag}}}": + continue + elif multiline_input and line.strip(): + if multiline_tag: + # Check if line is exactly "tag}" + if line.strip("\r\n") == f"{multiline_tag}}}": + break + else: + inp += line + "\n" + # Check if line is exactly "}" + elif line.strip("\r\n") == "}": break else: inp += line + "\n" - # Check if line is exactly "}" - elif line.strip("\r\n") == "}": - break - else: + elif multiline_input: inp += line + "\n" - elif multiline_input: - inp += line + "\n" - else: - inp = line - break + else: + inp = line + break - self.user_input(inp) - return inp + self.user_input(inp) + return inp + finally: + self.is_processing_prompt = False + self.is_processing_prompt = False async def stop_input_task(self): if self.input_task: @@ -1715,6 +1723,8 @@ def get_default_notification_command(self): return None # Unknown system def _send_notification(self): + if self.is_processing_prompt: + return if self.notifications_command: try: result = subprocess.run(self.notifications_command, shell=True, capture_output=True) @@ -1728,6 +1738,8 @@ def _send_notification(self): def notify_user_input_required(self): """Send a notification that user input is required.""" + if self.is_processing_prompt: + return if self.notifications: self._send_notification() diff --git a/cecli/tools/command.py b/cecli/tools/command.py index 3b541eb8c8b..af9abf42bfa 100644 --- a/cecli/tools/command.py +++ b/cecli/tools/command.py @@ -131,7 +131,7 @@ async def execute( @classmethod async def _get_confirmation(cls, coder, command_string, background): """Get user confirmation for command execution.""" - if coder.skip_cli_confirmations: + if coder.skip_cli_confirmations or getattr(coder, "globally_approved_tool_calls", False): return True command_string = coder.format_command_with_prefix(command_string) diff --git a/cecli/tools/command_interactive.py b/cecli/tools/command_interactive.py index 45d3251bdcb..39ec75755f0 100644 --- a/cecli/tools/command_interactive.py +++ b/cecli/tools/command_interactive.py @@ -37,6 +37,7 @@ async def execute(cls, coder, command_string, **kwargs): confirmed = ( True if coder.skip_cli_confirmations + or getattr(coder, "globally_approved_tool_calls", False) else await coder.io.confirm_ask( "Allow execution of this command?", subject=command_string, @@ -72,6 +73,7 @@ def _run_interactive(): else: coder.io.tool_output(">>> You may need to interact with the command below <<<") coder.io.tool_output(" \n") + coder.io.bell_on_next_input = False await coder.io.stop_input_task() await asyncio.sleep(1) exit_status, combined_output = _run_interactive() diff --git a/cecli/tools/ls.py b/cecli/tools/ls.py index 9a9ed276340..443f74acd67 100644 --- a/cecli/tools/ls.py +++ b/cecli/tools/ls.py @@ -11,67 +11,78 @@ class Tool(BaseTool): SCHEMA = { "type": "function", "function": { - "name": "Ls", - "description": "List files in a directory.", + "name": "ls", + "description": "List files in a directory. Paths are relative to the project root.", "parameters": { "type": "object", "properties": { - "directory": { + "path": { "type": "string", - "description": "The directory to list.", - }, + "description": ( + "The path of the directory to list, relative to the project root. " + "Defaults to the project root." + ), + "default": ".", + } }, - "required": ["directory"], + "required": [], }, }, } @classmethod - def execute(cls, coder, dir_path=None, directory=None, **kwargs): - # Handle both positional and keyword arguments for backward compatibility - if dir_path is None and directory is not None: - dir_path = directory - elif dir_path is None: - return "Error: Missing directory parameter" + def execute(cls, coder, path=None, directory=None, **kwargs): """ List files in directory and optionally add some to context. This provides information about the structure of the codebase, similar to how a developer would explore directories. """ + # Handle both positional and keyword arguments for backward compatibility + dir_path = path or directory or "." + try: - # Make the path relative to root if it's absolute - if dir_path.startswith("/"): - rel_dir = os.path.relpath(dir_path, coder.root) - else: - rel_dir = dir_path + # Create an absolute path from the provided relative path + abs_path = os.path.abspath(os.path.join(coder.root, dir_path)) - # Get absolute path - abs_dir = coder.abs_root_path(rel_dir) + # Security check: ensure the resolved path is within the project root + if not abs_path.startswith(os.path.abspath(coder.root)): + coder.io.tool_error( + f"Error: Path '{dir_path}' attempts to access files outside the project root." + ) + return "Error: Path is outside the project root." # Check if path exists - if not os.path.exists(abs_dir): - coder.io.tool_output(f"⚠️ Directory '{dir_path}' not found") + if not os.path.exists(abs_path): + coder.io.tool_output(f"⚠️ Path '{dir_path}' not found") return "Directory not found" # Get directory contents contents = [] - try: - with os.scandir(abs_dir) as entries: - for entry in entries: - if entry.is_file() and not entry.name.startswith("."): - rel_path = os.path.join(rel_dir, entry.name) - contents.append(rel_path) - except NotADirectoryError: - # If it's a file, just return the file - contents = [rel_dir] + if os.path.isdir(abs_path): + # It's a directory, list its contents + try: + with os.scandir(abs_path) as entries: + for entry in entries: + if entry.is_file() and not entry.name.startswith("."): + rel_path = os.path.relpath(entry.path, coder.root) + contents.append(rel_path) + except OSError as e: + coder.io.tool_error(f"Error listing directory '{dir_path}': {e}") + return f"Error: {e}" + elif os.path.isfile(abs_path): + # It's a file, just return its relative path + contents.append(os.path.relpath(abs_path, coder.root)) if contents: coder.io.tool_output(f"📋 Listed {len(contents)} file(s) in '{dir_path}'") - if len(contents) > 10: - return f"Found {len(contents)} files: {', '.join(contents[:10])}..." + sorted_contents = sorted(contents) + if len(sorted_contents) > 10: + return ( + f"Found {len(sorted_contents)} files: {', '.join(sorted_contents[:10])}..." + ) else: - return f"Found {len(contents)} files: {', '.join(contents)}" + return f"Found {len(sorted_contents)} files: {', '.join(sorted_contents)}" else: coder.io.tool_output(f"📋 No files found in '{dir_path}'") return "No files found in directory" diff --git a/tests/basic/test_io.py b/tests/basic/test_io.py index cd838cbfbb7..876a2b96bc2 100644 --- a/tests/basic/test_io.py +++ b/tests/basic/test_io.py @@ -648,3 +648,40 @@ def test_format_files_for_input_pretty_true_mixed_files( args_ed, _ = mock_columns.call_args_list[2] renderables_ed = args_ed[0] assert renderables_ed == ["Editable:", "edit1.txt", "edit[markup].txt"] +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from cecli.io import InputOutput + + +@pytest.mark.asyncio +async def test_notification_suppressed_during_processing(): + """ + Verify that notifications are not sent when a prompt is being processed. + """ + # Initialize InputOutput with notifications enabled + io = InputOutput(notifications=True) + io.is_processing_prompt = False # Start in idle state + + with patch.object(io, "_send_notification") as mock_send_notification: + # 1. Test when idle: notification should be sent + io.notify_user_input_required() + mock_send_notification.assert_called_once() + + # Reset mock for the next check + mock_send_notification.reset_mock() + + # 2. Test when processing: notification should be suppressed + io.is_processing_prompt = True + io.notify_user_input_required() + mock_send_notification.assert_not_called() + + # Reset mock for the next check + mock_send_notification.reset_mock() + + # 3. Test after processing: notification should be sent again + io.is_processing_prompt = False + io.notify_user_input_required() + mock_send_notification.assert_called_once()