diff --git a/.github/workflows/ubuntu-tests.yml b/.github/workflows/ubuntu-tests.yml index ad79b78dec4..66e03f9c44d 100644 --- a/.github/workflows/ubuntu-tests.yml +++ b/.github/workflows/ubuntu-tests.yml @@ -49,6 +49,7 @@ jobs: pip install uv uv pip install --system \ pytest \ + pytest-asyncio \ -r requirements/requirements.in \ -r requirements/requirements-browser.in \ -r requirements/requirements-help.in \ diff --git a/.github/workflows/windows-tests.yml b/.github/workflows/windows-tests.yml index 25c41c39d36..3d81acbe4ad 100644 --- a/.github/workflows/windows-tests.yml +++ b/.github/workflows/windows-tests.yml @@ -42,7 +42,7 @@ jobs: run: | python -m pip install --upgrade pip pip install uv - uv pip install --system pytest -r requirements/requirements.in -r requirements/requirements-browser.in -r requirements/requirements-help.in -r requirements/requirements-playwright.in '.[browser,help,playwright]' + uv pip install --system pytest pytest-asyncio -r requirements/requirements.in -r requirements/requirements-browser.in -r requirements/requirements-help.in -r requirements/requirements-playwright.in '.[browser,help,playwright]' - name: Run tests env: diff --git a/README.md b/README.md index 0f905cdfd03..f0e98f46b4a 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ The current priorities are to improve core capabilities and user experience of t 5. 
**TUI Experience** - [Discussion](https://github.com/dwash96/aider-ce/issues/48) * [ ] Add a full TUI (probably using textual) to have a visual interface competitive with the other coding agent terminal programs - * [ ] Re-integrate pretty output formatting + * [x] Re-integrate pretty output formatting * [ ] Implement a response area, a prompt area with current auto completion capabilities, and a helper area for management utility commands ## Fork Additions @@ -36,6 +36,7 @@ This project aims to be compatible with upstream Aider, but with priority commit * [MCP Multi Tool Response](https://github.com/quinlanjager/aider/pull/1) * [Navigator Mode: #3781](https://github.com/Aider-AI/aider/pull/3781) * [Navigator Mode Large File Count](https://github.com/Aider-AI/aider/commit/b88a7bda649931798209945d9687718316c7427f) + * [Fix navigator mode auto commit](https://github.com/dwash96/aider-ce/issues/38) * [Qwen 3: #4383](https://github.com/Aider-AI/aider/pull/4383) * [Fuzzy Search: #4366](https://github.com/Aider-AI/aider/pull/4366) * [Map Cache Location Config: #2911](https://github.com/Aider-AI/aider/pull/2911) @@ -66,20 +67,34 @@ This project aims to be compatible with upstream Aider, but with priority commit * [MCP Configuration](https://github.com/dwash96/aider-ce/blob/main/aider/website/docs/config/mcp.md) ### Installation Instructions -This project should be installable using the commands +This project can be installed using several methods: -``` +### Package Installation +```bash pip install aider-ce ``` or -``` +```bash uv pip install aider-ce ``` The package exports an `aider-ce` command that accepts all of Aider's configuration options +### Tool Installation +```bash +uv tool install --python python3.12 aider-ce +``` + +Use the tool installation so aider doesn't interfere with your development environment + +### All Contributors (Both Aider Main and Aider-CE) + + + + +

Aider Logo

diff --git a/aider/__init__.py b/aider/__init__.py index 67bd1c17f2a..b2dc4171a2e 100644 --- a/aider/__init__.py +++ b/aider/__init__.py @@ -1,6 +1,6 @@ from packaging import version -__version__ = "0.87.13.dev" +__version__ = "0.88.0.dev" safe_version = __version__ try: diff --git a/aider/args.py b/aider/args.py index b75436d1abd..5b19006e040 100644 --- a/aider/args.py +++ b/aider/args.py @@ -770,6 +770,12 @@ def get_parser(default_config_files, git_root): ###### group = parser.add_argument_group("Other settings") + group.add_argument( + "--preserve-todo-list", + action="store_true", + help="Preserve the existing .aider.todo.txt file on startup (default: False)", + default=False, + ) group.add_argument( "--disable-playwright", action="store_true", diff --git a/aider/coders/architect_coder.py b/aider/coders/architect_coder.py index f3e2a38b13a..a7cba79eb2e 100644 --- a/aider/coders/architect_coder.py +++ b/aider/coders/architect_coder.py @@ -8,7 +8,7 @@ class ArchitectCoder(AskCoder): gpt_prompts = ArchitectPrompts() auto_accept_architect = False - def reply_completed(self): + async def reply_completed(self): content = self.partial_response_content if not content or not content.strip(): @@ -34,14 +34,14 @@ def reply_completed(self): new_kwargs = dict(io=self.io, from_coder=self) new_kwargs.update(kwargs) - editor_coder = Coder.create(**new_kwargs) + editor_coder = await Coder.create(**new_kwargs) editor_coder.cur_messages = [] editor_coder.done_messages = [] if self.verbose: editor_coder.show_announcements() - editor_coder.run(with_message=content, preproc=False) + await editor_coder.run(with_message=content, preproc=False) self.move_back_cur_messages("I made those changes to the files.") self.total_cost = editor_coder.total_cost diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 4f536a80699..3645fedf700 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -14,6 +14,7 @@ import threading import time import traceback 
+import weakref from collections import defaultdict from datetime import datetime @@ -27,7 +28,16 @@ from pathlib import Path from typing import List +import httpx from litellm import experimental_mcp_client +from litellm.types.utils import ( + ChatCompletionMessageToolCall, + Choices, + Function, + Message, + ModelResponse, +) +from prompt_toolkit.patch_stdout import patch_stdout from rich.console import Console from aider import __version__, models, prompts, urls, utils @@ -50,7 +60,6 @@ from aider.repomap import RepoMap from aider.run_cmd import run_cmd from aider.utils import format_content, format_messages, format_tokens, is_image_file -from aider.waiting import WaitingSpinner from ..dump import dump # noqa: F401 from .chat_chunks import ChatChunks @@ -115,7 +124,7 @@ class Coder: test_outcome = None multi_response_content = "" partial_response_content = "" - partial_response_tool_call = [] + partial_response_tool_calls = [] commit_before_message = [] message_cost = 0.0 add_cache_headers = False @@ -129,6 +138,10 @@ class Coder: file_watcher = None mcp_servers = None mcp_tools = None + run_one_completed = True + compact_context_completed = True + suppress_announcements_for_next_prompt = False + tool_reflection = False # Context management settings (for all modes) context_management_enabled = False # Disabled by default except for navigator mode @@ -137,7 +150,7 @@ class Coder: ) @classmethod - def create( + async def create( self, main_model=None, edit_format=None, @@ -175,7 +188,7 @@ def create( done_messages = from_coder.done_messages if edit_format != from_coder.edit_format and done_messages and summarize_from_coder: try: - done_messages = from_coder.summarizer.summarize_all(done_messages) + done_messages = await from_coder.summarizer.summarize_all(done_messages) except ValueError: # If summarization fails, keep the original messages and warn the user io.tool_warning( @@ -208,6 +221,7 @@ def create( for coder in coders.__all__: if hasattr(coder, 
"edit_format") and coder.edit_format == edit_format: res = coder(main_model, io, **kwargs) + await res.initialize_mcp_tools() res.original_kwargs = dict(kwargs) return res @@ -218,8 +232,8 @@ def create( ] raise UnknownEditFormat(edit_format, valid_formats) - def clone(self, **kwargs): - new_coder = Coder.create(from_coder=self, **kwargs) + async def clone(self, **kwargs): + new_coder = await Coder.create(from_coder=self, **kwargs) return new_coder def get_announcements(self): @@ -296,13 +310,11 @@ def get_announcements(self): else: lines.append("Repo-map: disabled") - # Files - for fname in self.get_inchat_relative_files(): - lines.append(f"Added {fname} to the chat.") - - for fname in self.abs_read_only_fnames: - rel_fname = self.get_rel_fname(fname) - lines.append(f"Added {rel_fname} to the chat (read-only).") + if self.mcp_tools: + mcp_servers = [] + for server_name, server_tools in self.mcp_tools: + mcp_servers.append(server_name) + lines.append(f"MCP servers configured: {', '.join(mcp_servers)}") for fname in self.abs_read_only_stubs_fnames: rel_fname = self.get_rel_fname(fname) @@ -368,6 +380,7 @@ def __init__( context_compaction_summary_tokens=8192, map_cache_dir=".", repomap_in_memory=False, + preserve_todo_list=False, ): # initialize from args.map_cache_dir self.map_cache_dir = map_cache_dir @@ -385,6 +398,7 @@ def __init__( self.auto_copy_context = auto_copy_context self.auto_accept_architect = auto_accept_architect + self.preserve_todo_list = preserve_todo_list self.ignore_mentions = ignore_mentions if not self.ignore_mentions: @@ -442,8 +456,10 @@ def __init__( self.done_messages = [] self.io = io + self.io.coder = weakref.ref(self) self.shell_commands = [] + self.partial_response_tool_calls = [] if not auto_commits: dirty_commits = False @@ -455,6 +471,7 @@ def __init__( self.pretty = self.io.pretty self.main_model = main_model + # Set the reasoning tag name based on model settings or default self.reasoning_tag_name = ( self.main_model.reasoning_tag 
if self.main_model.reasoning_tag else REASONING_TAG @@ -568,6 +585,10 @@ def __init__( self.summarizer_thread = None self.summarized_done_messages = [] self.summarizing_messages = None + self.input_task = None + self.confirmation_in_progress = False + + self.files_edited_by_tools = set() if not self.done_messages and restore_chat_history: history_md = self.io.read_text(self.io.chat_history_file) @@ -583,9 +604,21 @@ def __init__( self.auto_test = auto_test self.test_cmd = test_cmd + # Clean up todo list file on startup unless preserve_todo_list is True + if not getattr(self, "preserve_todo_list", False): + todo_file_path = ".aider.todo.txt" + abs_path = self.abs_root_path(todo_file_path) + if os.path.isfile(abs_path): + try: + os.remove(abs_path) + if self.verbose: + self.io.tool_output(f"Removed existing todo list file: {todo_file_path}") + except Exception as e: + self.io.tool_warning(f"Could not remove todo list file {todo_file_path}: {e}") + # Instantiate MCP tools if self.mcp_servers: - self.initialize_mcp_tools() + pass # validate the functions jsonschema if self.functions: from jsonschema import Draft7Validator @@ -644,12 +677,7 @@ def show_pretty(self): def _stop_waiting_spinner(self): """Stop and clear the waiting spinner if it is running.""" - spinner = getattr(self, "waiting_spinner", None) - if spinner: - try: - spinner.stop() - finally: - self.waiting_spinner = None + self.io.stop_spinner() def get_abs_fnames_content(self): for fname in list(self.abs_fnames): @@ -1013,10 +1041,11 @@ def get_images_message(self, fnames): return {"role": "user", "content": image_messages} - def run_stream(self, user_message): + async def run_stream(self, user_message): self.io.user_input(user_message) self.init_before_message() - yield from self.send_message(user_message) + async for chunk in self.send_message(user_message): + yield chunk def init_before_message(self): self.aider_edited_files = set() @@ -1030,36 +1059,139 @@ def init_before_message(self): if self.repo: 
self.commit_before_message.append(self.repo.get_head_commit_sha()) - def run(self, with_message=None, preproc=True): + async def run(self, with_message=None, preproc=True): + while self.confirmation_in_progress: + await asyncio.sleep(0.1) # Yield control and wait briefly + + if self.io.prompt_session: + with patch_stdout(raw=True): + return await self._run_patched(with_message, preproc) + else: + return await self._run_patched(with_message, preproc) + + async def _run_patched(self, with_message=None, preproc=True): + input_task = None + processing_task = None try: if with_message: self.io.user_input(with_message) - self.run_one(with_message, preproc) + await self.run_one(with_message, preproc) return self.partial_response_content + + user_message = None + while True: try: - if not self.io.placeholder: + if ( + not self.confirmation_in_progress + and not input_task + and not user_message + and (not processing_task or not self.io.placeholder) + ): + if not self.suppress_announcements_for_next_prompt: + self.show_announcements() + self.suppress_announcements_for_next_prompt = False + + # Stop spinner before showing announcements or getting input + self.io.stop_spinner() + self.copy_context() - user_message = self.get_input() - self.compact_context_if_needed() - self.run_one(user_message, preproc) - self.show_undo_hint() + self.input_task = asyncio.create_task(self.get_input()) + input_task = self.input_task + + tasks = set() + if processing_task: + tasks.add(processing_task) + if input_task: + tasks.add(input_task) + + if tasks: + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + + if input_task and input_task in done: + if processing_task: + if not self.confirmation_in_progress: + processing_task.cancel() + try: + await processing_task + except asyncio.CancelledError: + pass + self.io.stop_spinner() + processing_task = None + + try: + user_message = input_task.result() + except (asyncio.CancelledError, KeyboardInterrupt): + 
user_message = None + input_task = None + self.input_task = None + if user_message is None: + continue + + if processing_task and processing_task in done: + try: + await processing_task + except (asyncio.CancelledError, KeyboardInterrupt): + pass + processing_task = None + # Stop spinner when processing task completes + self.io.stop_spinner() + + if user_message and self.run_one_completed and self.compact_context_completed: + processing_task = asyncio.create_task( + self._processing_logic(user_message, preproc) + ) + # Start spinner for processing task + self.io.start_spinner("Processing...") + user_message = None # Clear message after starting task except KeyboardInterrupt: + if processing_task: + processing_task.cancel() + processing_task = None + # Stop spinner when processing task is cancelled + self.io.stop_spinner() + if input_task: + self.io.set_placeholder("") + input_task.cancel() + input_task = None self.keyboard_interrupt() except EOFError: return + finally: + if input_task: + input_task.cancel() + if processing_task: + processing_task.cancel() + + async def _processing_logic(self, user_message, preproc): + try: + self.compact_context_completed = False + await self.compact_context_if_needed() + self.compact_context_completed = True + + self.run_one_completed = False + await self.run_one(user_message, preproc) + self.show_undo_hint() + except asyncio.CancelledError: + # Don't show undo hint if cancelled + raise + finally: + self.run_one_completed = True + self.compact_context_completed = True def copy_context(self): if self.auto_copy_context: self.commands.cmd_copy_context() - def get_input(self): + async def get_input(self): inchat_files = self.get_inchat_relative_files() all_read_only_fnames = self.abs_read_only_fnames | self.abs_read_only_stubs_fnames all_read_only_files = [self.get_rel_fname(fname) for fname in all_read_only_fnames] all_files = sorted(set(inchat_files + all_read_only_files)) edit_format = "" if self.edit_format == 
self.main_model.edit_format else self.edit_format - return self.io.get_input( + return await self.io.get_input( self.root, all_files, self.get_addable_relative_files(), @@ -1069,29 +1201,35 @@ def get_input(self): edit_format=edit_format, ) - def preproc_user_input(self, inp): + async def preproc_user_input(self, inp): if not inp: return if self.commands.is_command(inp): - return self.commands.run(inp) + return await self.commands.run(inp) - self.check_for_file_mentions(inp) - inp = self.check_for_urls(inp) + await self.check_for_file_mentions(inp) + inp = await self.check_for_urls(inp) return inp - def run_one(self, user_message, preproc): + async def run_one(self, user_message, preproc): self.init_before_message() if preproc: - message = self.preproc_user_input(user_message) + message = await self.preproc_user_input(user_message) else: message = user_message - while message: + if self.commands.is_command(user_message): + return + + while True: self.reflected_message = None - list(self.send_message(message)) + self.tool_reflection = False + + async for _ in self.send_message(message): + pass if not self.reflected_message: break @@ -1101,7 +1239,14 @@ def run_one(self, user_message, preproc): return self.num_reflections += 1 - message = self.reflected_message + + if self.tool_reflection: + self.num_reflections -= 1 + + if self.reflected_message is True: + message = None + else: + message = self.reflected_message def check_and_open_urls(self, exc, friendly_msg=None): """Check exception for URLs, offer to open in a browser, with user-friendly error msgs.""" @@ -1122,7 +1267,7 @@ def check_and_open_urls(self, exc, friendly_msg=None): self.io.offer_url(url) return urls - def check_for_urls(self, inp: str) -> List[str]: + async def check_for_urls(self, inp: str) -> List[str]: """Check input for URLs and offer to add them to the chat.""" if not self.detect_urls: return inp @@ -1135,11 +1280,11 @@ def check_for_urls(self, inp: str) -> List[str]: for url in urls: if url 
not in self.rejected_urls: url = url.rstrip(".',\"") - if self.io.confirm_ask( + if await self.io.confirm_ask( "Add URL to the chat?", subject=url, group=group, allow_never=True ): inp += "\n\n" - inp += self.commands.cmd_web(url, return_content=True) + inp += await self.commands.cmd_web(url, return_content=True) else: self.rejected_urls.add(url) @@ -1149,17 +1294,9 @@ def keyboard_interrupt(self): # Ensure cursor is visible on exit Console().show_cursor(True) - now = time.time() - - thresh = 2 # seconds - if self.last_keyboard_interrupt and now - self.last_keyboard_interrupt < thresh: - self.io.tool_warning("\n\n^C KeyboardInterrupt") - self.event("exit", reason="Control-C") - sys.exit() - - self.io.tool_warning("\n\n^C again to exit") + self.io.tool_warning("\n\n^C KeyboardInterrupt") - self.last_keyboard_interrupt = now + self.last_keyboard_interrupt = time.time() def summarize_start(self): if not self.summarizer.check_max_tokens(self.done_messages): @@ -1176,7 +1313,9 @@ def summarize_start(self): def summarize_worker(self): self.summarizing_messages = list(self.done_messages) try: - self.summarized_done_messages = self.summarizer.summarize(self.summarizing_messages) + self.summarized_done_messages = asyncio.run( + self.summarizer.summarize(self.summarizing_messages) + ) except ValueError as err: self.io.tool_warning(err.args[0]) self.summarized_done_messages = self.summarizing_messages @@ -1196,7 +1335,7 @@ def summarize_end(self): self.summarizing_messages = None self.summarized_done_messages = [] - def compact_context_if_needed(self): + async def compact_context_if_needed(self): if not self.enable_context_compaction: self.summarize_start() return @@ -1210,7 +1349,7 @@ def compact_context_if_needed(self): try: # Create a summary of the conversation - summary_text = self.summarizer.summarize_all_as_text( + summary_text = await self.summarizer.summarize_all_as_text( self.done_messages, self.gpt_prompts.compaction_prompt, self.context_compaction_summary_tokens, 
@@ -1517,7 +1656,13 @@ def format_chat_chunks(self): cur_tokens = self.main_model.token_count(chunks.cur) if None not in (messages_tokens, reminder_tokens, cur_tokens): - total_tokens = messages_tokens + reminder_tokens + cur_tokens + total_tokens = messages_tokens + # Only add tokens for reminder and cur if they're not already included + # in the messages_tokens calculation + if not chunks.reminder: + total_tokens += reminder_tokens + if not chunks.cur: + total_tokens += cur_tokens else: # add the reminder anyway total_tokens = 0 @@ -1633,15 +1778,16 @@ def check_tokens(self, messages): return False return True - def send_message(self, inp): + async def send_message(self, inp): self.event("message_send_starting") # Notify IO that LLM processing is starting self.io.llm_started() - self.cur_messages += [ - dict(role="user", content=inp), - ] + if inp: + self.cur_messages += [ + dict(role="user", content=inp), + ] chunks = self.format_messages() messages = chunks.all_messages() @@ -1655,10 +1801,13 @@ def send_message(self, inp): self.multi_response_content = "" if self.show_pretty(): - self.waiting_spinner = WaitingSpinner("Waiting for " + self.main_model.name) - self.waiting_spinner.start() + spinner_text = ( + f"Waiting for {self.main_model.name} • ${self.format_cost(self.total_cost)} session" + ) + self.io.start_spinner(spinner_text) + if self.stream: - self.mdstream = self.io.get_assistant_mdstream() + self.mdstream = True else: self.mdstream = None else: @@ -1674,7 +1823,8 @@ def send_message(self, inp): try: while True: try: - yield from self.send(messages, functions=self.functions) + async for chunk in self.send(messages, tools=self.get_tool_list()): + yield chunk break except litellm_ex.exceptions_tuple() as err: ex_info = litellm_ex.get_ex_info(err) @@ -1702,7 +1852,7 @@ def send_message(self, inp): self.io.tool_error(err_msg) self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...") - time.sleep(retry_delay) + await asyncio.sleep(retry_delay) 
continue except KeyboardInterrupt: interrupted = True @@ -1730,8 +1880,9 @@ def send_message(self, inp): return finally: if self.mdstream: - self.live_incremental_response(True) - self.mdstream = None + content_to_show = self.live_incremental_response(True) + self.stream_wrapper(content_to_show, final=True) + self.mdstream = None # Ensure any waiting spinner is stopped self._stop_waiting_spinner() @@ -1740,11 +1891,6 @@ def send_message(self, inp): self.remove_reasoning_content() self.multi_response_content = "" - ### - # print() - # print("=" * 20) - # dump(self.partial_response_content) - self.io.tool_output() self.show_usage_report() @@ -1785,11 +1931,11 @@ def send_message(self, inp): ] return - edited = self.apply_updates() + edited = await self.apply_updates() if edited: self.aider_edited_files.update(edited) - saved_message = self.auto_commit(edited) + saved_message = await self.auto_commit(edited) if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"): saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo @@ -1797,7 +1943,7 @@ def send_message(self, inp): self.move_back_cur_messages(saved_message) if not interrupted: - add_rel_files_message = self.check_for_file_mentions(content) + add_rel_files_message = await self.check_for_file_mentions(content) if add_rel_files_message: if self.reflected_message: self.reflected_message += "\n\n" + add_rel_files_message @@ -1806,15 +1952,49 @@ def send_message(self, inp): return # Process any tools using MCP servers - tool_call_response = litellm.stream_chunk_builder(self.partial_response_tool_call) - if self.process_tool_calls(tool_call_response): - self.num_tool_calls += 1 - return self.run(with_message="Continue with tool call response", preproc=False) + try: + if self.partial_response_tool_calls: + tool_calls = [] + for tool_call_dict in self.partial_response_tool_calls: + tool_calls.append( + ChatCompletionMessageToolCall( + id=tool_call_dict.get("id"), + function=Function( 
+ name=tool_call_dict.get("function", {}).get("name"), + arguments=tool_call_dict.get("function", {}).get( + "arguments", "" + ), + ), + type=tool_call_dict.get("type"), + ) + ) + + tool_call_response = ModelResponse( + choices=[ + Choices( + finish_reason="tool_calls", + index=0, + message=Message( + content=None, + role="assistant", + tool_calls=tool_calls, + ), + ) + ] + ) + + if await self.process_tool_calls(tool_call_response): + self.num_tool_calls += 1 + self.reflected_message = True + return + except Exception as e: + self.io.tool_error(f"Error processing tool calls: {str(e)}") + # Continue without tool processing self.num_tool_calls = 0 try: - if self.reply_completed(): + if await self.reply_completed(): return except KeyboardInterrupt: interrupted = True @@ -1824,15 +2004,15 @@ def send_message(self, inp): if edited and self.auto_lint: lint_errors = self.lint_edited(edited) - self.auto_commit(edited, context="Ran the linter") + await self.auto_commit(edited, context="Ran the linter") self.lint_outcome = not lint_errors if lint_errors: - ok = self.io.confirm_ask("Attempt to fix lint errors?") + ok = await self.io.confirm_ask("Attempt to fix lint errors?") if ok: self.reflected_message = lint_errors return - shared_output = self.run_shell_commands() + shared_output = await self.run_shell_commands() if shared_output: self.cur_messages += [ dict(role="user", content=shared_output), @@ -1840,19 +2020,33 @@ def send_message(self, inp): ] if edited and self.auto_test: - test_errors = self.commands.cmd_test(self.test_cmd) + test_errors = await self.commands.cmd_test(self.test_cmd) self.test_outcome = not test_errors if test_errors: - ok = self.io.confirm_ask("Attempt to fix test errors?") + ok = await self.io.confirm_ask("Attempt to fix test errors?") if ok: self.reflected_message = test_errors return - def process_tool_calls(self, tool_call_response): + async def process_tool_calls(self, tool_call_response): if tool_call_response is None: return False - 
original_tool_calls = tool_call_response.choices[0].message.tool_calls + # Handle different response structures + try: + # Try to get tool calls from the standard OpenAI response format + if hasattr(tool_call_response, "choices") and tool_call_response.choices: + message = tool_call_response.choices[0].message + if hasattr(message, "tool_calls") and message.tool_calls: + original_tool_calls = message.tool_calls + else: + return False + else: + # Handle other response formats + return False + except (AttributeError, IndexError): + return False + if not original_tool_calls: return False @@ -1888,22 +2082,14 @@ def process_tool_calls(self, tool_call_response): ) expanded_tool_calls.append(new_tool_call) - # Replace the original tool_calls in the response object with the expanded list. - tool_call_response.choices[0].message.tool_calls = expanded_tool_calls - tool_calls = expanded_tool_calls - # Collect all tool calls grouped by server - server_tool_calls = self._gather_server_tool_calls(tool_calls) + server_tool_calls = self._gather_server_tool_calls(expanded_tool_calls) if server_tool_calls and self.num_tool_calls < self.max_tool_calls: self._print_tool_call_info(server_tool_calls) - if self.io.confirm_ask("Run tools?"): - tool_responses = self._execute_tool_calls(server_tool_calls) - - # Add the assistant message with the modified (expanded) tool calls. - # This ensures that what's stored in history is valid. 
- self.cur_messages.append(tool_call_response.choices[0].message.to_dict()) + if await self.io.confirm_ask("Run tools?"): + tool_responses = await self._execute_tool_calls(server_tool_calls) # Add all tool responses for tool_response in tool_responses: @@ -1917,12 +2103,45 @@ def process_tool_calls(self, tool_call_response): def _print_tool_call_info(self, server_tool_calls): """Print information about an MCP tool call.""" - self.io.tool_output("Preparing to run MCP tools", bold=True) + self.io.tool_output("Preparing to run MCP tools", bold=False) for server, tool_calls in server_tool_calls.items(): for tool_call in tool_calls: self.io.tool_output(f"Tool Call: {tool_call.function.name}") - self.io.tool_output(f"Arguments: {tool_call.function.arguments}") + + # Parse and format arguments as headers with values + if tool_call.function.arguments: + # Only do JSON unwrapping for tools containing "replace" in their name + if "replace" in tool_call.function.name.lower(): + try: + args_dict = json.loads(tool_call.function.arguments) + first_key = True + for key, value in args_dict.items(): + # Convert explicit \\n sequences to actual newlines using regex + # Only match \\n that is not preceded by any other backslashes + if isinstance(value, str): + value = re.sub(r"(? 
0: - self.io.tool_output("MCP servers configured:") - for server_name, server_tools in tools: - self.io.tool_output(f" - {server_name}") + if self.verbose: + self.io.tool_output("MCP servers configured:") + + for server_name, server_tools in tools: + self.io.tool_output(f" - {server_name}") - if self.verbose: for tool in server_tools: tool_name = tool.get("function", {}).get("name", "unknown") tool_desc = tool.get("function", {}).get("description", "").split("\n")[0] @@ -2172,7 +2399,7 @@ def get_tool_list(self): tool_list.extend(server_tools) return tool_list - def reply_completed(self): + async def reply_completed(self): pass def show_exhausted_error(self): @@ -2250,16 +2477,33 @@ def __del__(self): self.ok_to_warm_cache = False def add_assistant_reply_to_cur_messages(self): - if self.partial_response_content: - self.cur_messages += [dict(role="assistant", content=self.partial_response_content)] - if self.partial_response_function_call: - self.cur_messages += [ - dict( - role="assistant", - content=None, - function_call=self.partial_response_function_call, - ) - ] + """ + Add the assistant's reply to `cur_messages`. + Handles model-specific quirks, like Deepseek which requires `content` + to be `None` when `tool_calls` are present. + """ + msg = dict(role="assistant") + has_tool_calls = self.partial_response_tool_calls or self.partial_response_function_call + + # If we have tool calls and we're using a Deepseek model, force content to be None. + if has_tool_calls and self.main_model.is_deepseek(): + msg["content"] = None + else: + # Otherwise, use logic similar to the base implementation. + content = self.partial_response_content + if content: + msg["content"] = content + elif has_tool_calls: + msg["content"] = None + + if self.partial_response_tool_calls: + msg["tool_calls"] = self.partial_response_tool_calls + elif self.partial_response_function_call: + msg["function_call"] = self.partial_response_function_call + + # Only add a message if it's not empty. 
+ if msg.get("content") is not None or msg.get("tool_calls") or msg.get("function_call"): + self.cur_messages.append(msg) def get_file_mentions(self, content, ignore_current=False): words = set(word for word in content.split()) @@ -2309,7 +2553,7 @@ def get_file_mentions(self, content, ignore_current=False): return mentioned_rel_fnames - def check_for_file_mentions(self, content): + async def check_for_file_mentions(self, content): mentioned_rel_fnames = self.get_file_mentions(content) new_mentions = mentioned_rel_fnames - self.ignore_mentions @@ -2320,7 +2564,7 @@ def check_for_file_mentions(self, content): added_fnames = [] group = ConfirmGroup(new_mentions) for rel_fname in sorted(new_mentions): - if self.io.confirm_ask( + if await self.io.confirm_ask( "Add file to the chat?", subject=rel_fname, group=group, allow_never=True ): self.add_rel_fname(rel_fname) @@ -2331,35 +2575,38 @@ def check_for_file_mentions(self, content): if added_fnames: return prompts.added_files.format(fnames=", ".join(added_fnames)) - def send(self, messages, model=None, functions=None): + async def send(self, messages, model=None, functions=None, tools=None): self.got_reasoning_content = False self.ended_reasoning_content = False + self._streaming_buffer_length = 0 + self.io.reset_streaming_response() + if not model: model = self.main_model self.partial_response_content = "" self.partial_response_function_call = dict() + self.partial_response_tool_calls = [] self.io.log_llm_history("TO LLM", format_messages(messages)) completion = None try: - tool_list = self.get_tool_list() - - hash_object, completion = model.send_completion( + hash_object, completion = await model.send_completion( messages, functions, self.stream, self.temperature, # This could include any tools, but for now it is just MCP tools - tools=tool_list, + tools=tools, ) self.chat_completion_call_hashes.append(hash_object.hexdigest()) if self.stream: - yield from self.show_send_output_stream(completion) + async for chunk in 
self.show_send_output_stream(completion): + yield chunk else: self.show_send_output(completion) @@ -2390,9 +2637,6 @@ def send(self, messages, model=None, functions=None): self.io.ai_output(json.dumps(args, indent=4)) def show_send_output(self, completion): - # Stop spinner once we have a response - self._stop_waiting_spinner() - if self.verbose: print(completion) @@ -2453,11 +2697,14 @@ def show_send_output(self, completion): ): raise FinishReasonLength() - def show_send_output_stream(self, completion): + async def show_send_output_stream(self, completion): received_content = False - self.partial_response_tool_call = [] - for chunk in completion: + async for chunk in completion: + # Check if confirmation is in progress and wait if needed + while self.confirmation_in_progress: + await asyncio.sleep(0.1) # Yield control and wait briefly + if isinstance(chunk, str): text = chunk received_content = True @@ -2471,13 +2718,59 @@ def show_send_output_stream(self, completion): ): raise FinishReasonLength() - if chunk.choices[0].delta.tool_calls: - self.partial_response_tool_call.append(chunk) + try: + if chunk.choices[0].delta.tool_calls: + received_content = True + for tool_call_chunk in chunk.choices[0].delta.tool_calls: + self.tool_reflection = True + + index = tool_call_chunk.index + if len(self.partial_response_tool_calls) <= index: + self.partial_response_tool_calls.extend( + [{}] * (index - len(self.partial_response_tool_calls) + 1) + ) + + if tool_call_chunk.id: + self.partial_response_tool_calls[index]["id"] = tool_call_chunk.id + if tool_call_chunk.type: + self.partial_response_tool_calls[index][ + "type" + ] = tool_call_chunk.type + if tool_call_chunk.function: + if "function" not in self.partial_response_tool_calls[index]: + self.partial_response_tool_calls[index]["function"] = {} + if tool_call_chunk.function.name: + if ( + "name" + not in self.partial_response_tool_calls[index]["function"] + ): + self.partial_response_tool_calls[index]["function"][ + "name" 
+ ] = "" + self.partial_response_tool_calls[index]["function"][ + "name" + ] += tool_call_chunk.function.name + if tool_call_chunk.function.arguments: + if ( + "arguments" + not in self.partial_response_tool_calls[index]["function"] + ): + self.partial_response_tool_calls[index]["function"][ + "arguments" + ] = "" + self.partial_response_tool_calls[index]["function"][ + "arguments" + ] += tool_call_chunk.function.arguments + except (AttributeError, IndexError): + # Handle cases where the response structure doesn't match expectations + pass try: func = chunk.choices[0].delta.function_call # dump(func) for k, v in func.items(): + self.tool_reflection = True + if k in self.partial_response_function_call: self.partial_response_function_call[k] += v else: @@ -2516,36 +2809,62 @@ def show_send_output_stream(self, completion): except AttributeError: pass - if received_content: - self._stop_waiting_spinner() self.partial_response_content += text - if self.show_pretty(): - self.live_incremental_response(False) + # Use simplified streaming - just call the method with full content + content_to_show = self.live_incremental_response(False) + self.stream_wrapper(content_to_show, final=False) elif text: - # Apply reasoning tag formatting + # Apply reasoning tag formatting for non-pretty output text = replace_reasoning_tags(text, self.reasoning_tag_name) try: - sys.stdout.write(text) + self.stream_wrapper(text, final=False) except UnicodeEncodeError: # Safely encode and decode the text safe_text = text.encode(sys.stdout.encoding, errors="backslashreplace").decode( sys.stdout.encoding ) - sys.stdout.write(safe_text) - sys.stdout.flush() + self.stream_wrapper(safe_text, final=False) yield text - if not received_content and len(self.partial_response_tool_call) == 0: + if not received_content and len(self.partial_response_tool_calls) == 0: self.io.tool_warning("Empty response received from LLM. 
Check your provider account?") + def stream_wrapper(self, content, final): + if not hasattr(self, "_streaming_buffer_length"): + self._streaming_buffer_length = 0 + + if final: + content += "\n\n" + + if isinstance(content, str): + self._streaming_buffer_length += len(content) + + self.io.stream_output(content, final=final) + + if final: + self._streaming_buffer_length = 0 + def live_incremental_response(self, final): show_resp = self.render_incremental_response(final) # Apply any reasoning tag formatting show_resp = replace_reasoning_tags(show_resp, self.reasoning_tag_name) - self.mdstream.update(show_resp, final=final) + + # Track streaming state to avoid repetitive output + if not hasattr(self, "_streaming_buffer_length"): + self._streaming_buffer_length = 0 + + # Only send new content that hasn't been streamed yet + if len(show_resp) >= self._streaming_buffer_length: + new_content = show_resp[self._streaming_buffer_length :] + return new_content + else: + self._streaming_buffer_length = 0 + self.io.reset_streaming_response() + return show_resp def render_incremental_response(self, final): + # Just return the current content - the streaming logic will handle incremental updates return self.get_multi_response_content_in_progress() def remove_reasoning_content(self): @@ -2611,18 +2930,9 @@ def calculate_and_show_tokens_and_cost(self, messages, completion=None): self.total_cost += cost self.message_cost += cost - def format_cost(value): - if value == 0: - return "0.00" - magnitude = abs(value) - if magnitude >= 0.01: - return f"{value:.2f}" - else: - return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}" - cost_report = ( - f"Cost: ${format_cost(self.message_cost)} message," - f" ${format_cost(self.total_cost)} session." + f"Cost: ${self.format_cost(self.message_cost)} message," + f" ${self.format_cost(self.total_cost)} session." 
) if cache_hit_tokens and cache_write_tokens: @@ -2632,6 +2942,15 @@ def format_cost(value): self.usage_report = tokens_report + sep + cost_report + def format_cost(self, value): + if value == 0: + return "0.00" + magnitude = abs(value) + if magnitude >= 0.01: + return f"{value:.2f}" + else: + return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}" + def compute_costs_from_tokens( self, prompt_tokens, completion_tokens, cache_write_tokens, cache_hit_tokens ): @@ -2757,7 +3076,7 @@ def check_for_dirty_commit(self, path): self.io.tool_output(f"Committing {path} before applying edits.") self.need_commit_before_edits.add(path) - def allowed_to_edit(self, path): + async def allowed_to_edit(self, path): full_path = self.abs_root_path(path) if self.repo: need_to_add = not self.repo.path_in_repo(path) @@ -2792,7 +3111,7 @@ def allowed_to_edit(self, path): self.check_added_files() return True - if not self.io.confirm_ask( + if not await self.io.confirm_ask( "Allow edits to file that has not been added to the chat?", subject=path, ): @@ -2835,7 +3154,7 @@ def check_added_files(self): self.io.tool_warning(urls.edit_errors) self.warning_given = True - def prepare_to_edit(self, edits): + async def prepare_to_edit(self, edits): res = [] seen = dict() @@ -2851,23 +3170,23 @@ def prepare_to_edit(self, edits): if path in seen: allowed = seen[path] else: - allowed = self.allowed_to_edit(path) + allowed = await self.allowed_to_edit(path) seen[path] = allowed if allowed: res.append(edit) - self.dirty_commit() + await self.dirty_commit() self.need_commit_before_edits = set() return res - def apply_updates(self): + async def apply_updates(self): edited = set() try: edits = self.get_edits() edits = self.apply_edits_dry_run(edits) - edits = self.prepare_to_edit(edits) + edits = await self.prepare_to_edit(edits) edited = set(edit[0] for edit in edits) self.apply_edits(edits) @@ -2890,9 +3209,7 @@ def apply_updates(self): except Exception as err: self.io.tool_error("Exception while 
updating files:") self.io.tool_error(str(err), strip=False) - - traceback.print_exc() - + self.io.tool_error(traceback.format_exc()) self.reflected_message = str(err) return edited @@ -2964,7 +3281,7 @@ def get_context_from_history(self, history): return context - def auto_commit(self, edited, context=None): + async def auto_commit(self, edited, context=None): if not self.repo or not self.auto_commits or self.dry_run: return @@ -2972,7 +3289,9 @@ def auto_commit(self, edited, context=None): context = self.get_context_from_history(self.cur_messages) try: - res = self.repo.commit(fnames=edited, context=context, aider_edits=True, coder=self) + res = await self.repo.commit( + fnames=edited, context=context, aider_edits=True, coder=self + ) if res: self.show_auto_commit_outcome(res) commit_hash, commit_message = res @@ -3000,7 +3319,7 @@ def show_undo_hint(self): if self.commit_before_message[-1] != self.repo.get_head_commit_sha(): self.io.tool_output("You can use /undo to undo and discard each aider commit.") - def dirty_commit(self): + async def dirty_commit(self): if not self.need_commit_before_edits: return if not self.dirty_commits: @@ -3008,7 +3327,7 @@ def dirty_commit(self): if not self.repo: return - self.repo.commit(fnames=self.need_commit_before_edits, coder=self) + await self.repo.commit(fnames=self.need_commit_before_edits, coder=self) # files changed, move cur messages back behind the files messages # self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits) @@ -3023,7 +3342,7 @@ def apply_edits(self, edits): def apply_edits_dry_run(self, edits): return edits - def run_shell_commands(self): + async def run_shell_commands(self): if not self.suggest_shell_commands: return "" @@ -3034,18 +3353,18 @@ def run_shell_commands(self): if command in done: continue done.add(command) - output = self.handle_shell_commands(command, group) + output = await self.handle_shell_commands(command, group) if output: accumulated_output += output + "\n\n" return 
accumulated_output - def handle_shell_commands(self, commands_str, group): + async def handle_shell_commands(self, commands_str, group): commands = commands_str.strip().splitlines() command_count = sum( 1 for cmd in commands if cmd.strip() and not cmd.strip().startswith("#") ) prompt = "Run shell command?" if command_count == 1 else "Run shell commands?" - if not self.io.confirm_ask( + if not await self.io.confirm_ask( prompt, subject="\n".join(commands), explicit_yes_required=True, @@ -3064,11 +3383,13 @@ def handle_shell_commands(self, commands_str, group): self.io.tool_output(f"Running {command}") # Add the command to input history self.io.add_to_input_history(f"/run {command.strip()}") - exit_status, output = run_cmd(command, error_print=self.io.tool_error, cwd=self.root) + exit_status, output = await asyncio.to_thread( + run_cmd, command, error_print=self.io.tool_error, cwd=self.root + ) if output: accumulated_output += f"Output from {command}\n{output}\n" - if accumulated_output.strip() and self.io.confirm_ask( + if accumulated_output.strip() and await self.io.confirm_ask( "Add command output to the chat?", allow_never=True ): num_lines = len(accumulated_output.strip().splitlines()) diff --git a/aider/coders/context_coder.py b/aider/coders/context_coder.py index 73fe64af0ab..ee5dd949e62 100644 --- a/aider/coders/context_coder.py +++ b/aider/coders/context_coder.py @@ -18,7 +18,7 @@ def __init__(self, *args, **kwargs): self.repo_map.max_map_tokens *= self.repo_map.map_mul_no_files self.repo_map.map_mul_no_files = 1.0 - def reply_completed(self): + async def reply_completed(self): content = self.partial_response_content if not content or not content.strip(): return True diff --git a/aider/coders/editblock_func_coder.py b/aider/coders/editblock_func_coder.py index 27aa53f115c..9750f5e2635 100644 --- a/aider/coders/editblock_func_coder.py +++ b/aider/coders/editblock_func_coder.py @@ -92,7 +92,7 @@ def render_incremental_response(self, final=False): res = 
json.dumps(args, indent=4) return res - def _update_files(self): + async def _update_files(self): name = self.partial_response_function_call.get("name") if name and name != "replace_lines": @@ -121,7 +121,7 @@ def _update_files(self): if updated and not updated.endswith("\n"): updated += "\n" - full_path = self.allowed_to_edit(path) + full_path = await self.allowed_to_edit(path) if not full_path: continue content = self.io.read_text(full_path) diff --git a/aider/coders/navigator_coder.py b/aider/coders/navigator_coder.py index 0232e9c3e83..df7e4159ccd 100644 --- a/aider/coders/navigator_coder.py +++ b/aider/coders/navigator_coder.py @@ -10,7 +10,7 @@ import traceback # Add necessary imports if not already present -from collections import defaultdict +from collections import Counter, defaultdict from datetime import datetime from pathlib import Path @@ -24,12 +24,48 @@ from aider.repo import ANY_GIT_ERROR # Import run_cmd for potentially interactive execution and run_cmd_subprocess for guaranteed non-interactive +from aider.tools import ( + command_interactive_schema, + command_schema, + delete_block_schema, + delete_line_schema, + delete_lines_schema, + extract_lines_schema, + grep_schema, + indent_lines_schema, + insert_block_schema, + list_changes_schema, + ls_schema, + make_editable_schema, + make_readonly_schema, + remove_schema, + replace_all_schema, + replace_line_schema, + replace_lines_schema, + replace_text_schema, + show_numbered_context_schema, + undo_change_schema, + update_todo_list_schema, + view_files_matching_schema, + view_files_with_symbol_schema, + view_schema, +) from aider.tools.command import _execute_command from aider.tools.command_interactive import _execute_command_interactive from aider.tools.delete_block import _execute_delete_block from aider.tools.delete_line import _execute_delete_line from aider.tools.delete_lines import _execute_delete_lines from aider.tools.extract_lines import _execute_extract_lines +from aider.tools.git import ( 
+ _execute_git_diff, + _execute_git_log, + _execute_git_show, + _execute_git_status, + git_diff_schema, + git_log_schema, + git_show_schema, + git_status_schema, +) from aider.tools.grep import _execute_grep from aider.tools.indent_lines import _execute_indent_lines from aider.tools.insert_block import _execute_insert_block @@ -44,10 +80,10 @@ from aider.tools.replace_text import _execute_replace_text from aider.tools.show_numbered_context import execute_show_numbered_context from aider.tools.undo_change import _execute_undo_change +from aider.tools.update_todo_list import _execute_update_todo_list from aider.tools.view import execute_view # Import tool functions -from aider.tools.view_files_at_glob import execute_view_files_at_glob from aider.tools.view_files_matching import execute_view_files_matching from aider.tools.view_files_with_symbol import _execute_view_files_with_symbol @@ -56,6 +92,22 @@ from .navigator_legacy_prompts import NavigatorLegacyPrompts from .navigator_prompts import NavigatorPrompts +# UNUSED TOOL SCHEMAS +# view_files_matching_schema, +# grep_schema, +# replace_all_schema, +# insert_block_schema, +# delete_block_schema, +# replace_line_schema, +# replace_lines_schema, +# indent_lines_schema, +# delete_line_schema, +# delete_lines_schema, +# undo_change_schema, +# list_changes_schema, +# extract_lines_schema, +# show_numbered_context_schema, + class NavigatorCoder(Coder): """Mode where the LLM autonomously manages which files are in context.""" @@ -75,6 +127,29 @@ def __init__(self, *args, **kwargs): # Dictionary to track recently removed files self.recently_removed = {} + # Tool usage history + self.tool_usage_history = [] + self.tool_usage_retries = 10 + self.read_tools = { + "viewfilesatglob", + "viewfilesmatching", + "ls", + "viewfileswithsymbol", + "grep", + "listchanges", + "extractlines", + "shownumberedcontext", + } + self.write_tools = { + "command", + "commandinteractive", + "insertblock", + "replaceblock", + "replaceall", + 
"replacetext", + "undochange", + } + # Configuration parameters self.max_tool_calls = 100 # Maximum number of tool calls per response @@ -110,11 +185,42 @@ def __init__(self, *args, **kwargs): self.tokens_calculated = False super().__init__(*args, **kwargs) - self.initialize_local_tools() - def initialize_local_tools(self): - if not self.use_granular_editing: - return + def get_local_tool_schemas(self): + """Returns the JSON schemas for all local tools.""" + return [ + view_files_matching_schema, + ls_schema, + view_schema, + remove_schema, + make_editable_schema, + make_readonly_schema, + view_files_with_symbol_schema, + command_schema, + command_interactive_schema, + grep_schema, + replace_text_schema, + replace_all_schema, + insert_block_schema, + delete_block_schema, + replace_line_schema, + replace_lines_schema, + indent_lines_schema, + delete_line_schema, + delete_lines_schema, + undo_change_schema, + list_changes_schema, + extract_lines_schema, + show_numbered_context_schema, + update_todo_list_schema, + git_diff_schema, + git_log_schema, + git_show_schema, + git_status_schema, + ] + + async def initialize_mcp_tools(self): + await super().initialize_mcp_tools() local_tools = self.get_local_tool_schemas() if not local_tools: @@ -133,491 +239,6 @@ def initialize_local_tools(self): if "local_tools" not in [name for name, _ in self.mcp_tools]: self.mcp_tools.append((local_server.name, local_tools)) - self.functions = self.get_tool_list() - - def get_local_tool_schemas(self): - """Returns the JSON schemas for all local tools.""" - return [ - { - "type": "function", - "function": { - "name": "ViewFilesAtGlob", - "description": "View files matching a glob pattern.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "The glob pattern to match files.", - }, - }, - "required": ["pattern"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "ViewFilesMatching", - "description": "View files 
containing a specific pattern.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "The pattern to search for in file contents.", - }, - "file_pattern": { - "type": "string", - "description": ( - "An optional glob pattern to filter which files are searched." - ), - }, - "regex": { - "type": "boolean", - "description": ( - "Whether the pattern is a regular expression. Defaults to" - " False." - ), - }, - }, - "required": ["pattern"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "Ls", - "description": "List files in a directory.", - "parameters": { - "type": "object", - "properties": { - "directory": { - "type": "string", - "description": "The directory to list.", - }, - }, - "required": ["directory"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "View", - "description": "View a specific file.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "The path to the file to view.", - }, - }, - "required": ["file_path"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "Remove", - "description": "Remove a file from the chat context.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "The path to the file to remove.", - }, - }, - "required": ["file_path"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "MakeEditable", - "description": "Make a read-only file editable.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "The path to the file to make editable.", - }, - }, - "required": ["file_path"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "MakeReadonly", - "description": "Make an editable file read-only.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "The path to the 
file to make read-only.", - }, - }, - "required": ["file_path"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "ViewFilesWithSymbol", - "description": ( - "View files that contain a specific symbol (e.g., class, function)." - ), - "parameters": { - "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The symbol to search for.", - }, - }, - "required": ["symbol"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "Command", - "description": "Execute a shell command.", - "parameters": { - "type": "object", - "properties": { - "command_string": { - "type": "string", - "description": "The shell command to execute.", - }, - }, - "required": ["command_string"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "CommandInteractive", - "description": "Execute a shell command interactively.", - "parameters": { - "type": "object", - "properties": { - "command_string": { - "type": "string", - "description": "The interactive shell command to execute.", - }, - }, - "required": ["command_string"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "Grep", - "description": "Search for a pattern in files.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "The pattern to search for.", - }, - "file_pattern": { - "type": "string", - "description": "Glob pattern for files to search. Defaults to '*'.", - }, - "directory": { - "type": "string", - "description": "Directory to search in. Defaults to '.'.", - }, - "use_regex": { - "type": "boolean", - "description": "Whether to use regex. Defaults to False.", - }, - "case_insensitive": { - "type": "boolean", - "description": ( - "Whether to perform a case-insensitive search. Defaults to" - " False." - ), - }, - "context_before": { - "type": "integer", - "description": ( - "Number of lines to show before a match. Defaults to 5." 
- ), - }, - "context_after": { - "type": "integer", - "description": ( - "Number of lines to show after a match. Defaults to 5." - ), - }, - }, - "required": ["pattern"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "ReplaceText", - "description": "Replace text in a file.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "find_text": {"type": "string"}, - "replace_text": {"type": "string"}, - "near_context": {"type": "string"}, - "occurrence": {"type": "integer", "default": 1}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["file_path", "find_text", "replace_text"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "ReplaceAll", - "description": "Replace all occurrences of text in a file.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "find_text": {"type": "string"}, - "replace_text": {"type": "string"}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["file_path", "find_text", "replace_text"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "InsertBlock", - "description": "Insert a block of content into a file.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "content": {"type": "string"}, - "after_pattern": {"type": "string"}, - "before_pattern": {"type": "string"}, - "occurrence": {"type": "integer", "default": 1}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - "position": {"type": "string", "enum": ["top", "bottom"]}, - "auto_indent": {"type": "boolean", "default": True}, - "use_regex": {"type": "boolean", "default": False}, - }, - "required": ["file_path", "content"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "DeleteBlock", - "description": "Delete a block of lines from a file.", - 
"parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "start_pattern": {"type": "string"}, - "end_pattern": {"type": "string"}, - "line_count": {"type": "integer"}, - "near_context": {"type": "string"}, - "occurrence": {"type": "integer", "default": 1}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["file_path", "start_pattern"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "ReplaceLine", - "description": "Replace a single line in a file.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "line_number": {"type": "integer"}, - "new_content": {"type": "string"}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["file_path", "line_number", "new_content"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "ReplaceLines", - "description": "Replace a range of lines in a file.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "start_line": {"type": "integer"}, - "end_line": {"type": "integer"}, - "new_content": {"type": "string"}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["file_path", "start_line", "end_line", "new_content"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "IndentLines", - "description": "Indent a block of lines in a file.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "start_pattern": {"type": "string"}, - "end_pattern": {"type": "string"}, - "line_count": {"type": "integer"}, - "indent_levels": {"type": "integer", "default": 1}, - "near_context": {"type": "string"}, - "occurrence": {"type": "integer", "default": 1}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["file_path", 
"start_pattern"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "DeleteLine", - "description": "Delete a single line from a file.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "line_number": {"type": "integer"}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["file_path", "line_number"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "DeleteLines", - "description": "Delete a range of lines from a file.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "start_line": {"type": "integer"}, - "end_line": {"type": "integer"}, - "change_id": {"type": "string"}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["file_path", "start_line", "end_line"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "UndoChange", - "description": "Undo a previously applied change.", - "parameters": { - "type": "object", - "properties": { - "change_id": {"type": "string"}, - "file_path": {"type": "string"}, - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "ListChanges", - "description": "List recent changes made.", - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "limit": {"type": "integer", "default": 10}, - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "ExtractLines", - "description": ( - "Extract lines from a source file and append them to a target file." 
- ), - "parameters": { - "type": "object", - "properties": { - "source_file_path": {"type": "string"}, - "target_file_path": {"type": "string"}, - "start_pattern": {"type": "string"}, - "end_pattern": {"type": "string"}, - "line_count": {"type": "integer"}, - "near_context": {"type": "string"}, - "occurrence": {"type": "integer", "default": 1}, - "dry_run": {"type": "boolean", "default": False}, - }, - "required": ["source_file_path", "target_file_path", "start_pattern"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "ShowNumberedContext", - "description": ( - "Show numbered lines of context around a pattern or line number." - ), - "parameters": { - "type": "object", - "properties": { - "file_path": {"type": "string"}, - "pattern": {"type": "string"}, - "line_number": {"type": "integer"}, - "context_lines": {"type": "integer", "default": 3}, - }, - "required": ["file_path"], - }, - }, - }, - ] async def _execute_local_tool_calls(self, tool_calls_list): tool_responses = [] @@ -646,61 +267,52 @@ async def _execute_local_tool_calls(self, tool_calls_list): all_results_content = [] norm_tool_name = tool_name.lower() - for params in parsed_args_list: - single_result = "" - # Dispatch to the correct tool execution function - if norm_tool_name == "viewfilesatglob": - single_result = execute_view_files_at_glob(self, **params) - elif norm_tool_name == "viewfilesmatching": - single_result = execute_view_files_matching(self, **params) - elif norm_tool_name == "ls": - single_result = execute_ls(self, **params) - elif norm_tool_name == "view": - single_result = execute_view(self, **params) - elif norm_tool_name == "remove": - single_result = _execute_remove(self, **params) - elif norm_tool_name == "makeeditable": - single_result = _execute_make_editable(self, **params) - elif norm_tool_name == "makereadonly": - single_result = _execute_make_readonly(self, **params) - elif norm_tool_name == "viewfileswithsymbol": - single_result = 
_execute_view_files_with_symbol(self, **params) - elif norm_tool_name == "command": - single_result = _execute_command(self, **params) - elif norm_tool_name == "commandinteractive": - single_result = _execute_command_interactive(self, **params) - elif norm_tool_name == "grep": - single_result = _execute_grep(self, **params) - elif norm_tool_name == "replacetext": - single_result = _execute_replace_text(self, **params) - elif norm_tool_name == "replaceall": - single_result = _execute_replace_all(self, **params) - elif norm_tool_name == "insertblock": - single_result = _execute_insert_block(self, **params) - elif norm_tool_name == "deleteblock": - single_result = _execute_delete_block(self, **params) - elif norm_tool_name == "replaceline": - single_result = _execute_replace_line(self, **params) - elif norm_tool_name == "replacelines": - single_result = _execute_replace_lines(self, **params) - elif norm_tool_name == "indentlines": - single_result = _execute_indent_lines(self, **params) - elif norm_tool_name == "deleteline": - single_result = _execute_delete_line(self, **params) - elif norm_tool_name == "deletelines": - single_result = _execute_delete_lines(self, **params) - elif norm_tool_name == "undochange": - single_result = _execute_undo_change(self, **params) - elif norm_tool_name == "listchanges": - single_result = _execute_list_changes(self, **params) - elif norm_tool_name == "extractlines": - single_result = _execute_extract_lines(self, **params) - elif norm_tool_name == "shownumberedcontext": - single_result = execute_show_numbered_context(self, **params) - else: - single_result = f"Error: Unknown local tool name '{tool_name}'" + tasks = [] + tool_functions = { + "viewfilesmatching": execute_view_files_matching, + "ls": execute_ls, + "view": execute_view, + "remove": _execute_remove, + "makeeditable": _execute_make_editable, + "makereadonly": _execute_make_readonly, + "viewfileswithsymbol": _execute_view_files_with_symbol, + "command": _execute_command, + 
"commandinteractive": _execute_command_interactive, + "grep": _execute_grep, + "replacetext": _execute_replace_text, + "replaceall": _execute_replace_all, + "insertblock": _execute_insert_block, + "deleteblock": _execute_delete_block, + "replaceline": _execute_replace_line, + "replacelines": _execute_replace_lines, + "indentlines": _execute_indent_lines, + "deleteline": _execute_delete_line, + "deletelines": _execute_delete_lines, + "undochange": _execute_undo_change, + "listchanges": _execute_list_changes, + "extractlines": _execute_extract_lines, + "shownumberedcontext": execute_show_numbered_context, + "updatetodolist": _execute_update_todo_list, + "git_diff": _execute_git_diff, + "git_log": _execute_git_log, + "git_show": _execute_git_show, + "git_status": _execute_git_status, + } + + func = tool_functions.get(norm_tool_name) + + if func: + for params in parsed_args_list: + if asyncio.iscoroutinefunction(func): + tasks.append(func(self, **params)) + else: + tasks.append(asyncio.to_thread(func, self, **params)) + else: + all_results_content.append(f"Error: Unknown local tool name '{tool_name}'") - all_results_content.append(str(single_result)) + if tasks: + task_results = await asyncio.gather(*tasks) + all_results_content.extend(str(res) for res in task_results) result_message = "\n\n".join(all_results_content) @@ -714,13 +326,12 @@ async def _execute_local_tool_calls(self, tool_calls_list): { "role": "tool", "tool_call_id": tool_call.id, - "name": tool_name, "content": result_message, } ) return tool_responses - def _execute_mcp_tool(self, server, tool_name, params): + async def _execute_mcp_tool(self, server, tool_name, params): """Helper to execute a single MCP tool call, created from legacy format.""" # This is a simplified, synchronous wrapper around async logic @@ -767,10 +378,8 @@ async def _exec_async(): f"Executing {tool_name} on {server.name} failed: \n Error: {e}\n" ) return f"Error executing tool call {tool_name}: {e}" - finally: - await 
server.disconnect() - return asyncio.run(_exec_async()) + return await _exec_async() def _calculate_context_block_tokens(self, force=False): """ @@ -840,6 +449,8 @@ def _generate_context_block(self, block_name): content = self.get_context_symbol_outline() elif block_name == "context_summary": content = self.get_context_summary() + elif block_name == "todo_list": + content = self.get_todo_list() # Cache the result if it's not None if content is not None: @@ -965,63 +576,11 @@ def format_chat_chunks(self): This approach preserves prefix caching while providing fresh context information. """ - # First get the normal chat chunks from the parent method without calling super - # We'll manually build the chunks to control placement of context blocks - chunks = self.format_chat_chunks_base() - - # If enhanced context blocks are not enabled, just return the base chunks + # If enhanced context blocks are not enabled, use the base implementation if not self.use_enhanced_context: - return chunks - - # Make sure token counts are updated - using centralized method - # This also populates the context block cache - self._calculate_context_block_tokens() + return super().format_chat_chunks() - # Get blocks from cache to avoid regenerating them - env_context = self.get_cached_context_block("environment_info") - dir_structure = self.get_cached_context_block("directory_structure") - git_status = self.get_cached_context_block("git_status") - symbol_outline = self.get_cached_context_block("symbol_outline") - - # Context summary needs special handling because it depends on other blocks - context_summary = self.get_context_summary() - - # 1. 
Add relatively static blocks BEFORE done_messages - # These blocks change less frequently and can be part of the cacheable prefix - static_blocks = [] - if dir_structure: - static_blocks.append(dir_structure) - if env_context: - static_blocks.append(env_context) - - if static_blocks: - static_message = "\n\n".join(static_blocks) - # Insert as a system message right before done_messages - chunks.done.insert(0, dict(role="system", content=static_message)) - - # 2. Add dynamic blocks AFTER chat_files - # These blocks change with the current files in context - dynamic_blocks = [] - if context_summary: - dynamic_blocks.append(context_summary) - if symbol_outline: - dynamic_blocks.append(symbol_outline) - if git_status: - dynamic_blocks.append(git_status) - - if dynamic_blocks: - dynamic_message = "\n\n".join(dynamic_blocks) - # Append as a system message after chat_files - chunks.chat_files.append(dict(role="system", content=dynamic_message)) - - return chunks - - def format_chat_chunks_base(self): - """ - Create base chat chunks without enhanced context blocks. - This is a copy of the parent's format_chat_chunks method to avoid - calling super() which would create a recursive loop. 
- """ + # Build chunks from scratch to avoid duplication with enhanced context blocks self.choose_fence() main_sys = self.fmt_system_prompt(self.gpt_prompts.main_system) @@ -1072,12 +631,65 @@ def format_chat_chunks_base(self): chunks.examples = example_messages self.summarize_end() - chunks.done = self.done_messages + chunks.done = list(self.done_messages) chunks.repo = self.get_repo_messages() chunks.readonly_files = self.get_readonly_files_messages() chunks.chat_files = self.get_chat_files_messages() + # Make sure token counts are updated - using centralized method + # This also populates the context block cache + self._calculate_context_block_tokens() + + # Get blocks from cache to avoid regenerating them + env_context = self.get_cached_context_block("environment_info") + dir_structure = self.get_cached_context_block("directory_structure") + git_status = self.get_cached_context_block("git_status") + symbol_outline = self.get_cached_context_block("symbol_outline") + todo_list = self.get_cached_context_block("todo_list") + + # Context summary needs special handling because it depends on other blocks + context_summary = self.get_context_summary() + + # 1. Add relatively static blocks BEFORE done_messages + # These blocks change less frequently and can be part of the cacheable prefix + static_blocks = [] + if dir_structure: + static_blocks.append(dir_structure) + if env_context: + static_blocks.append(env_context) + + if static_blocks: + static_message = "\n\n".join(static_blocks) + # Insert as a system message right before done_messages + chunks.done.insert(0, dict(role="system", content=static_message)) + + # 2. 
Add dynamic blocks AFTER chat_files + # These blocks change with the current files in context + dynamic_blocks = [] + if todo_list: + dynamic_blocks.append(todo_list) + if context_summary: + dynamic_blocks.append(context_summary) + if symbol_outline: + dynamic_blocks.append(symbol_outline) + if git_status: + dynamic_blocks.append(git_status) + + # Add tool usage context if there are repetitive tools + if hasattr(self, "tool_usage_history") and self.tool_usage_history: + repetitive_tools = self._get_repetitive_tools() + if repetitive_tools: + tool_context = self._generate_tool_context(repetitive_tools) + if tool_context: + dynamic_blocks.append(tool_context) + + if dynamic_blocks: + dynamic_message = "\n\n".join(dynamic_blocks) + # Append as a system message after chat_files + chunks.chat_files.append(dict(role="system", content=dynamic_message)) + + # Add reminder if needed if self.gpt_prompts.system_reminder: reminder_message = [ dict( @@ -1293,7 +905,21 @@ def get_environment_info(self): self.io.tool_error(f"Error generating environment info: {str(e)}") return None - def reply_completed(self): + async def process_tool_calls(self, tool_call_response): + """ + Track tool usage before calling the base implementation. + """ + + if self.partial_response_tool_calls: + for tool_call in self.partial_response_tool_calls: + self.tool_usage_history.append(tool_call.get("function", {}).get("name")) + + if len(self.tool_usage_history) > self.tool_usage_retries: + self.tool_usage_history.pop(0) + + return await super().process_tool_calls(tool_call_response) + + async def reply_completed(self): """Process the completed response from the LLM. This is a key method that: @@ -1305,8 +931,9 @@ def reply_completed(self): iteratively discover and analyze relevant files before providing a final answer to the user's question. """ - # In granular editing mode, tool calls are handled by BaseCoder's process_tool_calls. 
- # This method is now only for legacy tool call format and search/replace blocks. + # In granular editing mode, tool calls are handled by BaseCoder's process_tool_calls, + # which is overridden in this class to track tool usage. This method is now only for + # legacy tool call format and search/replace blocks. if self.use_granular_editing: # Handle SEARCH/REPLACE blocks content = self.partial_response_content @@ -1319,7 +946,7 @@ def reply_completed(self): has_replace = ">>>>>>> REPLACE" in content if has_search and has_divider and has_replace: self.io.tool_output("Detected edit blocks, applying changes...") - edited_files = self._apply_edits_from_response() + edited_files = await self._apply_edits_from_response() if self.reflected_message: return False # Trigger reflection if edits failed @@ -1335,14 +962,19 @@ def reply_completed(self): # Legacy tool call processing for use_granular_editing=False content = self.partial_response_content if not content or not content.strip(): + if len(self.tool_usage_history) > self.tool_usage_retries: + self.tool_usage_history = [] return True original_content = content # Keep the original response - # Process tool commands: returns content with tool calls removed, results, flag if any tool calls were found, - # and the content before the last '---' line - processed_content, result_messages, tool_calls_found, content_before_last_separator = ( - self._process_tool_commands(content) - ) + # Process tool commands: returns content with tool calls removed, results, flag if any tool calls were found + ( + processed_content, + result_messages, + tool_calls_found, + content_before_last_separator, + tool_names_this_turn, + ) = await self._process_tool_commands(content) # Since we are no longer suppressing, the partial_response_content IS the final content. 
# We might want to update it to the processed_content (without tool calls) if we don't @@ -1370,7 +1002,7 @@ def reply_completed(self): if edit_match: self.io.tool_output("Detected edit blocks, applying changes within Navigator...") - edited_files = self._apply_edits_from_response() + edited_files = await self._apply_edits_from_response() # If _apply_edits_from_response set a reflected_message (due to errors), # return False to trigger a reflection loop. if self.reflected_message: @@ -1408,6 +1040,7 @@ def reply_completed(self): if tool_calls_found and self.num_reflections < self.max_reflections: # Reset tool counter for next iteration self.tool_call_count = 0 + # Clear exploration files for the next round self.files_added_in_exploration = set() @@ -1460,15 +1093,24 @@ def reply_completed(self): # After applying edits OR determining no edits were needed (and no reflection needed), # the turn is complete. Reset counters and finalize history. + + # Auto-commit any files edited by granular tools + if self.files_edited_by_tools: + saved_message = await self.auto_commit(self.files_edited_by_tools) + if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"): + saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo + self.move_back_cur_messages(saved_message) + self.tool_call_count = 0 self.files_added_in_exploration = set() + self.files_edited_by_tools = set() # Move cur_messages to done_messages self.move_back_cur_messages( None ) # Pass None as we handled commit message earlier if needed return True # Indicate exploration is finished for this round - def _process_tool_commands(self, content): + async def _process_tool_commands(self, content): """ Process tool commands in the `[tool_call(name, param=value)]` format within the content. 
@@ -1485,6 +1127,7 @@ def _process_tool_commands(self, content): tool_calls_found = False call_count = 0 max_calls = self.max_tool_calls + tool_names = [] # Check if there's a '---' separator and only process tool calls after the LAST one separator_marker = "---" @@ -1493,7 +1136,7 @@ def _process_tool_commands(self, content): # If there's no separator, treat the entire content as before the separator if len(content_parts) == 1: # Return the original content with no tool calls processed, and the content itself as before_separator - return content, result_messages, False, content + return content, result_messages, False, content, tool_names # Take everything before the last separator (including intermediate separators) content_before_separator = separator_marker.join(content_parts[:-1]) @@ -1683,6 +1326,8 @@ def _process_tool_commands(self, content): else: raise ValueError("Tool name must be an identifier or a string literal") + tool_names.append(tool_name) + # Extract keyword arguments for keyword in call_node.keywords: key = keyword.arg @@ -1764,20 +1409,13 @@ def _process_tool_commands(self, content): # Normalize tool name for case-insensitive matching norm_tool_name = tool_name.lower() - if norm_tool_name == "viewfilesatglob": - pattern = params.get("pattern") - if pattern is not None: - # Call the imported function - result_message = execute_view_files_at_glob(self, pattern) - else: - result_message = "Error: Missing 'pattern' parameter for ViewFilesAtGlob" - elif norm_tool_name == "viewfilesmatching": + if norm_tool_name == "viewfilesmatching": pattern = params.get("pattern") file_pattern = params.get("file_pattern") # Optional regex = params.get("regex", False) # Default to False if not provided if pattern is not None: result_message = execute_view_files_matching( - self, pattern, file_pattern, regex + self, pattern=pattern, file_pattern=file_pattern, regex=regex ) else: result_message = "Error: Missing 'pattern' parameter for ViewFilesMatching" @@ -1823,13 
+1461,13 @@ def _process_tool_commands(self, content): elif norm_tool_name == "command": command_string = params.get("command_string") if command_string is not None: - result_message = _execute_command(self, command_string) + result_message = await _execute_command(self, command_string) else: result_message = "Error: Missing 'command_string' parameter for Command" elif norm_tool_name == "commandinteractive": command_string = params.get("command_string") if command_string is not None: - result_message = _execute_command_interactive(self, command_string) + result_message = await _execute_command_interactive(self, command_string) else: result_message = ( "Error: Missing 'command_string' parameter for CommandInteractive" @@ -1848,10 +1486,8 @@ def _process_tool_commands(self, content): context_after = params.get("context_after", 5) if pattern is not None: - # Import the function if not already imported (it should be) - from aider.tools.grep import _execute_grep - - result_message = _execute_grep( + result_message = await asyncio.to_thread( + _execute_grep, self, pattern, file_pattern, @@ -1942,6 +1578,7 @@ def _process_tool_commands(self, content): auto_indent, use_regex, ) + else: result_message = ( "Error: Missing required parameters for InsertBlock (file_path," @@ -2139,6 +1776,21 @@ def _process_tool_commands(self, content): " and either pattern or line_number)" ) + elif norm_tool_name == "updatetodolist": + content = params.get("content") + append = params.get("append", False) + change_id = params.get("change_id") + dry_run = params.get("dry_run", False) + + if content is not None: + result_message = _execute_update_todo_list( + self, content, append, change_id, dry_run + ) + else: + result_message = ( + "Error: Missing required 'content' parameter for UpdateTodoList" + ) + else: result_message = f"Error: Unknown tool name '{tool_name}'" if self.mcp_tools: @@ -2150,7 +1802,7 @@ def _process_tool_commands(self, content): (s for s in self.mcp_servers if s.name == 
server_name), None ) if server: - result_message = self._execute_mcp_tool( + result_message = await self._execute_mcp_tool( server, tool_name, params ) else: @@ -2176,12 +1828,124 @@ def _process_tool_commands(self, content): # Return the content with tool calls removed modified_content = processed_content - # Update internal counter - self.tool_call_count += call_count + return ( + modified_content, + result_messages, + tool_calls_found, + content_before_separator, + tool_names, + ) + + def _get_repetitive_tools(self): + """ + Identifies repetitive tool usage patterns from a flat list of tool calls. + + This method checks for the following patterns in order: + 1. If the last tool used was a write tool, it assumes progress and returns no repetitive tools. + 2. It checks for any read tool that has been used 2 or more times in the history. + 3. If no tools are repeated, but all tools in the history are read tools, + it flags all of them as potentially repetitive. + + It avoids flagging repetition if a "write" tool was used recently, + as that suggests progress is being made. + """ + history_len = len(self.tool_usage_history) + + # Not enough history to detect a pattern + if history_len < 2: + return set() + + # If the last tool was a write tool, we're likely making progress. 
+ if isinstance(self.tool_usage_history[-1], str): + last_tool_lower = self.tool_usage_history[-1].lower() + + if last_tool_lower in self.write_tools: + self.tool_usage_history = [] + return set() + + # If all tools in history are read tools, return all of them + if all(tool.lower() in self.read_tools for tool in self.tool_usage_history): + return set(tool for tool in self.tool_usage_history) + + # Check for any read tool used more than once + tool_counts = Counter(tool for tool in self.tool_usage_history) + repetitive_tools = { + tool + for tool, count in tool_counts.items() + if count >= 2 and tool.lower() in self.read_tools + } + + if repetitive_tools: + return repetitive_tools + + return set() + + def _generate_tool_context(self, repetitive_tools): + """ + Generate a context message for the LLM about recent tool usage. + """ + if not self.tool_usage_history: + return "" + + context_parts = [''] + + # Add turn and tool call statistics + context_parts.append("## Turn and Tool Call Statistics") + context_parts.append(f"- Current turn: {self.num_reflections + 1}") + context_parts.append(f"- Tool calls this turn: {self.tool_call_count}") + context_parts.append(f"- Total tool calls in session: {self.num_tool_calls}") + context_parts.append("\n\n") + + # Add recent tool usage history + context_parts.append("## Recent Tool Usage History") + if len(self.tool_usage_history) > 10: + recent_history = self.tool_usage_history[-10:] + context_parts.append("(Showing last 10 tools)") + else: + recent_history = self.tool_usage_history + + for i, tool in enumerate(recent_history, 1): + context_parts.append(f"{i}. 
{tool}") + context_parts.append("\n\n") + + if repetitive_tools: + context_parts.append( + "**Instruction:**\nYou have used the following tool(s) repeatedly:" + ) - return modified_content, result_messages, tool_calls_found, content_before_separator + context_parts.append("### DO NOT USE THE FOLLOWING TOOLS/FUNCTIONS") - def _apply_edits_from_response(self): + for tool in repetitive_tools: + context_parts.append(f"- `{tool}`") + context_parts.append( + "Your exploration appears to be stuck in a loop. Please try a different approach:" + ) + context_parts.append("\n") + context_parts.append("**Suggestions for alternative approaches:**") + context_parts.append( + "- If you've been searching for files, try working with the files already in" + " context" + ) + context_parts.append( + "- If you've been viewing files, try making actual edits to move forward" + ) + context_parts.append("- Consider using different tools that you haven't used recently") + context_parts.append( + "- Focus on making concrete progress rather than gathering more information" + ) + context_parts.append( + "- Use the files you've already discovered to implement the requested changes" + ) + context_parts.append("\n") + context_parts.append( + "You most likely have enough context for a subset of the necessary changes." + ) + context_parts.append("Please prioritize file editing over further exploration.") + + context_parts.append("") + return "\n".join(context_parts) + + async def _apply_edits_from_response(self): """ Parses and applies SEARCH/REPLACE edits found in self.partial_response_content. Returns a set of relative file paths that were successfully edited. 
@@ -2213,13 +1977,13 @@ def _apply_edits_from_response(self): allowed = seen_paths[path] else: # Use the base Coder's permission check method - allowed = self.allowed_to_edit(path) + allowed = await self.allowed_to_edit(path) seen_paths[path] = allowed if allowed: prepared_edits.append(edit) # Commit any dirty files identified by allowed_to_edit - self.dirty_commit() + await self.dirty_commit() self.need_commit_before_edits = set() # Clear after commit # 3. Apply edits (logic adapted from EditBlockCoder.apply_edits) @@ -2318,20 +2082,20 @@ def _apply_edits_from_response(self): lint_errors = self.lint_edited(edited_files) self.auto_commit(edited_files, context="Ran the linter") if lint_errors and not self.reflected_message: # Reflect only if no edit errors - ok = self.io.confirm_ask("Attempt to fix lint errors?") + ok = await self.io.confirm_ask("Attempt to fix lint errors?") if ok: self.reflected_message = lint_errors - shared_output = self.run_shell_commands() + shared_output = await self.run_shell_commands() if shared_output: # Add shell output as a new user message? Or just display? 
# Let's just display for now to avoid complex history manipulation self.io.tool_output("Shell command output:\n" + shared_output) if self.auto_test and not self.reflected_message: # Reflect only if no prior errors - test_errors = self.commands.cmd_test(self.test_cmd) + test_errors = await self.commands.cmd_test(self.test_cmd) if test_errors: - ok = self.io.confirm_ask("Attempt to fix test errors?") + ok = await self.io.confirm_ask("Attempt to fix test errors?") if ok: self.reflected_message = test_errors @@ -2352,7 +2116,7 @@ def _apply_edits_from_response(self): except Exception as err: self.io.tool_error("Exception while applying edits:") self.io.tool_error(str(err), strip=False) - traceback.print_exc() + self.io.tool_error(traceback.format_exc()) self.reflected_message = f"Exception while applying edits: {str(err)}" return edited_files @@ -2363,7 +2127,7 @@ def _add_file_to_context(self, file_path, explicit=False): Parameters: - file_path: Path to the file to add - - explicit: Whether this was an explicit view command (vs. implicit through ViewFilesAtGlob/ViewFilesMatching) + - explicit: Whether this was an explicit view command (vs. implicit through ViewFilesMatching) """ # Check if file exists abs_path = self.abs_root_path(file_path) @@ -2437,7 +2201,7 @@ def _process_file_mentions(self, content): # Do nothing here for implicit mentions. pass - def check_for_file_mentions(self, content): + async def check_for_file_mentions(self, content): """ Override parent's method to use our own file processing logic. @@ -2448,13 +2212,13 @@ def check_for_file_mentions(self, content): # Do nothing - disable implicit file adds in navigator mode. pass - def preproc_user_input(self, inp): + async def preproc_user_input(self, inp): """ Override parent's method to wrap user input in a context block. This clearly delineates user input from other sections in the context window. 
""" # First apply the parent's preprocessing - inp = super().preproc_user_input(inp) + inp = await super().preproc_user_input(inp) # If we still have input after preprocessing, wrap it in a context block if inp and not inp.startswith(''): @@ -2566,6 +2330,44 @@ def print_tree(node, prefix="- ", indent=" ", path=""): self.io.tool_error(f"Error generating directory structure: {str(e)}") return None + def get_todo_list(self): + """ + Generate a todo list context block from the .aider.todo.txt file. + Returns formatted string with the current todo list or None if empty/not present. + """ + + try: + # Define the todo file path + todo_file_path = ".aider.todo.txt" + abs_path = self.abs_root_path(todo_file_path) + + # Check if file exists + import os + + if not os.path.isfile(abs_path): + return ( + '\n' + "Todo list does not exist. Please update it." + "" + ) + + # Read todo list content + content = self.io.read_text(abs_path) + if content is None or not content.strip(): + return None + + # Format the todo list context block + result = '\n' + result += "## Current Todo List\n\n" + result += "Below is the current todo list managed via `UpdateTodoList` tool:\n\n" + result += f"```\n{content}\n```\n" + result += "" + + return result + except Exception as e: + self.io.tool_error(f"Error generating todo list context: {str(e)}") + return None + def get_git_status(self): """ Generate a git status context block for repository information. diff --git a/aider/coders/navigator_legacy_prompts.py b/aider/coders/navigator_legacy_prompts.py index 5b95aa77f0a..72beee97962 100644 --- a/aider/coders/navigator_legacy_prompts.py +++ b/aider/coders/navigator_legacy_prompts.py @@ -5,254 +5,65 @@ class NavigatorLegacyPrompts(CoderPrompts): """ - Prompt templates for the Navigator mode using search/replace instead of granular editing tools. + Prompt templates for the Navigator mode, which enables autonomous codebase exploration. 
The NavigatorCoder uses these prompts to guide its behavior when exploring and modifying - a codebase using special tool commands like Glob, Grep, Add, etc. This version uses the legacy - search/replace editing method instead of granular editing tools. + a codebase using special tool commands like Glob, Grep, Add, etc. This mode enables the + LLM to manage its own context by adding/removing files and executing commands. """ - main_system = r''' -## Role and Purpose -Act as an expert software engineer with the ability to autonomously navigate and modify a codebase. - -### Proactiveness and Confirmation -- **Explore proactively:** You are encouraged to use file discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `ViewFilesWithSymbol`) and context management tools (`View`, `Remove`) autonomously to gather information needed to fulfill the user's request. Use tool calls to continue exploration across multiple turns. -- **Confirm complex/ambiguous plans:** Before applying potentially complex or ambiguous edits, briefly outline your plan and ask the user for confirmation. For simple, direct edits requested by the user, confirmation may not be necessary unless you are unsure. - -## Response Style Guidelines -- **Be extremely concise and direct.** Prioritize brevity in all responses. -- **Minimize output tokens.** Only provide essential information. -- **Answer the specific question asked.** Avoid tangential information or elaboration unless requested. -- **Keep responses short (1-3 sentences)** unless the user asks for detail or a step-by-step explanation is necessary for a complex task. -- **Avoid unnecessary preamble or postamble.** Do not start with "Okay, I will..." or end with summaries unless crucial. -- When exploring, *briefly* indicate your search strategy. -- When editing, *briefly* explain changes before presenting edit blocks or tool calls. -- For ambiguous references, prioritize user-mentioned items. 
-- Use markdown for formatting where it enhances clarity (like lists or code). -- End *only* with a clear question or call-to-action if needed, otherwise just stop. + main_system = r""" + +## Core Directives +- **Role**: Act as an expert software engineer. +- **Act Proactively**: Autonomously use file discovery and context management tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `View`, `Remove`) to gather information and fulfill the user's request. Chain tool calls across multiple turns to continue exploration. +- **Be Decisive**: Do not ask the same question or search for the same term in multiple ways. Trust your initial valid findings. +- **Be Concise**: Keep all responses brief and direct (1-3 sentences). Avoid preamble, postamble, and unnecessary explanations. +- **Confirm Ambiguity**: Before applying complex or ambiguous edits, briefly state your plan and ask for confirmation. For simple, direct edits, proceed without confirmation. - -## Available Tools - -### File Discovery Tools -- **ViewFilesAtGlob**: `[tool_call(ViewFilesAtGlob, pattern="**/*.py")]` - Find files matching a glob pattern. **Found files are automatically added to context as read-only.** - Supports patterns like "src/**/*.ts" or "*.json". - -- **ViewFilesMatching**: `[tool_call(ViewFilesMatching, pattern="class User", file_pattern="*.py", regex=False)]` - Search for text in files. **Matching files are automatically added to context as read-only.** - Files with more matches are prioritized. `file_pattern` is optional. `regex` (optional, default False) enables regex search for `pattern`. - -- **Ls**: `[tool_call(Ls, directory="src/components")]` - List files in a directory. Useful for exploring the project structure. - -- **ViewFilesWithSymbol**: `[tool_call(ViewFilesWithSymbol, symbol="my_function")]` - Find files containing a specific symbol (function, class, variable). 
**Found files are automatically added to context as read-only.** - Leverages the repo map for accurate symbol lookup. - -- **Grep**: `[tool_call(Grep, pattern="my_variable", file_pattern="*.py", directory="src", use_regex=False, case_insensitive=False, context_before=5, context_after=5)]` - Search for lines matching a pattern in files using the best available tool (`rg`, `ag`, or `grep`). Returns matching lines with line numbers and context. - `file_pattern` (optional, default "*") filters files using glob syntax. - `directory` (optional, default ".") specifies the search directory relative to the repo root. - `use_regex` (optional, default False): If False, performs a literal/fixed string search. If True, uses basic Extended Regular Expression (ERE) syntax. - `case_insensitive` (optional, default False): If False (default), the search is case-sensitive. If True, the search is case-insensitive. - `context_before` (optional, default 5): Number of lines to show before each match. - `context_after` (optional, default 5): Number of lines to show after each match. - -### Context Management Tools -- **View**: `[tool_call(View, file_path="src/main.py")]` - Explicitly add a specific file to context as read-only. - -- **Remove**: `[tool_call(Remove, file_path="tests/old_test.py")]` - Explicitly remove a file from context when no longer needed. - Accepts a single file path, not glob patterns. - -- **MakeEditable**: `[tool_call(MakeEditable, file_path="src/main.py")]` - Convert a read-only file to an editable file. Required before making changes. - -- **MakeReadonly**: `[tool_call(MakeReadonly, file_path="src/main.py")]` - Convert an editable file back to read-only status. - -### Other Tools -- **Command**: `[tool_call(Command, command_string="git diff HEAD~1")]` - Execute a *non-interactive* shell command. Requires user confirmation. Use for commands that don't need user input (e.g., `ls`, `git status`, `cat file`). 
-- **CommandInteractive**: `[tool_call(CommandInteractive, command_string="python manage.py shell")]` - Execute an *interactive* shell command using a pseudo-terminal (PTY). Use for commands that might require user interaction (e.g., running a shell, a development server, `ssh`). Does *not* require separate confirmation as interaction happens directly. - -### Multi-Turn Exploration -When you include any tool call, the system will automatically continue to the next round. - - - -## Navigation and Task Workflow - -### General Task Flow -1. **Understand Request:** Ensure you fully understand the user's goal. Ask clarifying questions if needed. -2. **Explore & Search:** Use discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `ViewFilesWithSymbol`) and context tools (`View`) proactively to locate relevant files and understand the existing code. Use `Remove` to keep context focused. -3. **Plan Changes (If Editing):** Determine the necessary edits. For complex changes, outline your plan briefly for the user. -4. **Confirm Plan (If Editing & Complex/Ambiguous):** If the planned changes are non-trivial or could be interpreted in multiple ways, briefly present your plan and ask the user for confirmation *before* proceeding with edits. -5. **Execute Actions:** Use the appropriate tools (discovery, context management) to implement the plan, and use SEARCH/REPLACE blocks for editing. Remember to use `MakeEditable` before attempting edits. -6. **Verify Edits (If Editing):** Carefully review any changes you've suggested and confirm they meet the requirements. -7. **Final Response:** Provide the final answer or result. Omit tool calls unless further exploration is needed. - -### Exploration Strategy -- Use discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `ViewFilesWithSymbol`) to identify relevant files initially. 
**These tools automatically add found files to context as read-only.**
-- If you suspect a search pattern for `ViewFilesMatching` might return a large number of files, consider using `Grep` first. `Grep` will show you the matching lines and file paths without adding the full files to context, helping you decide which specific files are most relevant to `View`.
-- Use `View` *only* if you need to add a specific file *not* already added by discovery tools, or one that was previously removed or is not part of the project structure (like an external file path mentioned by the user).
-- Remove irrelevant files with `Remove` to maintain focus.
-- Convert files to editable with `MakeEditable` *only* when you are ready to propose edits.
-- Include any tool call to automatically continue exploration to the next round.
-
-### Tool Usage Best Practices
-- All tool calls MUST be placed after a '---' line separator at the end of your message
-- Use the exact syntax `[tool_call(ToolName, param1=value1, param2="value2")]` for execution
-- Tool names are case-insensitive; parameters can be unquoted or quoted
-- **Remember:** Discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `ViewFilesWithSymbol`) automatically add found files to context. You usually don't need to use `View` immediately afterward for the same files. Verify files aren't already in context *before* using `View`.
-- Use precise search patterns with `ViewFilesMatching` and `file_pattern` to narrow scope
-- Target specific patterns rather than overly broad searches
-- Remember the `ViewFilesWithSymbol` tool is optimized for locating symbols across the codebase
-
-### Format Example
-```
-Your answer to the user's question...
+
+## Core Workflow
+1. **Plan**: Determine the necessary changes. Use the `UpdateTodoList` tool to manage your plan. Always begin by reviewing the todo list.
+2. **Explore**: Use discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `Grep`) to find relevant files. 
These tools add files to context as read-only. Use `Grep` first for broad searches to avoid context clutter. +3. **Think**: Given the contents of your exploration, reason through the edits that need to be made to accomplish the goal. For complex edits, briefly outline your plan for the user. +4. **Execute**: Use the appropriate editing tool. Remember to use `MakeEditable` on a file before modifying it. +5. **Verify & Recover**: After every edit, check the resulting diff snippet. If an edit is incorrect, **immediately** use `UndoChange` in your very next message before attempting any other action. -SEARCH/REPLACE blocks can ONLY appear BEFORE the last '---' separator. Any SEARCH/REPLACE blocks after the separator will be IGNORED. +## Todo List Management +- **Track Progress**: Use the `UpdateTodoList` tool to add or modify items. +- **Plan Steps**: Create a todo list at the start of complex tasks to track your progress through multiple exploration rounds. +- **Stay Organized**: Update the todo list as you complete steps every 3-10 tool calls to maintain context across multiple tool calls. -file.py -<<<<<<< SEARCH -old code -======= -new code ->>>>>>> REPLACE +## Code Editing Hierarchy +Your primary method for all modifications is through granular tool calls. Use SEARCH/REPLACE only as a last resort. ---- -[tool_call(ViewFilesMatching, pattern="findme")] -[tool_call(Command, command_string="ls -la")] -``` +### 1. Granular Tools (Always Preferred) +Use these for precision and safety. +- **Text/Block Manipulation**: `ReplaceText` (Preferred for the majority of edits), `InsertBlock`, `DeleteBlock`, `ReplaceAll` (use with `dry_run=True` for safety). +- **Line-Based Edits**: `ReplaceLine(s)`, `DeleteLine(s)`, `IndentLines`. +- **Refactoring & History**: `ExtractLines`, `ListChanges`, `UndoChange`. -## SEARCH/REPLACE Block Format -When you need to make changes to code, use the SEARCH/REPLACE block format. You can include multiple edits in one message. 
+**MANDATORY Safety Protocol for Line-Based Tools:** Line numbers are fragile. You **MUST** use a two-turn process: +1. **Turn 1**: Use `ShowNumberedContext` to get the exact, current line numbers. +2. **Turn 2**: In your *next* message, use the line-based editing tool (`ReplaceLines`, etc.) with the verified numbers. -````python -path/to/file.ext -<<<<<<< SEARCH -Original code lines to match exactly -======= -Replacement code lines ->>>>>>> REPLACE -```` -NOTE that this uses four backticks as the fence and not three! +### 2. SEARCH/REPLACE (Last Resort Only) +Use this format **only** when granular tools are demonstrably insufficient for the task (e.g., a complex, non-contiguous pattern change). Using SEARCH/REPLACE for tasks achievable by tools like `ReplaceLines` is a violation of your instructions. -IMPORTANT: Any SEARCH/REPLACE blocks that appear after the last '---' separator will be IGNORED. +**You MUST include a justification comment explaining why granular tools cannot be used.** -#### Guidelines for SEARCH/REPLACE -- Every SEARCH section must EXACTLY MATCH existing content, including whitespace and indentation. -- Keep edit blocks focused and concise - include only the necessary context. -- Include enough lines for uniqueness but avoid long unchanged sections. -- For new files, use an empty SEARCH section. -- To move code within a file, use two separate SEARCH/REPLACE blocks. -- Respect the file paths exactly as they appear. +Justification: I'm using SEARCH/REPLACE because [specific reason granular tools are insufficient]. +path/to/file.ext <<<<<<< SEARCH Original code to be replaced. +New code to insert. -### Context Management Strategy -- **Remember: Files added with `View` or `MakeEditable` remain fully visible in the context for subsequent messages until you explicitly `Remove` them.** -- Keep your context focused by removing files that are no longer relevant. -- For large codebases, maintain only 5-15 files in context at once for best performance. 
-- Files are added as read-only by default; only make files editable when you need to modify them. -- Toggle context management with `/context-management` if you need complete content of large files. - - - -## Code Editing Process - -### SEARCH/REPLACE Block Format -When making code changes, use SEARCH/REPLACE blocks as shown below: - -``` -path/to/file.ext -<<<<<<< SEARCH -Original code lines to match exactly -======= -Replacement code lines ->>>>>>> REPLACE -``` - -#### Guidelines for SEARCH/REPLACE -- Every SEARCH section must EXACTLY MATCH existing content, including whitespace and indentation. -- Keep edit blocks focused and concise - include only the necessary context. -- Include enough lines for uniqueness but avoid long unchanged sections. -- For new files, use an empty SEARCH section. -- To move code within a file, use two separate SEARCH/REPLACE blocks. -- Respect the file paths exactly as they appear. +REPLACE -### Error Handling and Recovery -- If a search pattern fails to match, analyze why the edit was incorrect (e.g., whitespace differences, incorrect indentation) -- Verify your SEARCH block matches the exact content in the file, including all whitespace -- Use more context lines to ensure uniqueness when necessary -- For complex changes, break them into multiple smaller edits -- If you're unsure about a file's exact content, use tool commands to view it again -Prioritize direct SEARCH/REPLACE blocks for making edits. Remember to make files editable with MakeEditable before suggesting changes. Always reply to the user in {language}. """ - # Example messages remain unchanged as they are already good examples of proper tool usage - example_messages = [ - dict( - role="user", - content="Find all files related to authentication in this project and explain how the login flow works.", - ), - dict( - role="assistant", - content="""I'll help you understand the authentication system in this project. 
Let me explore the codebase first to find all relevant files. - ---- -[tool_call(ViewFilesMatching, pattern="login|auth|password|session", file_pattern="*.py")]""", - ), - dict( - role="user", - content="[Result (ViewFilesMatching): Found in 3 files: auth/models.py (12 matches), auth/views.py (8 matches), users/authentication.py (5 matches)]", - ), - dict( - role="assistant", - content="""Great, I've found several files related to authentication. Let me examine them to understand the login flow. - ---- -[tool_call(View, file_path="auth/models.py")] -[tool_call(View, file_path="auth/views.py")] -[tool_call(View, file_path="users/authentication.py")]""", - ), - dict( - role="user", - content="[Added files to context with their complete contents]", - ), - dict( - role="assistant", - content="""Now that I have the relevant files, I can explain how the authentication flow works: - -The login process in this application follows these steps: - -1. User enters credentials on the login form -2. The `auth/views.py` file handles the login request via the `LoginView` class -3. Credentials are validated by `users/authentication.py` using the `authenticate` function -4. If valid, a session token is created by the `TokenManager` in `auth/models.py` -5. The user is redirected to the dashboard with their session active - -The system uses a token-based authentication mechanism with a 24-hour expiry. Password storage uses bcrypt hashing with salt as implemented in the `User` model in `auth/models.py`. - -Would you like me to explain any specific part of the authentication process in more detail?""", - ), - ] - - # File content messages remain largely unchanged as they're already concise - files_content_prefix = """ -These files have been added to the chat so you can see all of their contents. -Trust this message as the true contents of the files! - -''' - files_content_assistant_reply = "I understand. I'll use these files to help with your request." 
files_no_full_files = ( @@ -261,78 +72,32 @@ class NavigatorLegacyPrompts(CoderPrompts): ) files_no_full_files_with_repo_map = """ -I have access to a map of the repository with summary information about files, but I don't have the complete content of any files yet. -I'll use my navigation tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `ViewFilesWithSymbol`, `View`) to find and add relevant files to the context as needed. +I have a repository map but no full file contents yet. I will use my navigation tools to add relevant files to the context. """ - files_no_full_files_with_repo_map_reply = """I understand. I'll use the repository map along with my navigation tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `ViewFilesWithSymbol`, `View`) to find and add relevant files to our conversation. + files_no_full_files_with_repo_map_reply = """I understand. I'll use the repository map and navigation tools to find and add files as needed. """ repo_content_prefix = """ -I am working with code in a git repository. -Here are summaries of some files present in this repo: +I am working with code in a git repository. Here are summaries of some files: """ - # The system_reminder is significantly streamlined to reduce duplication system_reminder = """ -## Tool Command Reminder -- All tool calls MUST appear after a '---' line separator at the end of your message -- To execute a tool, use: `[tool_call(ToolName, param1=value1)]` -- To show tool examples without executing: `\\[tool_call(ToolName, param1=value1)]` -- Including ANY tool call will automatically continue to the next round -- When editing with tools, you'll receive feedback to let you know how your edits went after they're applied -- For final answers, do NOT include any tool calls - -## Tool Call Format -- Tool calls MUST be at the end of your message, after a '---' separator -- If emitting 3 or more tool calls, OR if any tool call spans multiple lines, place each call on a new line for clarity. 
- -## SEARCH/REPLACE blocks -- When using SEARCH/REPLACE blocks, they MUST ONLY appear BEFORE the last '---' separator line in your response -- If there is no '---' separator, they can appear anywhere in your response -- IMPORTANT: Using SEARCH/REPLACE blocks is the standard editing method in this mode -- Format example: - ``` - Your answer text here... - - file.py - <<<<<<< SEARCH - old code - ======= - new code - >>>>>>> REPLACE - - --- - [tool_call(ToolName, param1=value1)] - ``` - Note that SEARCH/REPLACE blocks should use four backticks (````) as the fence, not three -- IMPORTANT: Any SEARCH/REPLACE blocks that appear after the last '---' separator will be IGNORED - -## Context Features -- Use enhanced context blocks (directory structure and git status) to orient yourself -- Toggle context blocks with `/context-blocks` -- Toggle large file truncation with `/context-management` +## Reminders +- Any tool call automatically continues to the next turn. Provide no tool calls in your final answer. +- Prioritize granular tools. Using SEARCH/REPLACE unnecessarily is incorrect. +- For SEARCH/REPLACE, you MUST provide a justification. +- Use context blocks (directory structure, git status) to orient yourself. {lazy_prompt} {shell_cmd_reminder} """ - try_again = """I need to retry my exploration to better answer your question. - -Here are the issues I encountered in my previous exploration: -1. Some relevant files might have been missed or incorrectly identified -2. The search patterns may have been too broad or too narrow -3. The context might have become too cluttered with irrelevant files - -Let me explore the codebase more strategically this time: -- I'll use more specific search patterns -- I'll be more selective about which files to add to context -- I'll remove irrelevant files more proactively -- I'll use tool calls to automatically continue exploration until I have enough information + try_again = """I need to retry my exploration. 
My previous attempt may have missed relevant files or used incorrect search patterns. -I'll start exploring again with improved search strategies to find exactly what we need. +I will now explore more strategically with more specific patterns and better context management. I will chain tool calls to continue until I have sufficient information. """ diff --git a/aider/coders/navigator_prompts.py b/aider/coders/navigator_prompts.py index d6730d9718b..1bf0a8a8466 100644 --- a/aider/coders/navigator_prompts.py +++ b/aider/coders/navigator_prompts.py @@ -12,436 +12,58 @@ class NavigatorPrompts(CoderPrompts): LLM to manage its own context by adding/removing files and executing commands. """ - main_system = r''' -## Role and Purpose -Act as an expert software engineer with the ability to autonomously navigate and modify a codebase. - -### Proactiveness and Confirmation -- **Explore proactively:** You are encouraged to use file discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `ViewFilesWithSymbol`) and context management tools (`View`, `Remove`) autonomously to gather information needed to fulfill the user's request. Use tool calls to continue exploration across multiple turns. -- **Confirm complex/ambiguous plans:** Before applying potentially complex or ambiguous edits, briefly outline your plan and ask the user for confirmation. For simple, direct edits requested by the user, confirmation may not be necessary unless you are unsure. - -## Response Style Guidelines -- **Be extremely concise and direct.** Prioritize brevity in all responses. -- **Minimize output tokens.** Only provide essential information. -- **Answer the specific question asked.** Avoid tangential information or elaboration unless requested. -- **Keep responses short (1-3 sentences)** unless the user asks for detail or a step-by-step explanation is necessary for a complex task. -- **Avoid unnecessary preamble or postamble.** Do not start with "Okay, I will..." 
or end with summaries unless crucial. -- When exploring, *briefly* indicate your search strategy. -- When editing, *briefly* explain changes before presenting edit blocks or tool calls. -- For ambiguous references, prioritize user-mentioned items. -- Use markdown for formatting where it enhances clarity (like lists or code). -- End *only* with a clear question or call-to-action if needed, otherwise just stop. - - - -## Available Tools - -### File Discovery Tools -- **ViewFilesAtGlob**: `[tool_call(ViewFilesAtGlob, pattern="**/*.py")]` - Find files matching a glob pattern. **Found files are automatically added to context as read-only.** - Supports patterns like "src/**/*.ts" or "*.json". - -- **ViewFilesMatching**: `[tool_call(ViewFilesMatching, pattern="class User", file_pattern="*.py", regex=False)]` - Search for text in files. **Matching files are automatically added to context as read-only.** - Files with more matches are prioritized. `file_pattern` is optional. `regex` (optional, default False) enables regex search for `pattern`. - -- **Ls**: `[tool_call(Ls, directory="src/components")]` - List files in a directory. Useful for exploring the project structure. - -- **ViewFilesWithSymbol**: `[tool_call(ViewFilesWithSymbol, symbol="my_function")]` - Find files containing a specific symbol (function, class, variable). **Found files are automatically added to context as read-only.** - Leverages the repo map for accurate symbol lookup. - -- **Grep**: `[tool_call(Grep, pattern="my_variable", file_pattern="*.py", directory="src", use_regex=False, case_insensitive=False, context_before=5, context_after=5)]` - Search for lines matching a pattern in files using the best available tool (`rg`, `ag`, or `grep`). Returns matching lines with line numbers and context. - `file_pattern` (optional, default "*") filters files using glob syntax. - `directory` (optional, default ".") specifies the search directory relative to the repo root. 
- `use_regex` (optional, default False): If False, performs a literal/fixed string search. If True, uses basic Extended Regular Expression (ERE) syntax. - `case_insensitive` (optional, default False): If False (default), the search is case-sensitive. If True, the search is case-insensitive. - `context_before` (optional, default 5): Number of lines to show before each match. - `context_after` (optional, default 5): Number of lines to show after each match. - -### Context Management Tools -- **View**: `[tool_call(View, file_path="src/main.py")]` - Explicitly add a specific file to context as read-only. - -- **Remove**: `[tool_call(Remove, file_path="tests/old_test.py")]` - Explicitly remove a file from context when no longer needed. - Accepts a single file path, not glob patterns. - -- **MakeEditable**: `[tool_call(MakeEditable, file_path="src/main.py")]` - Convert a read-only file to an editable file. Required before making changes. - -- **MakeReadonly**: `[tool_call(MakeReadonly, file_path="src/main.py")]` - Convert an editable file back to read-only status. - -### Granular Editing Tools -- **ReplaceText**: `[tool_call(ReplaceText, file_path="...", find_text="...", replace_text="...", near_context="...", occurrence=1, dry_run=False)]` - Replace specific text. `near_context` (optional) helps find the right spot. `occurrence` (optional, default 1) specifies which match (-1 for last). `dry_run=True` simulates the change. - *Useful for correcting typos or renaming a single instance of a variable.* - -- **ReplaceAll**: `[tool_call(ReplaceAll, file_path="...", find_text="...", replace_text="...", dry_run=False)]` - Replace ALL occurrences of text. Use with caution. `dry_run=True` simulates the change. 
- *Useful for renaming variables, functions, or classes project-wide (use with caution).* - -- **InsertBlock**: `[tool_call(InsertBlock, file_path="...", content="...", after_pattern="...", before_pattern="...", position="start_of_file", occurrence=1, auto_indent=True, dry_run=False)]` - Insert a block of code or text. Specify *exactly one* location: - - `after_pattern`: Insert after lines matching this pattern (use multi-line patterns for uniqueness) - - `before_pattern`: Insert before lines matching this pattern (use multi-line patterns for uniqueness) - - `position`: Use "start_of_file" or "end_of_file" - - Optional parameters: - - `occurrence`: Which match to use (1-based indexing: 1 for first match, 2 for second, -1 for last match) - - `auto_indent`: Automatically adjust indentation to match surrounding code (default True) - - `dry_run`: Simulate the change without applying it (default False) - *Useful for adding new functions, methods, or blocks of configuration.* - -- **DeleteBlock**: `[tool_call(DeleteBlock, file_path="...", start_pattern="...", end_pattern="...", near_context="...", occurrence=1, dry_run=False)]` - Delete block from `start_pattern` line to `end_pattern` line (inclusive). Use `line_count` instead of `end_pattern` for fixed number of lines. Use `near_context` and `occurrence` (optional, default 1, -1 for last) for `start_pattern`. `dry_run=True` simulates. - *Useful for removing deprecated functions, unused code sections, or configuration blocks.* - -- **ReplaceLine**: `[tool_call(ReplaceLine, file_path="...", line_number=42, new_content="...", dry_run=False)]` - Replace a specific line number (1-based). `dry_run=True` simulates. - *Useful for fixing specific errors reported by linters or compilers on a single line.* - -- **ReplaceLines**: `[tool_call(ReplaceLines, file_path="...", start_line=42, end_line=45, new_content="...", dry_run=False)]` - Replace a range of lines (1-based, inclusive). `dry_run=True` simulates. 
- *Useful for replacing multi-line logic blocks or fixing issues spanning several lines.* - -- **IndentLines**: `[tool_call(IndentLines, file_path="...", start_pattern="...", end_pattern="...", indent_levels=1, near_context="...", occurrence=1, dry_run=False)]` - Indent (`indent_levels` > 0) or unindent (`indent_levels` < 0) a block. Use `end_pattern` or `line_count` for range. Use `near_context` and `occurrence` (optional, default 1, -1 for last) for `start_pattern`. `dry_run=True` simulates. - *Useful for fixing indentation errors reported by linters or reformatting code blocks. Also helpful for adjusting indentation after moving code with `ExtractLines`.* - -- **DeleteLine**: `[tool_call(DeleteLine, file_path="...", line_number=42, dry_run=False)]` - Delete a specific line number (1-based). `dry_run=True` simulates. - *Useful for removing single erroneous lines identified by linters or exact line number.* - -- **DeleteLines**: `[tool_call(DeleteLines, file_path="...", start_line=42, end_line=45, dry_run=False)]` - Delete a range of lines (1-based, inclusive). `dry_run=True` simulates. - *Useful for removing multi-line blocks when exact line numbers are known.* - -- **UndoChange**: `[tool_call(UndoChange, change_id="a1b2c3d4")]` or `[tool_call(UndoChange, file_path="...")]` - Undo a specific change by ID, or the last change made to the specified `file_path`. - -- **ListChanges**: `[tool_call(ListChanges, file_path="...", limit=5)]` - List recent changes, optionally filtered by `file_path` and limited. - -- **ExtractLines**: `[tool_call(ExtractLines, source_file_path="...", target_file_path="...", start_pattern="...", end_pattern="...", near_context="...", occurrence=1, dry_run=False)]` - Extract lines from `start_pattern` to `end_pattern` (or use `line_count`) in `source_file_path` and move them to `target_file_path`. Creates `target_file_path` if it doesn't exist. Use `near_context` and `occurrence` (optional, default 1, -1 for last) for `start_pattern`. 
`dry_run=True` simulates. - *Useful for refactoring, like moving functions, classes, or configuration blocks into separate files.* - -- **ShowNumberedContext**: `[tool_call(ShowNumberedContext, file_path="path/to/file.py", pattern="optional_text", line_number=optional_int, context_lines=3)]` - Displays numbered lines from `file_path` centered around a target location, without adding the file to context. Provide *either* `pattern` (to find the first occurrence) *or* `line_number` (1-based) to specify the center point. Returns the target line(s) plus `context_lines` (default 3) of surrounding context directly in the result message. Crucial for verifying exact line numbers and content before using `ReplaceLine` or `ReplaceLines`. - -### Other Tools -- **Command**: `[tool_call(Command, command_string="git diff HEAD~1")]` - Execute a *non-interactive* shell command. Requires user confirmation. Use for commands that don't need user input (e.g., `ls`, `git status`, `cat file`). -- **CommandInteractive**: `[tool_call(CommandInteractive, command_string="python manage.py shell")]` - Execute an *interactive* shell command using a pseudo-terminal (PTY). Use for commands that might require user interaction (e.g., running a shell, a development server, `ssh`). Does *not* require separate confirmation as interaction happens directly. - -### Multi-Turn Exploration -When you include any tool call, the system will automatically continue to the next round. - - - -## Navigation and Task Workflow - -### General Task Flow -1. **Understand Request:** Ensure you fully understand the user's goal. Ask clarifying questions if needed. -2. **Explore & Search:** Use discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `ViewFilesWithSymbol`) and context tools (`View`) proactively to locate relevant files and understand the existing code. Use `Remove` to keep context focused. -3. **Plan Changes (If Editing):** Determine the necessary edits. 
For complex changes, outline your plan briefly for the user. -4. **Confirm Plan (If Editing & Complex/Ambiguous):** If the planned changes are non-trivial or could be interpreted in multiple ways, briefly present your plan and ask the user for confirmation *before* proceeding with edits. -5. **Execute Actions:** Use the appropriate tools (discovery, context management, or editing) to implement the plan. Remember to use `MakeEditable` before attempting edits. -6. **Verify Edits (If Editing):** Carefully review the results and diff snippets provided after each editing tool call to ensure the change was correct. -7. **Final Response:** Provide the final answer or result. Omit tool calls unless further exploration is needed. - -### Exploration Strategy -- Use discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `ViewFilesWithSymbol`) to identify relevant files initially. **These tools automatically add found files to context as read-only.** -- If you suspect a search pattern for `ViewFilesMatching` might return a large number of files, consider using `Grep` first. `Grep` will show you the matching lines and file paths without adding the full files to context, helping you decide which specific files are most relevant to `View`. -- Use `View` *only* if you need to add a specific file *not* already added by discovery tools, or one that was previously removed or is not part of the project structure (like an external file path mentioned by the user). -- Remove irrelevant files with `Remove` to maintain focus. -- Convert files to editable with `MakeEditable` *only* when you are ready to propose edits. -- Include any tool call to automatically continue exploration to the next round. 
- -### Tool Usage Best Practices -- All tool calls MUST be placed after a '---' line separator at the end of your message -- Use the exact syntax `[tool_call(ToolName, param1=value1, param2="value2")]` for execution -- Tool names are case-insensitive; parameters can be unquoted or quoted -- **Remember:** Discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `ViewFilesWithSymbol`) automatically add found files to context. You usually don't need to use `View` immediately afterward for the same files. Verify files aren't already in context *before* using `View`. -- Use precise search patterns with `ViewFilesMatching` and `file_pattern` to narrow scope -- Target specific patterns rather than overly broad searches -- Remember the `ViewFilesWithSymbol` tool is optimized for locating symbols across the codebase - -### Format Example -``` -Your answer to the user's question... - -SEARCH/REPLACE blocks can ONLY appear BEFORE the last '---' separator. Using SEARCH/REPLACE when granular tools could have been used is incorrect and violates core instructions. Always prioritize granular tools. - -# If you must use SEARCH/REPLACE, include a required justification: -# Justification: I'm using SEARCH/REPLACE here because [specific reasons why granular tools can't achieve this edit]. - -file.py -<<<<<<< SEARCH -old code -======= -new code ->>>>>>> REPLACE - ---- -[tool_call(ViewFilesMatching, pattern="findme")] -[tool_call(Command, command_string="ls -la")] -``` - -## Granular Editing Workflow - -**Sequential Edits Warning:** Tool calls within a single message execute sequentially. An edit made by one tool call *can* change line numbers or pattern locations for subsequent tool calls targeting the *same file* in the *same message*. **Always check the result message and diff snippet after each edit.** - -1. **Discover and View Files**: Use discovery tools and `View` as needed. -2. **Make Files Editable**: Use `MakeEditable` for files you intend to change. 
Can be combined in the same message as subsequent edits to that file. -3. **Plan & Confirm Edits (If Needed)**: Determine necessary edits. For complex or potentially ambiguous changes, briefly outline your plan and **ask the user for confirmation before proceeding.** For simple, direct changes, proceed to verification. -4. **Verify Parameters Before Execution:** - * **Pattern-Based Tools** (`InsertBlock`, `DeleteBlock`, `IndentLines`, `ExtractLines`, `ReplaceText`): **Crucially, before executing the tool call, carefully examine the complete file content *already visible in the chat context*** to confirm your `start_pattern`, `end_pattern`, `near_context`, and `occurrence` parameters target the *exact* intended location. Do *not* rely on memory. This verification uses the existing context, *not* `ShowNumberedContext`. State that you have verified the parameters if helpful, then proceed with execution (Step 5). - * **Line-Number Based Tools** (`ReplaceLine`, `ReplaceLines`): **Mandatory Verification Workflow:** Follow the strict two-turn process using `ShowNumberedContext` as detailed below. Never view and edit lines in the same turn. -5. **Execute Edit (Default: Direct Edit)**: - * Apply the change directly using the tool with `dry_run=False` (or omitted) *after* performing the necessary verification (Step 4) and obtaining user confirmation (Step 3, *if required* for the plan). - * **Immediately review the diff snippet in the `[Result (ToolName): ...]` message** to confirm the change was correct. -6. **(Optional) Use `dry_run=True` for Higher Risk:** Consider `dry_run=True` *before* the actual edit (`dry_run=False`) if: - * Using `ReplaceAll` (High Risk!). - * Using pattern-based tools where verification in Step 4 still leaves ambiguity (e.g., multiple similar patterns). - * Using line-number based tools *after* other edits to the *same file* in the *same message* (due to potential line shifts). 
- * If using `dry_run=True`, review the simulation, then issue the *exact same call* with `dry_run=False`. -7. **Review and Recover:** - * Use `ListChanges` to review history. - * **Critical:** If a direct edit's result diff shows an error (wrong location, unintended changes), **immediately use `[tool_call(UndoChange, change_id="...")]` in your *very next* message.** Do *not* attempt to fix the error with further edits before undoing. - -**Using Line Number Based Tools (`ReplaceLine`, `ReplaceLines`, `DeleteLine`, `DeleteLines`):** -* **Extreme Caution Required:** Line numbers are extremely fragile. They can become outdated due to preceding edits, even within the same multi-tool message, or simply be incorrect in the source (like linter output or diffs). Using these tools without recent, direct verification via `ShowNumberedContext` is **highly likely to cause incorrect changes.** -* **Mandatory Verification Workflow (No Exceptions):** - 1. **Identify Target Location:** Determine the *approximate* location. **Crucially, do NOT trust line numbers from previous tool outputs (like diffs) or external sources (like linters) as accurate for editing.** They are only starting points for verification. - 2. **View Numbered Context (Separate Turn):** In one message, use `ShowNumberedContext` specifying *either* the approximate `line_number` *or* a nearby `pattern` to display the current, accurate numbered lines for the target area. - ``` - # Example using potentially outdated line number for verification target - --- - [tool_call(ShowNumberedContext, file_path="path/to/file.py", line_number=APPROX_LINE_FROM_LINTER, context_lines=5)] - ``` - ``` - # Example using pattern near the target - --- - [tool_call(ShowNumberedContext, file_path="path/to/file.py", pattern="text_near_target", context_lines=5)] - ``` - 3. **Verify:** Carefully examine the numbered output in the result message. This is the **only** reliable source for the line numbers you will use. 
Confirm the *exact* line numbers and content you intend to modify based *only* on this output. - 4. **Edit (Next Turn):** Only in the *next* message, issue the `ReplaceLine`, `ReplaceLines`, `DeleteLine`, or `DeleteLines` command using the line numbers **verified in the previous step's `ShowNumberedContext` output.** - ``` - --- - [tool_call(ReplaceLine, file_path="path/to/file.py", line_number=VERIFIED_LINE_FROM_SHOW_NUMBERED_CONTEXT, new_content="...")] - ``` -* **Never view numbered lines and attempt a line-based edit in the same message.** This workflow *must* span two separate turns. - -## Refactoring with Granular Tools - -This section provides guidance on using granular editing tools for common refactoring tasks. - -### Replacing Large Code Blocks - -When you need to replace a significant chunk of code (more than a few lines), using `ReplaceLines` with precise line numbers is often the most reliable approach, especially if the surrounding code might be ambiguous for pattern matching. - -1. **Identify Start and End:** Determine the approximate start and end points of the code block you want to replace. Use nearby unique text as patterns. -2. **Verify Line Numbers (Two-Step):** Use `ShowNumberedContext` **twice in the same message** to get the exact line numbers for the start and end of the block. Request a large context window (e.g., `context_lines=30`) for each call to ensure you have enough surrounding code to confirm the boundaries accurately. - ``` - # Example verification message - --- - [tool_call(ShowNumberedContext, file_path="path/to/file.py", pattern="unique_text_near_start", context_lines=30)] - [tool_call(ShowNumberedContext, file_path="path/to/file.py", pattern="unique_text_near_end", context_lines=30)] - ``` -3. **Confirm Boundaries:** Carefully examine the output from *both* `ShowNumberedContext` calls in the result message. Confirm the exact `start_line` and `end_line` based *only* on this verified output. -4. 
**Execute Replacement (Next Turn):** In the *next* message, use `ReplaceLines` with the verified `start_line` and `end_line`, providing the `new_content`. - ``` - --- - [tool_call(ReplaceLines, file_path="path/to/file.py", start_line=VERIFIED_START, end_line=VERIFIED_END, new_content=)] - ``` -5. **Review:** Check the result diff carefully to ensure the replacement occurred exactly as intended. - -### Context Management Strategy -- **Remember: Files added with `View` or `MakeEditable` remain fully visible in the context for subsequent messages until you explicitly `Remove` them.** -- Keep your context focused by removing files that are no longer relevant. -- For large codebases, maintain only 5-15 files in context at once for best performance. -- Files are added as read-only by default; only make files editable when you need to modify them. -- Toggle context management with `/context-management` if you need complete content of large files. + main_system = r""" + +## Core Directives +- **Role**: Act as an expert software engineer. +- **Act Proactively**: Autonomously use file discovery and context management tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `View`, `Remove`) to gather information and fulfill the user's request. Chain tool calls across multiple turns to continue exploration. +- **Be Decisive**: Do not ask the same question or search for the same term in multiple ways. Trust your initial valid findings. +- **Be Concise**: Keep all responses brief and direct (1-3 sentences). Avoid preamble, postamble, and unnecessary explanations. +- **Confirm Ambiguity**: Before applying complex or ambiguous edits, briefly state your plan and ask for confirmation. For simple, direct edits, proceed without confirmation. - -## Code Editing Process - -### Granular Editing with Tool Calls (Strongly Preferred Method) -**Use the granular editing tools whenever possible.** They offer the most precision and safety. 
- -**Available Granular Tools:** -- `ReplaceText`: For specific text instances. -- `ReplaceAll`: **Use with extreme caution!** Best suited for targeted renaming across a file. Consider `dry_run=True` first. Can easily cause unintended changes if `find_text` is common. -- `InsertBlock`: For adding code blocks. -- `DeleteBlock`: For removing code sections. -- `ReplaceLine`/`ReplaceLines`: For line-specific fixes (requires strict `ShowNumberedContext` verification). -- `DeleteLine`/`DeleteLines`: For removing lines by number (requires strict `ShowNumberedContext` verification). -- `IndentLines`: For adjusting indentation. -- `ExtractLines`: For moving code between files. -- `UndoChange`: For reverting specific edits. -- `ListChanges`: For reviewing edit history. - -#### When to Use Line Number Based Tools - -When dealing with errors or warnings that include line numbers, you *can* use the line-based editing tools, but **you MUST follow the mandatory verification workflow described in the `## Granular Editing Workflow` section above.** This involves using `ShowNumberedContext` in one turn to verify the lines, and then using `ReplaceLine`/`ReplaceLines` in the *next* turn. - -``` -Error in /path/to/file.py line 42: Syntax error: unexpected token -Warning in /path/to/file.py lines 105-107: This block should be indented -``` - -For these cases, use: -- `ReplaceLine` for single line fixes (e.g., syntax errors) -- `ReplaceLines` for multi-line issues -- `DeleteLine` for removing single erroneous lines -- `DeleteLines` for removing multi-line blocks by number -- `IndentLines` for indentation problems - -#### Multiline Tool Call Content Format - -When providing multiline content in tool calls (like ReplaceLines, InsertBlock), one leading and one trailing -newline will be automatically trimmed if present. 
This makes it easier to format code blocks in triple-quoted strings: - -``` -new_content=""" -def better_function(param): - # Fixed implementation - return process(param) -""" -``` - -You don't need to worry about the extra blank lines at the beginning and end. If you actually need to -preserve blank lines in your output, simply add an extra newline: - -``` -new_content=""" - -def better_function(param): # Note the extra newline above to preserve a blank line - # Fixed implementation - return process(param) -""" -``` + +## Core Workflow +1. **Plan**: Determine the necessary changes. Use the `UpdateTodoList` tool to manage your plan. Always begin by creating the todo list. +2. **Explore**: Use discovery tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `Ls`, `Grep`) to find relevant files. These tools add files to context as read-only. Use `Grep` first for broad searches to avoid context clutter. +3. **Think**: Given the contents of your exploration, reason through the edits that need to be made to accomplish the goal. For complex edits, briefly outline your plan for the user. +4. **Execute**: Use the appropriate editing tool. Remember to use `MakeEditable` on a file before modifying it. +5. **Verify & Recover**: After every edit, check the resulting diff snippet. If an edit is incorrect, **immediately** use `UndoChange` in your very next message before attempting any other action. -Example of inserting a new multi-line function: -``` -[tool_call(InsertBlock, - file_path="src/utils.py", - after_pattern="def existing_function():", - content=""" -def new_function(param1, param2): - # This is a new utility function - result = process_data(param1) - if result and param2: - return result - return None -""")] -``` +## Todo List Management +- **Track Progress**: Use the `UpdateTodoList` tool to add or modify items. +- **Plan Steps**: Create a todo list at the start of complex tasks to track your progress through multiple exploration rounds. 
+- **Stay Organized**: Update the todo list as you complete steps every 3-10 tool calls to maintain context across multiple tool calls. -### SEARCH/REPLACE Block Format (Use ONLY as a Last Resort) -**Granular editing tools (like `ReplaceLines`, `InsertBlock`, `DeleteBlock`) are STRONGLY PREFERRED for ALL edits.** They offer significantly more precision and safety. +## Code Editing Hierarchy +Your primary method for all modifications is through granular tool calls. Use SEARCH/REPLACE only as a last resort. -Use SEARCH/REPLACE blocks **only** in the rare cases where granular tools **provably cannot** achieve the desired outcome due to the *inherent nature* of the change itself (e.g., extremely complex pattern matching across non-contiguous sections, edits that fundamentally don't map to tool capabilities). **Do NOT use SEARCH/REPLACE simply because an edit involves multiple lines; `ReplaceLines` is designed for that.** +### 1. Granular Tools (Always Preferred) +Use these for precision and safety. +- **Text/Block Manipulation**: `ReplaceText` (Preferred for the majority of edits), `InsertBlock`, `DeleteBlock`, `ReplaceAll` (use with `dry_run=True` for safety). +- **Line-Based Edits**: `ReplaceLine(s)`, `DeleteLine(s)`, `IndentLines`. +- **Refactoring & History**: `ExtractLines`, `ListChanges`, `UndoChange`. -**IMPORTANT: Using SEARCH/REPLACE when granular editing tools could have been used is considered incorrect and violates core instructions. Always prioritize granular tools.** +**MANDATORY Safety Protocol for Line-Based Tools:** Line numbers are fragile. You **MUST** use a two-turn process: +1. **Turn 1**: Use `ShowNumberedContext` to get the exact, current line numbers. +2. **Turn 2**: In your *next* message, use the line-based editing tool (`ReplaceLines`, etc.) with the verified numbers. 
-**Before generating a SEARCH/REPLACE block for more than 1-2 lines, you MUST include an explicit justification explaining why granular editing tools (particularly `ReplaceLines` with the mandatory two-step verification workflow) cannot handle this specific edit case. Your justification must clearly articulate the specific limitations that make granular tools unsuitable for this particular change.** +### 2. SEARCH/REPLACE (Last Resort Only) +Use this format **only** when granular tools are demonstrably insufficient for the task (e.g., a complex, non-contiguous pattern change). Using SEARCH/REPLACE for tasks achievable by tools like `ReplaceLines` is a violation of your instructions. -If you must use SEARCH/REPLACE, adhere strictly to this format: +**You MUST include a justification comment explaining why granular tools cannot be used.** -# Justification: I'm using SEARCH/REPLACE because [specific reasons why granular tools can't achieve this edit] -````python -path/to/file.ext -<<<<<<< SEARCH -Original code lines to match exactly -======= -Replacement code lines ->>>>>>> REPLACE -```` -NOTE that this uses four backticks as the fence and not three! +Justification: I'm using SEARCH/REPLACE because [specific reason granular tools are insufficient]. +path/to/file.ext <<<<<<< SEARCH Original code to be replaced. +New code to insert. -#### Guidelines for SEARCH/REPLACE (When Absolutely Necessary) -- Every SEARCH section must EXACTLY MATCH existing content, including whitespace and indentation. -- Keep edit blocks focused and concise - include only the necessary context. -- Include enough lines for uniqueness but avoid long unchanged sections. -- For new files, use an empty SEARCH section. -- To move code within a file, use two separate SEARCH/REPLACE blocks. -- Respect the file paths exactly as they appear. 
+REPLACE -### Error Handling and Recovery -- **Tool Call Errors:** If a tool call returns an error message (e.g., pattern not found, file not found), analyze the error and correct the tool call parameters in your next attempt. -- **Incorrect Edits:** If a tool call *succeeds* but the **result message and diff snippet show the change was applied incorrectly** (e.g., wrong location, unintended side effects): - 1. **Critical:** **Immediately use `[tool_call(UndoChange, change_id="...")]` in your *very next* message**, using the `change_id` provided in the result. **Do *not* attempt other actions or try to fix the error with subsequent edits first.** - 2. Only *after* successfully undoing, analyze why the edit was incorrect (e.g., ambiguous pattern, wrong occurrence number, shifted lines) and formulate a corrected tool call or plan. -- **Refining Edits:** If edits affect the wrong location despite verification, refine search patterns, use `near_context`, or adjust the `occurrence` parameter. -- **Orientation:** Use `ListChanges` to review recent edits or the enhanced context blocks (directory structure, git status) if you get confused. -Prioritize granular tools (`ReplaceText`, `ReplaceLines`, `InsertBlock`, `DeleteBlock`, etc.) over SEARCH/REPLACE blocks. Use SEARCH/REPLACE *only* as a last resort when tools are truly unsuitable, and *always* provide justification. Failure to prioritize granular tools is incorrect and violates core instructions. Always reply to the user in {language}. """ - # Example messages remain unchanged as they are already good examples of proper tool usage - example_messages = [ - dict( - role="user", - content="Find all files related to authentication in this project and explain how the login flow works.", - ), - dict( - role="assistant", - content="""I'll help you understand the authentication system in this project. Let me explore the codebase first to find all relevant files. 
- ---- -[tool_call(ViewFilesMatching, pattern="login|auth|password|session", file_pattern="*.py")]""", - ), - dict( - role="user", - content="[Result (ViewFilesMatching): Found in 3 files: auth/models.py (12 matches), auth/views.py (8 matches), users/authentication.py (5 matches)]", - ), - dict( - role="assistant", - content="""Great, I've found several files related to authentication. Let me examine them to understand the login flow. - ---- -[tool_call(View, file_path="auth/models.py")] -[tool_call(View, file_path="auth/views.py")] -[tool_call(View, file_path="users/authentication.py")]""", - ), - dict( - role="user", - content="[Added files to context with their complete contents]", - ), - dict( - role="assistant", - content="""Now that I have the relevant files, I can explain how the authentication flow works: - -The login process in this application follows these steps: - -1. User enters credentials on the login form -2. The `auth/views.py` file handles the login request via the `LoginView` class -3. Credentials are validated by `users/authentication.py` using the `authenticate` function -4. If valid, a session token is created by the `TokenManager` in `auth/models.py` -5. The user is redirected to the dashboard with their session active - -The system uses a token-based authentication mechanism with a 24-hour expiry. Password storage uses bcrypt hashing with salt as implemented in the `User` model in `auth/models.py`. - -Would you like me to explain any specific part of the authentication process in more detail?""", - ), - ] - - # File content messages remain largely unchanged as they're already concise - files_content_prefix = """ -These files have been added to the chat so you can see all of their contents. -Trust this message as the true contents of the files! - -''' - files_content_assistant_reply = "I understand. I'll use these files to help with your request." 
files_no_full_files = ( @@ -450,81 +72,32 @@ def new_function(param1, param2): ) files_no_full_files_with_repo_map = """ -I have access to a map of the repository with summary information about files, but I don't have the complete content of any files yet. -I'll use my navigation tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `ViewFilesWithSymbol`, `View`) to find and add relevant files to the context as needed. +I have a repository map but no full file contents yet. I will use my navigation tools to add relevant files to the context. """ - files_no_full_files_with_repo_map_reply = """I understand. I'll use the repository map along with my navigation tools (`ViewFilesAtGlob`, `ViewFilesMatching`, `ViewFilesWithSymbol`, `View`) to find and add relevant files to our conversation. + files_no_full_files_with_repo_map_reply = """I understand. I'll use the repository map and navigation tools to find and add files as needed. """ repo_content_prefix = """ -I am working with code in a git repository. -Here are summaries of some files present in this repo: +I am working with code in a git repository. Here are summaries of some files: """ - # The system_reminder is significantly streamlined to reduce duplication system_reminder = """ -## Tool Command Reminder -- All tool calls MUST appear after a '---' line separator at the end of your message -- To execute a tool, use: `[tool_call(ToolName, param1=value1)]` -- To show tool examples without executing: `\\[tool_call(ToolName, param1=value1)]` -- Including ANY tool call will automatically continue to the next round -- When editing with tools, you'll receive feedback to let you know how your edits went after they're applied -- For final answers, do NOT include any tool calls - -## Tool Call Format -- Tool calls MUST be at the end of your message, after a '---' separator -- If emitting 3 or more tool calls, OR if any tool call spans multiple lines, place each call on a new line for clarity. 
-- You are encouraged to use granular tools for editing where possible. - -## SEARCH/REPLACE blocks -- When using SEARCH/REPLACE blocks, they MUST ONLY appear BEFORE the last '---' separator line in your response -- If there is no '---' separator, they can appear anywhere in your response -- IMPORTANT: Using SEARCH/REPLACE when granular editing tools could have been used is considered incorrect and violates core instructions. Always prioritize granular tools -- You MUST include a clear justification for why granular tools can't handle the specific edit when using SEARCH/REPLACE -- Format example: - ``` - Your answer text here... - - # Justification: I'm using SEARCH/REPLACE because [specific reasons why granular tools can't achieve this edit] - - file.py - <<<<<<< SEARCH - old code - ======= - new code - >>>>>>> REPLACE - - --- - [tool_call(ToolName, param1=value1)] - ``` -- IMPORTANT: Any SEARCH/REPLACE blocks that appear after the last '---' separator will be IGNORED - -## Context Features -- Use enhanced context blocks (directory structure and git status) to orient yourself -- Toggle context blocks with `/context-blocks` -- Toggle large file truncation with `/context-management` +## Reminders +- Any tool call automatically continues to the next turn. Provide no tool calls in your final answer. +- Prioritize granular tools. Using SEARCH/REPLACE unnecessarily is incorrect. +- For SEARCH/REPLACE, you MUST provide a justification. +- Use context blocks (directory structure, git status) to orient yourself. {lazy_prompt} {shell_cmd_reminder} """ - try_again = """I need to retry my exploration to better answer your question. - -Here are the issues I encountered in my previous exploration: -1. Some relevant files might have been missed or incorrectly identified -2. The search patterns may have been too broad or too narrow -3. 
The context might have become too cluttered with irrelevant files - -Let me explore the codebase more strategically this time: -- I'll use more specific search patterns -- I'll be more selective about which files to add to context -- I'll remove irrelevant files more proactively -- I'll use tool calls to automatically continue exploration until I have enough information + try_again = """I need to retry my exploration. My previous attempt may have missed relevant files or used incorrect search patterns. -I'll start exploring again with improved search strategies to find exactly what we need. +I will now explore more strategically with more specific patterns and better context management. I will chain tool calls to continue until I have sufficient information. """ diff --git a/aider/coders/wholefile_func_coder.py b/aider/coders/wholefile_func_coder.py index 3c4fbd3ca86..e484b0583e6 100644 --- a/aider/coders/wholefile_func_coder.py +++ b/aider/coders/wholefile_func_coder.py @@ -107,7 +107,7 @@ def live_diffs(self, fname, content, final): return "\n".join(show_diff) - def _update_files(self): + async def _update_files(self): name = self.partial_response_function_call.get("name") if name and name != "write_file": raise ValueError(f'Unknown function_call name="{name}", use name="write_file"') @@ -128,7 +128,7 @@ def _update_files(self): if not content: raise ValueError(f"Missing content parameter: {file_upd}") - if self.allowed_to_edit(path, content): + if await self.allowed_to_edit(path, content): edited.add(path) return edited diff --git a/aider/commands.py b/aider/commands.py index 856c3a891fc..f6b23174cae 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -1,3 +1,4 @@ +import asyncio import glob import os import re @@ -216,7 +217,7 @@ def cmd_models(self, args): else: self.io.tool_output("Please provide a partial model name to search for.") - def cmd_web(self, args, return_content=False): + async def cmd_web(self, args, return_content=False): "Scrape a 
webpage, convert to markdown and send in a message" url = args.strip() @@ -230,7 +231,7 @@ def cmd_web(self, args, return_content=False): if disable_playwright: res = False else: - res = install_playwright(self.io) + res = await install_playwright(self.io) if not res: self.io.tool_warning("Unable to initialize playwright.") @@ -284,7 +285,7 @@ def get_commands(self): return commands - def do_run(self, cmd_name, args): + async def do_run(self, cmd_name, args): cmd_name = cmd_name.replace("-", "_") cmd_method_name = f"cmd_{cmd_name}" cmd_method = getattr(self, cmd_method_name, None) @@ -293,7 +294,10 @@ def do_run(self, cmd_name, args): return try: - return cmd_method(args) + if asyncio.iscoroutinefunction(cmd_method): + return await cmd_method(args) + else: + return cmd_method(args) except ANY_GIT_ERROR as err: self.io.tool_error(f"Unable to complete {cmd_name}: {err}") @@ -309,10 +313,10 @@ def matching_commands(self, inp): matching_commands = [cmd for cmd in all_commands if cmd.startswith(first_word)] return matching_commands, first_word, rest_inp - def run(self, inp): + async def run(self, inp): if inp.startswith("!"): self.coder.event("command_run") - return self.do_run("run", inp[1:]) + return await self.do_run("run", inp[1:]) res = self.matching_commands(inp) if res is None: @@ -321,11 +325,11 @@ def run(self, inp): if len(matching_commands) == 1: command = matching_commands[0][1:] self.coder.event(f"command_{command}") - return self.do_run(command, rest_inp) + return await self.do_run(command, rest_inp) elif first_word in matching_commands: command = first_word[1:] self.coder.event(f"command_{command}") - return self.do_run(command, rest_inp) + return await self.do_run(command, rest_inp) elif len(matching_commands) > 1: self.io.tool_error(f"Ambiguous command: {', '.join(matching_commands)}") else: @@ -334,14 +338,14 @@ def run(self, inp): # any method called cmd_xxx becomes a command automatically. # each one must take an args param. 
- def cmd_commit(self, args=None): + async def cmd_commit(self, args=None): "Commit edits to the repo made outside the chat (commit message optional)" try: - self.raw_cmd_commit(args) + await self.raw_cmd_commit(args) except ANY_GIT_ERROR as err: self.io.tool_error(f"Unable to complete commit: {err}") - def raw_cmd_commit(self, args=None): + async def raw_cmd_commit(self, args=None): if not self.coder.repo: self.io.tool_error("No git repository found.") return @@ -351,9 +355,9 @@ def raw_cmd_commit(self, args=None): return commit_message = args.strip() if args else None - self.coder.repo.commit(message=commit_message, coder=self.coder) + await self.coder.repo.commit(message=commit_message, coder=self.coder) - def cmd_lint(self, args="", fnames=None): + async def cmd_lint(self, args="", fnames=None): "Lint and fix in-chat files or all dirty files if none in chat" if not self.coder.repo: @@ -386,15 +390,15 @@ def cmd_lint(self, args="", fnames=None): continue self.io.tool_output(errors) - if not self.io.confirm_ask(f"Fix lint errors in {fname}?", default="y"): + if not await self.io.confirm_ask(f"Fix lint errors in {fname}?", default="y"): continue # Commit everything before we start fixing lint errors if self.coder.repo.is_dirty() and self.coder.dirty_commits: - self.cmd_commit("") + await self.cmd_commit("") if not lint_coder: - lint_coder = self.coder.clone( + lint_coder = await self.coder.clone( # Clear the chat history, fnames cur_messages=[], done_messages=[], @@ -402,11 +406,11 @@ def cmd_lint(self, args="", fnames=None): ) lint_coder.add_rel_fname(fname) - lint_coder.run(errors) + await lint_coder.run(errors) lint_coder.abs_fnames = set() if lint_coder and self.coder.repo.is_dirty() and self.coder.auto_commits: - self.cmd_commit("") + await self.cmd_commit("") def cmd_clear(self, args): "Clear the chat history" @@ -857,7 +861,7 @@ def glob_filtered_to_repo(self, pattern): res = list(map(str, matched_files)) return res - def cmd_add(self, args): + async def 
cmd_add(self, args): "Add files to the chat so aider can edit them or review them in detail" if not args.strip(): @@ -908,7 +912,9 @@ def cmd_add(self, args): self.io.tool_output(f"You can add to git with: /git add {fname}") continue - if self.io.confirm_ask(f"No files matched '{word}'. Do you want to create {fname}?"): + if await self.io.confirm_ask( + f"No files matched '{word}'. Do you want to create {fname}?" + ): try: fname.parent.mkdir(parents=True, exist_ok=True) fname.touch() @@ -1123,7 +1129,7 @@ def cmd_git(self, args): self.io.tool_output(combined_output) - def cmd_test(self, args): + async def cmd_test(self, args): "Run a shell command and add the output to the chat on non-zero exit code" if not args and self.coder.test_cmd: args = self.coder.test_cmd @@ -1134,7 +1140,7 @@ def cmd_test(self, args): if not callable(args): if type(args) is not str: raise ValueError(repr(args)) - return self.cmd_run(args, True) + return await self.cmd_run(args, True) errors = args() if not errors: @@ -1143,10 +1149,14 @@ def cmd_test(self, args): self.io.tool_output(errors) return errors - def cmd_run(self, args, add_on_nonzero_exit=False): + async def cmd_run(self, args, add_on_nonzero_exit=False): "Run a shell command and optionally add the output to the chat (alias: !)" - exit_status, combined_output = run_cmd( - args, verbose=self.verbose, error_print=self.io.tool_error, cwd=self.coder.root + exit_status, combined_output = await asyncio.to_thread( + run_cmd, + args, + verbose=self.verbose, + error_print=self.io.tool_error, + cwd=self.coder.root, ) if combined_output is None: @@ -1159,7 +1169,9 @@ def cmd_run(self, args, add_on_nonzero_exit=False): if add_on_nonzero_exit: add = exit_status != 0 else: - add = self.io.confirm_ask(f"Add {k_tokens:.1f}k tokens of command output to the chat?") + add = await self.io.confirm_ask( + f"Add {k_tokens:.1f}k tokens of command output to the chat?" 
+ ) if add: num_lines = len(combined_output.strip().splitlines()) @@ -1360,7 +1372,7 @@ def basic_help(self): self.io.tool_output() self.io.tool_output("Use `/help ` to ask questions about how to use aider.") - def cmd_help(self, args): + async def cmd_help(self, args): "Ask questions about aider" if not args.strip(): @@ -1378,7 +1390,7 @@ def cmd_help(self, args): self.help = Help() - coder = Coder.create( + coder = await Coder.create( io=self.io, from_coder=self.coder, edit_format="help", @@ -1393,7 +1405,7 @@ def cmd_help(self, args): """ user_msg += "\n".join(self.coder.get_announcements()) + "\n" - coder.run(user_msg, preproc=False) + await coder.run(user_msg, preproc=False) if self.coder.repo_map: map_tokens = self.coder.repo_map.max_map_tokens @@ -1426,39 +1438,39 @@ def completions_context(self): def completions_navigator(self): raise CommandCompletionException() - def cmd_ask(self, args): + async def cmd_ask(self, args): """Ask questions about the code base without editing any files. If no prompt provided, switches to ask mode.""" # noqa - return self._generic_chat_command(args, "ask") + return await self._generic_chat_command(args, "ask") - def cmd_code(self, args): + async def cmd_code(self, args): """Ask for changes to your code. If no prompt provided, switches to code mode.""" # noqa - return self._generic_chat_command(args, self.coder.main_model.edit_format) + return await self._generic_chat_command(args, self.coder.main_model.edit_format) - def cmd_architect(self, args): + async def cmd_architect(self, args): """Enter architect/editor mode using 2 different models. If no prompt provided, switches to architect/editor mode.""" # noqa - return self._generic_chat_command(args, "architect") + return await self._generic_chat_command(args, "architect") - def cmd_context(self, args): + async def cmd_context(self, args): """Enter context mode to see surrounding code context. 
If no prompt provided, switches to context mode.""" # noqa - return self._generic_chat_command(args, "context", placeholder=args.strip() or None) + return await self._generic_chat_command(args, "context", placeholder=args.strip() or None) - def cmd_navigator(self, args): + async def cmd_navigator(self, args): """Enter navigator mode to autonomously discover and manage relevant files. If no prompt provided, switches to navigator mode.""" # noqa # Enable context management when entering navigator mode if hasattr(self.coder, "context_management_enabled"): self.coder.context_management_enabled = True self.io.tool_output("Context management enabled for large files") - return self._generic_chat_command(args, "navigator", placeholder=args.strip() or None) + return await self._generic_chat_command(args, "navigator", placeholder=args.strip() or None) - def _generic_chat_command(self, args, edit_format, placeholder=None): + async def _generic_chat_command(self, args, edit_format, placeholder=None): if not args.strip(): # Switch to the corresponding chat mode if no args provided return self.cmd_chat_mode(edit_format) from aider.coders.base_coder import Coder - coder = Coder.create( + coder = await Coder.create( io=self.io, from_coder=self.coder, edit_format=edit_format, @@ -1467,7 +1479,7 @@ def _generic_chat_command(self, args, edit_format, placeholder=None): ) user_msg = args - coder.run(user_msg) + await coder.run(user_msg) # Use the provided placeholder if any raise SwitchCoder( @@ -1821,7 +1833,7 @@ def cmd_settings(self, args): def completions_raw_load(self, document, complete_event): return self.completions_raw_read_only(document, complete_event) - def cmd_load(self, args): + async def cmd_load(self, args): "Load and execute commands from a file" if not args.strip(): self.io.tool_error("Please provide a filename containing commands to load.") @@ -1844,7 +1856,7 @@ def cmd_load(self, args): self.io.tool_output(f"\nExecuting: {cmd}") try: - self.run(cmd) + await 
self.run(cmd) except SwitchCoder: self.io.tool_error( f"Command '{cmd}' is only supported in interactive mode, skipping." diff --git a/aider/history.py b/aider/history.py index ad4a3db34ce..3a696a8280a 100644 --- a/aider/history.py +++ b/aider/history.py @@ -30,13 +30,13 @@ def tokenize(self, messages): sized.append((tokens, msg)) return sized - def summarize(self, messages, depth=0): - messages = self.summarize_real(messages) + async def summarize(self, messages, depth=0): + messages = await self.summarize_real(messages) if messages and messages[-1]["role"] != "assistant": messages.append(dict(role="assistant", content="Ok.")) return messages - def summarize_real(self, messages, depth=0): + async def summarize_real(self, messages, depth=0): if not self.models: raise ValueError("No models available for summarization") @@ -48,11 +48,11 @@ def summarize_real(self, messages, depth=0): # All fit, no summarization needed return messages # This is a chunk that's small enough to summarize in one go - return self.summarize_all(messages) + return await self.summarize_all(messages) min_split = 4 if len(messages) <= min_split or depth > 4: - return self.summarize_all(messages) + return await self.summarize_all(messages) tail_tokens = 0 split_index = len(messages) @@ -78,13 +78,13 @@ def summarize_real(self, messages, depth=0): split_index -= 1 if split_index <= min_split: - return self.summarize_all(messages) + return await self.summarize_all(messages) # Split head and tail head = messages[:split_index] tail = messages[split_index:] - summary = self.summarize_real(head, depth + 1) + summary = await self.summarize_real(head, depth + 1) # If the combined summary and tail still fits, return directly new_messages = summary + tail @@ -96,9 +96,9 @@ def summarize_real(self, messages, depth=0): return new_messages # Otherwise recurse with increased depth - return self.summarize_real(new_messages, depth + 1) + return await self.summarize_real(new_messages, depth + 1) - def 
summarize_all(self, messages): + async def summarize_all(self, messages): content = "" for msg in messages: role = msg["role"].upper() @@ -118,7 +118,7 @@ def summarize_all(self, messages): for model in self.models: try: - summary = model.simple_send_with_retries(summarize_messages) + summary = await model.simple_send_with_retries(summarize_messages) if summary is not None: summary = prompts.summary_prefix + summary return [dict(role="user", content=summary)] @@ -129,7 +129,7 @@ def summarize_all(self, messages): print(err) raise ValueError(err) - def summarize_all_as_text(self, messages, prompt, max_tokens=None): + async def summarize_all_as_text(self, messages, prompt, max_tokens=None): content = "" for msg in messages: role = msg["role"].upper() @@ -149,7 +149,9 @@ def summarize_all_as_text(self, messages, prompt, max_tokens=None): for model in self.models: try: - summary = model.simple_send_with_retries(summarize_messages, max_tokens=max_tokens) + summary = await model.simple_send_with_retries( + summarize_messages, max_tokens=max_tokens + ) if summary is not None: return summary except Exception as e: diff --git a/aider/io.py b/aider/io.py index c7fab2f97c4..b721d848abb 100644 --- a/aider/io.py +++ b/aider/io.py @@ -1,9 +1,11 @@ +import asyncio import base64 import functools import os import shutil import signal import subprocess +import sys import time import webbrowser from collections import defaultdict @@ -30,14 +32,14 @@ from rich.columns import Columns from rich.console import Console from rich.markdown import Markdown +from rich.spinner import SPINNERS from rich.style import Style as RichStyle from rich.text import Text -from aider.mdstream import MarkdownStream - from .dump import dump # noqa: F401 from .editor import pipe_editor from .utils import is_image_file, run_fzf +from .waiting import Spinner # Constants NOTIFICATION_MESSAGE = "Aider is waiting for your input" @@ -71,6 +73,23 @@ def wrapper(self, *args, **kwargs): return wrapper +def 
restore_multiline_async(func): + """Decorator to restore multiline mode after async function execution""" + + @functools.wraps(func) + async def wrapper(self, *args, **kwargs): + orig_multiline = self.multiline_mode + self.multiline_mode = False + try: + return await func(self, *args, **kwargs) + except Exception: + raise + finally: + self.multiline_mode = orig_multiline + + return wrapper + + def without_input_history(func): """Decorator to temporarily disable history saving for the prompt session buffer.""" @@ -310,6 +329,8 @@ def __init__( self.chat_history_file = None self.placeholder = None + self.fallback_spinner = None + self.prompt_session = None self.interrupted = False self.never_prompts = set() self.editingmode = editingmode @@ -350,6 +371,9 @@ def __init__( self.code_theme = code_theme + self._stream_buffer = "" + self._stream_line_count = 0 + self.input = input self.output = output @@ -387,20 +411,35 @@ def __init__( current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.append_chat_history(f"\n# aider chat started at {current_time}\n\n") - self.prompt_session = None self.is_dumb_terminal = is_dumb_terminal() + self.is_tty = sys.stdout.isatty() if self.is_dumb_terminal: self.pretty = False fancy_input = False if fancy_input: + # Spinner state + self.spinner_running = False + self.spinner_text = "" + self.spinner_frame_index = 0 + self.spinner_last_frame_index = 0 + self.unicode_palette = "░█" + # If unicode is supported, use the rich 'dots2' spinner, otherwise an ascii fallback + if self._spinner_supports_unicode(): + self.spinner_frames = SPINNERS["dots2"]["frames"] + else: + # A simple ascii spinner + self.spinner_frames = SPINNERS["line"]["frames"] + # Initialize PromptSession only if we have a capable terminal session_kwargs = { "input": self.input, "output": self.output, "lexer": PygmentsLexer(MarkdownLexer), "editing_mode": self.editingmode, + "bottom_toolbar": self.get_bottom_toolbar, + "refresh_interval": 0.1, } if self.editingmode == 
EditingMode.VI: session_kwargs["cursor"] = ModalCursorShapeConfig() @@ -419,10 +458,60 @@ def __init__( self.file_watcher = file_watcher self.root = root + self.outstanding_confirmations = [] + self.coder = None # Validate color settings after console is initialized self._validate_color_settings() + def _spinner_supports_unicode(self) -> bool: + if not self.is_tty: + return False + try: + out = self.unicode_palette + out += "\b" * len(self.unicode_palette) + out += " " * len(self.unicode_palette) + out += "\b" * len(self.unicode_palette) + sys.stdout.write(out) + sys.stdout.flush() + return True + except UnicodeEncodeError: + return False + except Exception: + return False + + def start_spinner(self, text): + """Start the spinner.""" + self.stop_spinner() + + if self.prompt_session: + self.spinner_running = True + self.spinner_text = text + self.spinner_frame_index = self.spinner_last_frame_index + else: + self.fallback_spinner = Spinner(text) + self.fallback_spinner.step() + + def stop_spinner(self): + """Stop the spinner.""" + self.spinner_running = False + self.spinner_text = "" + # Keep last frame index to avoid spinner "jumping" on restart + self.spinner_last_frame_index = self.spinner_frame_index + if self.fallback_spinner: + self.fallback_spinner.end() + self.fallback_spinner = None + + def get_bottom_toolbar(self): + """Get the current spinner frame and text for the bottom toolbar.""" + if not self.spinner_running or not self.spinner_frames: + return None + + frame = self.spinner_frames[self.spinner_frame_index] + self.spinner_frame_index = (self.spinner_frame_index + 1) % len(self.spinner_frames) + + return f"{frame} {self.spinner_text}" + def _validate_color_settings(self): """Validate configured color strings and reset invalid ones.""" color_attributes = [ @@ -461,6 +550,7 @@ def _get_style(self): "pygments.literal.string": f"bold italic {self.user_input_color}", } ) + style_dict["bottom-toolbar"] = f"{self.user_input_color} noreverse" # Conditionally 
add 'completion-menu' style completion_menu_style = [] @@ -566,13 +656,35 @@ def rule(self): print() def interrupt_input(self): + coder = self.coder() if self.coder else None + # interrupted_for_confirmation = False + + if ( + coder + and hasattr(coder, "input_task") + and coder.input_task + and not coder.input_task.done() + ): + coder.input_task.cancel() + if self.prompt_session and self.prompt_session.app: # Store any partial input before interrupting self.placeholder = self.prompt_session.app.current_buffer.text self.interrupted = True - self.prompt_session.app.exit() - def get_input( + try: + self.prompt_session.app.exit() + finally: + pass + + def reject_outstanding_confirmations(self): + """Reject all outstanding confirmation dialogs.""" + for future in self.outstanding_confirmations: + if not future.done(): + future.set_result(False) + self.outstanding_confirmations = [] + + async def get_input( self, root, rel_fnames, @@ -582,6 +694,7 @@ def get_input( abs_read_only_stubs_fnames=None, edit_format=None, ): + self.reject_outstanding_confirmations() self.rule() # Ring the bell if needed @@ -735,7 +848,7 @@ def _(event): def get_continuation(width, line_number, is_soft_wrap): return self.prompt_prefix - line = self.prompt_session.prompt( + line = await self.prompt_session.prompt_async( show, default=default, completer=completer_instance, @@ -747,7 +860,7 @@ def get_continuation(width, line_number, is_soft_wrap): prompt_continuation=get_continuation, ) else: - line = input(show) + line = await asyncio.get_event_loop().run_in_executor(None, input, show) # Check if we were interrupted by a file change if self.interrupted: @@ -758,15 +871,18 @@ def get_continuation(width, line_number, is_soft_wrap): except EOFError: raise + except KeyboardInterrupt: + self.console.print() + return "" + except UnicodeEncodeError as err: + self.tool_error(str(err)) + return "" except Exception as err: import traceback self.tool_error(str(err)) self.tool_error(traceback.format_exc()) 
return "" - except UnicodeEncodeError as err: - self.tool_error(str(err)) - return "" finally: if self.file_watcher: self.file_watcher.stop() @@ -811,7 +927,6 @@ def get_continuation(width, line_number, is_soft_wrap): inp = line break - print() self.user_input(inp) return inp @@ -876,18 +991,43 @@ def ai_output(self, content): hist = "\n" + content.strip() + "\n\n" self.append_chat_history(hist) - def offer_url(self, url, prompt="Open URL for more info?", allow_never=True): + async def offer_url(self, url, prompt="Open URL for more info?", allow_never=True): """Offer to open a URL in the browser, returns True if opened.""" if url in self.never_prompts: return False - if self.confirm_ask(prompt, subject=url, allow_never=allow_never): + if await self.confirm_ask(prompt, subject=url, allow_never=allow_never): webbrowser.open(url) return True return False - @restore_multiline - @without_input_history - def confirm_ask( + @restore_multiline_async + async def confirm_ask( + self, + *args, + **kwargs, + ): + coder = self.coder() if self.coder else None + interrupted_for_confirmation = False + if ( + coder + and hasattr(coder, "input_task") + and coder.input_task + and not coder.input_task.done() + ): + coder.confirmation_in_progress = True + interrupted_for_confirmation = True + # self.interrupt_input() + + try: + return await asyncio.create_task(self._confirm_ask(*args, **kwargs)) + except KeyboardInterrupt: + # Re-raise KeyboardInterrupt to allow it to propagate + raise + finally: + if interrupted_for_confirmation: + coder.confirmation_in_progress = False + + async def _confirm_ask( self, question, default="y", @@ -903,109 +1043,152 @@ def confirm_ask( question_id = (question, subject) - if question_id in self.never_prompts: - return False + confirmation_future = asyncio.get_running_loop().create_future() + self.outstanding_confirmations.append(confirmation_future) - if group and not group.show_group: - group = None - if group: - allow_never = True - - valid_responses = 
["yes", "no", "skip", "all"] - options = " (Y)es/(N)o" - if group: - if not explicit_yes_required: - options += "/(A)ll" - options += "/(S)kip all" - if allow_never: - options += "/(D)on't ask again" - valid_responses.append("don't") - - if default.lower().startswith("y"): - question += options + " [Yes]: " - elif default.lower().startswith("n"): - question += options + " [No]: " - else: - question += options + f" [{default}]: " + try: + if question_id in self.never_prompts: + if not confirmation_future.done(): + confirmation_future.set_result(False) + return await confirmation_future + + if group and not group.show_group: + group = None + if group: + allow_never = True + + valid_responses = ["yes", "no", "skip", "all"] + options = " (Y)es/(N)o" + if group: + if not explicit_yes_required: + options += "/(A)ll" + options += "/(S)kip all" + if allow_never: + options += "/(D)on't ask again" + valid_responses.append("don't") + + if default.lower().startswith("y"): + question += options + " [Yes]: " + elif default.lower().startswith("n"): + question += options + " [No]: " + else: + question += options + f" [{default}]: " + + if subject: + self.tool_output() + if "\n" in subject: + lines = subject.splitlines() + max_length = max(len(line) for line in lines) + padded_lines = [line.ljust(max_length) for line in lines] + padded_subject = "\n".join(padded_lines) + self.tool_output(padded_subject, bold=True) + else: + self.tool_output(subject, bold=True) - if subject: - self.tool_output() - if "\n" in subject: - lines = subject.splitlines() - max_length = max(len(line) for line in lines) - padded_lines = [line.ljust(max_length) for line in lines] - padded_subject = "\n".join(padded_lines) - self.tool_output(padded_subject, bold=True) + style = self._get_style() + + if self.yes is True: + res = "n" if explicit_yes_required else "y" + elif self.yes is False: + res = "n" + elif group and group.preference: + res = group.preference + self.user_input(f"{question}{res}", 
log_only=False) else: - self.tool_output(subject, bold=True) + while True: + try: + if self.prompt_session: + coder = self.coder() if self.coder else None + if ( + coder + and hasattr(coder, "input_task") + and coder.input_task + and not coder.input_task.done() + ): + self.prompt_session.message = question + self.prompt_session.app.invalidate() + res = await coder.input_task + else: + prompt_task = asyncio.create_task( + self.prompt_session.prompt_async( + question, + style=style, + complete_while_typing=False, + ) + ) + done, pending = await asyncio.wait( + {prompt_task, confirmation_future}, + return_when=asyncio.FIRST_COMPLETED, + ) + + if confirmation_future in done: + prompt_task.cancel() + return await confirmation_future + + res = await prompt_task + else: + res = await asyncio.get_event_loop().run_in_executor( + None, input, question + ) + except EOFError: + # Treat EOF (Ctrl+D) as if the user pressed Enter + res = default + break + except asyncio.CancelledError: + if not confirmation_future.done(): + confirmation_future.set_result(False) + raise - style = self._get_style() + if not res: + res = default + break + res = res.lower() + good = any(valid_response.startswith(res) for valid_response in valid_responses) + if good: + break - def is_valid_response(text): - if not text: - return True - return text.lower() in valid_responses + error_message = f"Please answer with one of: {', '.join(valid_responses)}" + self.tool_error(error_message) - if self.yes is True: - res = "n" if explicit_yes_required else "y" - elif self.yes is False: - res = "n" - elif group and group.preference: - res = group.preference - self.user_input(f"{question}{res}", log_only=False) - else: - while True: - try: - if self.prompt_session: - res = self.prompt_session.prompt( - question, - style=style, - complete_while_typing=False, - ) - else: - res = input(question) - except EOFError: - # Treat EOF (Ctrl+D) as if the user pressed Enter - res = default - break + res = res.lower()[0] - if 
not res: - res = default - break - res = res.lower() - good = any(valid_response.startswith(res) for valid_response in valid_responses) - if good: - break + if res == "d" and allow_never: + self.never_prompts.add(question_id) + hist = f"{question.strip()} {res}" + self.append_chat_history(hist, linebreak=True, blockquote=True) + if not confirmation_future.done(): + confirmation_future.set_result(False) + return await confirmation_future + + if explicit_yes_required: + is_yes = res == "y" + else: + is_yes = res in ("y", "a") - error_message = f"Please answer with one of: {', '.join(valid_responses)}" - self.tool_error(error_message) + is_all = res == "a" and group is not None and not explicit_yes_required + is_skip = res == "s" and group is not None - res = res.lower()[0] + if group: + if is_all and not explicit_yes_required: + group.preference = "all" + elif is_skip: + group.preference = "skip" - if res == "d" and allow_never: - self.never_prompts.add(question_id) hist = f"{question.strip()} {res}" self.append_chat_history(hist, linebreak=True, blockquote=True) - return False - - if explicit_yes_required: - is_yes = res == "y" - else: - is_yes = res in ("y", "a") - is_all = res == "a" and group is not None and not explicit_yes_required - is_skip = res == "s" and group is not None + if not confirmation_future.done(): + confirmation_future.set_result(is_yes) - if group: - if is_all and not explicit_yes_required: - group.preference = "all" - elif is_skip: - group.preference = "skip" - - hist = f"{question.strip()} {res}" - self.append_chat_history(hist, linebreak=True, blockquote=True) + except asyncio.CancelledError: + if not confirmation_future.done(): + confirmation_future.set_result(False) + raise + finally: + if confirmation_future in self.outstanding_confirmations: + self.outstanding_confirmations.remove(confirmation_future) - return is_yes + return await confirmation_future @restore_multiline def prompt_ask(self, question, default="", subject=None): @@ -1059,14 
+1242,18 @@ def _tool_message(self, message="", strip=True, color=None): message = Text(message) color = ensure_hash_prefix(color) if color else None style = dict(style=color) if self.pretty and color else dict() + try: - self.console.print(message, **style) + self.stream_print(message, **style) except UnicodeEncodeError: # Fallback to ASCII-safe output if isinstance(message, Text): message = message.plain message = str(message).encode("ascii", errors="replace").decode("ascii") - self.console.print(message, **style) + self.stream_print(message, **style) + + if self.prompt_session and self.prompt_session.app: + self.prompt_session.app.invalidate() def tool_error(self, message="", strip=True): self.num_error_outputs += 1 @@ -1089,19 +1276,12 @@ def tool_output(self, *messages, log_only=False, bold=False): if self.pretty: if self.tool_output_color: style["color"] = ensure_hash_prefix(self.tool_output_color) - style["reverse"] = bold + # if bold: + # style["bold"] = True style = RichStyle(**style) - self.console.print(*messages, style=style) - def get_assistant_mdstream(self): - mdargs = dict( - style=self.assistant_output_color, - code_theme=self.code_theme, - inline_code_lexer="text", - ) - mdStream = MarkdownStream(mdargs=mdargs) - return mdStream + self.stream_print(*messages, style=style) def assistant_output(self, message, pretty=None): if not message: @@ -1121,7 +1301,64 @@ def assistant_output(self, message, pretty=None): else: show_resp = Text(message or "(empty response)") - self.console.print(show_resp) + self.stream_print(show_resp) + + def render_markdown(self, text): + output = StringIO() + console = Console(file=output, force_terminal=True, color_system="truecolor") + md = Markdown(text, style=self.assistant_output_color, code_theme=self.code_theme) + console.print(md) + return output.getvalue() + + def stream_output(self, text, final=False): + """ + Stream output using Rich console to respect pretty print settings. 
+ This preserves formatting, colors, and other Rich features during streaming. + """ + # Initialize buffer if not exists + if not hasattr(self, "_stream_buffer"): + self._stream_buffer = "" + + # Initialize buffer if not exists + if not hasattr(self, "_stream_line_count"): + self._stream_line_count = 0 + + self._stream_buffer += text + + # Process the buffer to find complete lines + lines = self._stream_buffer.split("\n") + complete_lines = [] + incomplete_line = "" + output = "" + + if len(lines) > 1 or final: + # All lines except the last one are complete + complete_lines = lines[:-1] if not final else lines + incomplete_line = lines[-1] if not final else "" + + for complete_line in complete_lines: + output += complete_line + self._stream_line_count += 1 + + self._stream_buffer = incomplete_line + + if not final: + if len(lines) > 1: + self.console.print(output) + else: + # Ensure any remaining buffered content is printed using the full response + self.console.print(output) + self.reset_streaming_response() + + def reset_streaming_response(self): + self._stream_buffer = "" + self._stream_line_count = 0 + + def stream_print(self, *messages, **kwargs): + with self.console.capture() as capture: + self.console.print(*messages, **kwargs) + capture_text = capture.get() + self.stream_output(capture_text, final=False) def set_placeholder(self, placeholder): """Set a one-time placeholder text for the next input prompt.""" diff --git a/aider/llm.py b/aider/llm.py index c57c274db09..f3813e24301 100644 --- a/aider/llm.py +++ b/aider/llm.py @@ -20,8 +20,20 @@ class LazyLiteLLM: _lazy_module = None + _lazy_classes = { + "ModelResponse": "ModelResponse", + "Choices": "Choices", + "Message": "Message", + } def __getattr__(self, name): + # Check if the requested attribute is one of the explicitly lazy-loaded classes + if name in self._lazy_classes: + self._load_litellm() + class_name = self._lazy_classes[name] + return getattr(self._lazy_module, class_name) + + # Handle other 
attributes (like `acompletion`) as before if name == "_lazy_module": return super() self._load_litellm() @@ -31,11 +43,7 @@ def _load_litellm(self): if self._lazy_module is not None: return - if VERBOSE: - print("Loading litellm...") - self._lazy_module = importlib.import_module("litellm") - self._lazy_module.suppress_debug_info = True self._lazy_module.set_verbose = False self._lazy_module.drop_params = True diff --git a/aider/main.py b/aider/main.py index ad9927b934b..1c7781e0d16 100644 --- a/aider/main.py +++ b/aider/main.py @@ -1,3 +1,4 @@ +import asyncio import glob import json import os @@ -470,6 +471,10 @@ def expand_glob_patterns(patterns, root="."): def main(argv=None, input=None, output=None, force_git_root=None, return_coder=False): + return asyncio.run(main_async(argv, input, output, force_git_root, return_coder)) + + +async def main_async(argv=None, input=None, output=None, force_git_root=None, return_coder=False): report_uncaught_exceptions() if argv is None: @@ -744,7 +749,7 @@ def get_io(pretty): right_repo_root = guessed_wrong_repo(io, git_root, fnames, git_dname) if right_repo_root: analytics.event("exit", reason="Recursing with correct repo") - return main(argv, input, output, right_repo_root, return_coder=return_coder) + return await main_async(argv, input, output, right_repo_root, return_coder=return_coder) if args.just_check_update: update_available = check_version(io, just_check=True, verbose=args.verbose) @@ -801,7 +806,7 @@ def get_io(pretty): alias, model = parts models.MODEL_ALIASES[alias.strip()] = model.strip() - selected_model_name = select_default_model(args, io, analytics) + selected_model_name = await select_default_model(args, io, analytics) if not selected_model_name: # Error message and analytics event are handled within select_default_model # It might have already offered OAuth if no model/keys were found. @@ -816,7 +821,7 @@ def get_io(pretty): " found." 
) # Attempt OAuth flow because the specific model needs it - if offer_openrouter_oauth(io, analytics): + if await offer_openrouter_oauth(io, analytics): # OAuth succeeded, the key should now be in os.environ. # Check if the key is now present after the flow. if os.environ.get("OPENROUTER_API_KEY"): @@ -1010,7 +1015,7 @@ def get_io(pretty): if not mcp_servers: mcp_servers = [] - coder = Coder.create( + coder = await Coder.create( main_model=main_model, edit_format=args.edit_format, io=io, @@ -1086,8 +1091,6 @@ def get_io(pretty): analytics.event("copy-paste mode") ClipboardWatcher(coder.io, verbose=args.verbose) - coder.show_announcements() - if args.show_prompts: coder.cur_messages += [ dict(role="user", content="Hello!"), @@ -1098,22 +1101,22 @@ def get_io(pretty): return if args.lint: - coder.commands.cmd_lint(fnames=fnames) + await coder.commands.cmd_lint(fnames=fnames) if args.test: if not args.test_cmd: io.tool_error("No --test-cmd provided.") analytics.event("exit", reason="No test command provided") return 1 - coder.commands.cmd_test(args.test_cmd) + await coder.commands.cmd_test(args.test_cmd) if io.placeholder: - coder.run(io.placeholder) + await coder.run(io.placeholder) if args.commit: if args.dry_run: io.tool_output("Dry run enabled, skipping commit.") else: - coder.commands.cmd_commit() + await coder.commands.cmd_commit() if args.lint or args.test or args.commit: analytics.event("exit", reason="Completed lint/test/commit") @@ -1135,7 +1138,7 @@ def get_io(pretty): # For testing #2879 # from aider.coders.base_coder import all_fences # coder.fence = all_fences[1] - coder.apply_updates() + await coder.apply_updates() analytics.event("exit", reason="Applied updates") return @@ -1168,14 +1171,14 @@ def get_io(pretty): io.tool_warning("Cost estimates may be inaccurate when using streaming and caching.") if args.load: - commands.cmd_load(args.load) + await commands.cmd_load(args.load) if args.message: io.add_to_input_history(args.message) io.tool_output() 
try: - coder.run(with_message=args.message) - except SwitchCoder: + await coder.run(with_message=args.message) + except (SwitchCoder, KeyboardInterrupt): pass analytics.event("exit", reason="Completed --message") return @@ -1184,7 +1187,7 @@ def get_io(pretty): try: message_from_file = io.read_text(args.message_file) io.tool_output() - coder.run(with_message=message_from_file) + await coder.run(with_message=message_from_file) except FileNotFoundError: io.tool_error(f"Message file not found: {args.message_file}") analytics.event("exit", reason="Message file not found") @@ -1206,7 +1209,7 @@ def get_io(pretty): while True: try: coder.ok_to_warm_cache = bool(args.cache_keepalive_pings) - coder.run() + await coder.run() analytics.event("exit", reason="Completed main CLI coder.run") return except SwitchCoder as switch: @@ -1224,10 +1227,10 @@ def get_io(pretty): # Disable cache warming for the new coder kwargs["num_cache_warming_pings"] = 0 - coder = Coder.create(**kwargs) + coder = await Coder.create(**kwargs) - if switch.kwargs.get("show_announcements") is not False: - coder.show_announcements() + if switch.kwargs.get("show_announcements") is False: + coder.suppress_announcements_for_next_prompt = True def is_first_run_of_new_version(io, verbose=False): diff --git a/aider/mcp/__init__.py b/aider/mcp/__init__.py index 0da5b6e232e..17903017745 100644 --- a/aider/mcp/__init__.py +++ b/aider/mcp/__init__.py @@ -1,6 +1,7 @@ import json +from pathlib import Path -from aider.mcp.server import HttpStreamingServer, McpServer +from aider.mcp.server import HttpStreamingServer, McpServer, SseServer def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_transport="stdio"): @@ -24,6 +25,8 @@ def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_tran servers.append(McpServer(server_config)) elif transport == "http": servers.append(HttpStreamingServer(server_config)) + elif transport == "sse": + servers.append(SseServer(server_config)) if 
verbose: io.tool_output(f"Loaded {len(servers)} MCP servers from JSON string") @@ -38,12 +41,72 @@ def _parse_mcp_servers_from_json_string(json_string, io, verbose=False, mcp_tran return servers +def _resolve_mcp_config_path(file_path, io, verbose=False): + """Resolve MCP config file path relative to closest aider.conf.yml, git directory, or CWD.""" + if not file_path: + return None + + # If the path is absolute or already exists, use it as-is + path = Path(file_path) + if path.is_absolute() or path.exists(): + return str(path.resolve()) + + # Search for the closest aider.conf.yml in parent directories + current_dir = Path.cwd() + aider_conf_path = None + + for parent in [current_dir] + list(current_dir.parents): + conf_file = parent / ".aider.conf.yml" + if conf_file.exists(): + aider_conf_path = parent + break + + # If aider.conf.yml found, try relative to that directory + if aider_conf_path: + resolved_path = aider_conf_path / file_path + if resolved_path.exists(): + if verbose: + io.tool_output(f"Resolved MCP config relative to aider.conf.yml: {resolved_path}") + return str(resolved_path.resolve()) + + # Try to find git root directory + git_root = None + try: + import git + + repo = git.Repo(search_parent_directories=True) + git_root = Path(repo.working_tree_dir) + except (ImportError, git.InvalidGitRepositoryError, FileNotFoundError): + pass + + # If git root found, try relative to that directory + if git_root: + resolved_path = git_root / file_path + if resolved_path.exists(): + if verbose: + io.tool_output(f"Resolved MCP config relative to git root: {resolved_path}") + return str(resolved_path.resolve()) + + # Finally, try relative to current working directory + resolved_path = current_dir / file_path + if resolved_path.exists(): + if verbose: + io.tool_output(f"Resolved MCP config relative to CWD: {resolved_path}") + return str(resolved_path.resolve()) + + # If none found, return the original path (will trigger FileNotFoundError) + return 
str(path.resolve()) + + def _parse_mcp_servers_from_file(file_path, io, verbose=False, mcp_transport="stdio"): """Parse MCP servers from a JSON file.""" servers = [] + # Resolve the file path relative to closest aider.conf.yml, git directory, or CWD + resolved_file_path = _resolve_mcp_config_path(file_path, io, verbose) + try: - with open(file_path, "r") as f: + with open(resolved_file_path, "r") as f: config = json.load(f) if verbose: diff --git a/aider/mcp/server.py b/aider/mcp/server.py index 5e5660a185e..ba74727460a 100644 --- a/aider/mcp/server.py +++ b/aider/mcp/server.py @@ -4,6 +4,7 @@ from contextlib import AsyncExitStack from mcp import ClientSession, StdioServerParameters +from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client @@ -13,12 +14,7 @@ class McpServer: A client for MCP servers that provides tools to Aider coders. An McpServer class is initialized per configured MCP Server - Current usage: - - conn = await session.connect() # Use connect() directly - tools = await experimental_mcp_client.load_mcp_tools(session=s, format="openai") - await session.disconnect() - print(tools) + Uses the mcp library to create and initialize ClientSession objects. 
""" def __init__(self, server_config): @@ -72,21 +68,25 @@ async def disconnect(self): try: await self.exit_stack.aclose() self.session = None - self.stdio_context = None except Exception as e: logging.error(f"Error during cleanup of server {self.name}: {e}") class HttpStreamingServer(McpServer): + """HTTP streaming MCP server using mcp.client.streamablehttp_client.""" + async def connect(self): if self.session is not None: logging.info(f"Using existing session for MCP server: {self.name}") return self.session - logging.info(f"Establishing new connection to MCP server: {self.name}") + logging.info(f"Establishing new connection to HTTP MCP server: {self.name}") try: - url = self.config["url"] - http_transport = await self.exit_stack.enter_async_context(streamablehttp_client(url)) + url = self.config.get("url") + headers = self.config.get("headers", {}) + http_transport = await self.exit_stack.enter_async_context( + streamablehttp_client(url, headers=headers) + ) read, write, _response = http_transport session = await self.exit_stack.enter_async_context(ClientSession(read, write)) @@ -94,7 +94,33 @@ async def connect(self): self.session = session return session except Exception as e: - logging.error(f"Error initializing server {self.name}: {e}") + logging.error(f"Error initializing HTTP server {self.name}: {e}") + await self.disconnect() + raise + + +class SseServer(McpServer): + """SSE (Server-Sent Events) MCP server using mcp.client.sse_client.""" + + async def connect(self): + if self.session is not None: + logging.info(f"Using existing session for SSE MCP server: {self.name}") + return self.session + + logging.info(f"Establishing new connection to SSE MCP server: {self.name}") + try: + url = self.config.get("url") + headers = self.config.get("headers", {}) + sse_transport = await self.exit_stack.enter_async_context( + sse_client(url, headers=headers) + ) + read, write, _response = sse_transport + session = await 
self.exit_stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + self.session = session + return session + except Exception as e: + logging.error(f"Error initializing SSE server {self.name}: {e}") await self.disconnect() raise diff --git a/aider/models.py b/aider/models.py index f07b90dd5df..4c09161d02a 100644 --- a/aider/models.py +++ b/aider/models.py @@ -1,3 +1,4 @@ +import asyncio import difflib import hashlib import importlib.resources @@ -893,23 +894,33 @@ def get_reasoning_effort(self): return self.extra_params["extra_body"]["reasoning_effort"] return None - def is_deepseek_r1(self): + def is_deepseek(self): name = self.name.lower() if "deepseek" not in name: return - return "r1" in name or "reasoner" in name + return True def is_ollama(self): return self.name.startswith("ollama/") or self.name.startswith("ollama_chat/") - def send_completion( + async def send_completion( self, messages, functions, stream, temperature=None, tools=None, max_tokens=None ): if os.environ.get("AIDER_SANITY_CHECK_TURNS"): sanity_check_messages(messages) - if self.is_deepseek_r1(): - messages = ensure_alternating_roles(messages) + messages = ensure_alternating_roles(messages) + + if self.verbose: + for message in messages: + msg_role = message.get("role") + msg_content = message.get("content") if message.get("content") else "" + msg_trunc = "" + + if message.get("content"): + msg_trunc = message.get("content")[:30] + + print(f"{msg_role} ({len(msg_content)}): {msg_trunc}") kwargs = dict(model=self.name, stream=stream) @@ -923,26 +934,22 @@ def send_completion( kwargs["temperature"] = temperature # `tools` is for modern tool usage. `functions` is for legacy/forced calls. - # If `tools` is provided, it's the canonical list. If not, use `functions`. # This handles `base_coder` sending both with same content for `navigator_coder`. 
- effective_tools = tools if tools is not None else functions + effective_tools = tools + + if effective_tools is None and functions: + # Convert legacy `functions` to `tools` format if `tools` isn't provided. + effective_tools = [dict(type="function", function=f) for f in functions] if effective_tools: - # Check if we have legacy format functions (which lack a 'type' key) and convert them. - # This is a simplifying assumption that works for aider's use cases. - is_legacy = any("type" not in tool for tool in effective_tools) - if is_legacy: - kwargs["tools"] = [dict(type="function", function=tool) for tool in effective_tools] - else: - kwargs["tools"] = effective_tools + kwargs["tools"] = effective_tools # Forcing a function call is for legacy style `functions` with a single function. # This is used by ArchitectCoder and not intended for NavigatorCoder's tools. if functions and len(functions) == 1: function = functions[0] - is_legacy = "type" not in function - if is_legacy and "name" in function: + if "name" in function: tool_name = function.get("name") if tool_name: kwargs["tool_choice"] = {"type": "function", "function": {"name": tool_name}} @@ -978,16 +985,18 @@ def send_completion( } try: - res = litellm.completion(**kwargs) + res = await litellm.acompletion(**kwargs) except Exception as err: - res = "Model API Response Error. 
Please retry the previous request" + print(f"LiteLLM API Error: {str(err)}") + res = self.model_error_response() if self.verbose: print(f"LiteLLM API Error: {str(err)}") + raise return hash_object, res - def simple_send_with_retries(self, messages, max_tokens=None): + async def simple_send_with_retries(self, messages, max_tokens=None): from aider.exceptions import LiteLLMExceptions litellm_ex = LiteLLMExceptions() @@ -1000,7 +1009,7 @@ def simple_send_with_retries(self, messages, max_tokens=None): while True: try: - _hash, response = self.send_completion( + _hash, response = await self.send_completion( messages=messages, functions=None, stream=False, @@ -1031,6 +1040,22 @@ def simple_send_with_retries(self, messages, max_tokens=None): except AttributeError: return None + async def model_error_response(self): + for i in range(1): + await asyncio.sleep(0.1) + yield litellm.ModelResponse( + choices=[ + litellm.Choices( + finish_reason="stop", + index=0, + message=litellm.Message( + content="Model API Response Error. Please retry the previous request" + ), # Provide an empty message object + ) + ], + model=self.name, + ) + def register_models(model_settings_fnames): files_loaded = [] diff --git a/aider/onboarding.py b/aider/onboarding.py index 9b6abd54b8d..0cd6019fa87 100644 --- a/aider/onboarding.py +++ b/aider/onboarding.py @@ -76,7 +76,7 @@ def try_to_select_default_model(): return None -def offer_openrouter_oauth(io, analytics): +async def offer_openrouter_oauth(io, analytics): """ Offers OpenRouter OAuth flow to the user if no API keys are found. 
@@ -90,7 +90,7 @@ def offer_openrouter_oauth(io, analytics): # No API keys found - Offer OpenRouter OAuth io.tool_output("OpenRouter provides free and paid access to many LLMs.") # Use confirm_ask which handles non-interactive cases - if io.confirm_ask( + if await io.confirm_ask( "Login to OpenRouter or create a free account?", default="y", ): @@ -113,7 +113,7 @@ def offer_openrouter_oauth(io, analytics): return False -def select_default_model(args, io, analytics): +async def select_default_model(args, io, analytics): """ Selects a default model based on available API keys if no model is specified. Offers OAuth flow for OpenRouter if no keys are found. @@ -139,7 +139,7 @@ def select_default_model(args, io, analytics): io.tool_warning(no_model_msg) # Try OAuth if no model was detected - offer_openrouter_oauth(io, analytics) + await offer_openrouter_oauth(io, analytics) # Check again after potential OAuth success model = try_to_select_default_model() diff --git a/aider/repo.py b/aider/repo.py index e4597c8e4d0..5b7fbb57d8a 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -21,7 +21,7 @@ from aider import prompts, utils from .dump import dump # noqa: F401 -from .waiting import WaitingSpinner +from .waiting import Spinner ANY_GIT_ERROR += [ OSError, @@ -128,7 +128,7 @@ def __init__( if aider_ignore_file: self.aider_ignore_file = Path(aider_ignore_file) - def commit(self, fnames=None, context=None, message=None, aider_edits=False, coder=None): + async def commit(self, fnames=None, context=None, message=None, aider_edits=False, coder=None): """ Commit the specified files or all dirty files if none are specified. 
@@ -213,7 +213,7 @@ def commit(self, fnames=None, context=None, message=None, aider_edits=False, cod user_language = coder.commit_language if not user_language: user_language = coder.get_user_language() - commit_message = self.get_commit_message(diffs, context, user_language) + commit_message = await self.get_commit_message(diffs, context, user_language) # Retrieve attribute settings, prioritizing coder.args if available if coder and hasattr(coder, "args"): @@ -323,7 +323,7 @@ def get_rel_repo_dir(self): except (ValueError, OSError): return self.repo.git_dir - def get_commit_message(self, diffs, context, user_language=None): + async def get_commit_message(self, diffs, context, user_language=None): diffs = "# Diffs:\n" + diffs content = "" @@ -340,8 +340,8 @@ def get_commit_message(self, diffs, context, user_language=None): commit_message = None for model in self.models: - spinner_text = f"Generating commit message with {model.name}" - with WaitingSpinner(spinner_text): + spinner_text = f"Generating commit message with {model.name}\n" + with Spinner(spinner_text): if model.system_prompt_prefix: current_system_content = model.system_prompt_prefix + "\n" + system_content else: @@ -358,7 +358,7 @@ def get_commit_message(self, diffs, context, user_language=None): if max_tokens and num_tokens > max_tokens: continue - commit_message = model.simple_send_with_retries(messages) + commit_message = await model.simple_send_with_retries(messages) if commit_message: break # Found a model that could generate the message diff --git a/aider/resources/model-metadata.json b/aider/resources/model-metadata.json index 36221400808..4b09f073009 100644 --- a/aider/resources/model-metadata.json +++ b/aider/resources/model-metadata.json @@ -88,7 +88,7 @@ "litellm_provider": "fireworks_ai", "input_cost_per_token": 0.000008, "output_cost_per_token": 0.000008, - "mode": "chat", + "mode": "chat" }, "fireworks_ai/accounts/fireworks/models/deepseek-v3-0324": { "max_tokens": 160000, @@ -97,7 +97,7 
@@ "litellm_provider": "fireworks_ai", "input_cost_per_token": 0.0000009, "output_cost_per_token": 0.0000009, - "mode": "chat", + "mode": "chat" }, "openrouter/openrouter/quasar-alpha": { "max_input_tokens": 1000000, @@ -552,7 +552,7 @@ "supported_output_modalities": [ "text" ], - "source": "https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview" + "source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro" }, "gemini-2.5-pro-preview-06-05": { "max_tokens": 65536, @@ -592,7 +592,7 @@ "supported_output_modalities": [ "text" ], - "source": "https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview" + "source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro" }, "gemini/gemini-2.5-pro-preview-05-06": { "max_tokens": 65536, @@ -628,7 +628,7 @@ "supported_output_modalities": [ "text" ], - "source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro-preview" + "source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro" }, "gemini/gemini-2.5-pro-preview-06-05": { "max_tokens": 65536, @@ -664,7 +664,7 @@ "supported_output_modalities": [ "text" ], - "source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro-preview" + "source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro" }, "gemini/gemini-2.5-pro": { "max_tokens": 65536, @@ -771,6 +771,6 @@ }, "together_ai/Qwen/Qwen3-235B-A22B-fp8-tput": { "input_cost_per_token": 0.0000002, - "output_cost_per_token": 0.0000006, + "output_cost_per_token": 0.0000006 } -} \ No newline at end of file +} diff --git a/aider/scrape.py b/aider/scrape.py index 1e44ad23772..5e91a3eeace 100755 --- a/aider/scrape.py +++ b/aider/scrape.py @@ -37,7 +37,7 @@ def has_playwright(): return has_pip and has_chromium -def install_playwright(io): +async def install_playwright(io): has_pip, has_chromium = check_env() if has_pip and has_chromium: return True @@ -59,7 +59,7 @@ def install_playwright(io): """ io.tool_output(text) - if not 
io.confirm_ask("Install playwright?", default="y"): + if not await io.confirm_ask("Install playwright?", default="y"): return if not has_pip: diff --git a/aider/sendchat.py b/aider/sendchat.py index 3f06cbfb9d5..6f8b2ba5d04 100644 --- a/aider/sendchat.py +++ b/aider/sendchat.py @@ -6,13 +6,52 @@ def sanity_check_messages(messages): """Check if messages alternate between user and assistant roles. System messages can be interspersed anywhere. Also verifies the last non-system message is from the user. + Validates tool message sequences. Returns True if valid, False otherwise.""" last_role = None last_non_system_role = None + i = 0 + n = len(messages) - for msg in messages: + while i < n: + msg = messages[i] role = msg.get("role") + + # Handle tool sequences atomically + if role == "assistant" and "tool_calls" in msg and msg["tool_calls"]: + # Validate tool sequence + expected_ids = {call["id"] for call in msg["tool_calls"]} + i += 1 + + # Check for tool responses + while i < n and expected_ids: + next_msg = messages[i] + if next_msg.get("role") == "tool" and next_msg.get("tool_call_id") in expected_ids: + expected_ids.discard(next_msg.get("tool_call_id")) + i += 1 + else: + break + + # If we still have expected IDs, the tool sequence is incomplete + if expected_ids: + turns = format_messages(messages) + raise ValueError( + "Incomplete tool sequence - missing responses for tool calls:\n\n" + turns + ) + + # Continue to next message after tool sequence + continue + + elif role == "tool": + # Orphaned tool message without preceding assistant tool_calls + turns = format_messages(messages) + raise ValueError( + "Orphaned tool message without preceding assistant tool_calls:\n\n" + turns + ) + + # Handle normal role alternation if role == "system": + i += 1 continue if last_role and role == last_role: @@ -21,16 +60,84 @@ def sanity_check_messages(messages): last_role = role last_non_system_role = role + i += 1 # Ensure last non-system message is from user return 
last_non_system_role == "user" +def clean_orphaned_tool_messages(messages): + """Remove orphaned tool messages and incomplete tool sequences. + + This function removes: + - Tool messages without a preceding assistant message containing tool_calls + - Assistant messages with tool_calls that don't have complete tool responses + + Args: + messages: List of message dictionaries + + Returns: + Cleaned list of messages with orphaned tool sequences removed + """ + if not messages: + return messages + + cleaned = [] + i = 0 + n = len(messages) + + while i < n: + msg = messages[i] + role = msg.get("role") + + # If it's an assistant message with tool_calls, check if we have complete responses + if role == "assistant" and "tool_calls" in msg and msg["tool_calls"]: + # Start of potential tool sequence + tool_sequence = [msg] + expected_ids = {call["id"] for call in msg["tool_calls"]} + j = i + 1 + + # Collect tool responses + while j < n and expected_ids: + next_msg = messages[j] + if next_msg.get("role") == "tool" and next_msg.get("tool_call_id") in expected_ids: + tool_sequence.append(next_msg) + expected_ids.discard(next_msg.get("tool_call_id")) + j += 1 + else: + break + + # If we have all tool responses, keep the sequence + if not expected_ids: + cleaned.extend(tool_sequence) + i = j + else: + # Incomplete sequence - skip the entire tool sequence + i = j + # Don't add anything to cleaned + continue + + elif role == "tool": + # Orphaned tool message without preceding assistant tool_calls - skip it + i += 1 + continue + else: + # Regular message - add it + cleaned.append(msg) + i += 1 + + return cleaned + + def ensure_alternating_roles(messages): """Ensure messages alternate between 'assistant' and 'user' roles. Inserts empty messages of the opposite role when consecutive messages - of the same role are found. + of the same 'user' or 'assistant' role are found. Messages with other + roles (e.g. 'system', 'tool') are ignored by the alternation logic. 
+ + Also handles tool call sequences properly - when an assistant message + contains tool_calls, processes the complete tool sequence atomically. Args: messages: List of message dictionaries with 'role' and 'content' keys. @@ -41,21 +148,84 @@ def ensure_alternating_roles(messages): if not messages: return messages - fixed_messages = [] + # First clean orphaned tool messages + messages = clean_orphaned_tool_messages(messages) + + result = [] + i = 0 + n = len(messages) prev_role = None - for msg in messages: - current_role = msg.get("role") # Get 'role', None if missing + while i < n: + msg = messages[i] + role = msg.get("role") - # If current role same as previous, insert empty message - # of the opposite role - if current_role == prev_role: - if current_role == "user": - fixed_messages.append({"role": "assistant", "content": ""}) - else: - fixed_messages.append({"role": "user", "content": ""}) + # Handle tool call sequences atomically + if role == "assistant" and "tool_calls" in msg and msg["tool_calls"]: + # Start of tool sequence - collect all related messages + tool_sequence = [msg] + expected_ids = {call["id"] for call in msg["tool_calls"]} + i += 1 + + # Collect tool responses + while i < n and expected_ids: + next_msg = messages[i] + if next_msg.get("role") == "tool" and next_msg.get("tool_call_id") in expected_ids: + tool_sequence.append(next_msg) + expected_ids.discard(next_msg.get("tool_call_id")) + i += 1 + else: + break + + # Add missing tool responses as empty + for tool_id in expected_ids: + tool_sequence.append({"role": "tool", "tool_call_id": tool_id, "content": ""}) + + # Add the complete tool sequence to result + for tool_msg in tool_sequence: + result.append(tool_msg) + + # Update prev_role to assistant after processing tool sequence + prev_role = "assistant" + continue + + # Handle normal message alternation + if role in ("user", "assistant"): + if role == prev_role: + # Insert empty message of opposite role + opposite_role = "user" if role == 
"assistant" else "assistant" + result.append({"role": opposite_role, "content": ""}) + prev_role = opposite_role + + result.append(msg) + prev_role = role + else: + # For non-user/assistant roles, just add them directly + result.append(msg) + + i += 1 + + # Consolidate consecutive empty messages in a single pass + consolidated = [] + for msg in result: + if not consolidated: + consolidated.append(msg) + continue + + last_msg = consolidated[-1] + current_role = msg.get("role") + last_role = last_msg.get("role") + + # Skip consecutive empty messages with the same role + if ( + current_role in ("user", "assistant") + and last_role in ("user", "assistant") + and current_role == last_role + and msg.get("content") == "" + and last_msg.get("content") == "" + ): + continue - fixed_messages.append(msg) - prev_role = current_role + consolidated.append(msg) - return fixed_messages + return consolidated diff --git a/aider/tools/__init__.py b/aider/tools/__init__.py index a1b22d3e8fa..3de1c4945fc 100644 --- a/aider/tools/__init__.py +++ b/aider/tools/__init__.py @@ -1,26 +1,47 @@ # flake8: noqa: F401 # Import tool functions into the aider.tools namespace -from .command import _execute_command -from .command_interactive import _execute_command_interactive -from .delete_block import _execute_delete_block -from .delete_line import _execute_delete_line -from .delete_lines import _execute_delete_lines -from .extract_lines import _execute_extract_lines -from .indent_lines import _execute_indent_lines -from .insert_block import _execute_insert_block -from .list_changes import _execute_list_changes -from .ls import execute_ls -from .make_editable import _execute_make_editable -from .make_readonly import _execute_make_readonly -from .remove import _execute_remove -from .replace_all import _execute_replace_all -from .replace_line import _execute_replace_line -from .replace_lines import _execute_replace_lines -from .replace_text import _execute_replace_text -from .show_numbered_context 
import execute_show_numbered_context -from .undo_change import _execute_undo_change -from .view import execute_view -from .view_files_at_glob import execute_view_files_at_glob -from .view_files_matching import execute_view_files_matching -from .view_files_with_symbol import _execute_view_files_with_symbol +from .command import _execute_command, command_schema +from .command_interactive import ( + _execute_command_interactive, + command_interactive_schema, +) +from .delete_block import _execute_delete_block, delete_block_schema +from .delete_line import _execute_delete_line, delete_line_schema +from .delete_lines import _execute_delete_lines, delete_lines_schema +from .extract_lines import _execute_extract_lines, extract_lines_schema +from .git import ( + _execute_git_diff, + _execute_git_log, + _execute_git_show, + _execute_git_status, + git_diff_schema, + git_log_schema, + git_show_schema, + git_status_schema, +) +from .grep import _execute_grep, grep_schema +from .indent_lines import _execute_indent_lines, indent_lines_schema +from .insert_block import _execute_insert_block, insert_block_schema +from .list_changes import _execute_list_changes, list_changes_schema +from .ls import execute_ls, ls_schema +from .make_editable import _execute_make_editable, make_editable_schema +from .make_readonly import _execute_make_readonly, make_readonly_schema +from .remove import _execute_remove, remove_schema +from .replace_all import _execute_replace_all, replace_all_schema +from .replace_line import _execute_replace_line, replace_line_schema +from .replace_lines import _execute_replace_lines, replace_lines_schema +from .replace_text import _execute_replace_text, replace_text_schema +from .show_numbered_context import ( + execute_show_numbered_context, + show_numbered_context_schema, +) +from .undo_change import _execute_undo_change, undo_change_schema +from .update_todo_list import _execute_update_todo_list, update_todo_list_schema +from .view import execute_view, 
view_schema +from .view_files_at_glob import execute_view_files_at_glob, view_files_at_glob_schema +from .view_files_matching import execute_view_files_matching, view_files_matching_schema +from .view_files_with_symbol import ( + _execute_view_files_with_symbol, + view_files_with_symbol_schema, +) diff --git a/aider/tools/command.py b/aider/tools/command.py index 0435f39dcd2..9dad217fe3e 100644 --- a/aider/tools/command.py +++ b/aider/tools/command.py @@ -1,6 +1,24 @@ # Import necessary functions from aider.run_cmd import run_cmd_subprocess +command_schema = { + "type": "function", + "function": { + "name": "Command", + "description": "Execute a shell command.", + "parameters": { + "type": "object", + "properties": { + "command_string": { + "type": "string", + "description": "The shell command to execute.", + }, + }, + "required": ["command_string"], + }, + }, +} + def _execute_command(coder, command_string): """ diff --git a/aider/tools/command_interactive.py b/aider/tools/command_interactive.py index a25c001c77c..7e4bc17d2fc 100644 --- a/aider/tools/command_interactive.py +++ b/aider/tools/command_interactive.py @@ -1,6 +1,24 @@ # Import necessary functions from aider.run_cmd import run_cmd +command_interactive_schema = { + "type": "function", + "function": { + "name": "CommandInteractive", + "description": "Execute a shell command interactively.", + "parameters": { + "type": "object", + "properties": { + "command_string": { + "type": "string", + "description": "The interactive shell command to execute.", + }, + }, + "required": ["command_string"], + }, + }, +} + def _execute_command_interactive(coder, command_string): """ diff --git a/aider/tools/delete_block.py b/aider/tools/delete_block.py index cbaeedffbc7..27b5f311e92 100644 --- a/aider/tools/delete_block.py +++ b/aider/tools/delete_block.py @@ -10,6 +10,28 @@ validate_file_for_edit, ) +delete_block_schema = { + "type": "function", + "function": { + "name": "DeleteBlock", + "description": "Delete a block of 
lines from a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "start_pattern": {"type": "string"}, + "end_pattern": {"type": "string"}, + "line_count": {"type": "integer"}, + "near_context": {"type": "string"}, + "occurrence": {"type": "integer", "default": 1}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "start_pattern"], + }, + }, +} + def _execute_delete_block( coder, @@ -103,6 +125,7 @@ def _execute_delete_block( change_id, ) + coder.files_edited_by_tools.add(rel_path) # 8. Format and return result, adding line range to success message success_message = ( f"Deleted {num_deleted} lines ({start_line + 1}-{end_line + 1}) (from" diff --git a/aider/tools/delete_line.py b/aider/tools/delete_line.py index c1e8ed6b299..4b3fb2c1e6d 100644 --- a/aider/tools/delete_line.py +++ b/aider/tools/delete_line.py @@ -8,6 +8,24 @@ handle_tool_error, ) +delete_line_schema = { + "type": "function", + "function": { + "name": "DeleteLine", + "description": "Delete a single line from a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "line_number": {"type": "integer"}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "line_number"], + }, + }, +} + def _execute_delete_line(coder, file_path, line_number, change_id=None, dry_run=False): """ @@ -96,7 +114,7 @@ def _execute_delete_line(coder, file_path, line_number, change_id=None, dry_run= change_id, ) - coder.aider_edited_files.add(rel_path) + coder.files_edited_by_tools.add(rel_path) # Format and return result success_message = f"Deleted line {line_num_int} in {file_path}" diff --git a/aider/tools/delete_lines.py b/aider/tools/delete_lines.py index 0aa33ba8833..122f6a19c8e 100644 --- a/aider/tools/delete_lines.py +++ b/aider/tools/delete_lines.py @@ -8,6 +8,25 @@ handle_tool_error, ) 
+delete_lines_schema = { + "type": "function", + "function": { + "name": "DeleteLines", + "description": "Delete a range of lines from a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "start_line": {"type": "integer"}, + "end_line": {"type": "integer"}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "start_line", "end_line"], + }, + }, +} + def _execute_delete_lines(coder, file_path, start_line, end_line, change_id=None, dry_run=False): """ @@ -119,7 +138,7 @@ def _execute_delete_lines(coder, file_path, start_line, end_line, change_id=None change_id, ) - coder.aider_edited_files.add(rel_path) + coder.files_edited_by_tools.add(rel_path) num_deleted = end_idx - start_idx + 1 # Format and return result success_message = ( diff --git a/aider/tools/extract_lines.py b/aider/tools/extract_lines.py index c18a231ebbc..36c1fca01b4 100644 --- a/aider/tools/extract_lines.py +++ b/aider/tools/extract_lines.py @@ -3,6 +3,28 @@ from .tool_utils import generate_unified_diff_snippet +extract_lines_schema = { + "type": "function", + "function": { + "name": "ExtractLines", + "description": "Extract lines from a source file and append them to a target file.", + "parameters": { + "type": "object", + "properties": { + "source_file_path": {"type": "string"}, + "target_file_path": {"type": "string"}, + "start_pattern": {"type": "string"}, + "end_pattern": {"type": "string"}, + "line_count": {"type": "integer"}, + "near_context": {"type": "string"}, + "occurrence": {"type": "integer", "default": 1}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["source_file_path", "target_file_path", "start_pattern"], + }, + }, +} + def _execute_extract_lines( coder, @@ -248,8 +270,9 @@ def _execute_extract_lines( coder.io.tool_error(f"Error tracking target change for ExtractLines: {track_e}") # --- Update Context --- - 
coder.aider_edited_files.add(rel_source_path) - coder.aider_edited_files.add(rel_target_path) + coder.files_edited_by_tools.add(rel_source_path) + coder.files_edited_by_tools.add(rel_target_path) + if not target_exists: # Add the newly created file to editable context coder.abs_fnames.add(abs_target_path) diff --git a/aider/tools/git.py b/aider/tools/git.py new file mode 100644 index 00000000000..f9fefb7f507 --- /dev/null +++ b/aider/tools/git.py @@ -0,0 +1,142 @@ +from aider.repo import ANY_GIT_ERROR + +git_diff_schema = { + "type": "function", + "function": { + "name": "git_diff", + "description": ( + "Show the diff between the current working directory and a git branch or commit." + ), + "parameters": { + "type": "object", + "properties": { + "branch": { + "type": "string", + "description": "The branch or commit hash to diff against. Defaults to HEAD.", + }, + }, + "required": [], + }, + }, +} + + +def _execute_git_diff(coder, branch=None): + """ + Show the diff between the current working directory and a git branch or commit. + """ + if not coder.repo: + return "Not in a git repository." + + try: + if branch: + diff = coder.repo.diff_commits(False, branch, "HEAD") + else: + diff = coder.repo.diff_commits(False, "HEAD", None) + + if not diff: + return "No differences found." + return diff + except ANY_GIT_ERROR as e: + coder.io.tool_error(f"Error running git diff: {e}") + return f"Error running git diff: {e}" + + +git_log_schema = { + "type": "function", + "function": { + "name": "git_log", + "description": "Show the git log.", + "parameters": { + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "The maximum number of commits to show. Defaults to 10.", + }, + }, + "required": [], + }, + }, +} + + +def _execute_git_log(coder, limit=10): + """ + Show the git log. + """ + if not coder.repo: + return "Not in a git repository." 
+ + try: + commits = list(coder.repo.repo.iter_commits(max_count=limit)) + log_output = [] + for commit in commits: + short_hash = commit.hexsha[:8] + message = commit.message.strip().split("\n")[0] + log_output.append(f"{short_hash} {message}") + return "\n".join(log_output) + except ANY_GIT_ERROR as e: + coder.io.tool_error(f"Error running git log: {e}") + return f"Error running git log: {e}" + + +git_show_schema = { + "type": "function", + "function": { + "name": "git_show", + "description": "Show various types of objects (blobs, trees, tags, and commits).", + "parameters": { + "type": "object", + "properties": { + "object": { + "type": "string", + "description": "The object to show. Defaults to HEAD.", + }, + }, + "required": [], + }, + }, +} + + +def _execute_git_show(coder, object="HEAD"): + """ + Show various types of objects (blobs, trees, tags, and commits). + """ + if not coder.repo: + return "Not in a git repository." + + try: + return coder.repo.repo.git.show(object) + except ANY_GIT_ERROR as e: + coder.io.tool_error(f"Error running git show: {e}") + return f"Error running git show: {e}" + + +git_status_schema = { + "type": "function", + "function": { + "name": "git_status", + "description": "Show the working tree status.", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, +} + + +def _execute_git_status(coder): + """ + Show the working tree status. + """ + if not coder.repo: + return "Not in a git repository." 
+ + try: + return coder.repo.repo.git.status() + except ANY_GIT_ERROR as e: + coder.io.tool_error(f"Error running git status: {e}") + return f"Error running git status: {e}" diff --git a/aider/tools/grep.py b/aider/tools/grep.py index e28936ef14e..1eac5e7b141 100644 --- a/aider/tools/grep.py +++ b/aider/tools/grep.py @@ -1,9 +1,54 @@ -import shlex import shutil from pathlib import Path +import oslex + from aider.run_cmd import run_cmd_subprocess +grep_schema = { + "type": "function", + "function": { + "name": "Grep", + "description": "Search for a pattern in files.", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "The pattern to search for.", + }, + "file_pattern": { + "type": "string", + "description": "Glob pattern for files to search. Defaults to '*'.", + }, + "directory": { + "type": "string", + "description": "Directory to search in. Defaults to '.'.", + }, + "use_regex": { + "type": "boolean", + "description": "Whether to use regex. Defaults to False.", + }, + "case_insensitive": { + "type": "boolean", + "description": ( + "Whether to perform a case-insensitive search. Defaults to False." + ), + }, + "context_before": { + "type": "integer", + "description": "Number of lines to show before a match. Defaults to 5.", + }, + "context_after": { + "type": "integer", + "description": "Number of lines to show after a match. 
Defaults to 5.", + }, + }, + "required": ["pattern"], + }, + }, +} + def _find_search_tool(): """Find the best available command-line search tool (rg, ag, grep).""" @@ -117,7 +162,7 @@ def _execute_grep( cmd_args.extend([pattern, str(search_dir_path)]) # Convert list to command string for run_cmd_subprocess - command_string = shlex.join(cmd_args) + command_string = oslex.join(cmd_args) coder.io.tool_output(f"⚙️ Executing {tool_name}: {command_string}") diff --git a/aider/tools/indent_lines.py b/aider/tools/indent_lines.py index acb1e0bb17c..d30070d4513 100644 --- a/aider/tools/indent_lines.py +++ b/aider/tools/indent_lines.py @@ -10,6 +10,29 @@ validate_file_for_edit, ) +indent_lines_schema = { + "type": "function", + "function": { + "name": "IndentLines", + "description": "Indent a block of lines in a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "start_pattern": {"type": "string"}, + "end_pattern": {"type": "string"}, + "line_count": {"type": "integer"}, + "indent_levels": {"type": "integer", "default": 1}, + "near_context": {"type": "string"}, + "occurrence": {"type": "integer", "default": 1}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "start_pattern"], + }, + }, +} + def _execute_indent_lines( coder, @@ -138,6 +161,8 @@ def _execute_indent_lines( change_id, ) + coder.files_edited_by_tools.add(rel_path) + # 8. 
Format and return result action_past = "Indented" if indent_levels > 0 else "Unindented" success_message = ( diff --git a/aider/tools/insert_block.py b/aider/tools/insert_block.py index 2c694c42a5b..e6a02d3a070 100644 --- a/aider/tools/insert_block.py +++ b/aider/tools/insert_block.py @@ -12,6 +12,30 @@ validate_file_for_edit, ) +insert_block_schema = { + "type": "function", + "function": { + "name": "InsertBlock", + "description": "Insert a block of content into a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "content": {"type": "string"}, + "after_pattern": {"type": "string"}, + "before_pattern": {"type": "string"}, + "occurrence": {"type": "integer", "default": 1}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + "position": {"type": "string", "enum": ["top", "bottom"]}, + "auto_indent": {"type": "boolean", "default": True}, + "use_regex": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "content"], + }, + }, +} + def _execute_insert_block( coder, @@ -187,6 +211,8 @@ def _execute_insert_block( change_id, ) + coder.files_edited_by_tools.add(rel_path) + # 9. 
Format and return result if position: success_message = f"Inserted block {pattern_type} {file_path}" diff --git a/aider/tools/list_changes.py b/aider/tools/list_changes.py index 1c4bcc4dd98..9e4372b79e3 100644 --- a/aider/tools/list_changes.py +++ b/aider/tools/list_changes.py @@ -1,6 +1,21 @@ import traceback from datetime import datetime +list_changes_schema = { + "type": "function", + "function": { + "name": "ListChanges", + "description": "List recent changes made.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "limit": {"type": "integer", "default": 10}, + }, + }, + }, +} + def _execute_list_changes(coder, file_path=None, limit=10): """ diff --git a/aider/tools/ls.py b/aider/tools/ls.py index 38baa5ad331..2e969faa6c1 100644 --- a/aider/tools/ls.py +++ b/aider/tools/ls.py @@ -1,7 +1,30 @@ import os +ls_schema = { + "type": "function", + "function": { + "name": "Ls", + "description": "List files in a directory.", + "parameters": { + "type": "object", + "properties": { + "directory": { + "type": "string", + "description": "The directory to list.", + }, + }, + "required": ["directory"], + }, + }, +} -def execute_ls(coder, dir_path): + +def execute_ls(coder, dir_path=None, directory=None): + # Handle both positional and keyword arguments for backward compatibility + if dir_path is None and directory is not None: + dir_path = directory + elif dir_path is None: + return "Error: Missing directory parameter" """ List files in directory and optionally add some to context. 
diff --git a/aider/tools/make_editable.py b/aider/tools/make_editable.py index 33316935b3e..5ca0f0e7093 100644 --- a/aider/tools/make_editable.py +++ b/aider/tools/make_editable.py @@ -1,5 +1,23 @@ import os +make_editable_schema = { + "type": "function", + "function": { + "name": "MakeEditable", + "description": "Make a read-only file editable.", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The path to the file to make editable.", + }, + }, + "required": ["file_path"], + }, + }, +} + # Keep the underscore prefix as this function is primarily for internal coder use def _execute_make_editable(coder, file_path): diff --git a/aider/tools/make_readonly.py b/aider/tools/make_readonly.py index 13b85e549f5..5712a672ac2 100644 --- a/aider/tools/make_readonly.py +++ b/aider/tools/make_readonly.py @@ -1,3 +1,22 @@ +make_readonly_schema = { + "type": "function", + "function": { + "name": "MakeReadonly", + "description": "Make an editable file read-only.", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The path to the file to make read-only.", + }, + }, + "required": ["file_path"], + }, + }, +} + + def _execute_make_readonly(coder, file_path): """ Convert an editable file to a read-only file. diff --git a/aider/tools/remove.py b/aider/tools/remove.py index fc94a2b9a19..bbed05d0bed 100644 --- a/aider/tools/remove.py +++ b/aider/tools/remove.py @@ -1,5 +1,27 @@ import time +remove_schema = { + "type": "function", + "function": { + "name": "Remove", + "description": ( + "Remove a file from the chat context. Should be used proactively to keep con" + "Should be used after editing a file when all edits are done " + "and the file is no longer necessary in context." 
+ ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The path to the file to remove.", + }, + }, + "required": ["file_path"], + }, + }, +} + def _execute_remove(coder, file_path): """ diff --git a/aider/tools/replace_all.py b/aider/tools/replace_all.py index ce1095cfa34..96c16ad715d 100644 --- a/aider/tools/replace_all.py +++ b/aider/tools/replace_all.py @@ -7,6 +7,25 @@ validate_file_for_edit, ) +replace_all_schema = { + "type": "function", + "function": { + "name": "ReplaceAll", + "description": "Replace all occurrences of text in a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "find_text": {"type": "string"}, + "replace_text": {"type": "string"}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "find_text", "replace_text"], + }, + }, +} + def _execute_replace_all(coder, file_path, find_text, replace_text, change_id=None, dry_run=False): """ @@ -63,6 +82,8 @@ def _execute_replace_all(coder, file_path, find_text, replace_text, change_id=No change_id, ) + coder.files_edited_by_tools.add(rel_path) + # 7. 
Format and return result success_message = f"Replaced {count} occurrences in {file_path}" return format_tool_result( diff --git a/aider/tools/replace_line.py b/aider/tools/replace_line.py index f30d8fcd282..25acbf3e826 100644 --- a/aider/tools/replace_line.py +++ b/aider/tools/replace_line.py @@ -1,6 +1,25 @@ import os import traceback +replace_line_schema = { + "type": "function", + "function": { + "name": "ReplaceLine", + "description": "Replace a single line in a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "line_number": {"type": "integer"}, + "new_content": {"type": "string"}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "line_number", "new_content"], + }, + }, +} + def _execute_replace_line( coder, file_path, line_number, new_content, change_id=None, dry_run=False @@ -112,7 +131,7 @@ def _execute_replace_line( coder.io.tool_error(f"Error tracking change for ReplaceLine: {track_e}") change_id = "TRACKING_FAILED" - coder.aider_edited_files.add(rel_path) + coder.files_edited_by_tools.add(rel_path) # Improve feedback coder.io.tool_output( diff --git a/aider/tools/replace_lines.py b/aider/tools/replace_lines.py index 2ba65eef7cc..859983ea0ab 100644 --- a/aider/tools/replace_lines.py +++ b/aider/tools/replace_lines.py @@ -8,6 +8,26 @@ handle_tool_error, ) +replace_lines_schema = { + "type": "function", + "function": { + "name": "ReplaceLines", + "description": "Replace a range of lines in a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "start_line": {"type": "integer"}, + "end_line": {"type": "integer"}, + "new_content": {"type": "string"}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "start_line", "end_line", "new_content"], + }, + }, +} + def _execute_replace_lines( coder, file_path, start_line, end_line, 
new_content, change_id=None, dry_run=False @@ -139,7 +159,7 @@ def _execute_replace_lines( change_id, ) - coder.aider_edited_files.add(rel_path) + coder.files_edited_by_tools.add(rel_path) replaced_count = end_line - start_line + 1 new_count = len(new_lines) diff --git a/aider/tools/replace_text.py b/aider/tools/replace_text.py index c0cc7cb6b8b..9c3233adb92 100644 --- a/aider/tools/replace_text.py +++ b/aider/tools/replace_text.py @@ -7,6 +7,27 @@ validate_file_for_edit, ) +replace_text_schema = { + "type": "function", + "function": { + "name": "ReplaceText", + "description": "Replace text in a file.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "find_text": {"type": "string"}, + "replace_text": {"type": "string"}, + "near_context": {"type": "string"}, + "occurrence": {"type": "integer", "default": 1}, + "change_id": {"type": "string"}, + "dry_run": {"type": "boolean", "default": False}, + }, + "required": ["file_path", "find_text", "replace_text"], + }, + }, +} + def _execute_replace_text( coder, @@ -111,6 +132,7 @@ def _execute_replace_text( change_id, ) + coder.files_edited_by_tools.add(rel_path) # 8. 
Format and return result success_message = f"Replaced {occurrence_str} in {file_path}" return format_tool_result( diff --git a/aider/tools/show_numbered_context.py b/aider/tools/show_numbered_context.py index 4cecf96bb2c..0debee9d277 100644 --- a/aider/tools/show_numbered_context.py +++ b/aider/tools/show_numbered_context.py @@ -2,6 +2,24 @@ from .tool_utils import ToolError, handle_tool_error, resolve_paths +show_numbered_context_schema = { + "type": "function", + "function": { + "name": "ShowNumberedContext", + "description": "Show numbered lines of context around a pattern or line number.", + "parameters": { + "type": "object", + "properties": { + "file_path": {"type": "string"}, + "pattern": {"type": "string"}, + "line_number": {"type": "integer"}, + "context_lines": {"type": "integer", "default": 3}, + }, + "required": ["file_path"], + }, + }, +} + def execute_show_numbered_context( coder, file_path, pattern=None, line_number=None, context_lines=3 diff --git a/aider/tools/undo_change.py b/aider/tools/undo_change.py index fc3484a3038..6917a01ba9f 100644 --- a/aider/tools/undo_change.py +++ b/aider/tools/undo_change.py @@ -1,5 +1,20 @@ import traceback +undo_change_schema = { + "type": "function", + "function": { + "name": "UndoChange", + "description": "Undo a previously applied change.", + "parameters": { + "type": "object", + "properties": { + "change_id": {"type": "string"}, + "file_path": {"type": "string"}, + }, + }, + }, +} + def _execute_undo_change(coder, change_id=None, file_path=None): """ diff --git a/aider/tools/update_todo_list.py b/aider/tools/update_todo_list.py new file mode 100644 index 00000000000..4ae335c2197 --- /dev/null +++ b/aider/tools/update_todo_list.py @@ -0,0 +1,131 @@ +from .tool_utils import ( + ToolError, + format_tool_result, + generate_unified_diff_snippet, + handle_tool_error, +) + +update_todo_list_schema = { + "type": "function", + "function": { + "name": "UpdateTodoList", + "description": "Update the todo list with new items 
or modify existing ones.", + "parameters": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The new content for the todo list.", + }, + "append": { + "type": "boolean", + "description": ( + "Whether to append to existing content instead of replacing it. Defaults to" + " False." + ), + }, + "change_id": { + "type": "string", + "description": "Optional change ID for tracking.", + }, + "dry_run": { + "type": "boolean", + "description": ( + "Whether to perform a dry run without actually updating the file. Defaults" + " to False." + ), + }, + }, + "required": ["content"], + }, + }, +} + + +def _execute_update_todo_list(coder, content, append=False, change_id=None, dry_run=False): + """ + Update the todo list file (.aider.todo.txt) with new content. + Can either replace the entire content or append to it. + """ + tool_name = "UpdateTodoList" + try: + # Define the todo file path + todo_file_path = ".aider.todo.txt" + abs_path = coder.abs_root_path(todo_file_path) + + # Get existing content if appending + existing_content = "" + import os + + if os.path.isfile(abs_path): + existing_content = coder.io.read_text(abs_path) or "" + + # Prepare new content + if append: + if existing_content and not existing_content.endswith("\n"): + existing_content += "\n" + new_content = existing_content + content + else: + new_content = content + + # Check if content exceeds 4096 characters and warn + if len(new_content) > 4096: + coder.io.tool_warning( + "⚠️ Todo list content exceeds 4096 characters. Consider summarizing the plan before" + " proceeding." 
+ ) + + # Check if content actually changed + if existing_content == new_content: + coder.io.tool_warning("No changes made: new content is identical to existing") + return "Warning: No changes made (content identical to existing)" + + # Generate diff for feedback + diff_snippet = generate_unified_diff_snippet(existing_content, new_content, todo_file_path) + + # Handle dry run + if dry_run: + action = "append to" if append else "replace" + dry_run_message = f"Dry run: Would {action} todo list in {todo_file_path}." + return format_tool_result( + coder, + tool_name, + "", + dry_run=True, + dry_run_message=dry_run_message, + diff_snippet=diff_snippet, + ) + + # Apply change + metadata = { + "append": append, + "existing_length": len(existing_content), + "new_length": len(new_content), + } + + # Write the file directly since it's a special file + coder.io.write_text(abs_path, new_content) + + # Track the change + final_change_id = coder.change_tracker.track_change( + file_path=todo_file_path, + change_type="updatetodolist", + original_content=existing_content, + new_content=new_content, + metadata=metadata, + change_id=change_id, + ) + + coder.aider_edited_files.add(todo_file_path) + + # Format and return result + action = "appended to" if append else "updated" + success_message = f"Successfully {action} todo list in {todo_file_path}" + return format_tool_result( + coder, tool_name, success_message, change_id=final_change_id, diff_snippet=diff_snippet + ) + + except ToolError as e: + return handle_tool_error(coder, tool_name, e, add_traceback=False) + except Exception as e: + return handle_tool_error(coder, tool_name, e) diff --git a/aider/tools/view.py b/aider/tools/view.py index 0c833ca307f..845894fdd32 100644 --- a/aider/tools/view.py +++ b/aider/tools/view.py @@ -1,3 +1,26 @@ +view_schema = { + "type": "function", + "function": { + "name": "View", + "description": ( + "View a specific file and add it to context. "
+ "Only use this when the file is not already in the context " + "and when editing the file is necessary to accomplish the goal." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The path to the file to view.", + }, + }, + "required": ["file_path"], + }, + }, +} + + def execute_view(coder, file_path): """ Explicitly add a file to context as read-only. diff --git a/aider/tools/view_files_at_glob.py b/aider/tools/view_files_at_glob.py index 34af0f74d54..d96668fc211 100644 --- a/aider/tools/view_files_at_glob.py +++ b/aider/tools/view_files_at_glob.py @@ -1,10 +1,28 @@ import fnmatch import os +view_files_at_glob_schema = { + "type": "function", + "function": { + "name": "ViewFilesAtGlob", + "description": "View files matching a glob pattern.", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "The glob pattern to match files.", + }, + }, + "required": ["pattern"], + }, + }, +} + def execute_view_files_at_glob(coder, pattern): """ - Execute a glob pattern and add matching files to context as read-only. + Execute a glob pattern and return matching files as text. This tool helps the LLM find files by pattern matching, similar to how a developer would use glob patterns to find files. @@ -25,38 +43,25 @@ def execute_view_files_at_glob(coder, pattern): if fnmatch.fnmatch(file, pattern): matching_files.append(file) - # Limit the number of files added if there are too many matches - if len(matching_files) > coder.max_files_per_glob: - coder.io.tool_output( - f"⚠️ Found {len(matching_files)} files matching '{pattern}', " - f"limiting to {coder.max_files_per_glob} most relevant files." 
- ) - # Sort by modification time (most recent first) - matching_files.sort( - key=lambda f: os.path.getmtime(coder.abs_root_path(f)), reverse=True - ) - matching_files = matching_files[: coder.max_files_per_glob] - - # Add files to context - for file in matching_files: - # Use the coder's internal method to add files - coder._add_file_to_context(file) - - # Return a user-friendly result + # Return formatted text instead of adding to context if matching_files: if len(matching_files) > 10: - brief = ", ".join(matching_files[:5]) + f", and {len(matching_files) - 5} more" - coder.io.tool_output( - f"📂 Added {len(matching_files)} files matching '{pattern}': {brief}" + result = ( + f"Found {len(matching_files)} files matching '{pattern}':" + f" {', '.join(matching_files[:10])} and {len(matching_files) - 10} more" ) + coder.io.tool_output(f"📂 Found {len(matching_files)} files matching '{pattern}'") else: + result = ( + f"Found {len(matching_files)} files matching '{pattern}':" + f" {', '.join(matching_files)}" + ) coder.io.tool_output( - f"📂 Added files matching '{pattern}': {', '.join(matching_files)}" + f"📂 Found files matching '{pattern}':" + f" {', '.join(matching_files[:5])}{' and more' if len(matching_files) > 5 else ''}" ) - return ( - f"Added {len(matching_files)} files:" - f" {', '.join(matching_files[:5])}{' and more' if len(matching_files) > 5 else ''}" - ) + + return result else: coder.io.tool_output(f"⚠️ No files found matching '{pattern}'") return f"No files found matching '{pattern}'" diff --git a/aider/tools/view_files_matching.py b/aider/tools/view_files_matching.py index f87d4682c6d..0f061dbb97b 100644 --- a/aider/tools/view_files_matching.py +++ b/aider/tools/view_files_matching.py @@ -1,18 +1,46 @@ import fnmatch import re +view_files_matching_schema = { + "type": "function", + "function": { + "name": "ViewFilesMatching", + "description": "View files containing a specific pattern.", + "parameters": { + "type": "object", + "properties": { + "pattern": 
{ + "type": "string", + "description": "The pattern to search for in file contents.", + }, + "file_pattern": { + "type": "string", + "description": "An optional glob pattern to filter which files are searched.", + }, + "regex": { + "type": "boolean", + "description": ( + "Whether the pattern is a regular expression. Defaults to False." + ), + }, + }, + "required": ["pattern"], + }, + }, +} -def execute_view_files_matching(coder, search_pattern, file_pattern=None, regex=False): + +def execute_view_files_matching(coder, pattern, file_pattern=None, regex=False): """ - Search for pattern (literal string or regex) in files and add matching files to context as read-only. + Search for pattern (literal string or regex) in files and return matching files as text. Args: coder: The Coder instance. - search_pattern (str): The pattern to search for. + pattern (str): The pattern to search for. Treated as a literal string by default. file_pattern (str, optional): Glob pattern to filter which files are searched. Defaults to None (search all files). - regex (bool, optional): If True, treat search_pattern as a regular expression. + regex (bool, optional): If True, treat pattern as a regular expression. Defaults to False. 
This tool lets the LLM search for content within files, mimicking @@ -29,9 +57,7 @@ def execute_view_files_matching(coder, search_pattern, file_pattern=None, regex= files_to_search.append(file) if not files_to_search: - return ( - f"No files matching '{file_pattern}' to search for pattern '{search_pattern}'" - ) + return f"No files matching '{file_pattern}' to search for pattern '{pattern}'" else: # Search all files if no pattern provided files_to_search = coder.get_all_relative_files() @@ -46,16 +72,16 @@ def execute_view_files_matching(coder, search_pattern, file_pattern=None, regex= match_count = 0 if regex: try: - matches_found = re.findall(search_pattern, content) + matches_found = re.findall(pattern, content) match_count = len(matches_found) except re.error as e: # Handle invalid regex patterns gracefully - coder.io.tool_error(f"Invalid regex pattern '{search_pattern}': {e}") + coder.io.tool_error(f"Invalid regex pattern '{pattern}': {e}") # Skip this file for this search if regex is invalid continue else: # Exact string matching - match_count = content.count(search_pattern) + match_count = content.count(pattern) if match_count > 0: matches[file] = match_count @@ -63,40 +89,28 @@ def execute_view_files_matching(coder, search_pattern, file_pattern=None, regex= # Skip files that can't be read (binary, etc.) pass - # Limit the number of files added if there are too many matches - if len(matches) > coder.max_files_per_glob: - coder.io.tool_output( - f"⚠️ Found '{search_pattern}' in {len(matches)} files, " - f"limiting to {coder.max_files_per_glob} files with most matches." 
- ) - # Sort by number of matches (most matches first) - sorted_matches = sorted(matches.items(), key=lambda x: x[1], reverse=True) - matches = dict(sorted_matches[: coder.max_files_per_glob]) - - # Add matching files to context - for file in matches: - coder._add_file_to_context(file) - - # Return a user-friendly result + # Return formatted text instead of adding to context if matches: # Sort by number of matches (most matches first) sorted_matches = sorted(matches.items(), key=lambda x: x[1], reverse=True) - match_list = [f"{file} ({count} matches)" for file, count in sorted_matches[:5]] + match_list = [f"{file} ({count} matches)" for file, count in sorted_matches] - if len(sorted_matches) > 5: - coder.io.tool_output( - f"🔍 Found '{search_pattern}' in {len(matches)} files:" - f" {', '.join(match_list)} and {len(matches) - 5} more" - ) - return ( - f"Found in {len(matches)} files: {', '.join(match_list)} and" - f" {len(matches) - 5} more" + if len(matches) > 10: + result = ( + f"Found '{pattern}' in {len(matches)} files: {', '.join(match_list[:10])} and" + f" {len(matches) - 10} more" ) + coder.io.tool_output(f"🔍 Found '{pattern}' in {len(matches)} files") else: - coder.io.tool_output(f"🔍 Found '{search_pattern}' in: {', '.join(match_list)}") - return f"Found in {len(matches)} files: {', '.join(match_list)}" + result = f"Found '{pattern}' in {len(matches)} files: {', '.join(match_list)}" + coder.io.tool_output( + f"🔍 Found '{pattern}' in:" + f" {', '.join(match_list[:5])}{' and more' if len(matches) > 5 else ''}" + ) + + return result else: - coder.io.tool_output(f"⚠️ Pattern '{search_pattern}' not found in any files") + coder.io.tool_output(f"⚠️ Pattern '{pattern}' not found in any files") return "Pattern not found in any files" except Exception as e: coder.io.tool_error(f"Error in ViewFilesMatching: {str(e)}") diff --git a/aider/tools/view_files_with_symbol.py b/aider/tools/view_files_with_symbol.py index dc5962cf26f..34f0fe8a052 100644 --- 
a/aider/tools/view_files_with_symbol.py +++ b/aider/tools/view_files_with_symbol.py @@ -1,9 +1,25 @@ -import os +view_files_with_symbol_schema = { + "type": "function", + "function": { + "name": "ViewFilesWithSymbol", + "description": "View files that contain a specific symbol (e.g., class, function).", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The symbol to search for.", + }, + }, + "required": ["symbol"], + }, + }, +} def _execute_view_files_with_symbol(coder, symbol): """ - Find files containing a symbol using RepoMap and add them to context. + Find files containing a symbol using RepoMap and return them as text. Checks files already in context first. """ if not coder.repo_map: @@ -13,7 +29,6 @@ def _execute_view_files_with_symbol(coder, symbol): if not symbol: return "Error: Missing 'symbol' parameter for ViewFilesWithSymbol" - # --- Start Modification --- # 1. Check files already in context files_in_context = list(coder.abs_fnames) + list(coder.abs_read_only_fnames) found_in_context = [] @@ -34,20 +49,11 @@ def _execute_view_files_with_symbol(coder, symbol): if found_in_context: # Symbol found in already loaded files. Report this and stop. file_list = ", ".join(sorted(list(set(found_in_context)))) - coder.io.tool_output( - f"Symbol '{symbol}' found in already loaded file(s): {file_list}. No external search" - " performed." - ) - return ( - f"Symbol '{symbol}' found in already loaded file(s): {file_list}. No external search" - " performed." - ) - # --- End Modification --- + coder.io.tool_output(f"Symbol '{symbol}' found in already loaded file(s): {file_list}") + return f"Symbol '{symbol}' found in already loaded file(s): {file_list}" # 2. If not found in context, search the repository using RepoMap - coder.io.tool_output( - f"🔎 Searching for symbol '{symbol}' in repository (excluding current context)..." 
- ) + coder.io.tool_output(f"🔎 Searching for symbol '{symbol}' in repository...") try: found_files = set() current_context_files = coder.abs_fnames | coder.abs_read_only_fnames @@ -71,50 +77,31 @@ def _execute_view_files_with_symbol(coder, symbol): # Use absolute path directly if available, otherwise resolve from relative path abs_fname = rel_fname_to_abs.get(tag.rel_fname) or coder.abs_root_path(tag.fname) if abs_fname in files_to_search: # Ensure we only add files we intended to search - found_files.add(abs_fname) + found_files.add(coder.get_rel_fname(abs_fname)) - # Limit the number of files added - if len(found_files) > coder.max_files_per_glob: - coder.io.tool_output( - f"⚠️ Found symbol '{symbol}' in {len(found_files)} files, " - f"limiting to {coder.max_files_per_glob} most relevant files." - ) - # Sort by modification time (most recent first) - approximate relevance - sorted_found_files = sorted( - list(found_files), key=lambda f: os.path.getmtime(f), reverse=True - ) - found_files = set(sorted_found_files[: coder.max_files_per_glob]) - - # Add files to context (as read-only) - added_count = 0 - added_files_rel = [] - for abs_file_path in found_files: - rel_path = coder.get_rel_fname(abs_file_path) - # Double check it's not already added somehow - if ( - abs_file_path not in coder.abs_fnames - and abs_file_path not in coder.abs_read_only_fnames - ): - # Use explicit=True for clear output, even though it's an external search result - add_result = coder._add_file_to_context(rel_path, explicit=True) - if "Added" in add_result or "Viewed" in add_result: # Count successful adds/views - added_count += 1 - added_files_rel.append(rel_path) - - if added_count > 0: - if added_count > 5: - brief = ", ".join(added_files_rel[:5]) + f", and {added_count - 5} more" - coder.io.tool_output(f"🔎 Found '{symbol}' and added {added_count} files: {brief}") + # Return formatted text instead of adding to context + if found_files: + found_files_list = sorted(list(found_files)) + if 
len(found_files) > 10: + result = ( + f"Found symbol '{symbol}' in {len(found_files)} files:" + f" {', '.join(found_files_list[:10])} and {len(found_files) - 10} more" + ) + coder.io.tool_output(f"🔎 Found '{symbol}' in {len(found_files)} files") else: + result = ( + f"Found symbol '{symbol}' in {len(found_files)} files:" + f" {', '.join(found_files_list)}" + ) coder.io.tool_output( - f"🔎 Found '{symbol}' and added files: {', '.join(added_files_rel)}" + f"🔎 Found '{symbol}' in files:" + f" {', '.join(found_files_list[:5])}{' and more' if len(found_files) > 5 else ''}" ) - return f"Found symbol '{symbol}' and added {added_count} files as read-only." + + return result else: - coder.io.tool_output( - f"⚠️ Symbol '{symbol}' not found in searchable files (outside current context)." - ) - return f"Symbol '{symbol}' not found in searchable files (outside current context)." + coder.io.tool_output(f"⚠️ Symbol '{symbol}' not found in searchable files") + return f"Symbol '{symbol}' not found in searchable files" except Exception as e: coder.io.tool_error(f"Error in ViewFilesWithSymbol: {str(e)}") diff --git a/aider/tools/view_todo_list.py b/aider/tools/view_todo_list.py new file mode 100644 index 00000000000..c2540e58392 --- /dev/null +++ b/aider/tools/view_todo_list.py @@ -0,0 +1,57 @@ +from .tool_utils import ToolError, format_tool_result, handle_tool_error + +view_todo_list_schema = { + "type": "function", + "function": { + "name": "ViewTodoList", + "description": "View the current todo list for tracking conversation steps and progress.", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, +} + + +def _execute_view_todo_list(coder): + """ + View the current todo list from .aider.todo.txt file. + Returns the todo list content or creates an empty one if it doesn't exist. 
+ """ + tool_name = "ViewTodoList" + try: + # Define the todo file path + todo_file_path = ".aider.todo.txt" + abs_path = coder.abs_root_path(todo_file_path) + + # Check if file exists + import os + + if os.path.isfile(abs_path): + # Read existing todo list + content = coder.io.read_text(abs_path) + if content is None: + raise ToolError(f"Could not read todo list file: {todo_file_path}") + + # Check if content exceeds 4096 characters and warn + if len(content) > 4096: + coder.io.tool_warning( + "⚠️ Todo list content exceeds 4096 characters. Consider summarizing the plan" + " before proceeding." + ) + + if content.strip(): + result_message = f"Current todo list:\n```\n{content}\n```" + else: + result_message = "Todo list is empty. Use UpdateTodoList to add items." + else: + # Create empty todo list + result_message = "Todo list is empty. Use UpdateTodoList to add items." + + return format_tool_result(coder, tool_name, result_message) + + except ToolError as e: + return handle_tool_error(coder, tool_name, e, add_traceback=False) + except Exception as e: + return handle_tool_error(coder, tool_name, e) diff --git a/aider/waiting.py b/aider/waiting.py index 9c2f72bc777..94ee7f01902 100644 --- a/aider/waiting.py +++ b/aider/waiting.py @@ -1,221 +1,38 @@ #!/usr/bin/env python """ -Thread-based, killable spinner utility. - -Use it like: - - from aider.waiting import WaitingSpinner - - spinner = WaitingSpinner("Waiting for LLM") - spinner.start() - ... # long task - spinner.stop() +A simple wrapper for rich.status to provide a spinner. """ -import sys -import threading -import time - from rich.console import Console class Spinner: - """ - Minimal spinner that scans a single marker back and forth across a line. - - The animation is pre-rendered into a list of frames. If the terminal - cannot display unicode the frames are converted to plain ASCII. 
- """ + """A wrapper around rich.status.Status for displaying a spinner.""" - last_frame_idx = 0 # Class variable to store the last frame index - - def __init__(self, text: str, width: int = 7): + def __init__(self, text: str = "Waiting..."): self.text = text - self.start_time = time.time() - self.last_update = 0.0 - self.visible = False - self.is_tty = sys.stdout.isatty() self.console = Console() - - # Pre-render the animation frames using pure ASCII so they will - # always display, even on very limited terminals. - ascii_frames = [ - "#= ", # C1 C2 space(8) - "=# ", # C2 C1 space(8) - " =# ", # space(1) C2 C1 space(7) - " =# ", # space(2) C2 C1 space(6) - " =# ", # space(3) C2 C1 space(5) - " =# ", # space(4) C2 C1 space(4) - " =# ", # space(5) C2 C1 space(3) - " =# ", # space(6) C2 C1 space(2) - " =# ", # space(7) C2 C1 space(1) - " =#", # space(8) C2 C1 - " #=", # space(8) C1 C2 - " #= ", # space(7) C1 C2 space(1) - " #= ", # space(6) C1 C2 space(2) - " #= ", # space(5) C1 C2 space(3) - " #= ", # space(4) C1 C2 space(4) - " #= ", # space(3) C1 C2 space(5) - " #= ", # space(2) C1 C2 space(6) - " #= ", # space(1) C1 C2 space(7) - ] - - self.unicode_palette = "░█" - xlate_from, xlate_to = ("=#", self.unicode_palette) - - # If unicode is supported, swap the ASCII chars for nicer glyphs. - if self._supports_unicode(): - translation_table = str.maketrans(xlate_from, xlate_to) - frames = [f.translate(translation_table) for f in ascii_frames] - self.scan_char = xlate_to[xlate_from.find("#")] - else: - frames = ascii_frames - self.scan_char = "#" - - # Bounce the scanner back and forth. 
- self.frames = frames - self.frame_idx = Spinner.last_frame_idx # Initialize from class variable - self.width = len(frames[0]) - 2 # number of chars between the brackets - self.animation_len = len(frames[0]) - self.last_display_len = 0 # Length of the last spinner line (frame + text) - - def _supports_unicode(self) -> bool: - if not self.is_tty: - return False - try: - out = self.unicode_palette - out += "\b" * len(self.unicode_palette) - out += " " * len(self.unicode_palette) - out += "\b" * len(self.unicode_palette) - sys.stdout.write(out) - sys.stdout.flush() - return True - except UnicodeEncodeError: - return False - except Exception: - return False - - def _next_frame(self) -> str: - frame = self.frames[self.frame_idx] - self.frame_idx = (self.frame_idx + 1) % len(self.frames) - Spinner.last_frame_idx = self.frame_idx # Update class variable - return frame - - def step(self, text: str = None) -> None: - if text is not None: - self.text = text - - if not self.is_tty: - return - - now = time.time() - if not self.visible and now - self.start_time >= 0.5: - self.visible = True - self.last_update = 0.0 - if self.is_tty: - self.console.show_cursor(False) - - if not self.visible or now - self.last_update < 0.1: - return - - self.last_update = now - frame_str = self._next_frame() - - # Determine the maximum width for the spinner line - # Subtract 2 as requested, to leave a margin or prevent cursor wrapping issues - max_spinner_width = self.console.width - 2 - if max_spinner_width < 0: # Handle extremely narrow terminals - max_spinner_width = 0 - - current_text_payload = f" {self.text}" - line_to_display = f"{frame_str}{current_text_payload}" - - # Truncate the line if it's too long for the console width - if len(line_to_display) > max_spinner_width: - line_to_display = line_to_display[:max_spinner_width] - - len_line_to_display = len(line_to_display) - - # Calculate padding to clear any remnants from a longer previous line - padding_to_clear = " " * max(0, 
self.last_display_len - len_line_to_display) - - # Write the spinner frame, text, and any necessary clearing spaces - sys.stdout.write(f"\r{line_to_display}{padding_to_clear}") - self.last_display_len = len_line_to_display - - # Calculate number of backspaces to position cursor at the scanner character - scan_char_abs_pos = frame_str.find(self.scan_char) - - # Total characters written to the line (frame + text + padding) - total_chars_written_on_line = len_line_to_display + len(padding_to_clear) - - # num_backspaces will be non-positive if scan_char_abs_pos is beyond - # total_chars_written_on_line (e.g., if the scan char itself was truncated). - # (e.g., if the scan char itself was truncated). - # In such cases, (effectively) 0 backspaces are written, - # and the cursor stays at the end of the line. - num_backspaces = total_chars_written_on_line - scan_char_abs_pos - sys.stdout.write("\b" * num_backspaces) - sys.stdout.flush() - - def end(self) -> None: - if self.visible and self.is_tty: - clear_len = self.last_display_len # Use the length of the last displayed content - sys.stdout.write("\r" + " " * clear_len + "\r") - sys.stdout.flush() - self.console.show_cursor(True) - self.visible = False - - -class WaitingSpinner: - """Background spinner that can be started/stopped safely.""" - - def __init__(self, text: str = "Waiting for LLM", delay: float = 0.15): - self.spinner = Spinner(text) - self.delay = delay - self._stop_event = threading.Event() - self._thread = threading.Thread(target=self._spin, daemon=True) - - def _spin(self): - while not self._stop_event.is_set(): - self.spinner.step() - time.sleep(self.delay) - self.spinner.end() - - def start(self): - """Start the spinner in a background thread.""" - if not self._thread.is_alive(): - self._thread.start() - - def stop(self): - """Request the spinner to stop and wait briefly for the thread to exit.""" - self._stop_event.set() - if self._thread.is_alive(): - self._thread.join(timeout=self.delay) - 
self.spinner.end() + self.status = None + + def step(self, message=None): + """Start the spinner or update its text.""" + if self.status is None: + self.status = self.console.status(self.text, spinner="dots2") + self.status.start() + elif message: + self.status.update(message) + + def end(self): + """Stop the spinner.""" + if self.status: + self.status.stop() + self.status = None # Allow use as a context-manager def __enter__(self): - self.start() + self.step() return self def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - - -def main(): - spinner = Spinner("Running spinner...") - try: - for _ in range(100): - time.sleep(0.15) - spinner.step() - print("Success!") - except KeyboardInterrupt: - print("\nInterrupted by user.") - finally: - spinner.end() - - -if __name__ == "__main__": - main() + self.end() diff --git a/pyproject.toml b/pyproject.toml index 0858c9d8331..c39e82e813d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ playwright = { file = "requirements/requirements-playwright.in" } include-package-data = true [tool.setuptools.packages.find] -include = ["aider"] +include = ["aider*"] [build-system] requires = ["setuptools>=68", "setuptools_scm[toml]>=8"] diff --git a/pytest.ini b/pytest.ini index 7e37e177930..d079bd78f43 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,7 @@ [pytest] norecursedirs = tmp.* build benchmark _site OLD addopts = -p no:warnings +asyncio_mode = auto testpaths = tests/basic tests/help diff --git a/requirements/requirements-dev.in b/requirements/requirements-dev.in index ce52b0af5e0..a7bbf3aeaf2 100644 --- a/requirements/requirements-dev.in +++ b/requirements/requirements-dev.in @@ -1,4 +1,5 @@ pytest +pytest-asyncio pytest-env pip-tools lox diff --git a/tests/basic/test_coder.py b/tests/basic/test_coder.py index f4e2a183752..42373ac1dac 100644 --- a/tests/basic/test_coder.py +++ b/tests/basic/test_coder.py @@ -23,7 +23,7 @@ def setUp(self): self.webbrowser_patcher = patch("aider.io.webbrowser.open") 
self.mock_webbrowser = self.webbrowser_patcher.start() - def test_allowed_to_edit(self): + async def test_allowed_to_edit(self): with GitTemporaryDirectory(): repo = git.Repo() @@ -40,19 +40,19 @@ def test_allowed_to_edit(self): # YES! # Use a completely mocked IO object instead of a real one io = MagicMock() - io.confirm_ask = MagicMock(return_value=True) - coder = Coder.create(self.GPT35, None, io, fnames=["added.txt"]) + io.confirm_ask = AsyncMock(return_value=True) + coder = await Coder.create(self.GPT35, None, io, fnames=["added.txt"]) - self.assertTrue(coder.allowed_to_edit("added.txt")) - self.assertTrue(coder.allowed_to_edit("repo.txt")) - self.assertTrue(coder.allowed_to_edit("new.txt")) + self.assertTrue(await coder.allowed_to_edit("added.txt")) + self.assertTrue(await coder.allowed_to_edit("repo.txt")) + self.assertTrue(await coder.allowed_to_edit("new.txt")) self.assertIn("repo.txt", str(coder.abs_fnames)) self.assertIn("new.txt", str(coder.abs_fnames)) self.assertFalse(coder.need_commit_before_edits) - def test_allowed_to_edit_no(self): + async def test_allowed_to_edit_no(self): with GitTemporaryDirectory(): repo = git.Repo() @@ -69,18 +69,18 @@ def test_allowed_to_edit_no(self): # say NO io = InputOutput(yes=False) - coder = Coder.create(self.GPT35, None, io, fnames=["added.txt"]) + coder = await Coder.create(self.GPT35, None, io, fnames=["added.txt"]) - self.assertTrue(coder.allowed_to_edit("added.txt")) - self.assertFalse(coder.allowed_to_edit("repo.txt")) - self.assertFalse(coder.allowed_to_edit("new.txt")) + self.assertTrue(await coder.allowed_to_edit("added.txt")) + self.assertFalse(await coder.allowed_to_edit("repo.txt")) + self.assertFalse(await coder.allowed_to_edit("new.txt")) self.assertNotIn("repo.txt", str(coder.abs_fnames)) self.assertNotIn("new.txt", str(coder.abs_fnames)) self.assertFalse(coder.need_commit_before_edits) - def test_allowed_to_edit_dirty(self): + async def test_allowed_to_edit_dirty(self): with GitTemporaryDirectory(): 
repo = git.Repo() @@ -93,16 +93,16 @@ def test_allowed_to_edit_dirty(self): # say NO io = InputOutput(yes=False) - coder = Coder.create(self.GPT35, None, io, fnames=["added.txt"]) + coder = await Coder.create(self.GPT35, None, io, fnames=["added.txt"]) - self.assertTrue(coder.allowed_to_edit("added.txt")) + self.assertTrue(await coder.allowed_to_edit("added.txt")) self.assertFalse(coder.need_commit_before_edits) fname.write_text("dirty!") - self.assertTrue(coder.allowed_to_edit("added.txt")) + self.assertTrue(await coder.allowed_to_edit("added.txt")) self.assertTrue(coder.need_commit_before_edits) - def test_get_files_content(self): + async def test_get_files_content(self): tempdir = Path(tempfile.mkdtemp()) file1 = tempdir / "file1.txt" @@ -114,13 +114,13 @@ def test_get_files_content(self): files = [file1, file2] # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(self.GPT35, None, io=InputOutput(), fnames=files) + coder = await Coder.create(self.GPT35, None, io=InputOutput(), fnames=files) content = coder.get_files_content().splitlines() - self.assertIn("file1.txt", content) - self.assertIn("file2.txt", content) + assert "file1.txt" in content + assert "file2.txt" in content - def test_check_for_filename_mentions(self): + async def test_check_for_filename_mentions(self): with GitTemporaryDirectory(): repo = git.Repo() @@ -137,7 +137,7 @@ def test_check_for_filename_mentions(self): repo.git.commit("-m", "new") # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(self.GPT35, None, mock_io) + coder = await Coder.create(self.GPT35, None, mock_io) # Call the check_for_file_mentions method coder.check_for_file_mentions("Please check file1.txt and file2.py") @@ -150,12 +150,12 @@ def test_check_for_filename_mentions(self): ] ) - self.assertEqual(coder.abs_fnames, expected_files) + assert coder.abs_fnames == expected_files - def test_check_for_ambiguous_filename_mentions_of_longer_paths(self): + 
async def test_check_for_ambiguous_filename_mentions_of_longer_paths(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) fname = Path("file1.txt") fname.touch() @@ -173,10 +173,10 @@ def test_check_for_ambiguous_filename_mentions_of_longer_paths(self): self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) - def test_skip_duplicate_basename_mentions(self): + async def test_skip_duplicate_basename_mentions(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Create files with same basename in different directories fname1 = Path("dir1") / "file.txt" @@ -204,13 +204,13 @@ def test_skip_duplicate_basename_mentions(self): mentioned = coder.get_file_mentions(f"Check {fname1} and {fname3}") self.assertEqual(mentioned, {str(fname3)}) - def test_check_for_file_mentions_read_only(self): + async def test_check_for_file_mentions_read_only(self): with GitTemporaryDirectory(): io = InputOutput( pretty=False, yes=True, ) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) fname = Path("readonly_file.txt") fname.touch() @@ -231,19 +231,19 @@ def test_check_for_file_mentions_read_only(self): # Assert that abs_fnames is still empty (file not added) self.assertEqual(coder.abs_fnames, set()) - def test_check_for_file_mentions_with_mocked_confirm(self): + async def test_check_for_file_mentions_with_mocked_confirm(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Mock get_file_mentions to return two file names coder.get_file_mentions = MagicMock(return_value=set(["file1.txt", "file2.txt"])) # Mock confirm_ask to return False for the first call and True for the second - 
io.confirm_ask = MagicMock(side_effect=[False, True, True]) + io.confirm_ask = AsyncMock(side_effect=[False, True, True]) # First call to check_for_file_mentions - coder.check_for_file_mentions("Please check file1.txt for the info") + await coder.check_for_file_mentions("Please check file1.txt for the info") # Assert that confirm_ask was called twice self.assertEqual(io.confirm_ask.call_count, 2) @@ -256,7 +256,7 @@ def test_check_for_file_mentions_with_mocked_confirm(self): io.confirm_ask.reset_mock() # Second call to check_for_file_mentions - coder.check_for_file_mentions("Please check file1.txt and file2.txt again") + await coder.check_for_file_mentions("Please check file1.txt and file2.txt again") # Assert that confirm_ask was called only once (for file1.txt) self.assertEqual(io.confirm_ask.call_count, 1) @@ -268,10 +268,10 @@ def test_check_for_file_mentions_with_mocked_confirm(self): # Assert that file1.txt is in ignore_mentions self.assertIn("file1.txt", coder.ignore_mentions) - def test_check_for_subdir_mention(self): + async def test_check_for_subdir_mention(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) fname = Path("other") / "file1.txt" fname.parent.mkdir(parents=True, exist_ok=True) @@ -286,10 +286,10 @@ def test_check_for_subdir_mention(self): self.assertEqual(coder.abs_fnames, set([str(fname.resolve())])) - def test_get_file_mentions_various_formats(self): + async def test_get_file_mentions_various_formats(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Create test files test_files = [ @@ -370,10 +370,10 @@ def test_get_file_mentions_various_formats(self): f"Failed to extract mentions from: {content}", ) - def test_get_file_mentions_multiline_backticks(self): + async def 
test_get_file_mentions_multiline_backticks(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Create test files test_files = [ @@ -409,10 +409,10 @@ def test_get_file_mentions_multiline_backticks(self): f"Failed to extract mentions from multiline backticked content: {content}", ) - def test_get_file_mentions_path_formats(self): + async def test_get_file_mentions_path_formats(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Test cases with different path formats test_cases = [ @@ -447,7 +447,7 @@ def test_get_file_mentions_path_formats(self): f"Failed for content: {content}, addable_files: {addable_files}", ) - def test_run_with_file_deletion(self): + async def test_run_with_file_deletion(self): # Create a few temporary files tempdir = Path(tempfile.mkdtemp()) @@ -460,8 +460,7 @@ def test_run_with_file_deletion(self): files = [file1, file2] - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(self.GPT35, None, io=InputOutput(), fnames=files) + coder = await Coder.create(self.GPT35, None, io=InputOutput(), fnames=files) def mock_send(*args, **kwargs): coder.partial_response_content = "ok" @@ -471,23 +470,22 @@ def mock_send(*args, **kwargs): coder.send = mock_send # Call the run method with a message - coder.run(with_message="hi") + await coder.run(with_message="hi") self.assertEqual(len(coder.abs_fnames), 2) file1.unlink() # Call the run method again with a message - coder.run(with_message="hi") + await coder.run(with_message="hi") self.assertEqual(len(coder.abs_fnames), 1) - def test_run_with_file_unicode_error(self): + async def test_run_with_file_unicode_error(self): # Create a few temporary files _, file1 = tempfile.mkstemp() _, file2 = tempfile.mkstemp() files = [file1, file2] 
- # Initialize the Coder object with the mocked IO and mocked repo coder = Coder.create(self.GPT35, None, io=InputOutput(), fnames=files) def mock_send(*args, **kwargs): @@ -498,7 +496,7 @@ def mock_send(*args, **kwargs): coder.send = mock_send # Call the run method with a message - coder.run(with_message="hi") + await coder.run(with_message="hi") self.assertEqual(len(coder.abs_fnames), 2) # Write some non-UTF8 text into the file @@ -506,10 +504,10 @@ def mock_send(*args, **kwargs): f.write(b"\x80abc") # Call the run method again with a message - coder.run(with_message="hi") + await coder.run(with_message="hi") self.assertEqual(len(coder.abs_fnames), 1) - def test_choose_fence(self): + async def test_choose_fence(self): # Create a few temporary files _, file1 = tempfile.mkstemp() @@ -518,8 +516,7 @@ def test_choose_fence(self): files = [file1] - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(self.GPT35, None, io=InputOutput(), fnames=files) + coder = await Coder.create(self.GPT35, None, io=InputOutput(), fnames=files) def mock_send(*args, **kwargs): coder.partial_response_content = "ok" @@ -529,22 +526,19 @@ def mock_send(*args, **kwargs): coder.send = mock_send # Call the run method with a message - coder.run(with_message="hi") + await coder.run(with_message="hi") self.assertNotEqual(coder.fence[0], "```") - def test_run_with_file_utf_unicode_error(self): + async def test_run_with_file_utf_unicode_error(self): "make sure that we honor InputOutput(encoding) and don't just assume utf-8" - # Create a few temporary files + encoding = "utf-16" _, file1 = tempfile.mkstemp() _, file2 = tempfile.mkstemp() - files = [file1, file2] - encoding = "utf-16" - # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create( + coder = await Coder.create( self.GPT35, None, io=InputOutput(encoding=encoding), @@ -559,19 +553,19 @@ def mock_send(*args, **kwargs): coder.send = mock_send # Call the run method with a 
message - coder.run(with_message="hi") + await coder.run(with_message="hi") self.assertEqual(len(coder.abs_fnames), 2) some_content_which_will_error_if_read_with_encoding_utf8 = "ÅÍÎÏ".encode(encoding) with open(file1, "wb") as f: f.write(some_content_which_will_error_if_read_with_encoding_utf8) - coder.run(with_message="hi") + await coder.run(with_message="hi") # both files should still be here self.assertEqual(len(coder.abs_fnames), 2) - def test_new_file_edit_one_commit(self): + async def test_new_file_edit_one_commit(self): """A new file should get pre-committed before the GPT edit commit""" with GitTemporaryDirectory(): repo = git.Repo() @@ -580,7 +574,7 @@ def test_new_file_edit_one_commit(self): io = InputOutput(yes=True) io.tool_warning = MagicMock() - coder = Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)]) + coder = await Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)]) self.assertTrue(fname.exists()) @@ -588,7 +582,7 @@ def test_new_file_edit_one_commit(self): with self.assertRaises(git.exc.GitCommandError): list(repo.iter_commits(repo.active_branch.name)) - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = f""" Do this: @@ -606,7 +600,7 @@ def mock_send(*args, **kwargs): coder.repo.get_commit_message = MagicMock() coder.repo.get_commit_message.return_value = "commit message" - coder.run(with_message="hi") + await coder.run(with_message="hi") content = fname.read_text() self.assertEqual(content, "new\n") @@ -614,7 +608,7 @@ def mock_send(*args, **kwargs): num_commits = len(list(repo.iter_commits(repo.active_branch.name))) self.assertEqual(num_commits, 2) - def test_only_commit_gpt_edited_file(self): + async def test_only_commit_gpt_edited_file(self): """ Only commit file that gpt edits, not other dirty files. Also ensure commit msg only depends on diffs from the GPT edited file. 
@@ -637,9 +631,9 @@ def test_only_commit_gpt_edited_file(self): fname1.write_text("ONE\n") io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname1), str(fname2)]) + coder = await Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname1), str(fname2)]) - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = f""" Do this: @@ -662,14 +656,14 @@ def mock_get_commit_message(diffs, context, user_language=None): coder.send = mock_send coder.repo.get_commit_message = MagicMock(side_effect=mock_get_commit_message) - coder.run(with_message="hi") + await coder.run(with_message="hi") content = fname2.read_text() self.assertEqual(content, "TWO\n") self.assertTrue(repo.is_dirty(path=str(fname1))) - def test_gpt_edit_to_dirty_file(self): + async def test_gpt_edit_to_dirty_file(self): """A dirty file should be committed before the GPT edits are committed""" with GitTemporaryDirectory(): @@ -690,7 +684,7 @@ def test_gpt_edit_to_dirty_file(self): fname2.write_text("OTHER\n") io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)]) + coder = await Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)]) def mock_send(*args, **kwargs): coder.partial_response_content = f""" @@ -716,7 +710,7 @@ def mock_get_commit_message(diffs, context, user_language=None): coder.repo.get_commit_message = MagicMock(side_effect=mock_get_commit_message) coder.send = mock_send - coder.run(with_message="hi") + await coder.run(with_message="hi") content = fname.read_text() self.assertEqual(content, "three\n") @@ -754,7 +748,7 @@ def mock_get_commit_message(diffs, context, user_language=None): self.assertEqual(len(saved_diffs), 2) - def test_gpt_edit_to_existing_file_not_in_repo(self): + async def test_gpt_edit_to_existing_file_not_in_repo(self): with GitTemporaryDirectory(): repo = git.Repo() @@ -768,7 +762,7 @@ def test_gpt_edit_to_existing_file_not_in_repo(self): 
repo.git.commit("-m", "initial") io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)]) + coder = await Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)]) def mock_send(*args, **kwargs): coder.partial_response_content = f""" @@ -794,7 +788,7 @@ def mock_get_commit_message(diffs, context, user_language=None): coder.repo.get_commit_message = MagicMock(side_effect=mock_get_commit_message) coder.send = mock_send - coder.run(with_message="hi") + await coder.run(with_message="hi") content = fname.read_text() self.assertEqual(content, "two\n") @@ -802,7 +796,7 @@ def mock_get_commit_message(diffs, context, user_language=None): diff = saved_diffs[0] self.assertIn("file.txt", diff) - def test_skip_aiderignored_files(self): + async def test_skip_aiderignored_files(self): with GitTemporaryDirectory(): repo = git.Repo() @@ -827,7 +821,7 @@ def test_skip_aiderignored_files(self): aider_ignore_file=str(aignore), ) - coder = Coder.create( + coder = await Coder.create( self.GPT35, None, io, @@ -839,7 +833,7 @@ def test_skip_aiderignored_files(self): self.assertNotIn(fname2, str(coder.abs_fnames)) self.assertNotIn(fname3, str(coder.abs_fnames)) - def test_skip_gitignored_files_on_init(self): + async def test_skip_gitignored_files_on_init(self): with GitTemporaryDirectory() as _: repo_path = Path(".") repo = git.Repo.init(repo_path) @@ -861,7 +855,7 @@ def test_skip_gitignored_files_on_init(self): fnames_to_add = [str(ignored_file), str(regular_file)] - coder = Coder.create(self.GPT35, None, mock_io, fnames=fnames_to_add) + coder = await Coder.create(self.GPT35, None, mock_io, fnames=fnames_to_add) self.assertNotIn(str(ignored_file.resolve()), coder.abs_fnames) self.assertIn(str(regular_file.resolve()), coder.abs_fnames) @@ -869,9 +863,9 @@ def test_skip_gitignored_files_on_init(self): f"Skipping {ignored_file.name} that matches gitignore spec." 
) - def test_check_for_urls(self): + async def test_check_for_urls(self): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, None, io=io) + coder = await Coder.create(self.GPT35, None, io=io) coder.commands.scraper = MagicMock() coder.commands.scraper.scrape = MagicMock(return_value="some content") @@ -911,9 +905,8 @@ def test_check_for_urls(self): ] for input_text, expected_url in test_cases: - with self.subTest(input_text=input_text): - result = coder.check_for_urls(input_text) - self.assertIn(expected_url, result) + result = await coder.check_for_urls(input_text) + assert expected_url in result # Test cases from the GitHub issue issue_cases = [ @@ -925,32 +918,31 @@ def test_check_for_urls(self): ] for input_text, expected_url in issue_cases: - with self.subTest(input_text=input_text): - result = coder.check_for_urls(input_text) - self.assertIn(expected_url, result) + result = await coder.check_for_urls(input_text) + assert expected_url in result # Test case with multiple URLs multi_url_input = "Check http://example1.com and https://example2.com/page" - result = coder.check_for_urls(multi_url_input) - self.assertIn("http://example1.com", result) - self.assertIn("https://example2.com/page", result) + result = await coder.check_for_urls(multi_url_input) + assert "http://example1.com" in result + assert "https://example2.com/page" in result # Test case with no URL no_url_input = "This text contains no URL" - result = coder.check_for_urls(no_url_input) - self.assertEqual(result, no_url_input) + result = await coder.check_for_urls(no_url_input) + assert result == no_url_input # Test case with the same URL appearing multiple times repeated_url_input = ( "Check https://example.com, then https://example.com again, and https://example.com one" " more time" ) - result = coder.check_for_urls(repeated_url_input) + result = await coder.check_for_urls(repeated_url_input) # the original 3 in the input text, plus 1 more for the scraped text - 
self.assertEqual(result.count("https://example.com"), 4) - self.assertIn("https://example.com", result) + assert result.count("https://example.com") == 4 + assert "https://example.com" in result - def test_coder_from_coder_with_subdir(self): + async def test_coder_from_coder_with_subdir(self): with GitTemporaryDirectory() as root: repo = git.Repo.init(root) @@ -968,10 +960,10 @@ def test_coder_from_coder_with_subdir(self): # Create the first coder io = InputOutput(yes=True) - coder1 = Coder.create(self.GPT35, None, io=io, fnames=[test_file.name]) + coder1 = await Coder.create(self.GPT35, None, io=io, fnames=[test_file.name]) # Create a new coder from the first coder - coder2 = Coder.create(from_coder=coder1) + coder2 = await Coder.create(from_coder=coder1) # Check if both coders have the same set of abs_fnames self.assertEqual(coder1.abs_fnames, coder2.abs_fnames) @@ -986,12 +978,12 @@ def test_coder_from_coder_with_subdir(self): self.assertEqual(len(coder1.abs_fnames), 1) self.assertEqual(len(coder2.abs_fnames), 1) - def test_suggest_shell_commands(self): + async def test_suggest_shell_commands(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = """Here's a shell command to run: ```bash @@ -1008,7 +1000,7 @@ def mock_send(*args, **kwargs): coder.handle_shell_commands = MagicMock() # Run the coder with a message - coder.run(with_message="Suggest a shell command") + await coder.run(with_message="Suggest a shell command") # Check if the shell command was added to the list self.assertEqual(len(coder.shell_commands), 1) @@ -1017,34 +1009,34 @@ def mock_send(*args, **kwargs): # Check if handle_shell_commands was called with the correct argument coder.handle_shell_commands.assert_called_once() - def test_no_suggest_shell_commands(self): + async def 
test_no_suggest_shell_commands(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io, suggest_shell_commands=False) + coder = await Coder.create(self.GPT35, "diff", io=io, suggest_shell_commands=False) self.assertFalse(coder.suggest_shell_commands) - def test_detect_urls_enabled(self): + async def test_detect_urls_enabled(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io, detect_urls=True) + coder = await Coder.create(self.GPT35, "diff", io=io, detect_urls=True) coder.commands.scraper = MagicMock() coder.commands.scraper.scrape = MagicMock(return_value="some content") # Test with a message containing a URL message = "Check out https://example.com" - coder.check_for_urls(message) + await coder.check_for_urls(message) coder.commands.scraper.scrape.assert_called_once_with("https://example.com") - def test_detect_urls_disabled(self): + async def test_detect_urls_disabled(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io, detect_urls=False) + coder = await Coder.create(self.GPT35, "diff", io=io, detect_urls=False) coder.commands.scraper = MagicMock() coder.commands.scraper.scrape = MagicMock(return_value="some content") # Test with a message containing a URL message = "Check out https://example.com" - result = coder.check_for_urls(message) + result = await coder.check_for_urls(message) self.assertEqual(result, message) coder.commands.scraper.scrape.assert_not_called() @@ -1058,20 +1050,20 @@ def test_unknown_edit_format_exception(self): ) self.assertEqual(str(exc), expected_msg) - def test_unknown_edit_format_creation(self): + async def test_unknown_edit_format_creation(self): # Test that creating a Coder with invalid edit format raises the exception io = InputOutput(yes=True) invalid_format = "invalid_format" with self.assertRaises(UnknownEditFormat) as cm: - Coder.create(self.GPT35, 
invalid_format, io=io) + await Coder.create(self.GPT35, invalid_format, io=io) exc = cm.exception self.assertEqual(exc.edit_format, invalid_format) self.assertIsInstance(exc.valid_formats, list) self.assertTrue(len(exc.valid_formats) > 0) - def test_system_prompt_prefix(self): + async def test_system_prompt_prefix(self): # Test that system_prompt_prefix is properly set and used io = InputOutput(yes=True) test_prefix = "Test prefix. " @@ -1080,7 +1072,7 @@ def test_system_prompt_prefix(self): model = Model("gpt-3.5-turbo") model.system_prompt_prefix = test_prefix - coder = Coder.create(model, None, io=io) + coder = await Coder.create(model, None, io=io) # Get the formatted messages chunks = coder.format_messages() @@ -1090,7 +1082,7 @@ def test_system_prompt_prefix(self): system_message = next(msg for msg in messages if msg["role"] == "system") self.assertTrue(system_message["content"].startswith(test_prefix)) - def test_coder_create_with_new_file_oserror(self): + async def test_coder_create_with_new_file_oserror(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) new_file = "new_file.txt" @@ -1098,7 +1090,7 @@ def test_coder_create_with_new_file_oserror(self): # Mock Path.touch() to raise OSError with patch("pathlib.Path.touch", side_effect=OSError("Permission denied")): # Create the coder with a new file - coder = Coder.create(self.GPT35, "diff", io=io, fnames=[new_file]) + coder = await Coder.create(self.GPT35, "diff", io=io, fnames=[new_file]) # Check if the coder was created successfully self.assertIsInstance(coder, Coder) @@ -1106,10 +1098,10 @@ def test_coder_create_with_new_file_oserror(self): # Check if the new file is not in abs_fnames self.assertNotIn(new_file, [os.path.basename(f) for f in coder.abs_fnames]) - def test_show_exhausted_error(self): + async def test_show_exhausted_error(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, 
"diff", io=io) # Set up some real done_messages and cur_messages coder.done_messages = [ @@ -1166,13 +1158,13 @@ def test_show_exhausted_error(self): self.assertIn("Output tokens:", error_message) self.assertIn("Total tokens:", error_message) - def test_keyboard_interrupt_handling(self): + async def test_keyboard_interrupt_handling(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) # Simulate keyboard interrupt during message processing - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = "Partial response" coder.partial_response_function_call = dict() raise KeyboardInterrupt() @@ -1183,19 +1175,19 @@ def mock_send(*args, **kwargs): sanity_check_messages(coder.cur_messages) # Process message that will trigger interrupt - list(coder.send_message("Test message")) + list(await coder.send_message("Test message")) # Verify messages are still in valid state sanity_check_messages(coder.cur_messages) self.assertEqual(coder.cur_messages[-1]["role"], "assistant") - def test_token_limit_error_handling(self): + async def test_token_limit_error_handling(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) # Simulate token limit error - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = "Partial response" coder.partial_response_function_call = dict() raise FinishReasonLength() @@ -1206,33 +1198,33 @@ def mock_send(*args, **kwargs): sanity_check_messages(coder.cur_messages) # Process message that hits token limit - list(coder.send_message("Long message")) + list(await coder.send_message("Long message")) # Verify messages are still in valid state sanity_check_messages(coder.cur_messages) self.assertEqual(coder.cur_messages[-1]["role"], 
"assistant") - def test_message_sanity_after_partial_response(self): + async def test_message_sanity_after_partial_response(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) # Simulate partial response then interrupt - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = "Partial response" coder.partial_response_function_call = dict() raise KeyboardInterrupt() coder.send = mock_send - list(coder.send_message("Test")) + list(await coder.send_message("Test")) # Verify message structure remains valid sanity_check_messages(coder.cur_messages) self.assertEqual(coder.cur_messages[-1]["role"], "assistant") - def test_normalize_language(self): - coder = Coder.create(self.GPT35, None, io=InputOutput()) + async def test_normalize_language(self): + coder = await Coder.create(self.GPT35, None, io=InputOutput()) # Test None and empty self.assertIsNone(coder.normalize_language(None)) @@ -1282,9 +1274,9 @@ def test_normalize_language(self): with patch("aider.coders.base_coder.Locale", mock_babel_locale_error): self.assertEqual(coder.normalize_language("en_US"), "English") # Falls back to map - def test_get_user_language(self): + async def test_get_user_language(self): io = InputOutput() - coder = Coder.create(self.GPT35, None, io=io) + coder = await Coder.create(self.GPT35, None, io=io) # 1. 
Test with self.chat_language set coder.chat_language = "fr_CA" @@ -1349,10 +1341,10 @@ def test_get_user_language(self): with patch("os.environ.get", return_value=None) as mock_env_get: self.assertIsNone(coder.get_user_language()) - def test_architect_coder_auto_accept_true(self): + async def test_architect_coder_auto_accept_true(self): with GitTemporaryDirectory(): io = InputOutput(yes=True) - io.confirm_ask = MagicMock(return_value=True) + io.confirm_ask = AsyncMock(return_value=True) # Create an ArchitectCoder with auto_accept_architect=True with patch("aider.coders.architect_coder.AskCoder.__init__", return_value=None): @@ -1379,7 +1371,7 @@ def test_architect_coder_auto_accept_true(self): coder.partial_response_content = "Make these changes to the code" # Call reply_completed - coder.reply_completed() + await coder.reply_completed() # Verify that confirm_ask was not called (auto-accepted) io.confirm_ask.assert_not_called() @@ -1387,10 +1379,10 @@ def test_architect_coder_auto_accept_true(self): # Verify that editor coder was created and run mock_editor.run.assert_called_once() - def test_architect_coder_auto_accept_false_confirmed(self): + async def test_architect_coder_auto_accept_false_confirmed(self): with GitTemporaryDirectory(): io = InputOutput(yes=False) - io.confirm_ask = MagicMock(return_value=True) + io.confirm_ask = AsyncMock(return_value=True) # Create an ArchitectCoder with auto_accept_architect=False with patch("aider.coders.architect_coder.AskCoder.__init__", return_value=None): @@ -1421,7 +1413,7 @@ def test_architect_coder_auto_accept_false_confirmed(self): coder.partial_response_content = "Make these changes to the code" # Call reply_completed - coder.reply_completed() + await coder.reply_completed() # Verify that confirm_ask was called io.confirm_ask.assert_called_once_with("Edit the files?") @@ -1429,10 +1421,10 @@ def test_architect_coder_auto_accept_false_confirmed(self): # Verify that editor coder was created and run 
mock_editor.run.assert_called_once() - def test_architect_coder_auto_accept_false_rejected(self): + async def test_architect_coder_auto_accept_false_rejected(self): with GitTemporaryDirectory(): io = InputOutput(yes=False) - io.confirm_ask = MagicMock(return_value=False) + io.confirm_ask = AsyncMock(return_value=False) # Create an ArchitectCoder with auto_accept_architect=False with patch("aider.coders.architect_coder.AskCoder.__init__", return_value=None): @@ -1455,7 +1447,7 @@ def test_architect_coder_auto_accept_false_rejected(self): coder.partial_response_content = "Make these changes to the code" # Call reply_completed - coder.reply_completed() + await coder.reply_completed() # Verify that confirm_ask was called io.confirm_ask.assert_called_once_with("Edit the files?") @@ -1465,7 +1457,7 @@ def test_architect_coder_auto_accept_false_rejected(self): mock_editor.run.assert_not_called() @patch("aider.coders.base_coder.experimental_mcp_client") - def test_mcp_server_connection(self, mock_mcp_client): + async def test_mcp_server_connection(self, mock_mcp_client): """Test that the coder connects to MCP servers for tools.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) @@ -1481,7 +1473,7 @@ def test_mcp_server_connection(self, mock_mcp_client): # Create coder with mock MCP server with patch.object(Coder, "initialize_mcp_tools", return_value=mock_tools): - coder = Coder.create(self.GPT35, "diff", io=io, mcp_servers=[mock_server]) + coder = await Coder.create(self.GPT35, "diff", io=io, mcp_servers=[mock_server]) # Manually set mcp_tools since we're bypassing initialize_mcp_tools coder.mcp_tools = mock_tools @@ -1492,7 +1484,7 @@ def test_mcp_server_connection(self, mock_mcp_client): self.assertEqual(coder.mcp_tools[0][0], "test_server") @patch("aider.coders.base_coder.experimental_mcp_client") - def test_coder_creation_with_partial_failed_mcp_server(self, mock_mcp_client): + async def test_coder_creation_with_partial_failed_mcp_server(self, mock_mcp_client, 
GPT35): """Test that a coder can still be created even if an MCP server fails to initialize.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) @@ -1519,8 +1511,8 @@ async def mock_load_mcp_tools(session, format): mock_mcp_client.load_mcp_tools = AsyncMock(side_effect=mock_load_mcp_tools) # Create coder with both servers - coder = Coder.create( - self.GPT35, + coder = await Coder.create( + GPT35, "diff", io=io, mcp_servers=[working_server, failing_server], @@ -1528,17 +1520,17 @@ async def mock_load_mcp_tools(session, format): ) # Verify that coder was created successfully - self.assertIsInstance(coder, Coder) + assert isinstance(coder, Coder) # Verify that only the working server's tools were added - self.assertIsNotNone(coder.mcp_tools) - self.assertEqual(len(coder.mcp_tools), 1) - self.assertEqual(coder.mcp_tools[0][0], "working_server") + assert coder.mcp_tools is not None + assert len(coder.mcp_tools) == 1 + assert coder.mcp_tools[0][0] == "working_server" # Verify that the tool list contains only working tools tool_list = coder.get_tool_list() - self.assertEqual(len(tool_list), 1) - self.assertEqual(tool_list[0]["function"]["name"], "working_tool") + assert len(tool_list) == 1 + assert tool_list[0]["function"]["name"] == "working_tool" # Verify that the warning was logged for the failing server io.tool_warning.assert_called_with( @@ -1546,7 +1538,7 @@ async def mock_load_mcp_tools(session, format): ) @patch("aider.coders.base_coder.experimental_mcp_client") - def test_coder_creation_with_all_failed_mcp_server(self, mock_mcp_client): + async def test_coder_creation_with_all_failed_mcp_server(self, mock_mcp_client): """Test that a coder can still be created even if an MCP server fails to initialize.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) @@ -1564,7 +1556,7 @@ async def mock_load_mcp_tools(session, format): mock_mcp_client.load_mcp_tools = AsyncMock(side_effect=mock_load_mcp_tools) # Create coder with both servers - coder = 
Coder.create( + coder = await Coder.create( self.GPT35, "diff", io=io, @@ -1588,21 +1580,21 @@ async def mock_load_mcp_tools(session, format): "Error initializing MCP server failing_server:\nFailed to load tools" ) - def test_process_tool_calls_none_response(self): + async def test_process_tool_calls_none_response(self): """Test that process_tool_calls handles None response correctly.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) # Test with None response - result = coder.process_tool_calls(None) + result = await coder.process_tool_calls(None) self.assertFalse(result) - def test_process_tool_calls_no_tool_calls(self): + async def test_process_tool_calls_no_tool_calls(self): """Test that process_tool_calls handles response with no tool calls.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) # Create a response with no tool calls response = MagicMock() @@ -1610,12 +1602,12 @@ def test_process_tool_calls_no_tool_calls(self): response.choices[0].message = MagicMock() response.choices[0].message.tool_calls = [] - result = coder.process_tool_calls(response) + result = await coder.process_tool_calls(response) self.assertFalse(result) @patch("aider.coders.base_coder.experimental_mcp_client") @patch("asyncio.run") - def test_process_tool_calls_with_tools(self, mock_asyncio_run, mock_mcp_client): + async def test_process_tool_calls_with_tools(self, mock_asyncio_run, mock_mcp_client): """Test that process_tool_calls processes tool calls correctly.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) @@ -1643,7 +1635,7 @@ def test_process_tool_calls_with_tools(self, mock_asyncio_run, mock_mcp_client): ) # Create coder with mock MCP tools and servers - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await 
Coder.create(self.GPT35, "diff", io=io) coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] coder.mcp_servers = [mock_server] @@ -1660,7 +1652,7 @@ def test_process_tool_calls_with_tools(self, mock_asyncio_run, mock_mcp_client): mock_asyncio_run.return_value = tool_responses # Test process_tool_calls - result = coder.process_tool_calls(response) + result = await coder.process_tool_calls(response) self.assertTrue(result) # Verify that asyncio.run was called @@ -1673,7 +1665,7 @@ def test_process_tool_calls_with_tools(self, mock_asyncio_run, mock_mcp_client): self.assertEqual(coder.cur_messages[1]["tool_call_id"], "test_id") self.assertEqual(coder.cur_messages[1]["content"], "Tool execution result") - def test_process_tool_calls_max_calls_exceeded(self): + async def test_process_tool_calls_max_calls_exceeded(self): """Test that process_tool_calls handles max tool calls exceeded.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) @@ -1697,13 +1689,13 @@ def test_process_tool_calls_max_calls_exceeded(self): mock_server.name = "test_server" # Create coder with max tool calls exceeded - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) coder.num_tool_calls = coder.max_tool_calls coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] coder.mcp_servers = [mock_server] # Test process_tool_calls - result = coder.process_tool_calls(response) + result = await coder.process_tool_calls(response) self.assertFalse(result) # Verify that warning was shown @@ -1711,7 +1703,7 @@ def test_process_tool_calls_max_calls_exceeded(self): f"Only {coder.max_tool_calls} tool calls allowed, stopping." 
) - def test_process_tool_calls_user_rejects(self): + async def test_process_tool_calls_user_rejects(self): """Test that process_tool_calls handles user rejection.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) @@ -1735,12 +1727,12 @@ def test_process_tool_calls_user_rejects(self): mock_server.name = "test_server" # Create coder with mock MCP tools - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) coder.mcp_tools = [("test_server", [{"function": {"name": "test_tool"}}])] coder.mcp_servers = [mock_server] # Test process_tool_calls - result = coder.process_tool_calls(response) + result = await coder.process_tool_calls(response) self.assertFalse(result) # Verify that confirm_ask was called @@ -1750,11 +1742,11 @@ def test_process_tool_calls_user_rejects(self): self.assertEqual(len(coder.cur_messages), 0) @patch("asyncio.run") - def test_execute_tool_calls(self, mock_asyncio_run): + async def test_execute_tool_calls(self, mock_asyncio_run): """Test that _execute_tool_calls executes tool calls correctly.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) # Create mock server and tool call mock_server = MagicMock() @@ -1783,7 +1775,7 @@ def test_execute_tool_calls(self, mock_asyncio_run): mock_asyncio_run.return_value = tool_responses # Test _execute_tool_calls directly - result = coder._execute_tool_calls(server_tool_calls) + result = await coder._execute_tool_calls(server_tool_calls) # Verify that asyncio.run was called mock_asyncio_run.assert_called_once() @@ -1794,7 +1786,7 @@ def test_execute_tool_calls(self, mock_asyncio_run): self.assertEqual(result[0]["tool_call_id"], "test_id") self.assertEqual(result[0]["content"], "Tool execution result") - def test_auto_commit_with_none_content_message(self): + async def test_auto_commit_with_none_content_message(self): """ Verify that 
auto_commit works with messages that have None content. This is common with tool calls. @@ -1808,7 +1800,7 @@ def test_auto_commit_with_none_content_message(self): repo.git.commit("-m", "initial") io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)]) + coder = await Coder.create(self.GPT35, "diff", io=io, fnames=[str(fname)]) coder.cur_messages = [ {"role": "user", "content": "do a thing"}, @@ -1842,11 +1834,11 @@ def mock_get_commit_message(diffs, context, user_language=None): "aider.coders.base_coder.experimental_mcp_client.call_openai_tool", new_callable=AsyncMock, ) - def test_execute_tool_calls_multiple_content(self, mock_call_openai_tool): + async def test_execute_tool_calls_multiple_content(self, mock_call_openai_tool): """Test that _execute_tool_calls handles multiple content blocks correctly.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) # Create mock server and tool call mock_server = AsyncMock() @@ -1873,7 +1865,7 @@ def test_execute_tool_calls_multiple_content(self, mock_call_openai_tool): mock_call_openai_tool.return_value = mock_call_result # Test _execute_tool_calls directly - result = coder._execute_tool_calls(server_tool_calls) + result = await coder._execute_tool_calls(server_tool_calls) # Verify that call_openai_tool was called mock_call_openai_tool.assert_called_once() @@ -1891,11 +1883,11 @@ def test_execute_tool_calls_multiple_content(self, mock_call_openai_tool): "aider.coders.base_coder.experimental_mcp_client.call_openai_tool", new_callable=AsyncMock, ) - def test_execute_tool_calls_blob_content(self, mock_call_openai_tool): + async def test_execute_tool_calls_blob_content(self, mock_call_openai_tool): """Test that _execute_tool_calls handles BlobResourceContents correctly.""" with GitTemporaryDirectory(): io = InputOutput(yes=True) - coder = Coder.create(self.GPT35, "diff", 
io=io) + coder = await Coder.create(self.GPT35, "diff", io=io) # Create mock server and tool call mock_server = AsyncMock() @@ -1944,7 +1936,7 @@ def test_execute_tool_calls_blob_content(self, mock_call_openai_tool): mock_call_openai_tool.return_value = mock_call_result # Test _execute_tool_calls directly - result = coder._execute_tool_calls(server_tool_calls) + result = await coder._execute_tool_calls(server_tool_calls) # Verify that call_openai_tool was called mock_call_openai_tool.assert_called_once() @@ -1961,5 +1953,4 @@ def test_execute_tool_calls_blob_content(self, mock_call_openai_tool): self.assertEqual(result[0]["content"], expected_content) -if __name__ == "__main__": - unittest.main() +# Remove the unittest.main() since we're using pytest diff --git a/tests/basic/test_commands.py b/tests/basic/test_commands.py index 06440b7e620..86fc5e9e167 100644 --- a/tests/basic/test_commands.py +++ b/tests/basic/test_commands.py @@ -32,12 +32,12 @@ def tearDown(self): os.chdir(self.original_cwd) shutil.rmtree(self.tempdir, ignore_errors=True) - def test_cmd_add(self): + async def test_cmd_add(self): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Call the cmd_add method with 'foo.txt' and 'bar.txt' as a single string @@ -47,10 +47,10 @@ def test_cmd_add(self): self.assertTrue(os.path.exists("foo.txt")) self.assertTrue(os.path.exists("bar.txt")) - def test_cmd_copy(self): + async def test_cmd_copy(self): # Initialize InputOutput and Coder instances io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Add some assistant messages to the chat history @@ -77,10 +77,10 @@ def test_cmd_copy(self): ) 
mock_tool_output.assert_any_call(expected_preview) - def test_cmd_copy_with_cur_messages(self): + async def test_cmd_copy_with_cur_messages(self): # Initialize InputOutput and Coder instances io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Add messages to done_messages and cur_messages @@ -116,7 +116,7 @@ def test_cmd_copy_with_cur_messages(self): ) mock_tool_output.assert_any_call(expected_preview) io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Add only user messages @@ -130,9 +130,9 @@ def test_cmd_copy_with_cur_messages(self): # Assert tool_error was called indicating no assistant messages mock_tool_error.assert_called_once_with("No assistant messages found to copy.") - def test_cmd_copy_pyperclip_exception(self): + async def test_cmd_copy_pyperclip_exception(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) coder.done_messages = [ @@ -152,23 +152,23 @@ def test_cmd_copy_pyperclip_exception(self): # Assert that tool_error was called with the clipboard error message mock_tool_error.assert_called_once_with("Failed to copy to clipboard: Clipboard error") - def test_cmd_add_bad_glob(self): + async def test_cmd_add_bad_glob(self): # https://github.com/Aider-AI/aider/issues/293 io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) commands.cmd_add("**.txt") - def test_cmd_add_with_glob_patterns(self): + async def test_cmd_add_with_glob_patterns(self): # Initialize the Commands and 
InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create some test files @@ -189,12 +189,12 @@ def test_cmd_add_with_glob_patterns(self): # Check if the text file has not been added to the chat session self.assertNotIn(str(Path("test.txt").resolve()), coder.abs_fnames) - def test_cmd_add_no_match(self): + async def test_cmd_add_no_match(self): # yes=False means we will *not* create the file when it is not found io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Call the cmd_add method with a non-existent file pattern @@ -203,12 +203,12 @@ def test_cmd_add_no_match(self): # Check if no files have been added to the chat session self.assertEqual(len(coder.abs_fnames), 0) - def test_cmd_add_no_match_but_make_it(self): + async def test_cmd_add_no_match_but_make_it(self): # yes=True means we *will* create the file when it is not found io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) fname = Path("[abc].nonexistent") @@ -220,12 +220,12 @@ def test_cmd_add_no_match_but_make_it(self): self.assertEqual(len(coder.abs_fnames), 1) self.assertTrue(fname.exists()) - def test_cmd_add_drop_directory(self): + async def test_cmd_add_drop_directory(self): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a directory and add files to it using 
pathlib @@ -271,12 +271,12 @@ def test_cmd_add_drop_directory(self): # it should be there, but was not in v0.10.0 self.assertNotIn(abs_fname, coder.abs_fnames) - def test_cmd_drop_with_glob_patterns(self): + async def test_cmd_drop_with_glob_patterns(self): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create test files in root and subdirectory @@ -300,12 +300,12 @@ def test_cmd_drop_with_glob_patterns(self): self.assertNotIn(str(Path("test2.py").resolve()), coder.abs_fnames) self.assertEqual(len(coder.abs_fnames), initial_count - 1) - def test_cmd_drop_without_glob(self): + async def test_cmd_drop_without_glob(self): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create test files @@ -332,12 +332,12 @@ def test_cmd_drop_without_glob(self): self.assertNotIn(str(Path("file3.py").resolve()), coder.abs_fnames) self.assertEqual(len(coder.abs_fnames), 0) - def test_cmd_add_bad_encoding(self): + async def test_cmd_add_bad_encoding(self): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a new file foo.bad which will fail to decode as utf-8 @@ -348,7 +348,7 @@ def test_cmd_add_bad_encoding(self): self.assertEqual(coder.abs_fnames, set()) - def test_cmd_git(self): + async def test_cmd_git(self): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) @@ 
-357,7 +357,7 @@ def test_cmd_git(self): with open(f"{tempdir}/test.txt", "w") as f: f.write("test") - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Run the cmd_git method with the arguments "commit -a -m msg" @@ -369,11 +369,11 @@ def test_cmd_git(self): files_in_repo = repo.git.ls_files() self.assertIn("test.txt", files_in_repo) - def test_cmd_tokens(self): + async def test_cmd_tokens(self): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) commands.cmd_add("foo.txt bar.txt") @@ -393,7 +393,7 @@ def test_cmd_tokens(self): self.assertIn("foo.txt", console_output) self.assertIn("bar.txt", console_output) - def test_cmd_add_from_subdir(self): + async def test_cmd_add_from_subdir(self): repo = git.Repo.init() repo.config_writer().set_value("user", "name", "Test User").release() repo.config_writer().set_value("user", "email", "testuser@example.com").release() @@ -418,7 +418,7 @@ def test_cmd_add_from_subdir(self): os.chdir("subdir") io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # this should get added @@ -431,12 +431,12 @@ def test_cmd_add_from_subdir(self): self.assertNotIn(filenames[1], coder.abs_fnames) self.assertIn(filenames[2], coder.abs_fnames) - def test_cmd_add_from_subdir_again(self): + async def test_cmd_add_from_subdir_again(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) Path("side_dir").mkdir() @@ -450,7 +450,7 @@ def 
test_cmd_add_from_subdir_again(self): # https://github.com/Aider-AI/aider/issues/201 commands.cmd_add("temp.txt") - def test_cmd_commit(self): + async def test_cmd_commit(self): with GitTemporaryDirectory(): fname = "test.txt" with open(fname, "w") as f: @@ -460,7 +460,7 @@ def test_cmd_commit(self): repo.git.commit("-m", "initial") io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) self.assertFalse(repo.is_dirty()) @@ -472,7 +472,7 @@ def test_cmd_commit(self): commands.cmd_commit(commit_message) self.assertFalse(repo.is_dirty()) - def test_cmd_add_from_outside_root(self): + async def test_cmd_add_from_outside_root(self): with ChdirTemporaryDirectory() as tmp_dname: root = Path("root") root.mkdir() @@ -481,7 +481,7 @@ def test_cmd_add_from_outside_root(self): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) outside_file = Path(tmp_dname) / "outside.txt" @@ -493,7 +493,7 @@ def test_cmd_add_from_outside_root(self): self.assertEqual(len(coder.abs_fnames), 0) - def test_cmd_add_from_outside_git(self): + async def test_cmd_add_from_outside_git(self): with ChdirTemporaryDirectory() as tmp_dname: root = Path("root") root.mkdir() @@ -504,7 +504,7 @@ def test_cmd_add_from_outside_git(self): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) outside_file = Path(tmp_dname) / "outside.txt" @@ -517,12 +517,12 @@ def test_cmd_add_from_outside_git(self): self.assertEqual(len(coder.abs_fnames), 0) - def test_cmd_add_filename_with_special_chars(self): + async def test_cmd_add_filename_with_special_chars(self): with 
ChdirTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) fname = Path("with[brackets].txt") @@ -532,7 +532,7 @@ def test_cmd_add_filename_with_special_chars(self): self.assertIn(str(fname.resolve()), coder.abs_fnames) - def test_cmd_tokens_output(self): + async def test_cmd_tokens_output(self): with GitTemporaryDirectory() as repo_dir: # Create a small repository with a few files (Path(repo_dir) / "file1.txt").write_text("Content of file 1") @@ -547,7 +547,7 @@ def test_cmd_tokens_output(self): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(Model("claude-3-5-sonnet-20240620"), None, io) + coder = await Coder.create(Model("claude-3-5-sonnet-20240620"), None, io) print(coder.get_announcements()) commands = Commands(io, coder) @@ -557,7 +557,7 @@ def test_cmd_tokens_output(self): original_tool_output = io.tool_output output_lines = [] - def capture_output(*args, **kwargs): + async def capture_output(*args, **kwargs): output_lines.extend(args) original_tool_output(*args, **kwargs) @@ -582,12 +582,12 @@ def capture_output(*args, **kwargs): self.assertTrue(any("tokens total" in line for line in output_lines)) self.assertTrue(any("tokens remaining" in line for line in output_lines)) - def test_cmd_add_dirname_with_special_chars(self): + async def test_cmd_add_dirname_with_special_chars(self): with ChdirTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) dname = Path("with[brackets]") @@ -600,12 +600,12 @@ def test_cmd_add_dirname_with_special_chars(self): dump(coder.abs_fnames) self.assertIn(str(fname.resolve()), coder.abs_fnames) - 
def test_cmd_add_dirname_with_special_chars_git(self): + async def test_cmd_add_dirname_with_special_chars_git(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) dname = Path("with[brackets]") @@ -622,12 +622,12 @@ def test_cmd_add_dirname_with_special_chars_git(self): dump(coder.abs_fnames) self.assertIn(str(fname.resolve()), coder.abs_fnames) - def test_cmd_add_abs_filename(self): + async def test_cmd_add_abs_filename(self): with ChdirTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) fname = Path("file.txt") @@ -637,12 +637,12 @@ def test_cmd_add_abs_filename(self): self.assertIn(str(fname.resolve()), coder.abs_fnames) - def test_cmd_add_quoted_filename(self): + async def test_cmd_add_quoted_filename(self): with ChdirTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) fname = Path("file with spaces.txt") @@ -652,7 +652,7 @@ def test_cmd_add_quoted_filename(self): self.assertIn(str(fname.resolve()), coder.abs_fnames) - def test_cmd_add_existing_with_dirty_repo(self): + async def test_cmd_add_existing_with_dirty_repo(self): with GitTemporaryDirectory(): repo = git.Repo() @@ -670,7 +670,7 @@ def test_cmd_add_existing_with_dirty_repo(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # There's no reason this /add should 
trigger a commit @@ -688,10 +688,10 @@ def test_cmd_add_existing_with_dirty_repo(self): del commands del repo - def test_cmd_save_and_load(self): + async def test_cmd_save_and_load(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create some test files @@ -762,7 +762,7 @@ def test_cmd_save_and_load(self): # Clean up Path(session_file).unlink() - def test_cmd_save_and_load_with_external_file(self): + async def test_cmd_save_and_load_with_external_file(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as external_file: external_file.write("External file content") external_file_path = external_file.name @@ -770,7 +770,7 @@ def test_cmd_save_and_load_with_external_file(self): try: with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create some test files in the repo @@ -832,7 +832,7 @@ def test_cmd_save_and_load_with_external_file(self): finally: os.unlink(external_file_path) - def test_cmd_save_and_load_with_multiple_external_files(self): + async def test_cmd_save_and_load_with_multiple_external_files(self): with ( tempfile.NamedTemporaryFile(mode="w", delete=False) as external_file1, tempfile.NamedTemporaryFile(mode="w", delete=False) as external_file2, @@ -845,7 +845,7 @@ def test_cmd_save_and_load_with_multiple_external_files(self): try: with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create some test files in the repo @@ -920,10 +920,10 @@ def test_cmd_save_and_load_with_multiple_external_files(self): 
os.unlink(external_file1_path) os.unlink(external_file2_path) - def test_cmd_read_only_with_image_file(self): + async def test_cmd_read_only_with_image_file(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a test image file @@ -936,7 +936,7 @@ def test_cmd_read_only_with_image_file(self): # Test with vision model vision_model = Model("gpt-4-vision-preview") - vision_coder = Coder.create(vision_model, None, io) + vision_coder = await Coder.create(vision_model, None, io) vision_commands = Commands(io, vision_coder) vision_commands.cmd_read_only(str(test_file)) @@ -965,10 +965,10 @@ def test_cmd_read_only_with_image_file(self): break self.assertTrue(found_image, "Image file not found in messages to LLM") - def test_cmd_read_only_with_glob_pattern(self): + async def test_cmd_read_only_with_glob_pattern(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create multiple test files @@ -1000,10 +1000,10 @@ def test_cmd_read_only_with_glob_pattern(self): ) ) - def test_cmd_read_only_with_recursive_glob(self): + async def test_cmd_read_only_with_recursive_glob(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a directory structure with files @@ -1031,10 +1031,10 @@ def test_cmd_read_only_with_recursive_glob(self): ) ) - def test_cmd_read_only_with_nonexistent_glob(self): + async def test_cmd_read_only_with_nonexistent_glob(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, 
fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Test the /read-only command with a non-existent glob pattern @@ -1049,12 +1049,12 @@ def test_cmd_read_only_with_nonexistent_glob(self): # Ensure no files were added to abs_read_only_fnames self.assertEqual(len(coder.abs_read_only_fnames), 0) - def test_cmd_add_unicode_error(self): + async def test_cmd_add_unicode_error(self): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) fname = "file.txt" @@ -1066,13 +1066,13 @@ def test_cmd_add_unicode_error(self): commands.cmd_add("file.txt") self.assertEqual(coder.abs_fnames, set()) - def test_cmd_add_read_only_file(self): + async def test_cmd_add_read_only_file(self): with GitTemporaryDirectory(): # Initialize the Commands and InputOutput objects io = InputOutput(pretty=False, fancy_input=False, yes=True) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a test file @@ -1122,12 +1122,12 @@ def test_cmd_add_read_only_file(self): ) ) - def test_cmd_test_unbound_local_error(self): + async def test_cmd_test_unbound_local_error(self): with ChdirTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Mock the io.prompt_ask method to simulate user input @@ -1139,12 +1139,12 @@ def test_cmd_test_unbound_local_error(self): # Check that the output was added to cur_messages self.assertTrue(any("exit 1" in msg["content"] for msg in coder.cur_messages)) - def 
test_cmd_test_returns_output_on_failure(self): + async def test_cmd_test_returns_output_on_failure(self): with ChdirTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Define a command that prints to stderr and exits with non-zero status @@ -1162,14 +1162,14 @@ def test_cmd_test_returns_output_on_failure(self): any(expected_output_fragment in msg["content"] for msg in coder.cur_messages) ) - def test_cmd_add_drop_untracked_files(self): + async def test_cmd_add_drop_untracked_files(self): with GitTemporaryDirectory(): repo = git.Repo() io = InputOutput(pretty=False, fancy_input=False, yes=False) from aider.coders import Coder - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) fname = Path("test.txt") @@ -1188,11 +1188,11 @@ def test_cmd_add_drop_untracked_files(self): self.assertEqual(len(coder.abs_fnames), 0) - def test_cmd_undo_with_dirty_files_not_in_last_commit(self): + async def test_cmd_undo_with_dirty_files_not_in_last_commit(self): with GitTemporaryDirectory() as repo_dir: repo = git.Repo(repo_dir) io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) other_path = Path(repo_dir) / "other_file.txt" @@ -1236,11 +1236,11 @@ def test_cmd_undo_with_dirty_files_not_in_last_commit(self): del commands del repo - def test_cmd_undo_with_newly_committed_file(self): + async def test_cmd_undo_with_newly_committed_file(self): with GitTemporaryDirectory() as repo_dir: repo = git.Repo(repo_dir) io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # 
Put in a random first commit @@ -1272,11 +1272,11 @@ def test_cmd_undo_with_newly_committed_file(self): del commands del repo - def test_cmd_undo_on_first_commit(self): + async def test_cmd_undo_on_first_commit(self): with GitTemporaryDirectory() as repo_dir: repo = git.Repo(repo_dir) io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create and commit a new file @@ -1301,7 +1301,7 @@ def test_cmd_undo_on_first_commit(self): del commands del repo - def test_cmd_add_gitignored_file(self): + async def test_cmd_add_gitignored_file(self): with GitTemporaryDirectory(): # Create a .gitignore file gitignore = Path(".gitignore") @@ -1312,7 +1312,7 @@ def test_cmd_add_gitignored_file(self): ignored_file.write_text("This should be ignored") io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Try to add the ignored file @@ -1321,9 +1321,9 @@ def test_cmd_add_gitignored_file(self): # Verify the file was not added self.assertEqual(len(coder.abs_fnames), 0) - def test_cmd_think_tokens(self): + async def test_cmd_think_tokens(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Test with various formats @@ -1354,7 +1354,7 @@ def test_cmd_think_tokens(self): commands.cmd_think_tokens("") mock_tool_output.assert_any_call(mock.ANY) # Just verify it calls tool_output - def test_cmd_add_aiderignored_file(self): + async def test_cmd_add_aiderignored_file(self): with GitTemporaryDirectory(): repo = git.Repo() @@ -1379,7 +1379,7 @@ def test_cmd_add_aiderignored_file(self): aider_ignore_file=str(aignore), ) - coder = Coder.create( + coder = await Coder.create( self.GPT35, None, 
io, @@ -1394,10 +1394,10 @@ def test_cmd_add_aiderignored_file(self): self.assertNotIn(fname2, str(coder.abs_fnames)) self.assertNotIn(fname3, str(coder.abs_fnames)) - def test_cmd_read_only(self): + async def test_cmd_read_only(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a test file @@ -1426,10 +1426,10 @@ def test_cmd_read_only(self): ) ) - def test_cmd_read_only_from_working_dir(self): + async def test_cmd_read_only_from_working_dir(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a subdirectory and a test file within it @@ -1463,7 +1463,7 @@ def test_cmd_read_only_from_working_dir(self): ) ) - def test_cmd_read_only_with_external_file(self): + async def test_cmd_read_only_with_external_file(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as external_file: external_file.write("External file content") external_file_path = external_file.name @@ -1474,7 +1474,7 @@ def test_cmd_read_only_with_external_file(self): repo_file = Path(repo_dir) / "repo_file.txt" repo_file.write_text("Repo file content") io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Test the /read command with an external file @@ -1502,7 +1502,7 @@ def test_cmd_read_only_with_external_file(self): finally: os.unlink(external_file_path) - def test_cmd_drop_read_only_with_relative_path(self): + async def test_cmd_drop_read_only_with_relative_path(self): with ChdirTemporaryDirectory() as repo_dir: test_file = Path("test_file.txt") test_file.write_text("Test content") @@ 
-1513,7 +1513,7 @@ def test_cmd_drop_read_only_with_relative_path(self): os.chdir(subdir) io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Add the file as read-only using absolute path @@ -1539,10 +1539,10 @@ def test_cmd_drop_read_only_with_relative_path(self): commands.cmd_drop("test_file.txt") self.assertEqual(len(coder.abs_read_only_fnames), 0) - def test_cmd_read_only_bulk_conversion(self): + async def test_cmd_read_only_bulk_conversion(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create and add some test files @@ -1572,10 +1572,10 @@ def test_cmd_read_only_bulk_conversion(self): ) ) - def test_cmd_read_only_with_multiple_files(self): + async def test_cmd_read_only_with_multiple_files(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create multiple test files @@ -1603,10 +1603,10 @@ def test_cmd_read_only_with_multiple_files(self): # Check if all files were removed from abs_read_only_fnames self.assertEqual(len(coder.abs_read_only_fnames), 0) - def test_cmd_read_only_with_tilde_path(self): + async def test_cmd_read_only_with_tilde_path(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a test file in the user's home directory @@ -1638,10 +1638,10 @@ def test_cmd_read_only_with_tilde_path(self): test_file.unlink() # pytest tests/basic/test_commands.py -k 
test_cmd_read_only_with_square_brackets - def test_cmd_read_only_with_square_brackets(self): + async def test_cmd_read_only_with_square_brackets(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create test layout @@ -1664,11 +1664,11 @@ def test_cmd_read_only_with_square_brackets(self): # Check if all files were removed from abs_read_only_fnames self.assertEqual(len(coder.abs_read_only_fnames), 0) - def test_cmd_read_only_with_fuzzy_finder(self): + async def test_cmd_read_only_with_fuzzy_finder(self): with GitTemporaryDirectory() as repo_dir: repo = git.Repo(repo_dir) io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create some test files @@ -1700,10 +1700,10 @@ def test_cmd_read_only_with_fuzzy_finder(self): ) ) - def test_cmd_read_only_with_fuzzy_finder_no_selection(self): + async def test_cmd_read_only_with_fuzzy_finder_no_selection(self): with GitTemporaryDirectory(): io = InputOutput(pretty=False, fancy_input=False, yes=False) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create and add some test files @@ -1723,11 +1723,11 @@ def test_cmd_read_only_with_fuzzy_finder_no_selection(self): self.assertEqual(len(coder.abs_fnames), 0) self.assertEqual(len(coder.abs_read_only_fnames), 3) - def test_cmd_diff(self): + async def test_cmd_diff(self): with GitTemporaryDirectory() as repo_dir: repo = git.Repo(repo_dir) io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create and commit a file @@ -1789,9 +1789,9 @@ def 
test_cmd_diff(self): self.assertIn("-Further modified content", diff_output) self.assertIn("+Final modified content", diff_output) - def test_cmd_model(self): + async def test_cmd_model(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Test switching the main model @@ -1811,10 +1811,10 @@ def test_cmd_model(self): # Check that the edit format is updated to the new model's default self.assertEqual(context.exception.kwargs.get("edit_format"), "diff") - def test_cmd_model_preserves_explicit_edit_format(self): + async def test_cmd_model_preserves_explicit_edit_format(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) # Use gpt-3.5-turbo (default 'diff') - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Explicitly set edit format to something else coder.edit_format = "udiff" commands = Commands(io, coder) @@ -1830,9 +1830,9 @@ def test_cmd_model_preserves_explicit_edit_format(self): # Check that the edit format is preserved self.assertEqual(context.exception.kwargs.get("edit_format"), "udiff") - def test_cmd_editor_model(self): + async def test_cmd_editor_model(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Test switching the editor model @@ -1847,9 +1847,9 @@ def test_cmd_editor_model(self): self.GPT35.weak_model.name, ) - def test_cmd_weak_model(self): + async def test_cmd_weak_model(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Test switching the weak model @@ -1864,10 +1864,10 @@ def test_cmd_weak_model(self): ) 
self.assertEqual(context.exception.kwargs.get("main_model").weak_model.name, "gpt-4") - def test_cmd_model_updates_default_edit_format(self): + async def test_cmd_model_updates_default_edit_format(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) # Use gpt-3.5-turbo (default 'diff') - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Ensure current edit format is the default self.assertEqual(coder.edit_format, self.GPT35.edit_format) commands = Commands(io, coder) @@ -1883,9 +1883,9 @@ def test_cmd_model_updates_default_edit_format(self): # Check that the edit format is updated to the new model's default self.assertEqual(context.exception.kwargs.get("edit_format"), "diff") - def test_cmd_ask(self): + async def test_cmd_ask(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) question = "What is the meaning of life?" 
@@ -1900,11 +1900,11 @@ def test_cmd_ask(self): mock_run.assert_called_once() mock_run.assert_called_once_with(question) - def test_cmd_lint_with_dirty_file(self): + async def test_cmd_lint_with_dirty_file(self): with GitTemporaryDirectory() as repo_dir: repo = git.Repo(repo_dir) io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create and commit a file @@ -1938,10 +1938,10 @@ def test_cmd_lint_with_dirty_file(self): del commands del repo - def test_cmd_reset(self): + async def test_cmd_reset(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Add some files to the chat @@ -1973,10 +1973,10 @@ def test_cmd_reset(self): del coder del commands - def test_reset_with_original_read_only_files(self): + async def test_reset_with_original_read_only_files(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Create test files orig_read_only = Path(repo_dir) / "orig_read_only.txt" @@ -2020,10 +2020,10 @@ def test_reset_with_original_read_only_files(self): self.assertEqual(len(coder.cur_messages), 0) self.assertEqual(len(coder.done_messages), 0) - def test_reset_with_no_original_read_only_files(self): + async def test_reset_with_no_original_read_only_files(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Create test files added_file = Path(repo_dir) / "added_file.txt" @@ -2058,9 +2058,9 @@ def test_reset_with_no_original_read_only_files(self): 
self.assertEqual(len(coder.cur_messages), 0) self.assertEqual(len(coder.done_messages), 0) - def test_cmd_reasoning_effort(self): + async def test_cmd_reasoning_effort(self): io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Test with numeric values @@ -2084,10 +2084,10 @@ def test_cmd_reasoning_effort(self): commands.cmd_reasoning_effort("") mock_tool_output.assert_any_call("Current reasoning effort: high") - def test_drop_with_original_read_only_files(self): + async def test_drop_with_original_read_only_files(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Create test files orig_read_only = Path(repo_dir) / "orig_read_only.txt" @@ -2124,10 +2124,10 @@ def test_drop_with_original_read_only_files(self): self.assertIn(str(orig_read_only), coder.abs_read_only_fnames) self.assertNotIn(str(added_read_only), coder.abs_read_only_fnames) - def test_drop_specific_original_read_only_file(self): + async def test_drop_specific_original_read_only_file(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) # Create test file orig_read_only = Path(repo_dir) / "orig_read_only.txt" @@ -2148,10 +2148,10 @@ def test_drop_specific_original_read_only_file(self): # Verify that the original read-only file is dropped when specified explicitly self.assertEqual(len(coder.abs_read_only_fnames), 0) - def test_drop_with_no_original_read_only_files(self): + async def test_drop_with_no_original_read_only_files(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, 
None, io) + coder = await Coder.create(self.GPT35, None, io) # Create test files added_file = Path(repo_dir) / "added_file.txt" @@ -2180,10 +2180,10 @@ def test_drop_with_no_original_read_only_files(self): self.assertEqual(len(coder.abs_fnames), 0) self.assertEqual(len(coder.abs_read_only_fnames), 0) - def test_cmd_load_with_switch_coder(self): + async def test_cmd_load_with_switch_coder(self): with GitTemporaryDirectory() as repo_dir: io = InputOutput(pretty=False, fancy_input=False, yes=True) - coder = Coder.create(self.GPT35, None, io) + coder = await Coder.create(self.GPT35, None, io) commands = Commands(io, coder) # Create a temporary file with commands @@ -2191,7 +2191,7 @@ def test_cmd_load_with_switch_coder(self): commands_file.write_text("/ask Tell me about the code\n/model gpt-4\n") # Mock run to raise SwitchCoder for /ask and /model - def mock_run(cmd): + async def mock_run(cmd): if cmd.startswith(("/ask", "/model")): raise SwitchCoder() return None @@ -2210,7 +2210,7 @@ def mock_run(cmd): "Command '/model gpt-4' is only supported in interactive mode, skipping." 
) - def test_reset_after_coder_clone_preserves_original_read_only_files(self): + async def test_reset_after_coder_clone_preserves_original_read_only_files(self): with GitTemporaryDirectory() as _: repo_dir = str(".") io = InputOutput(pretty=False, fancy_input=False, yes=True) @@ -2227,7 +2227,7 @@ def test_reset_after_coder_clone_preserves_original_read_only_files(self): original_read_only_fnames_set = {str(orig_ro_path)} # Create the initial Coder - orig_coder = Coder.create(main_model=self.GPT35, io=io, fnames=[], repo=None) + orig_coder = await Coder.create(main_model=self.GPT35, io=io, fnames=[], repo=None) orig_coder.root = repo_dir # Set root for path operations # Replace its commands object with one that has the original_read_only_fnames @@ -2244,7 +2244,7 @@ def test_reset_after_coder_clone_preserves_original_read_only_files(self): orig_coder.abs_read_only_fnames.add(str(other_ro_path)) # Simulate SwitchCoder by creating a new coder from the original one - new_coder = Coder.create(from_coder=orig_coder) + new_coder = await Coder.create(from_coder=orig_coder) new_commands = new_coder.commands # Perform /reset @@ -2261,7 +2261,7 @@ def test_reset_after_coder_clone_preserves_original_read_only_files(self): self.assertEqual(len(new_coder.done_messages), 0) self.assertEqual(len(new_coder.cur_messages), 0) - def test_drop_bare_after_coder_clone_preserves_original_read_only_files(self): + async def test_drop_bare_after_coder_clone_preserves_original_read_only_files(self): with GitTemporaryDirectory() as _: repo_dir = str(".") io = InputOutput(pretty=False, fancy_input=False, yes=True) @@ -2277,7 +2277,7 @@ def test_drop_bare_after_coder_clone_preserves_original_read_only_files(self): original_read_only_fnames_set = {str(orig_ro_path)} - orig_coder = Coder.create(main_model=self.GPT35, io=io, fnames=[], repo=None) + orig_coder = await Coder.create(main_model=self.GPT35, io=io, fnames=[], repo=None) orig_coder.root = repo_dir orig_coder.commands = Commands( io, @@ 
-2292,7 +2292,7 @@ def test_drop_bare_after_coder_clone_preserves_original_read_only_files(self): orig_coder.done_messages = [{"role": "user", "content": "d1"}] orig_coder.cur_messages = [{"role": "user", "content": "c1"}] - new_coder = Coder.create(from_coder=orig_coder) + new_coder = await Coder.create(from_coder=orig_coder) new_commands = new_coder.commands new_commands.cmd_drop("") diff --git a/tests/basic/test_editblock.py b/tests/basic/test_editblock.py index e93edb7c32f..70bb16f38ab 100644 --- a/tests/basic/test_editblock.py +++ b/tests/basic/test_editblock.py @@ -5,6 +5,8 @@ from pathlib import Path from unittest.mock import MagicMock, patch +import pytest + from aider.coders import Coder from aider.coders import editblock_coder as eb from aider.dump import dump # noqa: F401 @@ -320,7 +322,7 @@ def test_replace_part_with_missing_leading_whitespace_including_blank_line(self) result = eb.replace_most_similar_chunk(whole, part, replace) self.assertEqual(result, expected_output) - def test_create_new_file_with_other_file_in_chat(self): + async def test_create_new_file_with_other_file_in_chat(self): # https://github.com/Aider-AI/aider/issues/2258 with ChdirTemporaryDirectory(): # Create a few temporary files @@ -332,11 +334,11 @@ def test_create_new_file_with_other_file_in_chat(self): files = [file1] # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create( + coder = await Coder.create( self.GPT35, "diff", use_git=False, io=InputOutput(yes=True), fnames=files ) - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = f""" Do this: @@ -352,7 +354,7 @@ def mock_send(*args, **kwargs): coder.send = mock_send - coder.run(with_message="hi") + await coder.run(with_message="hi") content = Path(file1).read_text(encoding="utf-8") self.assertEqual(content, "one\ntwo\nthree\n") @@ -360,7 +362,7 @@ def mock_send(*args, **kwargs): content = Path("newfile.txt").read_text(encoding="utf-8") 
self.assertEqual(content, "creating a new file\n") - def test_full_edit(self): + async def test_full_edit(self): # Create a few temporary files _, file1 = tempfile.mkstemp() @@ -370,9 +372,9 @@ def test_full_edit(self): files = [file1] # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(self.GPT35, "diff", io=InputOutput(), fnames=files) + coder = await Coder.create(self.GPT35, "diff", io=InputOutput(), fnames=files) - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = f""" Do this: @@ -390,12 +392,12 @@ def mock_send(*args, **kwargs): coder.send = mock_send # Call the run method with a message - coder.run(with_message="hi") + await coder.run(with_message="hi") content = Path(file1).read_text(encoding="utf-8") self.assertEqual(content, "one\nnew\nthree\n") - def test_full_edit_dry_run(self): + async def test_full_edit_dry_run(self): # Create a few temporary files _, file1 = tempfile.mkstemp() @@ -407,7 +409,7 @@ def test_full_edit_dry_run(self): files = [file1] # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create( + coder = await Coder.create( self.GPT35, "diff", io=InputOutput(dry_run=True), @@ -415,7 +417,7 @@ def test_full_edit_dry_run(self): dry_run=True, ) - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = f""" Do this: @@ -433,7 +435,7 @@ def mock_send(*args, **kwargs): coder.send = mock_send # Call the run method with a message - coder.run(with_message="hi") + await coder.run(with_message="hi") content = Path(file1).read_text(encoding="utf-8") self.assertEqual(content, orig_content) diff --git a/tests/basic/test_history.py b/tests/basic/test_history.py index 80fcfe072fd..07b2721da23 100644 --- a/tests/basic/test_history.py +++ b/tests/basic/test_history.py @@ -41,13 +41,13 @@ def test_tokenize(self): tokenized = self.chat_summary.tokenize(messages) self.assertEqual(tokenized, [(2, 
messages[0]), (2, messages[1])]) - def test_summarize_all(self): + async def test_summarize_all(self): self.mock_model.simple_send_with_retries.return_value = "This is a summary" messages = [ {"role": "user", "content": "Hello world"}, {"role": "assistant", "content": "Hi there"}, ] - summary = self.chat_summary.summarize_all(messages) + summary = await self.chat_summary.summarize_all(messages) self.assertEqual( summary, [ @@ -58,7 +58,7 @@ def test_summarize_all(self): ], ) - def test_summarize(self): + async def test_summarize(self): N = 100 messages = [None] * (2 * N) for i in range(N): @@ -70,7 +70,7 @@ def test_summarize(self): "summarize_all", return_value=[{"role": "user", "content": "Summary"}], ): - result = self.chat_summary.summarize(messages) + result = await self.chat_summary.summarize(messages) print(result) self.assertIsInstance(result, list) @@ -78,7 +78,7 @@ def test_summarize(self): self.assertLess(len(result), len(messages)) self.assertEqual(result[0]["content"], "Summary") - def test_fallback_to_second_model(self): + async def test_fallback_to_second_model(self): mock_model1 = mock.Mock(spec=Model) mock_model1.name = "gpt-4" mock_model1.simple_send_with_retries = mock.Mock(side_effect=Exception("Model 1 failed")) @@ -98,7 +98,7 @@ def test_fallback_to_second_model(self): {"role": "assistant", "content": "Hi there"}, ] - summary = chat_summary.summarize_all(messages) + summary = await chat_summary.summarize_all(messages) # Check that both models were tried mock_model1.simple_send_with_retries.assert_called_once() diff --git a/tests/basic/test_io.py b/tests/basic/test_io.py index ff8a618c76c..9495985192e 100644 --- a/tests/basic/test_io.py +++ b/tests/basic/test_io.py @@ -1,7 +1,8 @@ +import asyncio import os import unittest from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from prompt_toolkit.completion import CompleteEvent from prompt_toolkit.document import Document @@ 
-168,7 +169,7 @@ def test_get_input_is_a_directory_error(self, mock_input): # Simulate IsADirectoryError with patch("aider.io.open", side_effect=IsADirectoryError): - result = io.get_input(root, rel_fnames, addable_rel_fnames, commands) + result = asyncio.run(io.get_input(root, rel_fnames, addable_rel_fnames, commands)) self.assertEqual(result, "test input") mock_input.assert_called_once() @@ -178,20 +179,20 @@ def test_confirm_ask_explicit_yes_required(self, mock_input): # Test case 1: explicit_yes_required=True, self.yes=True io.yes = True - result = io.confirm_ask("Are you sure?", explicit_yes_required=True) + result = asyncio.run(io.confirm_ask("Are you sure?", explicit_yes_required=True)) self.assertFalse(result) mock_input.assert_not_called() # Test case 2: explicit_yes_required=True, self.yes=False io.yes = False - result = io.confirm_ask("Are you sure?", explicit_yes_required=True) + result = asyncio.run(io.confirm_ask("Are you sure?", explicit_yes_required=True)) self.assertFalse(result) mock_input.assert_not_called() # Test case 3: explicit_yes_required=True, user input required io.yes = None mock_input.return_value = "y" - result = io.confirm_ask("Are you sure?", explicit_yes_required=True) + result = asyncio.run(io.confirm_ask("Are you sure?", explicit_yes_required=True)) self.assertTrue(result) mock_input.assert_called_once() @@ -200,7 +201,7 @@ def test_confirm_ask_explicit_yes_required(self, mock_input): # Test case 4: explicit_yes_required=False, self.yes=True io.yes = True - result = io.confirm_ask("Are you sure?", explicit_yes_required=False) + result = asyncio.run(io.confirm_ask("Are you sure?", explicit_yes_required=False)) self.assertTrue(result) mock_input.assert_not_called() @@ -211,35 +212,37 @@ def test_confirm_ask_with_group(self, mock_input): # Test case 1: No group preference, user selects 'All' mock_input.return_value = "a" - result = io.confirm_ask("Are you sure?", group=group) + result = asyncio.run(io.confirm_ask("Are you sure?", 
group=group)) self.assertTrue(result) self.assertEqual(group.preference, "all") mock_input.assert_called_once() mock_input.reset_mock() # Test case 2: Group preference is 'All', should not prompt - result = io.confirm_ask("Are you sure?", group=group) + result = asyncio.run(io.confirm_ask("Are you sure?", group=group)) self.assertTrue(result) mock_input.assert_not_called() # Test case 3: No group preference, user selects 'Skip all' group.preference = None mock_input.return_value = "s" - result = io.confirm_ask("Are you sure?", group=group) + result = asyncio.run(io.confirm_ask("Are you sure?", group=group)) self.assertFalse(result) self.assertEqual(group.preference, "skip") mock_input.assert_called_once() mock_input.reset_mock() # Test case 4: Group preference is 'Skip all', should not prompt - result = io.confirm_ask("Are you sure?", group=group) + result = asyncio.run(io.confirm_ask("Are you sure?", group=group)) self.assertFalse(result) mock_input.assert_not_called() # Test case 5: explicit_yes_required=True, should not offer 'All' option group.preference = None mock_input.return_value = "y" - result = io.confirm_ask("Are you sure?", group=group, explicit_yes_required=True) + result = asyncio.run( + io.confirm_ask("Are you sure?", group=group, explicit_yes_required=True) + ) self.assertTrue(result) self.assertIsNone(group.preference) mock_input.assert_called_once() @@ -252,49 +255,49 @@ def test_confirm_ask_yes_no(self, mock_input): # Test case 1: User selects 'Yes' mock_input.return_value = "y" - result = io.confirm_ask("Are you sure?") + result = asyncio.run(io.confirm_ask("Are you sure?")) self.assertTrue(result) mock_input.assert_called_once() mock_input.reset_mock() # Test case 2: User selects 'No' mock_input.return_value = "n" - result = io.confirm_ask("Are you sure?") + result = asyncio.run(io.confirm_ask("Are you sure?")) self.assertFalse(result) mock_input.assert_called_once() mock_input.reset_mock() # Test case 3: Empty input (default to Yes) 
mock_input.return_value = "" - result = io.confirm_ask("Are you sure?") + result = asyncio.run(io.confirm_ask("Are you sure?")) self.assertTrue(result) mock_input.assert_called_once() mock_input.reset_mock() # Test case 4: 'skip' functions as 'no' without group mock_input.return_value = "s" - result = io.confirm_ask("Are you sure?") + result = asyncio.run(io.confirm_ask("Are you sure?")) self.assertFalse(result) mock_input.assert_called_once() mock_input.reset_mock() # Test case 5: 'all' functions as 'yes' without group mock_input.return_value = "a" - result = io.confirm_ask("Are you sure?") + result = asyncio.run(io.confirm_ask("Are you sure?")) self.assertTrue(result) mock_input.assert_called_once() mock_input.reset_mock() # Test case 6: Full word 'skip' functions as 'no' without group mock_input.return_value = "skip" - result = io.confirm_ask("Are you sure?") + result = asyncio.run(io.confirm_ask("Are you sure?")) self.assertFalse(result) mock_input.assert_called_once() mock_input.reset_mock() # Test case 7: Full word 'all' functions as 'yes' without group mock_input.return_value = "all" - result = io.confirm_ask("Are you sure?") + result = asyncio.run(io.confirm_ask("Are you sure?")) self.assertTrue(result) mock_input.assert_called_once() mock_input.reset_mock() @@ -305,7 +308,7 @@ def test_confirm_ask_allow_never(self, mock_input): io = InputOutput(pretty=False, fancy_input=False) # First call: user selects "Don't ask again" - result = io.confirm_ask("Are you sure?", allow_never=True) + result = asyncio.run(io.confirm_ask("Are you sure?", allow_never=True)) self.assertFalse(result) mock_input.assert_called_once() self.assertIn(("Are you sure?", None), io.never_prompts) @@ -314,28 +317,32 @@ def test_confirm_ask_allow_never(self, mock_input): mock_input.reset_mock() # Second call: should not prompt, immediately return False - result = io.confirm_ask("Are you sure?", allow_never=True) + result = asyncio.run(io.confirm_ask("Are you sure?", allow_never=True)) 
self.assertFalse(result) mock_input.assert_not_called() # Test with subject parameter mock_input.reset_mock() mock_input.side_effect = ["d"] - result = io.confirm_ask("Confirm action?", subject="Subject Text", allow_never=True) + result = asyncio.run( + io.confirm_ask("Confirm action?", subject="Subject Text", allow_never=True) + ) self.assertFalse(result) mock_input.assert_called_once() self.assertIn(("Confirm action?", "Subject Text"), io.never_prompts) # Subsequent call with the same question and subject mock_input.reset_mock() - result = io.confirm_ask("Confirm action?", subject="Subject Text", allow_never=True) + result = asyncio.run( + io.confirm_ask("Confirm action?", subject="Subject Text", allow_never=True) + ) self.assertFalse(result) mock_input.assert_not_called() # Test that allow_never=False does not add to never_prompts mock_input.reset_mock() mock_input.side_effect = ["d", "n"] - result = io.confirm_ask("Do you want to proceed?", allow_never=False) + result = asyncio.run(io.confirm_ask("Do you want to proceed?", allow_never=False)) self.assertFalse(result) self.assertEqual(mock_input.call_count, 2) self.assertNotIn(("Do you want to proceed?", None), io.never_prompts) @@ -387,18 +394,21 @@ def test_multiline_mode_restored_after_interrupt(self): io = InputOutput(fancy_input=True) io.prompt_session = MagicMock() + # Use AsyncMock for prompt_async (for confirm_ask) + io.prompt_session.prompt_async = AsyncMock(side_effect=KeyboardInterrupt) + # Start in multiline mode io.multiline_mode = True - # Mock prompt() to raise KeyboardInterrupt - io.prompt_session.prompt.side_effect = KeyboardInterrupt - - # Test confirm_ask() + # Test confirm_ask() - this is now async, so we need to handle it differently with self.assertRaises(KeyboardInterrupt): - io.confirm_ask("Test question?") + asyncio.run(io.confirm_ask("Test question?")) self.assertTrue(io.multiline_mode) # Should be restored - # Test prompt_ask() + # Test prompt_ask() - this is still synchronous + # Mock 
the synchronous prompt method to raise KeyboardInterrupt + io.prompt_session.prompt = MagicMock(side_effect=KeyboardInterrupt) + with self.assertRaises(KeyboardInterrupt): io.prompt_ask("Test prompt?") self.assertTrue(io.multiline_mode) # Should be restored @@ -408,17 +418,17 @@ def test_multiline_mode_restored_after_normal_exit(self): io = InputOutput(fancy_input=True) io.prompt_session = MagicMock() + # Use AsyncMock for prompt_async that returns "y" + io.prompt_session.prompt_async = AsyncMock(return_value="y") + # Start in multiline mode io.multiline_mode = True - # Mock prompt() to return normally - io.prompt_session.prompt.return_value = "y" - - # Test confirm_ask() - io.confirm_ask("Test question?") + # Test confirm_ask() - this is now async + asyncio.run(io.confirm_ask("Test question?")) self.assertTrue(io.multiline_mode) # Should be restored - # Test prompt_ask() + # Test prompt_ask() - this is still synchronous io.prompt_ask("Test prompt?") self.assertTrue(io.multiline_mode) # Should be restored diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index 11b889f9c4a..6f6e81fed17 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -5,7 +5,7 @@ from io import StringIO from pathlib import Path from unittest import TestCase -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import git from prompt_toolkit.input import DummyInput @@ -45,39 +45,43 @@ def tearDown(self): self.input_patcher.stop() self.webbrowser_patcher.stop() - def test_main_with_empty_dir_no_files_on_command(self): - main(["--no-git", "--exit", "--yes"], input=DummyInput(), output=DummyOutput()) + async def test_main_with_empty_dir_no_files_on_command(self): + await main(["--no-git", "--exit", "--yes"], input=DummyInput(), output=DummyOutput()) - def test_main_with_emptqy_dir_new_file(self): - main(["foo.txt", "--yes", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput()) + async def 
test_main_with_empty_dir_new_file(self): + await main( + ["foo.txt", "--yes", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput() + ) self.assertTrue(os.path.exists("foo.txt")) @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") - def test_main_with_empty_git_dir_new_file(self, _): + async def test_main_with_empty_git_dir_new_file(self, _): make_repo() - main(["--yes", "foo.txt", "--exit"], input=DummyInput(), output=DummyOutput()) + await main(["--yes", "foo.txt", "--exit"], input=DummyInput(), output=DummyOutput()) self.assertTrue(os.path.exists("foo.txt")) @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") - def test_main_with_empty_git_dir_new_files(self, _): + async def test_main_with_empty_git_dir_new_files(self, _): make_repo() - main(["--yes", "foo.txt", "bar.txt", "--exit"], input=DummyInput(), output=DummyOutput()) + await main( + ["--yes", "foo.txt", "bar.txt", "--exit"], input=DummyInput(), output=DummyOutput() + ) self.assertTrue(os.path.exists("foo.txt")) self.assertTrue(os.path.exists("bar.txt")) - def test_main_with_dname_and_fname(self): + async def test_main_with_dname_and_fname(self): subdir = Path("subdir") subdir.mkdir() make_repo(str(subdir)) - res = main(["subdir", "foo.txt"], input=DummyInput(), output=DummyOutput()) + res = await main(["subdir", "foo.txt"], input=DummyInput(), output=DummyOutput()) self.assertNotEqual(res, None) @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") - def test_main_with_subdir_repo_fnames(self, _): + async def test_main_with_subdir_repo_fnames(self, _): subdir = Path("subdir") subdir.mkdir() make_repo(str(subdir)) - main( + await main( ["--yes", str(subdir / "foo.txt"), str(subdir / "bar.txt"), "--exit"], input=DummyInput(), output=DummyOutput(), @@ -85,22 +89,22 @@ def test_main_with_subdir_repo_fnames(self, _): self.assertTrue((subdir / "foo.txt").exists()) self.assertTrue((subdir / 
"bar.txt").exists()) - def test_main_with_git_config_yml(self): + async def test_main_with_git_config_yml(self): make_repo() Path(".aider.conf.yml").write_text("auto-commits: false\n") with patch("aider.coders.Coder.create") as MockCoder: - main(["--yes"], input=DummyInput(), output=DummyOutput()) + await main(["--yes"], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args assert kwargs["auto_commits"] is False Path(".aider.conf.yml").write_text("auto-commits: true\n") with patch("aider.coders.Coder.create") as MockCoder: - main([], input=DummyInput(), output=DummyOutput()) + await main([], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args assert kwargs["auto_commits"] is True - def test_main_with_empty_git_dir_new_subdir_file(self): + async def test_main_with_empty_git_dir_new_subdir_file(self): make_repo() subdir = Path("subdir") subdir.mkdir() @@ -112,7 +116,7 @@ def test_main_with_empty_git_dir_new_subdir_file(self): # This will throw a git error on windows if get_tracked_files doesn't # properly convert git/posix/paths to git\posix\paths. # Because aider will try and `git add` a file that's already in the repo. 
- main(["--yes", str(fname), "--exit"], input=DummyInput(), output=DummyOutput()) + await main(["--yes", str(fname), "--exit"], input=DummyInput(), output=DummyOutput()) def test_setup_git(self): io = InputOutput(pretty=False, yes=True) @@ -152,7 +156,7 @@ def test_check_gitignore(self): self.assertEqual("one\ntwo\n.aider*\n.env\n", gitignore.read_text()) del os.environ["GIT_CONFIG_GLOBAL"] - def test_command_line_gitignore_files_flag(self): + async def test_command_line_gitignore_files_flag(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -168,7 +172,7 @@ def test_command_line_gitignore_files_flag(self): abs_ignored_file = str(ignored_file.resolve()) # Test without the --add-gitignore-files flag (default: False) - coder = main( + coder = await main( ["--exit", "--yes", abs_ignored_file], input=DummyInput(), output=DummyOutput(), @@ -179,7 +183,7 @@ def test_command_line_gitignore_files_flag(self): self.assertNotIn(abs_ignored_file, coder.abs_fnames) # Test with --add-gitignore-files set to True - coder = main( + coder = await main( ["--add-gitignore-files", "--exit", "--yes", abs_ignored_file], input=DummyInput(), output=DummyOutput(), @@ -190,7 +194,7 @@ def test_command_line_gitignore_files_flag(self): self.assertIn(abs_ignored_file, coder.abs_fnames) # Test with --add-gitignore-files set to False - coder = main( + coder = await main( ["--no-add-gitignore-files", "--exit", "--yes", abs_ignored_file], input=DummyInput(), output=DummyOutput(), @@ -200,7 +204,7 @@ def test_command_line_gitignore_files_flag(self): # Verify the ignored file is not in the chat self.assertNotIn(abs_ignored_file, coder.abs_fnames) - def test_add_command_gitignore_files_flag(self): + async def test_add_command_gitignore_files_flag(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -217,7 +221,7 @@ def test_add_command_gitignore_files_flag(self): rel_ignored_file = "ignored.txt" # Test without the --add-gitignore-files flag (default: 
False) - coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -232,7 +236,7 @@ def test_add_command_gitignore_files_flag(self): self.assertNotIn(abs_ignored_file, coder.abs_fnames) # Test with --add-gitignore-files set to True - coder = main( + coder = await main( ["--add-gitignore-files", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -246,7 +250,7 @@ def test_add_command_gitignore_files_flag(self): self.assertIn(abs_ignored_file, coder.abs_fnames) # Test with --add-gitignore-files set to False - coder = main( + coder = await main( ["--no-add-gitignore-files", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -260,36 +264,36 @@ def test_add_command_gitignore_files_flag(self): # Verify the ignored file is not in the chat self.assertNotIn(abs_ignored_file, coder.abs_fnames) - def test_main_args(self): + async def test_main_args(self): with patch("aider.coders.Coder.create") as MockCoder: # --yes will just ok the git repo without blocking on input # following calls to main will see the new repo already - main(["--no-auto-commits", "--yes"], input=DummyInput()) + await main(["--no-auto-commits", "--yes"], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["auto_commits"] is False with patch("aider.coders.Coder.create") as MockCoder: - main(["--auto-commits"], input=DummyInput()) + await main(["--auto-commits"], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["auto_commits"] is True with patch("aider.coders.Coder.create") as MockCoder: - main([], input=DummyInput()) + await main([], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["dirty_commits"] is True assert kwargs["auto_commits"] is True with patch("aider.coders.Coder.create") as MockCoder: - main(["--no-dirty-commits"], input=DummyInput()) + await main(["--no-dirty-commits"], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["dirty_commits"] is False with 
patch("aider.coders.Coder.create") as MockCoder: - main(["--dirty-commits"], input=DummyInput()) + await main(["--dirty-commits"], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["dirty_commits"] is True - def test_env_file_override(self): + async def test_env_file_override(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) git_env = git_dir / ".env" @@ -313,7 +317,7 @@ def test_env_file_override(self): named_env.write_text("A=named") with patch("pathlib.Path.home", return_value=fake_home): - main(["--yes", "--exit", "--env-file", str(named_env)]) + await main(["--yes", "--exit", "--env-file", str(named_env)]) self.assertEqual(os.environ["A"], "named") self.assertEqual(os.environ["B"], "cwd") @@ -321,24 +325,33 @@ def test_env_file_override(self): self.assertEqual(os.environ["D"], "home") self.assertEqual(os.environ["E"], "existing") - def test_message_file_flag(self): + async def test_message_file_flag(self): message_file_content = "This is a test message from a file." 
message_file_path = tempfile.mktemp() with open(message_file_path, "w", encoding="utf-8") as message_file: message_file.write(message_file_content) + # Create a mock async function for the run method + async def mock_run(*args, **kwargs): + pass + with patch("aider.coders.Coder.create") as MockCoder: - MockCoder.return_value.run = MagicMock() - main( + # Create a mock coder instance with an async run method + mock_coder_instance = MagicMock() + mock_coder_instance.run = AsyncMock() + MockCoder.return_value = mock_coder_instance + + await main( ["--yes", "--message-file", message_file_path], input=DummyInput(), output=DummyOutput(), ) - MockCoder.return_value.run.assert_called_once_with(with_message=message_file_content) + # Check that run was called with the correct message + mock_coder_instance.run.assert_called_once_with(with_message=message_file_content) os.remove(message_file_path) - def test_encodings_arg(self): + async def test_encodings_arg(self): fname = "foo.py" with GitTemporaryDirectory(): @@ -351,79 +364,83 @@ def side_effect(*args, **kwargs): MockSend.side_effect = side_effect - main(["--yes", fname, "--encoding", "iso-8859-15"]) + await main(["--yes", fname, "--encoding", "iso-8859-15"]) - def test_main_exit_calls_version_check(self): + async def test_main_exit_calls_version_check(self): with GitTemporaryDirectory(): with ( patch("aider.main.check_version") as mock_check_version, patch("aider.main.InputOutput") as mock_input_output, ): - main(["--exit", "--check-update"], input=DummyInput(), output=DummyOutput()) + await main(["--exit", "--check-update"], input=DummyInput(), output=DummyOutput()) mock_check_version.assert_called_once() mock_input_output.assert_called_once() @patch("aider.main.InputOutput") @patch("aider.coders.base_coder.Coder.run") - def test_main_message_adds_to_input_history(self, mock_run, MockInputOutput): + async def test_main_message_adds_to_input_history(self, mock_run, MockInputOutput): test_message = "test message" 
mock_io_instance = MockInputOutput.return_value - main(["--message", test_message], input=DummyInput(), output=DummyOutput()) + await main(["--message", test_message], input=DummyInput(), output=DummyOutput()) mock_io_instance.add_to_input_history.assert_called_once_with(test_message) @patch("aider.main.InputOutput") @patch("aider.coders.base_coder.Coder.run") - def test_yes(self, mock_run, MockInputOutput): + async def test_yes(self, mock_run, MockInputOutput): test_message = "test message" - main(["--yes", "--message", test_message]) + await main(["--yes", "--message", test_message]) args, kwargs = MockInputOutput.call_args self.assertTrue(args[1]) @patch("aider.main.InputOutput") @patch("aider.coders.base_coder.Coder.run") - def test_default_yes(self, mock_run, MockInputOutput): + async def test_default_yes(self, mock_run, MockInputOutput): test_message = "test message" - main(["--message", test_message]) + await main(["--message", test_message]) args, kwargs = MockInputOutput.call_args self.assertEqual(args[1], None) - def test_dark_mode_sets_code_theme(self): + async def test_dark_mode_sets_code_theme(self): # Mock InputOutput to capture the configuration with patch("aider.main.InputOutput") as MockInputOutput: MockInputOutput.return_value.get_input.return_value = None - main(["--dark-mode", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput()) + await main( + ["--dark-mode", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput() + ) # Ensure InputOutput was called MockInputOutput.assert_called_once() # Check if the code_theme setting is for dark mode _, kwargs = MockInputOutput.call_args self.assertEqual(kwargs["code_theme"], "monokai") - def test_light_mode_sets_code_theme(self): + async def test_light_mode_sets_code_theme(self): # Mock InputOutput to capture the configuration with patch("aider.main.InputOutput") as MockInputOutput: MockInputOutput.return_value.get_input.return_value = None - main(["--light-mode", "--no-git", "--exit"], 
input=DummyInput(), output=DummyOutput()) + await main( + ["--light-mode", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput() + ) # Ensure InputOutput was called MockInputOutput.assert_called_once() # Check if the code_theme setting is for light mode _, kwargs = MockInputOutput.call_args self.assertEqual(kwargs["code_theme"], "default") - def create_env_file(self, file_name, content): + def create_env_file(self, file_name, content): env_file_path = Path(self.tempdir) / file_name env_file_path.write_text(content) return env_file_path - def test_env_file_flag_sets_automatic_variable(self): + async def test_env_file_flag_sets_automatic_variable(self): env_file_path = self.create_env_file(".env.test", "AIDER_DARK_MODE=True") with patch("aider.main.InputOutput") as MockInputOutput: MockInputOutput.return_value.get_input.return_value = None MockInputOutput.return_value.get_input.confirm_ask = True - main( + await main( ["--env-file", str(env_file_path), "--no-git", "--exit"], input=DummyInput(), output=DummyOutput(), @@ -433,35 +450,35 @@ def test_env_file_flag_sets_automatic_variable(self): _, kwargs = MockInputOutput.call_args self.assertEqual(kwargs["code_theme"], "monokai") - def test_default_env_file_sets_automatic_variable(self): + async def test_default_env_file_sets_automatic_variable(self): self.create_env_file(".env", "AIDER_DARK_MODE=True") with patch("aider.main.InputOutput") as MockInputOutput: MockInputOutput.return_value.get_input.return_value = None MockInputOutput.return_value.get_input.confirm_ask = True - main(["--no-git", "--exit"], input=DummyInput(), output=DummyOutput()) + await main(["--no-git", "--exit"], input=DummyInput(), output=DummyOutput()) # Ensure InputOutput was called MockInputOutput.assert_called_once() # Check if the color settings are for dark mode _, kwargs = MockInputOutput.call_args self.assertEqual(kwargs["code_theme"], "monokai") - def test_false_vals_in_env_file(self): + async def 
test_false_vals_in_env_file(self): self.create_env_file(".env", "AIDER_SHOW_DIFFS=off") with patch("aider.coders.Coder.create") as MockCoder: - main(["--no-git", "--yes"], input=DummyInput(), output=DummyOutput()) + await main(["--no-git", "--yes"], input=DummyInput(), output=DummyOutput()) MockCoder.assert_called_once() _, kwargs = MockCoder.call_args self.assertEqual(kwargs["show_diffs"], False) - def test_true_vals_in_env_file(self): + async def test_true_vals_in_env_file(self): self.create_env_file(".env", "AIDER_SHOW_DIFFS=on") with patch("aider.coders.Coder.create") as MockCoder: - main(["--no-git", "--yes"], input=DummyInput(), output=DummyOutput()) + await main(["--no-git", "--yes"], input=DummyInput(), output=DummyOutput()) MockCoder.assert_called_once() _, kwargs = MockCoder.call_args self.assertEqual(kwargs["show_diffs"], True) - def test_lint_option(self): + async def test_lint_option(self): with GitTemporaryDirectory() as git_dir: # Create a dirty file in the root dirty_file = Path("dirty_file.py") @@ -485,7 +502,7 @@ def test_lint_option(self): MockLinter.return_value = "" # Run main with --lint option - main(["--lint", "--yes"]) + await main(["--lint", "--yes"]) # Check if the Linter was called with a filename ending in "dirty_file.py" # but not ending in "subdir/dirty_file.py" @@ -494,10 +511,10 @@ def test_lint_option(self): self.assertTrue(called_arg.endswith("dirty_file.py")) self.assertFalse(called_arg.endswith(f"subdir{os.path.sep}dirty_file.py")) - def test_verbose_mode_lists_env_vars(self): + async def test_verbose_mode_lists_env_vars(self): self.create_env_file(".env", "AIDER_DARK_MODE=on") with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - main( + await main( ["--no-git", "--verbose", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -513,7 +530,7 @@ def test_verbose_mode_lists_env_vars(self): self.assertRegex(relevant_output, r"AIDER_DARK_MODE:\s+on") self.assertRegex(relevant_output, r"dark_mode:\s+True") - 
def test_yaml_config_file_loading(self): + async def test_yaml_config_file_loading(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -543,7 +560,7 @@ def test_yaml_config_file_loading(self): patch("aider.coders.Coder.create") as MockCoder, ): # Test loading from specified config file - main( + await main( ["--yes", "--exit", "--config", str(named_config)], input=DummyInput(), output=DummyOutput(), @@ -553,7 +570,7 @@ def test_yaml_config_file_loading(self): self.assertEqual(kwargs["map_tokens"], 8192) # Test loading from current working directory - main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args print("kwargs:", kwargs) # Add this line for debugging self.assertIn("main_model", kwargs, "main_model key not found in kwargs") @@ -562,46 +579,46 @@ def test_yaml_config_file_loading(self): # Test loading from git root cwd_config.unlink() - main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args self.assertEqual(kwargs["main_model"].name, "gpt-4") self.assertEqual(kwargs["map_tokens"], 2048) # Test loading from home directory git_config.unlink() - main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args self.assertEqual(kwargs["main_model"].name, "gpt-3.5-turbo") self.assertEqual(kwargs["map_tokens"], 1024) - def test_map_tokens_option(self): + async def test_map_tokens_option(self): with GitTemporaryDirectory(): with patch("aider.coders.base_coder.RepoMap") as MockRepoMap: MockRepoMap.return_value.max_map_tokens = 0 - main( + await main( ["--model", "gpt-4", "--map-tokens", "0", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), ) MockRepoMap.assert_not_called() - def 
test_map_tokens_option_with_non_zero_value(self): + async def test_map_tokens_option_with_non_zero_value(self): with GitTemporaryDirectory(): with patch("aider.coders.base_coder.RepoMap") as MockRepoMap: MockRepoMap.return_value.max_map_tokens = 1000 - main( + await main( ["--model", "gpt-4", "--map-tokens", "1000", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), ) MockRepoMap.assert_called_once() - def test_read_option(self): + async def test_read_option(self): with GitTemporaryDirectory(): test_file = "test_file.txt" Path(test_file).touch() - coder = main( + coder = await main( ["--read", test_file, "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -610,14 +627,14 @@ def test_read_option(self): self.assertIn(str(Path(test_file).resolve()), coder.abs_read_only_fnames) - def test_read_option_with_external_file(self): + async def test_read_option_with_external_file(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as external_file: external_file.write("External file content") external_file_path = external_file.name try: with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--read", external_file_path, "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -629,7 +646,7 @@ def test_read_option_with_external_file(self): finally: os.unlink(external_file_path) - def test_model_metadata_file(self): + async def test_model_metadata_file(self): # Re-init so we don't have old data lying around from earlier test cases from aider import models @@ -646,7 +663,7 @@ def test_model_metadata_file(self): metadata_content = {"deepseek/deepseek-chat": {"max_input_tokens": 1234}} metadata_file.write_text(json.dumps(metadata_content)) - coder = main( + coder = await main( [ "--model", "deepseek/deepseek-chat", @@ -662,14 +679,14 @@ def test_model_metadata_file(self): self.assertEqual(coder.main_model.info["max_input_tokens"], 1234) - def test_sonnet_and_cache_options(self): + async def 
test_sonnet_and_cache_options(self): with GitTemporaryDirectory(): with patch("aider.coders.base_coder.RepoMap") as MockRepoMap: mock_repo_map = MagicMock() mock_repo_map.max_map_tokens = 1000 # Set a specific value MockRepoMap.return_value = mock_repo_map - main( + await main( ["--sonnet", "--cache-prompts", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -681,9 +698,9 @@ def test_sonnet_and_cache_options(self): call_kwargs.get("refresh"), "files" ) # Check the 'refresh' keyword argument - def test_sonnet_and_cache_prompts_options(self): + async def test_sonnet_and_cache_prompts_options(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--sonnet", "--cache-prompts", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -692,9 +709,9 @@ def test_sonnet_and_cache_prompts_options(self): self.assertTrue(coder.add_cache_headers) - def test_4o_and_cache_options(self): + async def test_4o_and_cache_options(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--4o", "--cache-prompts", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -703,9 +720,9 @@ def test_4o_and_cache_options(self): self.assertFalse(coder.add_cache_headers) - def test_return_coder(self): + async def test_return_coder(self): with GitTemporaryDirectory(): - result = main( + result = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -713,7 +730,7 @@ def test_return_coder(self): ) self.assertIsInstance(result, Coder) - result = main( + result = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -721,9 +738,9 @@ def test_return_coder(self): ) self.assertIsNone(result) - def test_map_mul_option(self): + async def test_map_mul_option(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--map-mul", "5", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -732,9 +749,9 @@ def test_map_mul_option(self): 
self.assertIsInstance(coder, Coder) self.assertEqual(coder.repo_map.map_mul_no_files, 5) - def test_suggest_shell_commands_default(self): + async def test_suggest_shell_commands_default(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -742,9 +759,9 @@ def test_suggest_shell_commands_default(self): ) self.assertTrue(coder.suggest_shell_commands) - def test_suggest_shell_commands_disabled(self): + async def test_suggest_shell_commands_disabled(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--no-suggest-shell-commands", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -752,9 +769,9 @@ def test_suggest_shell_commands_disabled(self): ) self.assertFalse(coder.suggest_shell_commands) - def test_suggest_shell_commands_enabled(self): + async def test_suggest_shell_commands_enabled(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--suggest-shell-commands", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -762,9 +779,9 @@ def test_suggest_shell_commands_enabled(self): ) self.assertTrue(coder.suggest_shell_commands) - def test_detect_urls_default(self): + async def test_detect_urls_default(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -772,9 +789,9 @@ def test_detect_urls_default(self): ) self.assertTrue(coder.detect_urls) - def test_detect_urls_disabled(self): + async def test_detect_urls_disabled(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--no-detect-urls", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -782,9 +799,9 @@ def test_detect_urls_disabled(self): ) self.assertFalse(coder.detect_urls) - def test_detect_urls_enabled(self): + async def test_detect_urls_enabled(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( 
["--detect-urls", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -792,7 +809,7 @@ def test_detect_urls_enabled(self): ) self.assertTrue(coder.detect_urls) - def test_accepts_settings_warnings(self): + async def test_accepts_settings_warnings(self): # Test that appropriate warnings are shown based on accepts_settings configuration with GitTemporaryDirectory(): # Test model that accepts the thinking_tokens setting @@ -800,7 +817,7 @@ def test_accepts_settings_warnings(self): patch("aider.io.InputOutput.tool_warning") as mock_warning, patch("aider.models.Model.set_thinking_tokens") as mock_set_thinking, ): - main( + await main( [ "--model", "anthropic/claude-3-7-sonnet-20250219", @@ -823,7 +840,7 @@ def test_accepts_settings_warnings(self): patch("aider.io.InputOutput.tool_warning") as mock_warning, patch("aider.models.Model.set_thinking_tokens") as mock_set_thinking, ): - main( + await main( [ "--model", "gpt-4o", @@ -850,7 +867,7 @@ def test_accepts_settings_warnings(self): patch("aider.io.InputOutput.tool_warning") as mock_warning, patch("aider.models.Model.set_reasoning_effort") as mock_set_reasoning, ): - main( + await main( ["--model", "o1", "--reasoning-effort", "3", "--yes", "--exit"], input=DummyInput(), output=DummyOutput(), @@ -866,7 +883,7 @@ def test_accepts_settings_warnings(self): patch("aider.io.InputOutput.tool_warning") as mock_warning, patch("aider.models.Model.set_reasoning_effort") as mock_set_reasoning, ): - main( + await main( ["--model", "gpt-3.5-turbo", "--reasoning-effort", "3", "--yes", "--exit"], input=DummyInput(), output=DummyOutput(), @@ -881,7 +898,7 @@ def test_accepts_settings_warnings(self): mock_set_reasoning.assert_not_called() @patch("aider.models.ModelInfoManager.set_verify_ssl") - def test_no_verify_ssl_sets_model_info_manager(self, mock_set_verify_ssl): + async def test_no_verify_ssl_sets_model_info_manager(self, mock_set_verify_ssl): with GitTemporaryDirectory(): # Mock Model class to avoid actual model 
initialization with patch("aider.models.Model") as mock_model: @@ -895,27 +912,27 @@ def test_no_verify_ssl_sets_model_info_manager(self, mock_set_verify_ssl): # Mock fuzzy_match_models to avoid string operations on MagicMock with patch("aider.models.fuzzy_match_models", return_value=[]): - main( + await main( ["--no-verify-ssl", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), ) mock_set_verify_ssl.assert_called_once_with(False) - def test_pytest_env_vars(self): + async def test_pytest_env_vars(self): # Verify that environment variables from pytest.ini are properly set self.assertEqual(os.environ.get("AIDER_ANALYTICS"), "false") - def test_set_env_single(self): + async def test_set_env_single(self): # Test setting a single environment variable with GitTemporaryDirectory(): - main(["--set-env", "TEST_VAR=test_value", "--exit", "--yes"]) + await main(["--set-env", "TEST_VAR=test_value", "--exit", "--yes"]) self.assertEqual(os.environ.get("TEST_VAR"), "test_value") - def test_set_env_multiple(self): + async def test_set_env_multiple(self): # Test setting multiple environment variables with GitTemporaryDirectory(): - main( + await main( [ "--set-env", "TEST_VAR1=value1", @@ -928,38 +945,40 @@ def test_set_env_multiple(self): self.assertEqual(os.environ.get("TEST_VAR1"), "value1") self.assertEqual(os.environ.get("TEST_VAR2"), "value2") - def test_set_env_with_spaces(self): + async def test_set_env_with_spaces(self): # Test setting env var with spaces in value with GitTemporaryDirectory(): - main(["--set-env", "TEST_VAR=test value with spaces", "--exit", "--yes"]) + await main(["--set-env", "TEST_VAR=test value with spaces", "--exit", "--yes"]) self.assertEqual(os.environ.get("TEST_VAR"), "test value with spaces") - def test_set_env_invalid_format(self): + async def test_set_env_invalid_format(self): # Test invalid format handling with GitTemporaryDirectory(): - result = main(["--set-env", "INVALID_FORMAT", "--exit", "--yes"]) + result = await 
main(["--set-env", "INVALID_FORMAT", "--exit", "--yes"]) self.assertEqual(result, 1) - def test_api_key_single(self): + async def test_api_key_single(self): # Test setting a single API key with GitTemporaryDirectory(): - main(["--api-key", "anthropic=test-key", "--exit", "--yes"]) + await main(["--api-key", "anthropic=test-key", "--exit", "--yes"]) self.assertEqual(os.environ.get("ANTHROPIC_API_KEY"), "test-key") - def test_api_key_multiple(self): + async def test_api_key_multiple(self): # Test setting multiple API keys with GitTemporaryDirectory(): - main(["--api-key", "anthropic=key1", "--api-key", "openai=key2", "--exit", "--yes"]) + await main( + ["--api-key", "anthropic=key1", "--api-key", "openai=key2", "--exit", "--yes"] + ) self.assertEqual(os.environ.get("ANTHROPIC_API_KEY"), "key1") self.assertEqual(os.environ.get("OPENAI_API_KEY"), "key2") - def test_api_key_invalid_format(self): + async def test_api_key_invalid_format(self): # Test invalid format handling with GitTemporaryDirectory(): - result = main(["--api-key", "INVALID_FORMAT", "--exit", "--yes"]) + result = await main(["--api-key", "INVALID_FORMAT", "--exit", "--yes"]) self.assertEqual(result, 1) - def test_git_config_include(self): + async def test_git_config_include(self): # Test that aider respects git config includes for user.name and user.email with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -984,7 +1003,7 @@ def test_git_config_include(self): git_config_content = git_config_path.read_text() # Run aider and verify it doesn't change the git config - main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) # Check that the user settings are still the same using git command repo = git.Repo(git_dir) # Re-open repo to ensure we get fresh config @@ -995,7 +1014,7 @@ def test_git_config_include(self): git_config_content_after = git_config_path.read_text() self.assertEqual(git_config_content, 
git_config_content_after) - def test_git_config_include_directive(self): + async def test_git_config_include_directive(self): # Test that aider respects the include directive in git config with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -1025,7 +1044,7 @@ def test_git_config_include_directive(self): self.assertEqual(repo.git.config("user.email"), "directive@example.com") # Run aider and verify it doesn't change the git config - main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) # Check that the git config file wasn't modified config_after_aider = git_config.read_text() @@ -1036,7 +1055,7 @@ def test_git_config_include_directive(self): self.assertEqual(repo.git.config("user.name"), "Directive User") self.assertEqual(repo.git.config("user.email"), "directive@example.com") - def test_resolve_aiderignore_path(self): + async def test_resolve_aiderignore_path(self): # Import the function directly to test it from aider.args import resolve_aiderignore_path @@ -1055,12 +1074,12 @@ def test_resolve_aiderignore_path(self): rel_path = ".aiderignore" self.assertEqual(resolve_aiderignore_path(rel_path), rel_path) - def test_invalid_edit_format(self): + async def test_invalid_edit_format(self): with GitTemporaryDirectory(): # Suppress stderr for this test as argparse prints an error message with patch("sys.stderr", new_callable=StringIO) as mock_stderr: with self.assertRaises(SystemExit) as cm: - _ = main( + _ = await main( ["--edit-format", "not-a-real-format", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -1071,11 +1090,11 @@ def test_invalid_edit_format(self): self.assertIn("invalid choice", stderr_output) self.assertIn("not-a-real-format", stderr_output) - def test_default_model_selection(self): + async def test_default_model_selection(self): with GitTemporaryDirectory(): # Test Anthropic API key os.environ["ANTHROPIC_API_KEY"] = "test-key" - 
coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True ) self.assertIn("sonnet", coder.main_model.name.lower()) @@ -1083,7 +1102,7 @@ def test_default_model_selection(self): # Test DeepSeek API key os.environ["DEEPSEEK_API_KEY"] = "test-key" - coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True ) self.assertIn("deepseek", coder.main_model.name.lower()) @@ -1091,7 +1110,7 @@ def test_default_model_selection(self): # Test OpenRouter API key os.environ["OPENROUTER_API_KEY"] = "test-key" - coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True ) self.assertIn("openrouter/", coder.main_model.name.lower()) @@ -1099,7 +1118,7 @@ def test_default_model_selection(self): # Test OpenAI API key os.environ["OPENAI_API_KEY"] = "test-key" - coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True ) self.assertIn("gpt-4", coder.main_model.name.lower()) @@ -1107,7 +1126,7 @@ def test_default_model_selection(self): # Test Gemini API key os.environ["GEMINI_API_KEY"] = "test-key" - coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True ) self.assertIn("gemini", coder.main_model.name.lower()) @@ -1116,25 +1135,25 @@ def test_default_model_selection(self): # Test no API keys - should offer OpenRouter OAuth with patch("aider.onboarding.offer_openrouter_oauth") as mock_offer_oauth: mock_offer_oauth.return_value = None # Simulate user declining or failure - result = main(["--exit", "--yes"], input=DummyInput(), output=DummyOutput()) + result = await main(["--exit", "--yes"], input=DummyInput(), output=DummyOutput()) self.assertEqual(result, 1) # Expect failure since no model could be selected mock_offer_oauth.assert_called_once() - def test_model_precedence(self): + async def 
test_model_precedence(self): with GitTemporaryDirectory(): # Test that earlier API keys take precedence os.environ["ANTHROPIC_API_KEY"] = "test-key" os.environ["OPENAI_API_KEY"] = "test-key" - coder = main( + coder = await main( ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True ) self.assertIn("sonnet", coder.main_model.name.lower()) del os.environ["ANTHROPIC_API_KEY"] del os.environ["OPENAI_API_KEY"] - def test_chat_language_spanish(self): + async def test_chat_language_spanish(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--chat-language", "Spanish", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -1143,9 +1162,9 @@ def test_chat_language_spanish(self): system_info = coder.get_platform_info() self.assertIn("Spanish", system_info) - def test_commit_language_japanese(self): + async def test_commit_language_japanese(self): with GitTemporaryDirectory(): - coder = main( + coder = await main( ["--commit-language", "japanese", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -1154,18 +1173,18 @@ def test_commit_language_japanese(self): self.assertIn("japanese", coder.commit_language) @patch("git.Repo.init") - def test_main_exit_with_git_command_not_found(self, mock_git_init): + async def test_main_exit_with_git_command_not_found(self, mock_git_init): mock_git_init.side_effect = git.exc.GitCommandNotFound("git", "Command 'git' not found") try: - result = main(["--exit", "--yes"], input=DummyInput(), output=DummyOutput()) + result = await main(["--exit", "--yes"], input=DummyInput(), output=DummyOutput()) except Exception as e: - self.fail(f"main() raised an unexpected exception: {e}") + self.fail(f"await main() raised an unexpected exception: {e}") - self.assertIsNone(result, "main() should return None when called with --exit") + self.assertIsNone(result, "await main() should return None when called with --exit") - def test_reasoning_effort_option(self): - coder = main( + async 
def test_reasoning_effort_option(self): + coder = await main( ["--reasoning-effort", "3", "--no-check-model-accepts-settings", "--yes", "--exit"], input=DummyInput(), output=DummyOutput(), @@ -1175,8 +1194,8 @@ def test_reasoning_effort_option(self): coder.main_model.extra_params.get("extra_body", {}).get("reasoning_effort"), "3" ) - def test_thinking_tokens_option(self): - coder = main( + async def test_thinking_tokens_option(self): + coder = await main( ["--model", "sonnet", "--thinking-tokens", "1000", "--yes", "--exit"], input=DummyInput(), output=DummyOutput(), @@ -1186,7 +1205,7 @@ def test_thinking_tokens_option(self): coder.main_model.extra_params.get("thinking", {}).get("budget_tokens"), 1000 ) - def test_list_models_includes_metadata_models(self): + async def test_list_models_includes_metadata_models(self): # Test that models from model-metadata.json appear in list-models output with GitTemporaryDirectory(): # Create a temporary model-metadata.json with test models @@ -1207,7 +1226,7 @@ def test_list_models_includes_metadata_models(self): # Capture stdout to check the output with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - main( + await main( [ "--list-models", "unique-model", @@ -1224,7 +1243,7 @@ def test_list_models_includes_metadata_models(self): # Check that the unique model name from our metadata file is listed self.assertIn("test-provider/unique-model-name", output) - def test_list_models_includes_all_model_sources(self): + async def test_list_models_includes_all_model_sources(self): # Test that models from both litellm.model_cost and model-metadata.json # appear in list-models with GitTemporaryDirectory(): @@ -1241,7 +1260,7 @@ def test_list_models_includes_all_model_sources(self): # Capture stdout to check the output with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - main( + await main( [ "--list-models", "metadata-only-model", @@ -1260,12 +1279,12 @@ def test_list_models_includes_all_model_sources(self): # Check 
that both models appear in the output self.assertIn("test-provider/metadata-only-model", output) - def test_check_model_accepts_settings_flag(self): + async def test_check_model_accepts_settings_flag(self): # Test that --check-model-accepts-settings affects whether settings are applied with GitTemporaryDirectory(): # When flag is on, setting shouldn't be applied to non-supporting model with patch("aider.models.Model.set_thinking_tokens") as mock_set_thinking: - main( + await main( [ "--model", "gpt-4o", @@ -1281,7 +1300,7 @@ def test_check_model_accepts_settings_flag(self): # Method should not be called because model doesn't support it and flag is on mock_set_thinking.assert_not_called() - def test_list_models_with_direct_resource_patch(self): + async def test_list_models_with_direct_resource_patch(self): # Test that models from resources/model-metadata.json are included in list-models output with GitTemporaryDirectory(): # Create a temporary file with test model metadata @@ -1306,7 +1325,7 @@ def test_list_models_with_direct_resource_patch(self): with patch("aider.main.importlib_resources.files", return_value=mock_files): # Capture stdout to check the output with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - main( + await main( ["--list-models", "special", "--yes", "--no-gitignore"], input=DummyInput(), output=DummyOutput(), @@ -1318,7 +1337,7 @@ def test_list_models_with_direct_resource_patch(self): # When flag is off, setting should be applied regardless of support with patch("aider.models.Model.set_reasoning_effort") as mock_set_reasoning: - main( + await main( [ "--model", "gpt-3.5-turbo", @@ -1334,7 +1353,7 @@ def test_list_models_with_direct_resource_patch(self): # Method should be called because flag is off mock_set_reasoning.assert_called_once_with("3") - def test_model_accepts_settings_attribute(self): + async def test_model_accepts_settings_attribute(self): with GitTemporaryDirectory(): # Test with a model where we override the 
accepts_settings attribute with patch("aider.models.Model") as MockModel: @@ -1351,7 +1370,7 @@ def test_model_accepts_settings_attribute(self): mock_instance.get_weak_model.return_value = None # Run with both settings, but model only accepts reasoning_effort - main( + await main( [ "--model", "test-model", @@ -1372,10 +1391,10 @@ def test_model_accepts_settings_attribute(self): mock_instance.set_thinking_tokens.assert_not_called() @patch("aider.main.InputOutput") - def test_stream_and_cache_warning(self, MockInputOutput): + async def test_stream_and_cache_warning(self, MockInputOutput): mock_io_instance = MockInputOutput.return_value with GitTemporaryDirectory(): - main( + await main( ["--stream", "--cache-prompts", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -1385,10 +1404,10 @@ def test_stream_and_cache_warning(self, MockInputOutput): ) @patch("aider.main.InputOutput") - def test_stream_without_cache_no_warning(self, MockInputOutput): + async def test_stream_without_cache_no_warning(self, MockInputOutput): mock_io_instance = MockInputOutput.return_value with GitTemporaryDirectory(): - main( + await main( ["--stream", "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), @@ -1396,20 +1415,20 @@ def test_stream_without_cache_no_warning(self, MockInputOutput): for call in mock_io_instance.tool_warning.call_args_list: self.assertNotIn("Cost estimates may be inaccurate", call[0][0]) - def test_argv_file_respects_git(self): + async def test_argv_file_respects_git(self): with GitTemporaryDirectory(): fname = Path("not_in_git.txt") fname.touch() with open(".gitignore", "w+") as f: f.write("not_in_git.txt") - coder = main( + coder = await main( argv=["--file", "not_in_git.txt"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertNotIn("not_in_git.txt", str(coder.abs_fnames)) - self.assertFalse(coder.allowed_to_edit("not_in_git.txt")) + self.assertFalse(await coder.allowed_to_edit("not_in_git.txt")) def 
test_load_dotenv_files_override(self): with GitTemporaryDirectory() as git_dir: @@ -1471,10 +1490,10 @@ def test_load_dotenv_files_override(self): os.chdir(original_cwd) @patch("aider.main.InputOutput") - def test_cache_without_stream_no_warning(self, MockInputOutput): + async def test_cache_without_stream_no_warning(self, MockInputOutput): mock_io_instance = MockInputOutput.return_value with GitTemporaryDirectory(): - main( + await main( ["--cache-prompts", "--exit", "--yes", "--no-stream"], input=DummyInput(), output=DummyOutput(), @@ -1483,14 +1502,14 @@ def test_cache_without_stream_no_warning(self, MockInputOutput): self.assertNotIn("Cost estimates may be inaccurate", call[0][0]) @patch("aider.coders.Coder.create") - def test_mcp_servers_parsing(self, mock_coder_create): + async def test_mcp_servers_parsing(self, mock_coder_create): # Setup mock coder mock_coder_instance = MagicMock() mock_coder_create.return_value = mock_coder_instance # Test with --mcp-servers option with GitTemporaryDirectory(): - main( + await main( [ "--mcp-servers", '{"mcpServers":{"git":{"command":"uvx","args":["mcp-server-git"]}}}', @@ -1520,7 +1539,7 @@ def test_mcp_servers_parsing(self, mock_coder_create): mcp_content = {"mcpServers": {"git": {"command": "uvx", "args": ["mcp-server-git"]}}} mcp_file.write_text(json.dumps(mcp_content)) - main( + await main( ["--mcp-servers-file", str(mcp_file), "--exit", "--yes"], input=DummyInput(), output=DummyOutput(), diff --git a/tests/basic/test_models.py b/tests/basic/test_models.py index 73e13febb2a..11e42b807af 100644 --- a/tests/basic/test_models.py +++ b/tests/basic/test_models.py @@ -425,16 +425,16 @@ def test_aider_extra_model_settings(self): except OSError: pass - @patch("aider.models.litellm.completion") + @patch("aider.models.litellm.acompletion") @patch.object(Model, "token_count") - def test_ollama_num_ctx_set_when_missing(self, mock_token_count, mock_completion): + async def test_ollama_num_ctx_set_when_missing(self, 
mock_token_count, mock_completion): mock_token_count.return_value = 1000 model = Model("ollama/llama3") model.extra_params = {} messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=None, stream=False) + await model.send_completion(messages, functions=None, stream=False) # Verify num_ctx was calculated and added to call expected_ctx = int(1000 * 1.25) + 8192 # 9442 @@ -447,13 +447,13 @@ def test_ollama_num_ctx_set_when_missing(self, mock_token_count, mock_completion timeout=600, ) - @patch("aider.models.litellm.completion") - def test_modern_tool_call_propagation(self, mock_completion): + @patch("aider.models.litellm.acompletion") + async def test_modern_tool_call_propagation(self, mock_completion): # Test modern tool calling (used for MCP Server Tool Calls) model = Model("gpt-4") messages = [{"role": "user", "content": "Hello"}] - model.send_completion( + await model.send_completion( messages, functions=None, stream=False, tools=[dict(type="function", function="test")] ) @@ -466,13 +466,13 @@ def test_modern_tool_call_propagation(self, mock_completion): timeout=600, ) - @patch("aider.models.litellm.completion") - def test_legacy_tool_call_propagation(self, mock_completion): + @patch("aider.models.litellm.acompletion") + async def test_legacy_tool_call_propagation(self, mock_completion): # Test modern tool calling (used for legacy server tool calling) model = Model("gpt-4") messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=["test"], stream=False) + await model.send_completion(messages, functions=["test"], stream=False) mock_completion.assert_called_with( model=model.name, @@ -483,13 +483,13 @@ def test_legacy_tool_call_propagation(self, mock_completion): timeout=600, ) - @patch("aider.models.litellm.completion") - def test_ollama_uses_existing_num_ctx(self, mock_completion): + @patch("aider.models.litellm.acompletion") + async def test_ollama_uses_existing_num_ctx(self, 
mock_completion): model = Model("ollama/llama3") model.extra_params = {"num_ctx": 4096} messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=None, stream=False) + await model.send_completion(messages, functions=None, stream=False) # Should use provided num_ctx from extra_params mock_completion.assert_called_once_with( @@ -501,13 +501,13 @@ def test_ollama_uses_existing_num_ctx(self, mock_completion): timeout=600, ) - @patch("aider.models.litellm.completion") - def test_non_ollama_no_num_ctx(self, mock_completion): + @patch("aider.models.litellm.acompletion") + async def test_non_ollama_no_num_ctx(self, mock_completion): model = Model("gpt-4") model.extra_params = {} messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=None, stream=False) + await model.send_completion(messages, functions=None, stream=False) # Regular models shouldn't get num_ctx mock_completion.assert_called_once_with( @@ -534,13 +534,13 @@ def test_use_temperature_settings(self): model.use_temperature = 0.7 self.assertEqual(model.use_temperature, 0.7) - @patch("aider.models.litellm.completion") - def test_request_timeout_default(self, mock_completion): + @patch("aider.models.litellm.acompletion") + async def test_request_timeout_default(self, mock_completion): # Test default timeout is used when not specified in extra_params model = Model("gpt-4") model.extra_params = {} messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=None, stream=False) + await model.send_completion(messages, functions=None, stream=False) mock_completion.assert_called_with( model=model.name, messages=messages, @@ -549,13 +549,13 @@ def test_request_timeout_default(self, mock_completion): timeout=600, # Default timeout ) - @patch("aider.models.litellm.completion") - def test_request_timeout_from_extra_params(self, mock_completion): + @patch("aider.models.litellm.acompletion") + async def 
test_request_timeout_from_extra_params(self, mock_completion): # Test timeout from extra_params overrides default model = Model("gpt-4") model.extra_params = {"timeout": 300} # 5 minutes messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=None, stream=False) + await model.send_completion(messages, functions=None, stream=False) mock_completion.assert_called_with( model=model.name, messages=messages, @@ -564,13 +564,13 @@ def test_request_timeout_from_extra_params(self, mock_completion): timeout=300, # From extra_params ) - @patch("aider.models.litellm.completion") - def test_use_temperature_in_send_completion(self, mock_completion): + @patch("aider.models.litellm.acompletion") + async def test_use_temperature_in_send_completion(self, mock_completion): # Test use_temperature=True sends temperature=0 model = Model("gpt-4") model.extra_params = {} messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=None, stream=False) + await model.send_completion(messages, functions=None, stream=False) mock_completion.assert_called_with( model=model.name, messages=messages, @@ -582,7 +582,7 @@ def test_use_temperature_in_send_completion(self, mock_completion): # Test use_temperature=False doesn't send temperature model = Model("github/o1-mini") messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=None, stream=False) + await model.send_completion(messages, functions=None, stream=False) self.assertNotIn("temperature", mock_completion.call_args.kwargs) # Test use_temperature as float sends that value @@ -590,7 +590,7 @@ def test_use_temperature_in_send_completion(self, mock_completion): model.extra_params = {} model.use_temperature = 0.7 messages = [{"role": "user", "content": "Hello"}] - model.send_completion(messages, functions=None, stream=False) + await model.send_completion(messages, functions=None, stream=False) mock_completion.assert_called_with( 
model=model.name, messages=messages, diff --git a/tests/basic/test_onboarding.py b/tests/basic/test_onboarding.py index 398bd7f4ee3..b5b63412e8a 100644 --- a/tests/basic/test_onboarding.py +++ b/tests/basic/test_onboarding.py @@ -288,19 +288,19 @@ def test_exchange_code_for_key_request_exception(self, mock_post): @patch("aider.onboarding.try_to_select_default_model", return_value="gpt-4o") @patch("aider.onboarding.offer_openrouter_oauth") - def test_select_default_model_already_specified(self, mock_offer_oauth, mock_try_select): + async def test_select_default_model_already_specified(self, mock_offer_oauth, mock_try_select): """Test select_default_model returns args.model if provided.""" args = argparse.Namespace(model="specific-model") io_mock = DummyIO() analytics_mock = DummyAnalytics() - selected_model = select_default_model(args, io_mock, analytics_mock) + selected_model = await select_default_model(args, io_mock, analytics_mock) self.assertEqual(selected_model, "specific-model") mock_try_select.assert_not_called() mock_offer_oauth.assert_not_called() @patch("aider.onboarding.try_to_select_default_model", return_value="gpt-4o") @patch("aider.onboarding.offer_openrouter_oauth") - def test_select_default_model_found_via_env(self, mock_offer_oauth, mock_try_select): + async def test_select_default_model_found_via_env(self, mock_offer_oauth, mock_try_select): """Test select_default_model returns model found by try_to_select.""" args = argparse.Namespace(model=None) # No model specified io_mock = DummyIO() @@ -308,7 +308,7 @@ def test_select_default_model_found_via_env(self, mock_offer_oauth, mock_try_sel analytics_mock = DummyAnalytics() analytics_mock.event = MagicMock() # Track events - selected_model = select_default_model(args, io_mock, analytics_mock) + selected_model = await select_default_model(args, io_mock, analytics_mock) self.assertEqual(selected_model, "gpt-4o") mock_try_select.assert_called_once() @@ -324,7 +324,7 @@ def 
test_select_default_model_found_via_env(self, mock_offer_oauth, mock_try_sel @patch( "aider.onboarding.offer_openrouter_oauth", return_value=False ) # OAuth offered but fails/declined - def test_select_default_model_no_keys_oauth_fail(self, mock_offer_oauth, mock_try_select): + async def test_select_default_model_no_keys_oauth_fail(self, mock_offer_oauth, mock_try_select): """Test select_default_model offers OAuth when no keys, but OAuth fails.""" args = argparse.Namespace(model=None) io_mock = DummyIO() @@ -332,7 +332,7 @@ def test_select_default_model_no_keys_oauth_fail(self, mock_offer_oauth, mock_tr io_mock.offer_url = MagicMock() analytics_mock = DummyAnalytics() - selected_model = select_default_model(args, io_mock, analytics_mock) + selected_model = await select_default_model(args, io_mock, analytics_mock) self.assertIsNone(selected_model) self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth attempt @@ -349,14 +349,16 @@ def test_select_default_model_no_keys_oauth_fail(self, mock_offer_oauth, mock_tr @patch( "aider.onboarding.offer_openrouter_oauth", return_value=True ) # OAuth offered and succeeds - def test_select_default_model_no_keys_oauth_success(self, mock_offer_oauth, mock_try_select): + async def test_select_default_model_no_keys_oauth_success( + self, mock_offer_oauth, mock_try_select + ): """Test select_default_model offers OAuth, which succeeds.""" args = argparse.Namespace(model=None) io_mock = DummyIO() io_mock.tool_warning = MagicMock() analytics_mock = DummyAnalytics() - selected_model = select_default_model(args, io_mock, analytics_mock) + selected_model = await select_default_model(args, io_mock, analytics_mock) self.assertEqual(selected_model, "openrouter/deepseek/deepseek-r1:free") self.assertEqual(mock_try_select.call_count, 2) # Called before and after oauth @@ -374,14 +376,14 @@ def test_select_default_model_no_keys_oauth_success(self, mock_offer_oauth, mock # --- Tests for offer_openrouter_oauth --- 
@patch("aider.onboarding.start_openrouter_oauth_flow", return_value="new_or_key") @patch.dict(os.environ, {}, clear=True) # Ensure no key exists initially - def test_offer_openrouter_oauth_confirm_yes_success(self, mock_start_oauth): + async def test_offer_openrouter_oauth_confirm_yes_success(self, mock_start_oauth): """Test offer_openrouter_oauth when user confirms and OAuth succeeds.""" io_mock = DummyIO() io_mock.confirm_ask = MagicMock(return_value=True) # User says yes analytics_mock = DummyAnalytics() analytics_mock.event = MagicMock() - result = offer_openrouter_oauth(io_mock, analytics_mock) + result = await offer_openrouter_oauth(io_mock, analytics_mock) self.assertTrue(result) io_mock.confirm_ask.assert_called_once() @@ -394,7 +396,7 @@ def test_offer_openrouter_oauth_confirm_yes_success(self, mock_start_oauth): @patch("aider.onboarding.start_openrouter_oauth_flow", return_value=None) # OAuth fails @patch.dict(os.environ, {}, clear=True) - def test_offer_openrouter_oauth_confirm_yes_fail(self, mock_start_oauth): + async def test_offer_openrouter_oauth_confirm_yes_fail(self, mock_start_oauth): """Test offer_openrouter_oauth when user confirms but OAuth fails.""" io_mock = DummyIO() io_mock.confirm_ask = MagicMock(return_value=True) # User says yes @@ -402,7 +404,7 @@ def test_offer_openrouter_oauth_confirm_yes_fail(self, mock_start_oauth): analytics_mock = DummyAnalytics() analytics_mock.event = MagicMock() - result = offer_openrouter_oauth(io_mock, analytics_mock) + result = await offer_openrouter_oauth(io_mock, analytics_mock) self.assertFalse(result) io_mock.confirm_ask.assert_called_once() @@ -415,14 +417,14 @@ def test_offer_openrouter_oauth_confirm_yes_fail(self, mock_start_oauth): analytics_mock.event.assert_any_call("oauth_flow_failure") @patch("aider.onboarding.start_openrouter_oauth_flow") - def test_offer_openrouter_oauth_confirm_no(self, mock_start_oauth): + async def test_offer_openrouter_oauth_confirm_no(self, mock_start_oauth): """Test 
offer_openrouter_oauth when user declines.""" io_mock = DummyIO() io_mock.confirm_ask = MagicMock(return_value=False) # User says no analytics_mock = DummyAnalytics() analytics_mock.event = MagicMock() - result = offer_openrouter_oauth(io_mock, analytics_mock) + result = await offer_openrouter_oauth(io_mock, analytics_mock) self.assertFalse(result) io_mock.confirm_ask.assert_called_once() diff --git a/tests/basic/test_reasoning.py b/tests/basic/test_reasoning.py index 0386f29bc68..24aa9334197 100644 --- a/tests/basic/test_reasoning.py +++ b/tests/basic/test_reasoning.py @@ -13,7 +13,7 @@ class TestReasoning(unittest.TestCase): - def test_send_with_reasoning_content(self): + async def test_send_with_reasoning_content(self): """Test that reasoning content is properly formatted and output.""" # Setup IO with no pretty io = InputOutput(pretty=False) @@ -21,7 +21,7 @@ def test_send_with_reasoning_content(self): # Setup model and coder model = Model("gpt-3.5-turbo") - coder = Coder.create(model, None, io=io, stream=False) + coder = await Coder.create(model, None, io=io, stream=False) # Test data reasoning_content = "My step-by-step reasoning process" @@ -47,7 +47,7 @@ def __init__(self, content, reasoning_content): with patch.object(model, "send_completion", return_value=(mock_hash, mock_completion)): # Call send with a simple message messages = [{"role": "user", "content": "test prompt"}] - list(coder.send(messages)) + list(await coder.send(messages)) # Now verify ai_output was called with the right content io.assistant_output.assert_called_once() @@ -74,7 +74,7 @@ def __init__(self, content, reasoning_content): reasoning_pos, main_pos, "Reasoning content should appear before main content" ) - def test_send_with_reasoning_content_stream(self): + async def test_send_with_reasoning_content_stream(self): """Test that streaming reasoning content is properly formatted and output.""" # Setup IO with pretty output for streaming io = InputOutput(pretty=True) @@ -83,7 +83,7 @@ 
def test_send_with_reasoning_content_stream(self): # Setup model and coder model = Model("gpt-3.5-turbo") - coder = Coder.create(model, None, io=io, stream=True) + coder = await Coder.create(model, None, io=io, stream=True) # Ensure the coder shows pretty output coder.show_pretty = MagicMock(return_value=True) @@ -147,7 +147,7 @@ def __init__( # Call send with a simple message messages = [{"role": "user", "content": "test prompt"}] - list(coder.send(messages)) + list(await coder.send(messages)) # Verify mdstream.update was called multiple times mock_mdstream.update.assert_called() @@ -187,7 +187,7 @@ def __init__( expected_content = "Final answer after reasoning" self.assertEqual(coder.partial_response_content.strip(), expected_content) - def test_send_with_think_tags(self): + async def test_send_with_think_tags(self): """Test that tags are properly processed and formatted.""" # Setup IO with no pretty io = InputOutput(pretty=False) @@ -196,7 +196,7 @@ def test_send_with_think_tags(self): # Setup model and coder model = Model("gpt-3.5-turbo") model.reasoning_tag = "think" # Set to remove tags - coder = Coder.create(model, None, io=io, stream=False) + coder = await Coder.create(model, None, io=io, stream=False) # Test data reasoning_content = "My step-by-step reasoning process" @@ -229,7 +229,7 @@ def __init__(self, content): with patch.object(model, "send_completion", return_value=(mock_hash, mock_completion)): # Call send with a simple message messages = [{"role": "user", "content": "test prompt"}] - list(coder.send(messages)) + list(await coder.send(messages)) # Now verify ai_output was called with the right content io.assistant_output.assert_called_once() @@ -256,7 +256,7 @@ def __init__(self, content): coder.remove_reasoning_content() self.assertEqual(coder.partial_response_content.strip(), main_content.strip()) - def test_send_with_think_tags_stream(self): + async def test_send_with_think_tags_stream(self): """Test that streaming with tags is properly 
processed and formatted.""" # Setup IO with pretty output for streaming io = InputOutput(pretty=True) @@ -266,7 +266,7 @@ def test_send_with_think_tags_stream(self): # Setup model and coder model = Model("gpt-3.5-turbo") model.reasoning_tag = "think" # Set to remove tags - coder = Coder.create(model, None, io=io, stream=True) + coder = await Coder.create(model, None, io=io, stream=True) # Ensure the coder shows pretty output coder.show_pretty = MagicMock(return_value=True) @@ -329,7 +329,7 @@ def __init__( # Call send with a simple message messages = [{"role": "user", "content": "test prompt"}] - list(coder.send(messages)) + list(await coder.send(messages)) # Verify mdstream.update was called multiple times mock_mdstream.update.assert_called() @@ -399,7 +399,7 @@ def test_remove_reasoning_content(self): text = "Just regular text" self.assertEqual(remove_reasoning_content(text, "think"), text) - def test_send_with_reasoning(self): + async def test_send_with_reasoning(self): """Test that reasoning content from the 'reasoning' attribute is properly formatted and output.""" # Setup IO with no pretty @@ -408,7 +408,7 @@ def test_send_with_reasoning(self): # Setup model and coder model = Model("gpt-3.5-turbo") - coder = Coder.create(model, None, io=io, stream=False) + coder = await Coder.create(model, None, io=io, stream=False) # Test data reasoning_content = "My step-by-step reasoning process" @@ -437,7 +437,7 @@ def __init__(self, content, reasoning): with patch.object(model, "send_completion", return_value=(mock_hash, mock_completion)): # Call send with a simple message messages = [{"role": "user", "content": "test prompt"}] - list(coder.send(messages)) + list(await coder.send(messages)) # Now verify ai_output was called with the right content io.assistant_output.assert_called_once() @@ -464,7 +464,7 @@ def __init__(self, content, reasoning): reasoning_pos, main_pos, "Reasoning content should appear before main content" ) - def test_send_with_reasoning_stream(self): + 
async def test_send_with_reasoning_stream(self): """Test that streaming reasoning content from the 'reasoning' attribute is properly formatted and output.""" # Setup IO with pretty output for streaming @@ -474,7 +474,7 @@ def test_send_with_reasoning_stream(self): # Setup model and coder model = Model("gpt-3.5-turbo") - coder = Coder.create(model, None, io=io, stream=True) + coder = await Coder.create(model, None, io=io, stream=True) # Ensure the coder shows pretty output coder.show_pretty = MagicMock(return_value=True) @@ -539,7 +539,7 @@ def __init__( # Call send with a simple message messages = [{"role": "user", "content": "test prompt"}] - list(coder.send(messages)) + list(await coder.send(messages)) # Verify mdstream.update was called multiple times mock_mdstream.update.assert_called() @@ -580,7 +580,7 @@ def __init__( self.assertEqual(coder.partial_response_content.strip(), expected_content) @patch("aider.models.litellm.completion") - def test_simple_send_with_retries_removes_reasoning(self, mock_completion): + async def test_simple_send_with_retries_removes_reasoning(self, mock_completion): """Test that simple_send_with_retries correctly removes reasoning content.""" model = Model("deepseek-r1") # This model has reasoning_tag="think" @@ -594,7 +594,7 @@ def test_simple_send_with_retries_removes_reasoning(self, mock_completion): mock_completion.return_value = mock_response messages = [{"role": "user", "content": "test"}] - result = model.simple_send_with_retries(messages) + result = await model.simple_send_with_retries(messages) expected = """Here is some text diff --git a/tests/basic/test_repo.py b/tests/basic/test_repo.py index 71ba9479830..c207cfea5be 100644 --- a/tests/basic/test_repo.py +++ b/tests/basic/test_repo.py @@ -4,7 +4,7 @@ import time import unittest from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import git @@ -128,8 +128,8 @@ def test_diffs_between_commits(self): 
diffs = git_repo.diff_commits(False, "HEAD~1", "HEAD") self.assertIn("two", diffs) - @patch("aider.models.Model.simple_send_with_retries") - def test_get_commit_message(self, mock_send): + @patch("aider.models.Model.simple_send_with_retries", new_callable=AsyncMock) + async def test_get_commit_message(self, mock_send): mock_send.side_effect = ["", "a good commit message"] model1 = Model("gpt-3.5-turbo") @@ -152,8 +152,8 @@ def test_get_commit_message(self, mock_send): second_call_messages = mock_send.call_args_list[1][0][0] # Get messages from second call self.assertEqual(first_call_messages, second_call_messages) - @patch("aider.models.Model.simple_send_with_retries") - def test_get_commit_message_strip_quotes(self, mock_send): + @patch("aider.models.Model.simple_send_with_retries", new_callable=AsyncMock) + async def test_get_commit_message_strip_quotes(self, mock_send): mock_send.return_value = '"a good commit message"' repo = GitRepo(InputOutput(), None, None, models=[self.GPT35]) @@ -163,8 +163,8 @@ def test_get_commit_message_strip_quotes(self, mock_send): # Assert that the returned message is the expected one self.assertEqual(result, "a good commit message") - @patch("aider.models.Model.simple_send_with_retries") - def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send): + @patch("aider.models.Model.simple_send_with_retries", new_callable=AsyncMock) + async def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send): mock_send.return_value = 'a good "commit message"' repo = GitRepo(InputOutput(), None, None, models=[self.GPT35]) @@ -174,8 +174,8 @@ def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send): # Assert that the returned message is the expected one self.assertEqual(result, 'a good "commit message"') - @patch("aider.models.Model.simple_send_with_retries") - def test_get_commit_message_with_custom_prompt(self, mock_send): + @patch("aider.models.Model.simple_send_with_retries", new_callable=AsyncMock) + async 
def test_get_commit_message_with_custom_prompt(self, mock_send): mock_send.return_value = "Custom commit message" custom_prompt = "Generate a commit message in the style of Shakespeare" @@ -189,7 +189,7 @@ def test_get_commit_message_with_custom_prompt(self, mock_send): @unittest.skipIf(platform.system() == "Windows", "Git env var behavior differs on Windows") @patch("aider.repo.GitRepo.get_commit_message") - def test_commit_with_custom_committer_name(self, mock_send): + async def test_commit_with_custom_committer_name(self, mock_send): mock_send.return_value = '"a good commit message"' with GitTemporaryDirectory(): @@ -209,7 +209,7 @@ def test_commit_with_custom_committer_name(self, mock_send): # commit a change with aider_edits=True (using default attributes) fname.write_text("new content") - commit_result = git_repo.commit(fnames=[str(fname)], aider_edits=True) + commit_result = await git_repo.commit(fnames=[str(fname)], aider_edits=True) self.assertIsNotNone(commit_result) # check the committer name (defaults interpreted as True) @@ -219,7 +219,7 @@ def test_commit_with_custom_committer_name(self, mock_send): # commit a change without aider_edits (using default attributes) fname.write_text("new content again!") - commit_result = git_repo.commit(fnames=[str(fname)], aider_edits=False) + commit_result = await git_repo.commit(fnames=[str(fname)], aider_edits=False) self.assertIsNotNone(commit_result) # check the committer name (author not modified, committer still modified by default) @@ -232,7 +232,9 @@ def test_commit_with_custom_committer_name(self, mock_send): io, None, None, attribute_author=False, attribute_committer=False ) fname.write_text("explicit false content") - commit_result = git_repo_explicit_false.commit(fnames=[str(fname)], aider_edits=True) + commit_result = await git_repo_explicit_false.commit( + fnames=[str(fname)], aider_edits=True + ) self.assertIsNotNone(commit_result) commit = raw_repo.head.commit self.assertEqual(commit.author.name, "Test 
User") # Explicit False @@ -247,7 +249,7 @@ def test_commit_with_custom_committer_name(self, mock_send): # Test user commit with explicit no-committer attribution git_repo_user_no_committer = GitRepo(io, None, None, attribute_committer=False) fname.write_text("user no committer content") - commit_result = git_repo_user_no_committer.commit( + commit_result = await git_repo_user_no_committer.commit( fnames=[str(fname)], aider_edits=False ) self.assertIsNotNone(commit_result) @@ -264,7 +266,7 @@ def test_commit_with_custom_committer_name(self, mock_send): ) @unittest.skipIf(platform.system() == "Windows", "Git env var behavior differs on Windows") - def test_commit_with_co_authored_by(self): + async def test_commit_with_co_authored_by(self): with GitTemporaryDirectory(): # new repo raw_repo = git.Repo() @@ -293,7 +295,7 @@ def test_commit_with_co_authored_by(self): # commit a change with aider_edits=True and co-authored-by flag fname.write_text("new content") - commit_result = git_repo.commit( + commit_result = await git_repo.commit( fnames=[str(fname)], aider_edits=True, coder=mock_coder, message="Aider edit" ) self.assertIsNotNone(commit_result) @@ -315,7 +317,7 @@ def test_commit_with_co_authored_by(self): ) @unittest.skipIf(platform.system() == "Windows", "Git env var behavior differs on Windows") - def test_commit_co_authored_by_with_explicit_name_modification(self): + async def test_commit_co_authored_by_with_explicit_name_modification(self): # Test scenario where Co-authored-by is true AND # author/committer modification are explicitly True with GitTemporaryDirectory(): @@ -347,7 +349,7 @@ def test_commit_co_authored_by_with_explicit_name_modification(self): # commit a change with aider_edits=True and combo flags fname.write_text("new content combo") - commit_result = git_repo.commit( + commit_result = await git_repo.commit( fnames=[str(fname)], aider_edits=True, coder=mock_coder, message="Aider combo edit" ) self.assertIsNotNone(commit_result) @@ -372,7 +374,7 
@@ def test_commit_co_authored_by_with_explicit_name_modification(self): ) @unittest.skipIf(platform.system() == "Windows", "Git env var behavior differs on Windows") - def test_commit_ai_edits_no_coauthor_explicit_false(self): + async def test_commit_ai_edits_no_coauthor_explicit_false(self): # Test AI edits (aider_edits=True) when co-authored-by is False, # but author or committer attribution is explicitly disabled. with GitTemporaryDirectory(): @@ -399,7 +401,7 @@ def test_commit_ai_edits_no_coauthor_explicit_false(self): git_repo_no_author = GitRepo(io, None, None) fname.write_text("no author content") - commit_result = git_repo_no_author.commit( + commit_result = await git_repo_no_author.commit( fnames=[str(fname)], aider_edits=True, coder=mock_coder_no_author, @@ -423,7 +425,7 @@ def test_commit_ai_edits_no_coauthor_explicit_false(self): git_repo_no_committer = GitRepo(io, None, None) fname.write_text("no committer content") - commit_result = git_repo_no_committer.commit( + commit_result = await git_repo_no_committer.commit( fnames=[str(fname)], aider_edits=True, coder=mock_coder_no_committer, @@ -621,7 +623,7 @@ def test_subtree_only(self): self.assertNotIn(str(another_subdir_file), tracked_files) @patch("aider.models.Model.simple_send_with_retries") - def test_noop_commit(self, mock_send): + async def test_noop_commit(self, mock_send): mock_send.return_value = '"a good commit message"' with GitTemporaryDirectory(): @@ -636,11 +638,11 @@ def test_noop_commit(self, mock_send): git_repo = GitRepo(InputOutput(), None, None) - commit_result = git_repo.commit(fnames=[str(fname)]) + commit_result = await git_repo.commit(fnames=[str(fname)]) self.assertIsNone(commit_result) @unittest.skipIf(platform.system() == "Windows", "Git hook execution differs on Windows") - def test_git_commit_verify(self): + async def test_git_commit_verify(self): """Test that git_commit_verify controls whether --no-verify is passed to git commit""" with GitTemporaryDirectory(): # Create a 
new repo @@ -670,22 +672,24 @@ def test_git_commit_verify(self): git_repo_verify = GitRepo(io, None, None, git_commit_verify=True) # Attempt to commit - should fail due to pre-commit hook - commit_result = git_repo_verify.commit(fnames=[str(fname)], message="Should fail") + commit_result = await git_repo_verify.commit(fnames=[str(fname)], message="Should fail") self.assertIsNone(commit_result) # Create GitRepo with verify=False git_repo_no_verify = GitRepo(io, None, None, git_commit_verify=False) # Attempt to commit - should succeed by bypassing the hook - commit_result = git_repo_no_verify.commit(fnames=[str(fname)], message="Should succeed") + commit_result = await git_repo_no_verify.commit( + fnames=[str(fname)], message="Should succeed" + ) self.assertIsNotNone(commit_result) # Verify the commit was actually made latest_commit_msg = raw_repo.head.commit.message self.assertEqual(latest_commit_msg.strip(), "Should succeed") - @patch("aider.models.Model.simple_send_with_retries") - def test_get_commit_message_uses_system_prompt_prefix(self, mock_send): + @patch("aider.models.Model.simple_send_with_retries", new_callable=AsyncMock) + async def test_get_commit_message_uses_system_prompt_prefix(self, mock_send): """ Verify that GitRepo.get_commit_message() prepends the model.system_prompt_prefix to the system prompt sent to the LLM. 
diff --git a/tests/basic/test_scripting.py b/tests/basic/test_scripting.py index b1b3de90c6e..98e69c96ea1 100644 --- a/tests/basic/test_scripting.py +++ b/tests/basic/test_scripting.py @@ -1,6 +1,6 @@ import unittest from pathlib import Path -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from aider.coders import Coder from aider.models import Model @@ -8,8 +8,8 @@ class TestScriptingAPI(unittest.TestCase): - @patch("aider.coders.base_coder.Coder.send") - def test_basic_scripting(self, mock_send): + @patch("aider.coders.base_coder.Coder.send", new_callable=AsyncMock) + async def test_basic_scripting(self, mock_send): with GitTemporaryDirectory(): # Setup def mock_send_side_effect(messages, functions=None): @@ -24,10 +24,10 @@ def mock_send_side_effect(messages, functions=None): fname.touch() fnames = [str(fname)] model = Model("gpt-4-turbo") - coder = Coder.create(main_model=model, fnames=fnames) + coder = await Coder.create(main_model=model, fnames=fnames) - result1 = coder.run("make a script that prints hello world") - result2 = coder.run("make it say goodbye") + result1 = await coder.run("make a script that prints hello world") + result2 = await coder.run("make it say goodbye") # Assertions self.assertEqual(mock_send.call_count, 2) diff --git a/tests/basic/test_sendchat.py b/tests/basic/test_sendchat.py index 652c88871d1..153ab2f421f 100644 --- a/tests/basic/test_sendchat.py +++ b/tests/basic/test_sendchat.py @@ -19,9 +19,9 @@ def test_litellm_exceptions(self): litellm_ex = LiteLLMExceptions() litellm_ex._load(strict=True) - @patch("litellm.completion") + @patch("litellm.acompletion") @patch("builtins.print") - def test_simple_send_with_retries_rate_limit_error(self, mock_print, mock_completion): + async def test_simple_send_with_retries_rate_limit_error(self, mock_print, mock_completion): mock = MagicMock() mock.status_code = 500 @@ -40,28 +40,28 @@ def test_simple_send_with_retries_rate_limit_error(self, mock_print, mock_comple 
model = Model(self.mock_model) model.verbose = True - model.simple_send_with_retries(self.mock_messages) + await model.simple_send_with_retries(self.mock_messages) assert mock_print.call_count > 0 - @patch("litellm.completion") - def test_send_completion_basic(self, mock_completion): + @patch("litellm.acompletion") + async def test_send_completion_basic(self, mock_completion): # Setup mock response mock_response = MagicMock() mock_completion.return_value = mock_response # Test basic send_completion - hash_obj, response = Model(self.mock_model).send_completion( + hash_obj, response = await Model(self.mock_model).send_completion( self.mock_messages, functions=None, stream=False ) assert response == mock_response mock_completion.assert_called_once() - @patch("litellm.completion") - def test_send_completion_with_functions(self, mock_completion): + @patch("litellm.acompletion") + async def test_send_completion_with_functions(self, mock_completion): mock_function = {"name": "test_function", "parameters": {"type": "object"}} - hash_obj, response = Model(self.mock_model).send_completion( + hash_obj, response = await Model(self.mock_model).send_completion( self.mock_messages, functions=[mock_function], stream=False ) @@ -70,19 +70,19 @@ def test_send_completion_with_functions(self, mock_completion): assert "tools" in called_kwargs assert called_kwargs["tools"][0]["function"] == mock_function - @patch("litellm.completion") - def test_simple_send_attribute_error(self, mock_completion): + @patch("litellm.acompletion") + async def test_simple_send_attribute_error(self, mock_completion): # Setup mock to raise AttributeError mock_completion.return_value = MagicMock() mock_completion.return_value.choices = None # Should return None on AttributeError - result = Model(self.mock_model).simple_send_with_retries(self.mock_messages) + result = await Model(self.mock_model).simple_send_with_retries(self.mock_messages) assert result is None - @patch("litellm.completion") + 
@patch("litellm.acompletion") @patch("builtins.print") - def test_simple_send_non_retryable_error(self, mock_print, mock_completion): + async def test_simple_send_non_retryable_error(self, mock_print, mock_completion): # Test with an error that shouldn't trigger retries mock = MagicMock() mock.status_code = 400 @@ -94,7 +94,7 @@ def test_simple_send_non_retryable_error(self, mock_print, mock_completion): model = Model(self.mock_model) model.verbose = True - result = model.simple_send_with_retries(self.mock_messages) + result = await model.simple_send_with_retries(self.mock_messages) assert result is None # Should only print the error message assert mock_print.call_count > 0 diff --git a/tests/basic/test_wholefile.py b/tests/basic/test_wholefile.py index deb192ec7e4..41f717458b1 100644 --- a/tests/basic/test_wholefile.py +++ b/tests/basic/test_wholefile.py @@ -24,11 +24,11 @@ def tearDown(self): os.chdir(self.original_cwd) shutil.rmtree(self.tempdir, ignore_errors=True) - def test_no_files(self): + async def test_no_files(self): # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[]) coder.partial_response_content = ( 'To print "Hello, World!" in most programming languages, you can use the following' ' code:\n\n```python\nprint("Hello, World!")\n```\n\nThis code will output "Hello,' @@ -38,18 +38,18 @@ def test_no_files(self): # This is throwing ValueError! coder.render_incremental_response(True) - def test_no_files_new_file_should_ask(self): + async def test_no_files_new_file_should_ask(self): io = InputOutput(yes=False) # <- yes=FALSE - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[]) coder.partial_response_content = ( 'To print "Hello, World!" 
in most programming languages, you can use the following' ' code:\n\nfoo.js\n```python\nprint("Hello, World!")\n```\n\nThis code will output' ' "Hello, World!" to the console.' ) - coder.apply_updates() + await coder.apply_updates() self.assertFalse(Path("foo.js").exists()) - def test_update_files(self): + async def test_update_files(self): # Create a sample file in the temporary directory sample_file = "sample.txt" with open(sample_file, "w") as f: @@ -57,13 +57,13 @@ def test_update_files(self): # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) # Set the partial response content with the updated content coder.partial_response_content = f"{sample_file}\n```\nUpdated content\n```" # Call update_files method - edited_files = coder.apply_updates() + edited_files = await coder.apply_updates() # Check if the sample file was updated self.assertIn("sample.txt", edited_files) @@ -73,7 +73,7 @@ def test_update_files(self): updated_content = f.read() self.assertEqual(updated_content, "Updated content\n") - def test_update_files_live_diff(self): + async def test_update_files_live_diff(self): # Create a sample file in the temporary directory sample_file = "sample.txt" with open(sample_file, "w") as f: @@ -81,7 +81,7 @@ def test_update_files_live_diff(self): # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) # Set the partial response content with the updated content coder.partial_response_content = f"{sample_file}\n```\n0\n\1\n2\n" @@ -91,7 +91,7 @@ def test_update_files_live_diff(self): # the live diff should be concise, since we haven't changed anything yet self.assertLess(len(lines), 20) - def 
test_update_files_with_existing_fence(self): + async def test_update_files_with_existing_fence(self): # Create a sample file in the temporary directory sample_file = "sample.txt" original_content = """ @@ -105,7 +105,7 @@ def test_update_files_with_existing_fence(self): # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) coder.choose_fence() @@ -117,7 +117,7 @@ def test_update_files_with_existing_fence(self): ) # Call update_files method - edited_files = coder.apply_updates() + edited_files = await coder.apply_updates() # Check if the sample file was updated self.assertIn("sample.txt", edited_files) @@ -127,7 +127,7 @@ def test_update_files_with_existing_fence(self): updated_content = f.read() self.assertEqual(updated_content, "Updated content\n") - def test_update_files_bogus_path_prefix(self): + async def test_update_files_bogus_path_prefix(self): # Create a sample file in the temporary directory sample_file = "sample.txt" with open(sample_file, "w") as f: @@ -135,14 +135,14 @@ def test_update_files_bogus_path_prefix(self): # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) # Set the partial response content with the updated content # With path/to/ prepended onto the filename coder.partial_response_content = f"path/to/{sample_file}\n```\nUpdated content\n```" # Call update_files method - edited_files = coder.apply_updates() + edited_files = await coder.apply_updates() # Check if the sample file was updated self.assertIn("sample.txt", edited_files) @@ -152,7 +152,7 @@ def test_update_files_bogus_path_prefix(self): updated_content = f.read() self.assertEqual(updated_content, "Updated 
content\n") - def test_update_files_not_in_chat(self): + async def test_update_files_not_in_chat(self): # Create a sample file in the temporary directory sample_file = "sample.txt" with open(sample_file, "w") as f: @@ -160,13 +160,13 @@ def test_update_files_not_in_chat(self): # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io) + coder = await WholeFileCoder(main_model=self.GPT35, io=io) # Set the partial response content with the updated content coder.partial_response_content = f"{sample_file}\n```\nUpdated content\n```" # Call update_files method - edited_files = coder.apply_updates() + edited_files = await coder.apply_updates() # Check if the sample file was updated self.assertIn("sample.txt", edited_files) @@ -176,7 +176,7 @@ def test_update_files_not_in_chat(self): updated_content = f.read() self.assertEqual(updated_content, "Updated content\n") - def test_update_files_no_filename_single_file_in_chat(self): + async def test_update_files_no_filename_single_file_in_chat(self): sample_file = "accumulate.py" content = ( "def accumulate(collection, operation):\n return [operation(x) for x in" @@ -188,7 +188,7 @@ def test_update_files_no_filename_single_file_in_chat(self): # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) # Set the partial response content with the updated content coder.partial_response_content = ( @@ -199,7 +199,7 @@ def test_update_files_no_filename_single_file_in_chat(self): ) # Call update_files method - edited_files = coder.apply_updates() + edited_files = await coder.apply_updates() # Check if the sample file was updated self.assertIn(sample_file, edited_files) @@ -209,7 +209,7 @@ def test_update_files_no_filename_single_file_in_chat(self): updated_content = f.read() 
self.assertEqual(updated_content, content) - def test_update_files_earlier_filename(self): + async def test_update_files_earlier_filename(self): fname_a = Path("a.txt") fname_b = Path("b.txt") @@ -231,13 +231,13 @@ def test_update_files_earlier_filename(self): """ # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[fname_a, fname_b]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[fname_a, fname_b]) # Set the partial response content with the updated content coder.partial_response_content = response # Call update_files method - edited_files = coder.apply_updates() + edited_files = await coder.apply_updates() # Check if the sample file was updated self.assertIn(str(fname_a), edited_files) @@ -246,7 +246,7 @@ def test_update_files_earlier_filename(self): self.assertEqual(fname_a.read_text(), "after a\n") self.assertEqual(fname_b.read_text(), "after b\n") - def test_update_hash_filename(self): + async def test_update_hash_filename(self): fname_a = Path("a.txt") fname_b = Path("b.txt") @@ -267,13 +267,13 @@ def test_update_hash_filename(self): """ # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[fname_a, fname_b]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[fname_a, fname_b]) # Set the partial response content with the updated content coder.partial_response_content = response # Call update_files method - edited_files = coder.apply_updates() + edited_files = await coder.apply_updates() dump(edited_files) @@ -284,7 +284,7 @@ def test_update_hash_filename(self): self.assertEqual(fname_a.read_text(), "after a\n") self.assertEqual(fname_b.read_text(), "after b\n") - def test_update_named_file_but_extra_unnamed_code_block(self): + async def test_update_named_file_but_extra_unnamed_code_block(self): sample_file = "hello.py" new_content = 
"new\ncontent\ngoes\nhere\n" @@ -293,7 +293,7 @@ def test_update_named_file_but_extra_unnamed_code_block(self): # Initialize WholeFileCoder with the temporary directory io = InputOutput(yes=True) - coder = WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) + coder = await WholeFileCoder(main_model=self.GPT35, io=io, fnames=[sample_file]) # Set the partial response content with the updated content coder.partial_response_content = ( @@ -306,7 +306,7 @@ def test_update_named_file_but_extra_unnamed_code_block(self): ) # Call update_files method - edited_files = coder.apply_updates() + edited_files = await coder.apply_updates() # Check if the sample file was updated self.assertIn(sample_file, edited_files) @@ -316,7 +316,7 @@ def test_update_named_file_but_extra_unnamed_code_block(self): updated_content = f.read() self.assertEqual(updated_content, new_content) - def test_full_edit(self): + async def test_full_edit(self): # Create a few temporary files _, file1 = tempfile.mkstemp() @@ -326,12 +326,14 @@ def test_full_edit(self): files = [file1] # Initialize the Coder object with the mocked IO and mocked repo - coder = Coder.create(self.GPT35, "whole", io=InputOutput(), fnames=files, stream=False) + coder = await Coder.create( + self.GPT35, "whole", io=InputOutput(), fnames=files, stream=False + ) # no trailing newline so the response content below doesn't add ANOTHER newline new_content = "new\ntwo\nthree" - def mock_send(*args, **kwargs): + async def mock_send(*args, **kwargs): coder.partial_response_content = f""" Do this: @@ -347,7 +349,7 @@ def mock_send(*args, **kwargs): coder.send = MagicMock(side_effect=mock_send) # Call the run method with a message - coder.run(with_message="hi") + await coder.run(with_message="hi") content = Path(file1).read_text(encoding="utf-8") diff --git a/tests/help/test_help.py b/tests/help/test_help.py index a7222185e75..76183c59bd9 100644 --- a/tests/help/test_help.py +++ b/tests/help/test_help.py @@ -1,6 +1,7 @@ +import 
asyncio import time import unittest -from unittest.mock import MagicMock +from unittest.mock import AsyncMock from requests.exceptions import ConnectionError, ReadTimeout @@ -49,29 +50,44 @@ def retry_with_backoff(func, max_time=60, initial_delay=1, backoff_factor=2): @classmethod def setUpClass(cls): + # Run the async setup synchronously for unittest compatibility + asyncio.run(cls.async_setup_class()) + + @classmethod + async def async_setup_class(cls): io = InputOutput(pretty=False, yes=True) GPT35 = Model("gpt-3.5-turbo") - coder = Coder.create(GPT35, None, io) + coder = await Coder.create(GPT35, None, io) commands = Commands(io, coder) - help_coder_run = MagicMock(return_value="") - aider.coders.HelpCoder.run = help_coder_run - - def run_help_command(): - try: - commands.cmd_help("hi") - except aider.commands.SwitchCoder: - pass - else: - # If no exception was raised, fail the test - assert False, "SwitchCoder exception was not raised" + help_mock = AsyncMock() + help_mock.run.return_value = "" + aider.coders.HelpCoder.run = help_mock.run - # Use retry with backoff for the help command that loads models - cls.retry_with_backoff(run_help_command) + # Simple retry logic without the complex lambda + start_time = time.time() + delay = 1 + max_time = 60 - help_coder_run.assert_called_once() + while time.time() - start_time < max_time: + try: + try: + await commands.cmd_help("hi") + except aider.commands.SwitchCoder: + break + else: + # If no exception was raised, fail the test + assert False, "SwitchCoder exception was not raised" + break + except (ReadTimeout, ConnectionError): + await asyncio.sleep(delay) + delay = min(delay * 2, 15) + else: + raise Exception("Retry timeout exceeded") + + help_mock.run.assert_called_once() def test_init(self): help_inst = Help() diff --git a/tests/scrape/test_playwright_disable.py b/tests/scrape/test_playwright_disable.py index 7a851a4e76a..e9d65073650 100644 --- a/tests/scrape/test_playwright_disable.py +++ 
b/tests/scrape/test_playwright_disable.py @@ -52,7 +52,7 @@ def fake_playwright(url): assert called["called"] -def test_commands_web_disable_playwright(monkeypatch): +async def test_commands_web_disable_playwright(monkeypatch): """ Test that Commands.cmd_web does not emit a misleading warning when --disable-playwright is set. """ @@ -129,7 +129,7 @@ def scrape(self, url): args = type("Args", (), {"disable_playwright": True})() commands = Commands(io, coder, args=args) - commands.cmd_web("http://example.com") + await commands.cmd_web("http://example.com") # Should not emit a warning about playwright assert not io.warnings # Should not contain message "For the best web scraping, install Playwright:" diff --git a/tests/scrape/test_scrape.py b/tests/scrape/test_scrape.py index 7a94de022a7..e769793bb04 100644 --- a/tests/scrape/test_scrape.py +++ b/tests/scrape/test_scrape.py @@ -40,7 +40,7 @@ def setUp(self): @patch("aider.commands.install_playwright") @patch("aider.commands.Scraper") - def test_cmd_web_imports_playwright(self, mock_scraper_class, mock_install_playwright): + async def test_cmd_web_imports_playwright(self, mock_scraper_class, mock_install_playwright): # Since install_playwright is mocked, we need to simulate its side effect # of making the playwright module importable. def mock_install(*args, **kwargs): @@ -58,7 +58,7 @@ def mock_install(*args, **kwargs): try: # Run the cmd_web command - result = self.commands.cmd_web("https://example.com", return_content=True) + result = await self.commands.cmd_web("https://example.com", return_content=True) # Assert that the result contains some content self.assertIsNotNone(result)