diff --git a/aider/args.py b/aider/args.py index 4c8865f0040..8c0f64a7fb2 100644 --- a/aider/args.py +++ b/aider/args.py @@ -794,6 +794,18 @@ def get_parser(default_config_files, git_root): help="Preserve the existing .aider.todo.txt file on startup (default: False)", default=False, ) + group.add_argument( + "--auto-save", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable/disable automatic saving of sessions as 'auto-save' (default: False)", + ) + group.add_argument( + "--auto-load", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable/disable automatic loading of 'auto-save' session on startup (default: False)", + ) group.add_argument( "--disable-playwright", action="store_true", diff --git a/aider/coders/agent_coder.py b/aider/coders/agent_coder.py index 4d70b19d954..80f1d97484c 100644 --- a/aider/coders/agent_coder.py +++ b/aider/coders/agent_coder.py @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs): def _build_tool_registry(self): """ Build a registry of available tools with their normalized names and process_response functions. - Handles agent configuration with whitelist/blacklist functionality. + Handles agent configuration with includelist/excludelist functionality. Returns: dict: Mapping of normalized tool names to tool modules @@ -182,10 +182,14 @@ def _build_tool_registry(self): # Process agent configuration if provided agent_config = self._get_agent_config() - tools_whitelist = agent_config.get("tools_whitelist", []) - tools_blacklist = agent_config.get("tools_blacklist", []) + tools_includelist = agent_config.get( + "tools_includelist", agent_config.get("tools_whitelist", []) + ) + tools_excludelist = agent_config.get( + "tools_excludelist", agent_config.get("tools_blacklist", []) + ) - # Always include essential tools regardless of whitelist/blacklist + # Always include essential tools regardless of includelist/excludelist essential_tools = {"makeeditable", "replacetext", "view", "finished"} for module in tool_modules: if hasattr(module, "NORM_NAME") and hasattr(module, "process_response"): @@ -194,16 +198,16 @@ def _build_tool_registry(self): # Check if tool should be included based on configuration should_include = True - # If whitelist is specified, only include tools in whitelist - if tools_whitelist: - should_include = tool_name in tools_whitelist + # If includelist is specified, only include tools in includelist + if tools_includelist: + should_include = tool_name in tools_includelist # Always include essential tools if tool_name in essential_tools: should_include = True - # Exclude tools in blacklist (unless they're essential) - if tool_name in tools_blacklist and tool_name not in essential_tools: + # Exclude tools in excludelist (unless they're essential) + if tool_name in tools_excludelist and tool_name not in essential_tools: should_include = False if should_include: @@ -236,10 +240,10 @@ def _get_agent_config(self): # Set defaults for missing values if "large_file_token_threshold" not in config: config["large_file_token_threshold"] = 25000 - if "tools_whitelist" not in config: - config["tools_whitelist"] = [] - if "tools_blacklist" not in config: - config["tools_blacklist"] = [] + if "tools_includelist" not in config: + config["tools_includelist"] = [] + if "tools_excludelist" not in config: + config["tools_excludelist"] = [] # Apply configuration to instance self.large_file_token_threshold = config["large_file_token_threshold"] @@ -255,12 +259,6 @@ def get_local_tool_schemas(self): if hasattr(tool_module, "schema"): schemas.append(tool_module.schema) - # Add git schemas from the tool registry - git_tools = [git_diff, git_log, git_show, git_status] - for git_tool in git_tools: - if hasattr(git_tool, "schema"): - schemas.append(git_tool.schema) - return schemas async def initialize_mcp_tools(self): @@ -935,6 +933,7 @@ async def process_tool_calls(self, tool_call_response): """ Track tool usage before calling the base implementation. """ + self.auto_save_session() if self.partial_response_tool_calls: for tool_call in self.partial_response_tool_calls: @@ -976,6 +975,7 @@ async def reply_completed(self): ) = await self._process_tool_commands(content) if self.agent_finished: + self.tool_usage_history = [] return True # Since we are no longer suppressing, the partial_response_content IS the final content. diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 589826f0681..13e67e8cb3e 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -59,6 +59,7 @@ from aider.repo import ANY_GIT_ERROR, GitRepo from aider.repomap import RepoMap from aider.run_cmd import run_cmd +from aider.sessions import SessionManager from aider.utils import format_tokens, is_image_file from ..dump import dump # noqa: F401 @@ -1121,6 +1122,8 @@ async def _run_linear(self, with_message=None, preproc=True): self.keyboard_interrupt() except (asyncio.CancelledError, IndexError): pass + + self.auto_save_session() except EOFError: return finally: @@ -1272,6 +1275,8 @@ async def _run_patched(self, with_message=None, preproc=True): self.io.stop_spinner() self.keyboard_interrupt() + + self.auto_save_session() except EOFError: return finally: @@ -2240,12 +2245,20 @@ def _print_tool_call_info(self, server_tool_calls): for server, tool_calls in server_tool_calls.items(): for tool_call in tool_calls: - self.io.tool_output(f"Tool Call: {tool_call.function.name}") + color_start = "[blue]" if self.pretty else "" + color_end = "[/blue]" if self.pretty else "" + self.io.tool_output( + f"{color_start}Tool Call:{color_end} {server.name} • {tool_call.function.name}" + ) # Parse and format arguments as headers with values if tool_call.function.arguments: # Only do JSON unwrapping for tools containing "replace" in their name - if "replace" in tool_call.function.name.lower(): + if ( + "replace" in tool_call.function.name.lower() + or "insert" in tool_call.function.name.lower() + or "update" in tool_call.function.name.lower() + ): try: args_dict = json.loads(tool_call.function.arguments) first_key = True @@ -2258,7 +2271,7 @@ def _print_tool_call_info(self, server_tool_calls): if first_key: self.io.tool_output("\n") first_key = False - self.io.tool_output(f"{key}:") + self.io.tool_output(f"{color_start}{key}:{color_end}") # Split the value by newlines and output each line separately if isinstance(value, str): for line in value.split("\n"): @@ -2269,13 +2282,11 @@ def _print_tool_call_info(self, server_tool_calls): except json.JSONDecodeError: # If JSON parsing fails, show raw arguments raw_args = tool_call.function.arguments - self.io.tool_output(f"Arguments: {raw_args}") + self.io.tool_output(f"{color_start}Arguments:{color_end} {raw_args}") else: # For non-replace tools, show raw arguments raw_args = tool_call.function.arguments - self.io.tool_output(f"Arguments: {raw_args}") - - self.io.tool_output(f"MCP Server: {server.name}") + self.io.tool_output(f"{color_start}Arguments:{color_end} {raw_args}") if self.verbose: self.io.tool_output(f"Tool ID: {tool_call.id}") @@ -2295,7 +2306,15 @@ def _gather_server_tool_calls(self, tool_calls): return None server_tool_calls = {} + tool_id_set = set() + for tool_call in tool_calls: + # LLM APIs sometimes return duplicates and that's annoying part 3 + if tool_call.get("id") in tool_id_set: + continue + + tool_id_set.add(tool_call.get("id")) + # Check if this tool_call matches any MCP tool for server_name, server_tools in self.mcp_tools: for tool in server_tools: @@ -2343,8 +2362,16 @@ async def _exec_server_tools(server, tool_calls_list): try: # Connect to the server once session = await server.connect() + tool_id_set = set() + # Execute all tool calls for this server for tool_call in tool_calls_list: + # LLM APIs sometimes return duplicates and that's annoying part 4 + if tool_call.id in tool_id_set: + continue + + tool_id_set.add(tool_call.id) + try: # Arguments can be a stream of JSON objects. # We need to parse them and run a tool call for each. @@ -3491,6 +3518,17 @@ def apply_edits(self, edits): def apply_edits_dry_run(self, edits): return edits + def auto_save_session(self): + """Automatically save the current session as 'auto-save'.""" + if not getattr(self.args, "auto_save", False): + return + try: + session_manager = SessionManager(self, self.io) + session_manager.save_session("auto-save", False) + except Exception: + # Don't show errors for auto-save to avoid interrupting the user experience + pass + async def run_shell_commands(self): if not self.suggest_shell_commands: return "" diff --git a/aider/commands.py b/aider/commands.py index 3523e98be3b..1c3c84f0882 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -1,14 +1,11 @@ import asyncio import glob -import json import os import re import subprocess import sys import tempfile -import time from collections import OrderedDict -from datetime import datetime from os.path import expanduser from pathlib import Path @@ -17,7 +14,7 @@ from prompt_toolkit.completion import Completion, PathCompleter from prompt_toolkit.document import Document -from aider import models, prompts, voice +from aider import models, prompts, sessions, voice from aider.editor import pipe_editor from aider.format_settings import format_settings from aider.help import Help, install_help_extra @@ -89,7 +86,7 @@ def __init__( self.original_read_only_fnames = set(original_read_only_fnames or []) self.cmd_running = False - def cmd_model(self, args): + async def cmd_model(self, args): "Switch the Main Model to a new LLM" model_name = args.strip() @@ -103,7 +100,7 @@ def cmd_model(self, args): editor_model=self.coder.main_model.editor_model.name, weak_model=self.coder.main_model.weak_model.name, ) - models.sanity_check_models(self.io, model) + await models.sanity_check_models(self.io, model) # Check if the current edit format is the default for the old model old_model_edit_format = self.coder.main_model.edit_format @@ -116,7 +113,7 @@ def cmd_model(self, args): raise SwitchCoder(main_model=model, edit_format=new_edit_format) - def cmd_editor_model(self, args): + async def cmd_editor_model(self, args): "Switch the Editor Model to a new LLM" model_name = args.strip() @@ -125,10 +122,10 @@ def cmd_editor_model(self, args): editor_model=model_name, weak_model=self.coder.main_model.weak_model.name, ) - models.sanity_check_models(self.io, model) + await models.sanity_check_models(self.io, model) raise SwitchCoder(main_model=model) - def cmd_weak_model(self, args): + async def cmd_weak_model(self, args): "Switch the Weak Model to a new LLM" model_name = args.strip() @@ -137,7 +134,7 @@ def cmd_weak_model(self, args): editor_model=self.coder.main_model.editor_model.name, weak_model=model_name, ) - models.sanity_check_models(self.io, model) + await models.sanity_check_models(self.io, model) raise SwitchCoder(main_model=model) def cmd_chat_mode(self, args): @@ -235,9 +232,13 @@ async def cmd_web(self, args, return_content=False): if disable_playwright: res = False else: - res = await install_playwright(self.io) - if not res: + try: + res = await install_playwright(self.io) + if not res: + self.io.tool_warning("Unable to initialize playwright.") + except Exception: self.io.tool_warning("Unable to initialize playwright.") + res = False self.scraper = Scraper( print_error=self.io.tool_error, @@ -245,7 +246,7 @@ async def cmd_web(self, args, return_content=False): verify_ssl=self.verify_ssl, ) - content = self.scraper.scrape(url) or "" + content = await self.scraper.scrape(url) or "" content = f"Here is the content of {url}:\n\n" + content if return_content: return content @@ -1373,7 +1374,7 @@ async def cmd_help(self, args): from aider.coders.base_coder import Coder if not self.help: - res = install_help_extra(self.io) + res = await install_help_extra(self.io) if not res: self.io.tool_error("Unable to initialize interactive help.") return @@ -2034,168 +2035,29 @@ def _find_session_file(self, session_name): def cmd_save_session(self, args): """Save the current chat session to a named file in .aider/sessions/""" - if not args.strip(): - self.io.tool_error("Please provide a session name.") - return - - session_name = args.strip() - session_file = self._get_session_file_path(session_name) - - # Collect session data - session_data = { - "version": "1.0", - "timestamp": time.time(), - "session_name": session_name, - "model": self.coder.main_model.name, - "edit_format": self.coder.edit_format, - "chat_history": { - "done_messages": self.coder.done_messages, - "cur_messages": self.coder.cur_messages, - }, - "files": { - "editable": [self.coder.get_rel_fname(f) for f in self.coder.abs_fnames], - "read_only": [self.coder.get_rel_fname(f) for f in self.coder.abs_read_only_fnames], - "read_only_stubs": [ - self.coder.get_rel_fname(f) for f in self.coder.abs_read_only_stubs_fnames - ], - }, - "settings": { - "root": self.coder.root, - "auto_commits": self.coder.auto_commits, - "auto_lint": self.coder.auto_lint, - "auto_test": self.coder.auto_test, - }, - } - - try: - with open(session_file, "w", encoding="utf-8") as f: - json.dump(session_data, f, indent=2, ensure_ascii=False) - self.io.tool_output(f"Session saved to: {session_file}") - except Exception as e: - self.io.tool_error(f"Error saving session: {e}") + session_manager = sessions.SessionManager(self.coder, self.io) + session_manager.save_session(args.strip()) def cmd_list_sessions(self, args): """List all saved sessions in .aider/sessions/""" - session_dir = self._get_session_directory() - session_files = list(session_dir.glob("*.json")) + session_manager = sessions.SessionManager(self.coder, self.io) + sessions_list = session_manager.list_sessions() - if not session_files: - self.io.tool_output("No saved sessions found.") + if not sessions_list: return self.io.tool_output("Saved sessions:") - for session_file in sorted(session_files): - try: - with open(session_file, "r", encoding="utf-8") as f: - session_data = json.load(f) - session_name = session_data.get("session_name", session_file.stem) - timestamp = session_data.get("timestamp", 0) - model = session_data.get("model", "unknown") - edit_format = session_data.get("edit_format", "unknown") - - # Format timestamp - if timestamp: - date_str = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M") - else: - date_str = "unknown date" - - self.io.tool_output( - f" {session_name} (model: {model}, format: {edit_format}, {date_str})" - ) - except Exception as e: - self.io.tool_output(f" {session_file.stem} [error reading: {e}]") - - def cmd_load_session(self, args): - """Load a saved session by name or file path""" - if not args.strip(): - self.io.tool_error("Please provide a session name or file path.") - return - - session_name = args.strip() - session_file = self._find_session_file(session_name) - - if not session_file: - self.io.tool_error(f"Session not found: {session_name}") - self.io.tool_output("Use /list-sessions to see available sessions.") - return - - try: - with open(session_file, "r", encoding="utf-8") as f: - session_data = json.load(f) - except Exception as e: - self.io.tool_error(f"Error loading session: {e}") - return - - # Verify session format - if not isinstance(session_data, dict) or "version" not in session_data: - self.io.tool_error("Invalid session format.") - return - - # Load session data - try: - # Clear current state - self.coder.abs_fnames = set() - self.coder.abs_read_only_fnames = set() - self.coder.abs_read_only_stubs_fnames = set() - self.coder.done_messages = [] - self.coder.cur_messages = [] - - # Load chat history - chat_history = session_data.get("chat_history", {}) - self.coder.done_messages = chat_history.get("done_messages", []) - self.coder.cur_messages = chat_history.get("cur_messages", []) - - # Load files - files = session_data.get("files", {}) - for rel_fname in files.get("editable", []): - abs_fname = self.coder.abs_root_path(rel_fname) - if os.path.exists(abs_fname): - self.coder.abs_fnames.add(abs_fname) - else: - self.io.tool_warning(f"File not found, skipping: {rel_fname}") - - for rel_fname in files.get("read_only", []): - abs_fname = self.coder.abs_root_path(rel_fname) - if os.path.exists(abs_fname): - self.coder.abs_read_only_fnames.add(abs_fname) - else: - self.io.tool_warning(f"File not found, skipping: {rel_fname}") - - for rel_fname in files.get("read_only_stubs", []): - abs_fname = self.coder.abs_root_path(rel_fname) - if os.path.exists(abs_fname): - self.coder.abs_read_only_stubs_fnames.add(abs_fname) - else: - self.io.tool_warning(f"File not found, skipping: {rel_fname}") - - # Load settings - settings = session_data.get("settings", {}) - if "auto_commits" in settings: - self.coder.auto_commits = settings["auto_commits"] - if "auto_lint" in settings: - self.coder.auto_lint = settings["auto_lint"] - if "auto_test" in settings: - self.coder.auto_test = settings["auto_test"] - - self.io.tool_output( - f"Session loaded: {session_data.get('session_name', session_file.stem)}" - ) + for session_info in sessions_list: self.io.tool_output( - f"Model: {session_data.get('model', 'unknown')}, Edit format:" - f" {session_data.get('edit_format', 'unknown')}" + f" {session_info['name']} (model: {session_info['model']}, " + f"format: {session_info['edit_format']}, " + f"{session_info['num_messages']} messages, {session_info['num_files']} files)" ) - # Show summary - num_messages = len(self.coder.done_messages) + len(self.coder.cur_messages) - num_files = ( - len(self.coder.abs_fnames) - + len(self.coder.abs_read_only_fnames) - + len(self.coder.abs_read_only_stubs_fnames) - ) - self.io.tool_output(f"Loaded {num_messages} messages and {num_files} files") - - except Exception as e: - self.io.tool_error(f"Error applying session data: {e}") + def cmd_load_session(self, args): + """Load a saved session by name or file path""" + session_manager = sessions.SessionManager(self.coder, self.io) + session_manager.load_session(args.strip()) def cmd_copy_context(self, args=None): """Copy the current chat context as markdown, suitable to paste into a web UI""" diff --git a/aider/help.py b/aider/help.py index f6587e638e1..3da5deec300 100755 --- a/aider/help.py +++ b/aider/help.py @@ -15,13 +15,13 @@ warnings.simplefilter("ignore", category=FutureWarning) -def install_help_extra(io): +async def install_help_extra(io): pip_install_cmd = [ "aider-ce[help]", "--extra-index-url", "https://download.pytorch.org/whl/cpu", ] - res = utils.check_pip_install_extra( + res = await utils.check_pip_install_extra( io, "llama_index.embeddings.huggingface", "To use interactive /help you need to install the help extras", diff --git a/aider/main.py b/aider/main.py index c315338ab11..0c58da839eb 100644 --- a/aider/main.py +++ b/aider/main.py @@ -5,6 +5,7 @@ import re import sys import threading +import time import traceback import webbrowser from dataclasses import fields @@ -88,10 +89,10 @@ def guessed_wrong_repo(io, git_root, fnames, git_dname): return str(check_repo) -def make_new_repo(git_root, io): +async def make_new_repo(git_root, io): try: repo = git.Repo.init(git_root) - check_gitignore(git_root, io, False) + await check_gitignore(git_root, io, False) except ANY_GIT_ERROR as err: # issue #1233 io.tool_error(f"Unable to create git repo in {git_root}") io.tool_output(str(err)) @@ -101,7 +102,7 @@ def make_new_repo(git_root, io): return repo -def setup_git(git_root, io): +async def setup_git(git_root, io): if git is None: return @@ -122,11 +123,11 @@ def setup_git(git_root, io): "You should probably run aider in your project's directory, not your home dir." ) return - elif cwd and io.confirm_ask( + elif cwd and await io.confirm_ask( "No git repo found, create one to track aider's changes (recommended)?" ): git_root = str(cwd.resolve()) - repo = make_new_repo(git_root, io) + repo = await make_new_repo(git_root, io) if not repo: return @@ -155,7 +156,7 @@ def setup_git(git_root, io): return repo.working_tree_dir -def check_gitignore(git_root, io, ask=True): +async def check_gitignore(git_root, io, ask=True): if not git_root: return @@ -191,7 +192,9 @@ def check_gitignore(git_root, io, ask=True): if ask: io.tool_output("You can skip this check with --no-gitignore") - if not io.confirm_ask(f"Add {', '.join(patterns_to_add)} to .gitignore (recommended)?"): + if not await io.confirm_ask( + f"Add {', '.join(patterns_to_add)} to .gitignore (recommended)?" + ): return content += "\n".join(patterns_to_add) + "\n" @@ -208,8 +211,8 @@ def check_gitignore(git_root, io, ask=True): io.tool_output(f" {pattern}") -def check_streamlit_install(io): - return utils.check_pip_install_extra( +async def check_streamlit_install(io): + return await utils.check_pip_install_extra( io, "streamlit", "You need to install the aider browser feature", @@ -217,7 +220,7 @@ def check_streamlit_install(io): ) -def write_streamlit_credentials(): +async def write_streamlit_credentials(): from streamlit.file_util import get_streamlit_file_path # See https://github.com/Aider-AI/aider/issues/772 @@ -472,7 +475,7 @@ def expand_glob_patterns(patterns, root="."): PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__)) log_file = None -file_blacklist = ["get_bottom_toolbar", ""] +file_excludelist = ["get_bottom_toolbar", ""] def custom_tracer(frame, event, arg): @@ -492,15 +495,21 @@ def custom_tracer(frame, event, arg): func_name = frame.f_code.co_name line_no = frame.f_lineno - if func_name not in file_blacklist: - log_file.write(f"-> CALL: {func_name}() in {os.path.basename(filename)}:{line_no}\n") + if func_name not in file_excludelist: + log_file.write( + f"-> CALL: {func_name}() in {os.path.basename(filename)}:{line_no} -" + f" {time.time()}\n" + ) if event == "return": func_name = frame.f_code.co_name line_no = frame.f_lineno - if func_name not in file_blacklist: - log_file.write(f"<- RETURN: {func_name}() in {os.path.basename(filename)}:{line_no}\n") + if func_name not in file_excludelist: + log_file.write( + f"<- RETURN: {func_name}() in {os.path.basename(filename)}:{line_no} -" + f" {time.time()}\n" + ) # Must return the trace function (or a local one) for subsequent events return custom_tracer @@ -715,7 +724,7 @@ def get_io(pretty): " personal info." ) io.tool_output(f"For more info: {urls.analytics}") - disable = not io.confirm_ask( + disable = not await io.confirm_ask( "Allow collection of anonymous analytics to help improve aider?" ) @@ -733,7 +742,7 @@ def get_io(pretty): analytics.event("launched") if args.gui and not return_coder: - if not check_streamlit_install(io): + if not await check_streamlit_install(io): analytics.event("exit", reason="Streamlit not installed") return analytics.event("gui session") @@ -800,12 +809,12 @@ def get_io(pretty): return 0 if not update_available else 1 if args.install_main_branch: - success = install_from_main_branch(io) + success = await install_from_main_branch(io) analytics.event("exit", reason="Installed main branch") return 0 if success else 1 if args.upgrade: - success = install_upgrade(io) + success = await install_upgrade(io) analytics.event("exit", reason="Upgrade completed") return 0 if success else 1 @@ -813,9 +822,9 @@ def get_io(pretty): check_version(io, verbose=args.verbose) if args.git: - git_root = setup_git(git_root, io) + git_root = await setup_git(git_root, io) if args.gitignore: - check_gitignore(git_root, io) + await check_gitignore(git_root, io) if args.verbose: show = format_settings(parser, args) @@ -963,7 +972,7 @@ def get_io(pretty): return 1 if args.show_model_warnings: - problem = models.sanity_check_models(io, main_model) + problem = await models.sanity_check_models(io, main_model) if problem: analytics.event("model warning", main_model=main_model) io.tool_output("You can skip this check with --no-show-model-warnings") @@ -1255,6 +1264,17 @@ def get_io(pretty): analytics.event("cli session", main_model=main_model, edit_format=main_model.edit_format) + # Auto-load session if enabled + if args.auto_load: + try: + from aider.sessions import SessionManager + + session_manager = SessionManager(coder, io) + session_manager.load_session("auto-save") + except Exception: + # Don't show errors for auto-load to avoid interrupting the user experience + pass + while True: try: coder.ok_to_warm_cache = bool(args.cache_keepalive_pings) diff --git a/aider/models.py b/aider/models.py index 8bcd3e07bb0..3b159789ce7 100644 --- a/aider/models.py +++ b/aider/models.py @@ -1118,12 +1118,12 @@ def validate_variables(vars): return dict(keys_in_environment=True, missing_keys=missing) -def sanity_check_models(io, main_model): - problem_main = sanity_check_model(io, main_model) +async def sanity_check_models(io, main_model): + problem_main = await sanity_check_model(io, main_model) problem_weak = None if main_model.weak_model and main_model.weak_model is not main_model: - problem_weak = sanity_check_model(io, main_model.weak_model) + problem_weak = await sanity_check_model(io, main_model.weak_model) problem_editor = None if ( @@ -1131,12 +1131,12 @@ def sanity_check_models(io, main_model): and main_model.editor_model is not main_model and main_model.editor_model is not main_model.weak_model ): - problem_editor = sanity_check_model(io, main_model.editor_model) + problem_editor = await sanity_check_model(io, main_model.editor_model) return problem_main or problem_weak or problem_editor -def sanity_check_model(io, model): +async def sanity_check_model(io, model): show = False if model.missing_keys: @@ -1158,7 +1158,7 @@ def sanity_check_model(io, model): io.tool_warning(f"Warning for {model}: Unknown which environment variables are required.") # Check for model-specific dependencies - check_for_dependencies(io, model.name) + await check_for_dependencies(io, model.name) if not model.info: show = True @@ -1175,7 +1175,7 @@ def sanity_check_model(io, model): return show -def check_for_dependencies(io, model_name): +async def check_for_dependencies(io, model_name): """ Check for model-specific dependencies and install them if needed. @@ -1185,13 +1185,13 @@ def check_for_dependencies(io, model_name): """ # Check if this is a Bedrock model and ensure boto3 is installed if model_name.startswith("bedrock/"): - check_pip_install_extra( + await check_pip_install_extra( io, "boto3", "AWS Bedrock models require the boto3 package.", ["boto3"] ) # Check if this is a Vertex AI model and ensure google-cloud-aiplatform is installed elif model_name.startswith("vertex_ai/"): - check_pip_install_extra( + await check_pip_install_extra( io, "google.cloud.aiplatform", "Google Vertex AI models require the google-cloud-aiplatform package.", diff --git a/aider/repo.py b/aider/repo.py index e26fa7ea5b9..0e7664af120 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -216,7 +216,7 @@ async def commit(self, fnames=None, context=None, message=None, aider_edits=Fals commit_message = await self.get_commit_message(diffs, context, user_language) # Retrieve attribute settings, prioritizing coder.args if available - if coder and hasattr(coder, "args"): + if coder and hasattr(coder, "args") and coder.args: attribute_author = coder.args.attribute_author attribute_committer = coder.args.attribute_committer attribute_commit_message_author = coder.args.attribute_commit_message_author diff --git a/aider/repomap.py b/aider/repomap.py index 40712213d45..b1af7d176d1 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -697,7 +697,7 @@ def get_ranked_tags( # Ideally, this will help downweight boiler plate in frameworks, interfaces, and abstract classes if len(defines[ident]) > 4: exp = min(len(defines[ident]), 32) - mul *= math.log((4 / (2**exp)) + 1) + mul *= math.log2((4 / (2**exp)) + 1) # Calculate multiplier: log(number of unique file references * total references ^ 2) # Used to balance the number of times an identifier appears with its number of refs per file @@ -709,7 +709,7 @@ def get_ranked_tags( # With absolute number of references throughout a codebase unique_file_refs = len(set(references[ident])) total_refs = len(references[ident]) - ext_mul = round(math.log(unique_file_refs * total_refs**2 + 1)) + ext_mul = round(math.log2(unique_file_refs * total_refs**2 + 1)) for referencer, num_refs in Counter(references[ident]).items(): for definer in definers: diff --git a/aider/scrape.py b/aider/scrape.py index 5e91a3eeace..3bdef29214b 100755 --- a/aider/scrape.py +++ b/aider/scrape.py @@ -14,32 +14,41 @@ # platforms. -def check_env(): +def check_playwright(): try: - from playwright.sync_api import sync_playwright + from playwright.async_api import async_playwright # noqa: F401 + from playwright.sync_api import sync_playwright # noqa: F401 - has_pip = True + has_playwright = True except ImportError: - has_pip = False + has_playwright = False + + return has_playwright - try: - with sync_playwright() as p: - p.chromium.launch() - has_chromium = True - except Exception: - has_chromium = False - return has_pip, has_chromium +async def check_chromium(): + has_chromium = False + if check_playwright(): + from playwright.async_api import async_playwright -def has_playwright(): - has_pip, has_chromium = check_env() - return has_pip and has_chromium + try: + async with async_playwright() as p: + browser = await p.chromium.launch() + await browser.close() + has_chromium = True + except Exception as e: + has_chromium = False + print(f"chromium errors {e}") + + return has_chromium async def install_playwright(io): - has_pip, has_chromium = check_env() - if has_pip and has_chromium: + has_playwright = check_playwright() + has_chromium = await check_chromium() + + if has_playwright and has_chromium: return True pip_cmd = utils.get_pip_install(["aider-ce[playwright]"]) @@ -47,7 +56,7 @@ async def install_playwright(io): chromium_cmd = [sys.executable] + chromium_cmd.split() cmds = "" - if not has_pip: + if not has_playwright: cmds += " ".join(pip_cmd) + "\n" if not has_chromium: cmds += " ".join(chromium_cmd) + "\n" @@ -62,7 +71,7 @@ async def install_playwright(io): if not await io.confirm_ask("Install playwright?", default="y"): return - if not has_pip: + if not has_playwright: success, output = utils.run_install(pip_cmd) if not success: io.tool_error(output) @@ -95,7 +104,7 @@ def __init__(self, print_error=None, playwright_available=None, verify_ssl=True) self.playwright_available = playwright_available self.verify_ssl = verify_ssl - def scrape(self, url): + async def scrape(self, url): """ Scrape a url and turn it into readable markdown if it's HTML. If it's plain text or non-HTML, return it as-is. @@ -104,7 +113,7 @@ def scrape(self, url): """ if self.playwright_available: - content, mime_type = self.scrape_with_playwright(url) + content, mime_type = await self.scrape_with_playwright(url) else: content, mime_type = self.scrape_with_httpx(url) @@ -140,34 +149,34 @@ def looks_like_html(self, content): return False # Internals... - def scrape_with_playwright(self, url): + async def scrape_with_playwright(self, url): import playwright # noqa: F401 - from playwright.sync_api import Error as PlaywrightError - from playwright.sync_api import TimeoutError as PlaywrightTimeoutError - from playwright.sync_api import sync_playwright + from playwright.async_api import Error as PlaywrightError + from playwright.async_api import TimeoutError as PlaywrightTimeoutError + from playwright.async_api import async_playwright - with sync_playwright() as p: + async with async_playwright() as p: try: - browser = p.chromium.launch() + browser = await p.chromium.launch() except Exception as e: self.playwright_available = False self.print_error(str(e)) return None, None try: - context = browser.new_context(ignore_https_errors=not self.verify_ssl) - page = context.new_page() + context = await browser.new_context(ignore_https_errors=not self.verify_ssl) + page = await context.new_page() - user_agent = page.evaluate("navigator.userAgent") + user_agent = await page.evaluate("navigator.userAgent") user_agent = user_agent.replace("Headless", "") user_agent = user_agent.replace("headless", "") user_agent += " " + aider_user_agent - page.set_extra_http_headers({"User-Agent": user_agent}) + await page.set_extra_http_headers({"User-Agent": user_agent}) response = None try: - response = page.goto(url, wait_until="networkidle", timeout=5000) + response = await page.goto(url, wait_until="networkidle", timeout=5000) except PlaywrightTimeoutError: self.print_error(f"Page didn't quiesce, scraping content anyway: {url}") response = None @@ -176,10 +185,10 @@ def scrape_with_playwright(self, url): return None, None try: - content = page.content() + content = await page.content() mime_type = None if response: - content_type = response.header_value("content-type") + content_type = await response.header_value("content-type") if content_type: mime_type = content_type.split(";")[0] except PlaywrightError as e: @@ -187,7 +196,7 @@ def scrape_with_playwright(self, url): content = None mime_type = None finally: - browser.close() + await browser.close() return content, mime_type @@ -271,9 +280,9 @@ def slimdown_html(soup): return soup -def main(url): - scraper = Scraper(playwright_available=has_playwright()) - content = scraper.scrape(url) +async def main(url): + scraper = Scraper(playwright_available=check_playwright()) + content = await scraper.scrape(url) print(content) diff --git a/aider/sessions.py b/aider/sessions.py new file mode 100644 index 00000000000..cfff99fe98a --- /dev/null +++ b/aider/sessions.py @@ -0,0 +1,240 @@ +"""Session management utilities for Aider.""" + +import json +import os +from pathlib import Path +from typing import Dict, List, Optional + + +class SessionManager: + """Manages chat session saving, listing, and loading.""" + + def __init__(self, coder, io): + self.coder = coder + self.io = io + + def _get_session_directory(self) -> Path: + """Get the session directory, creating it if necessary.""" + session_dir = Path(self.coder.abs_root_path(".aider/sessions")) + os.makedirs(session_dir, exist_ok=True) + return session_dir + + def save_session(self, session_name: str, output=True) -> bool: + """Save the current chat session to a named file.""" + if not session_name: + if output: + self.io.tool_error("Please provide a session name.") + return False + + session_name = session_name.replace(".json", "") + session_dir = self._get_session_directory() + session_file = session_dir / f"{session_name}.json" + + if session_file.exists(): + if output: + self.io.tool_warning(f"Session '{session_name}' already exists. Overwriting.") + + try: + session_data = self._build_session_data(session_name) + with open(session_file, "w", encoding="utf-8") as f: + json.dump(session_data, f, indent=2) + + if output: + self.io.tool_output(f"Session saved: {session_file}") + + return True + + except Exception as e: + self.io.tool_error(f"Error saving session: {e}") + return False + + def list_sessions(self) -> List[Dict]: + """List all saved sessions with metadata.""" + session_dir = self._get_session_directory() + session_files = list(session_dir.glob("*.json")) + + if not session_files: + self.io.tool_output("No saved sessions found.") + return [] + + sessions = [] + for session_file in sorted(session_files, key=lambda x: x.stat().st_mtime, reverse=True): + try: + with open(session_file, "r", encoding="utf-8") as f: + session_data = json.load(f) + + session_info = { + "name": session_file.stem, + "file": session_file, + "model": session_data.get("model", "unknown"), + "edit_format": session_data.get("edit_format", "unknown"), + "num_messages": len( + session_data.get("chat_history", {}).get("done_messages", []) + ) + len(session_data.get("chat_history", {}).get("cur_messages", [])), + "num_files": ( + len(session_data.get("files", {}).get("editable", [])) + + len(session_data.get("files", {}).get("read_only", [])) + + len(session_data.get("files", {}).get("read_only_stubs", [])) + ), + } + sessions.append(session_info) + + except Exception as e: + self.io.tool_output(f" {session_file.stem} [error reading: {e}]") + + return sessions + + def load_session(self, session_identifier: str) -> bool: + """Load a saved session by name or file path.""" + if not session_identifier: + self.io.tool_error("Please provide a session name or file path.") + return False + + # Try to find the session file + session_file = self._find_session_file(session_identifier) + if not session_file: + return False + + try: + with open(session_file, "r", encoding="utf-8") as f: + session_data = json.load(f) + except Exception as e: + self.io.tool_error(f"Error loading session: {e}") + return False + + # Verify session format + if not isinstance(session_data, dict) or "version" not in session_data: + self.io.tool_error("Invalid session format.") + return False + + # Apply session data + return self._apply_session_data(session_data, session_file) + + def _build_session_data(self, session_name) -> Dict: + """Build session data dictionary from current coder state.""" + # Get relative paths for all files + editable_files = [ + self.coder.get_rel_fname(abs_fname) for abs_fname in self.coder.abs_fnames + ] + read_only_files = [ + self.coder.get_rel_fname(abs_fname) for abs_fname in self.coder.abs_read_only_fnames + ] + read_only_stubs_files = [ + self.coder.get_rel_fname(abs_fname) + for abs_fname in self.coder.abs_read_only_stubs_fnames + ] + + return { + "version": 1, + "session_name": session_name, + "model": self.coder.main_model.name, + "edit_format": self.coder.edit_format, + "chat_history": { + "done_messages": self.coder.done_messages, + "cur_messages": self.coder.cur_messages, + }, + "files": { + "editable": editable_files, + "read_only": read_only_files, + "read_only_stubs": read_only_stubs_files, + }, + "settings": { + "auto_commits": self.coder.auto_commits, + "auto_lint": self.coder.auto_lint, + "auto_test": self.coder.auto_test, + }, + } + + def _find_session_file(self, session_identifier: str) -> Optional[Path]: + """Find session file by name or path.""" + # Check if it's a direct file path + session_file = Path(session_identifier) + if session_file.exists(): + return session_file + + # Check if it's a session name in the sessions directory + session_dir = self._get_session_directory() + + # Try with .json extension + if not session_identifier.endswith(".json"): + session_file = session_dir / f"{session_identifier}.json" + if session_file.exists(): + return session_file + + session_file = session_dir / f"{session_identifier}" + if session_file.exists(): + return session_file + + self.io.tool_error(f"Session not found: {session_identifier}") + self.io.tool_output("Use /list-sessions to see available sessions.") + return None + + def _apply_session_data(self, session_data: Dict, session_file: Path) -> bool: + """Apply session data to current coder state.""" + try: + # Clear current state + self.coder.abs_fnames = set() + self.coder.abs_read_only_fnames = set() + self.coder.abs_read_only_stubs_fnames = set() + self.coder.done_messages = [] + self.coder.cur_messages = [] + + # Load chat history + chat_history = session_data.get("chat_history", {}) + self.coder.done_messages = chat_history.get("done_messages", []) + self.coder.cur_messages = chat_history.get("cur_messages", []) + + # Load files + files = session_data.get("files", {}) + for rel_fname in files.get("editable", []): + abs_fname = self.coder.abs_root_path(rel_fname) + if os.path.exists(abs_fname): + self.coder.abs_fnames.add(abs_fname) + else: + self.io.tool_warning(f"File not found, skipping: {rel_fname}") + + for rel_fname in files.get("read_only", []): + abs_fname = self.coder.abs_root_path(rel_fname) + if os.path.exists(abs_fname): + self.coder.abs_read_only_fnames.add(abs_fname) + else: + self.io.tool_warning(f"File not found, skipping: {rel_fname}") + + for rel_fname in files.get("read_only_stubs", []): + abs_fname = self.coder.abs_root_path(rel_fname) + if os.path.exists(abs_fname): + self.coder.abs_read_only_stubs_fnames.add(abs_fname) + else: + self.io.tool_warning(f"File not found, skipping: {rel_fname}") + + # Load settings + settings = session_data.get("settings", {}) + if "auto_commits" in settings: + self.coder.auto_commits = settings["auto_commits"] + if "auto_lint" in settings: + self.coder.auto_lint = settings["auto_lint"] + if "auto_test" in settings: + self.coder.auto_test = settings["auto_test"] + + self.io.tool_output( + f"Session loaded: {session_data.get('session_name', session_file.stem)}" + ) + self.io.tool_output( + f"Model: {session_data.get('model', 'unknown')}, Edit format:" + f" {session_data.get('edit_format', 'unknown')}" + ) + + # Show summary + num_messages = len(self.coder.done_messages) + len(self.coder.cur_messages) + num_files = ( + len(self.coder.abs_fnames) + + len(self.coder.abs_read_only_fnames) + + len(self.coder.abs_read_only_stubs_fnames) + ) + self.io.tool_output(f"Loaded {num_messages} messages and {num_files} files") + + return True + + except Exception as e: + self.io.tool_error(f"Error applying session data: {e}") + return False diff --git a/aider/utils.py b/aider/utils.py index 0a7a06ded11..50b0e023fdd 100644 --- a/aider/utils.py +++ b/aider/utils.py @@ -338,7 +338,7 @@ def touch_file(fname): return False -def check_pip_install_extra(io, module, prompt, pip_install_cmd, self_update=False): +async def check_pip_install_extra(io, module, prompt, pip_install_cmd, self_update=False): if module: try: __import__(module) @@ -357,7 +357,9 @@ def check_pip_install_extra(io, module, prompt, pip_install_cmd, self_update=Fal print(printable_shell_command(cmd)) # plain print so it doesn't line-wrap return - if not io.confirm_ask("Run pip install?", default="y", subject=printable_shell_command(cmd)): + if not await io.confirm_ask( + "Run pip install?", default="y", subject=printable_shell_command(cmd) + ): return success, output = run_install(cmd) diff --git a/aider/versioncheck.py b/aider/versioncheck.py index 1de5f3da1f1..c994e1c76ff 100644 --- a/aider/versioncheck.py +++ b/aider/versioncheck.py @@ -12,12 +12,12 @@ VERSION_CHECK_FNAME = Path.home() / ".aider" / "caches" / "versioncheck" -def install_from_main_branch(io): +async def install_from_main_branch(io): """ Install the latest version of aider from the main branch of the GitHub repository. """ - return utils.check_pip_install_extra( + return await utils.check_pip_install_extra( io, None, "Install the development version of aider from the main branch?", @@ -26,7 +26,7 @@ def install_from_main_branch(io): ) -def install_upgrade(io, latest_version=None): +async def install_upgrade(io, latest_version=None): """ Install the latest version of aider from PyPI. """ @@ -46,7 +46,7 @@ def install_upgrade(io, latest_version=None): io.tool_warning(text) return True - success = utils.check_pip_install_extra( + success = await utils.check_pip_install_extra( io, None, new_ver_text, @@ -75,7 +75,7 @@ def check_version(io, just_check=False, verbose=False): import requests try: - response = requests.get("https://pypi.org/pypi/aider-chat/json") + response = requests.get("https://pypi.org/pypi/aider-ce/json") data = response.json() latest_version = data["info"]["version"] current_version = aider.__version__ diff --git a/aider/website/docs/config/agent-mode.md b/aider/website/docs/config/agent-mode.md index 4140db63305..038a04352bd 100644 --- a/aider/website/docs/config/agent-mode.md +++ b/aider/website/docs/config/agent-mode.md @@ -80,7 +80,6 @@ Agent Mode prioritizes granular tools over SEARCH/REPLACE: - **Line number verification**: Two-step process for line-based edits to prevents errors - **Tool usage monitoring**: Prevents infinite loops by tracking repetitive patterns - ### Workflow Process #### 1. Exploration Phase @@ -151,13 +150,14 @@ Agent Mode can be configured using the `--agent-config` command line argument, w #### Configuration Options -- **`tools_whitelist`**: Array of tool names to allow (only these tools will be available) -- **`tools_blacklist`**: Array of tool names to exclude (these tools will be disabled) +- **`tools_includelist`**: Array of tool names to allow (only these tools will be available) +- **`tools_excludelist`**: Array of tool names to exclude (these tools will be disabled) - **`large_file_token_threshold`**: Maximum token threshold for large file warnings (default: 25000) #### Essential Tools -Certain tools are always available regardless of whitelist/blacklist settings: +Certain tools are always available regardless of includelist/excludelist settings: + - `makeeditable` - Make files editable - `replacetext` - Basic text replacement - `view` - View files @@ -167,16 +167,16 @@ Certain tools are always available regardless of whitelist/blacklist settings: ```bash # Only allow specific tools -aider --agent --agent-config '{"tools_whitelist": ["view", "makeeditable", "replacetext", "finished"]}' +aider-ce --agent --agent-config '{"tools_includelist": ["view", "makeeditable", "replacetext", "finished"]}' -# Exclude specific tools -aider --agent --agent-config '{"tools_blacklist": ["command", "commandinteractive"]}' +# Exclude specific tools +aider-ce --agent --agent-config '{"tools_excludelist": ["command", "commandinteractive"]}' # Custom large file threshold -aider --agent --agent-config '{"large_file_token_threshold": 10000}' +aider-ce --agent --agent-config '{"large_file_token_threshold": 10000}' # Combined configuration -aider --agent --agent-config '{"large_file_token_threshold": 10000, "tools_whitelist": ["view", "makeeditable", "replacetext", "finished", "gitdiff"]}' +aider-ce --agent --agent-config '{"large_file_token_threshold": 10000, "tools_includelist": ["view", "makeeditable", "replacetext", "finished", "gitdiff"]}' ``` This configuration system allows for fine-grained control over which tools are available in Agent Mode, enabling security-conscious deployments and specialized workflows while maintaining essential functionality. @@ -189,4 +189,5 @@ This configuration system allows for fine-grained control over which tools are a - **Scalable exploration**: Can handle large codebases through strategic context management - **Recovery mechanisms**: Built-in undo and safety features -Agent Mode represents a significant evolution in aider's capabilities, enabling more sophisticated and autonomous codebase manipulation while maintaining safety and control through the tool-based architecture. \ No newline at end of file +Agent Mode represents a significant evolution in aider's capabilities, enabling more sophisticated and autonomous codebase manipulation while maintaining safety and control through the tool-based architecture. + diff --git a/aider/website/docs/sessions.md b/aider/website/docs/sessions.md index d38f449d583..8ca925fba2b 100644 --- a/aider/website/docs/sessions.md +++ b/aider/website/docs/sessions.md @@ -11,6 +11,27 @@ Aider provides session management commands that allow you to save, load, and man ### `/save-session ` Save the current chat session to a named file in `.aider/sessions/`. +### Auto-Save and Auto-Load +Aider can automatically save and load sessions using command line options: + +**Auto-save:** +```bash +aider --auto-save +``` + +**Auto-load:** +```bash +aider --auto-load +``` + +**In configuration files:** +```yaml +auto-save: true +auto-load: true +``` + +When `--auto-save` is enabled, aider will automatically save your session as 'auto-save' when you exit. When `--auto-load` is enabled, aider will automatically load the 'auto-save' session on startup if it exists. + **Usage:** ``` /save-session my-project-session @@ -173,10 +194,10 @@ If a session fails to load: ### Multiple Model Sessions ``` # Save session with specific model -/model gpt-4 +/model gpt-5 /save-session gpt4-session # Try different model -/model claude-3 +/model claude-sonnet-4.5 /save-session claude-session ``` diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index 6f6e81fed17..5cd128aba8a 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -118,9 +118,9 @@ async def test_main_with_empty_git_dir_new_subdir_file(self): # Because aider will try and `git add` a file that's already in the repo. await main(["--yes", str(fname), "--exit"], input=DummyInput(), output=DummyOutput()) - def test_setup_git(self): + async def test_setup_git(self): io = InputOutput(pretty=False, yes=True) - git_root = setup_git(None, io) + git_root = await setup_git(None, io) git_root = Path(git_root).resolve() self.assertEqual(git_root, Path(self.tempdir).resolve()) @@ -130,7 +130,7 @@ def test_setup_git(self): self.assertTrue(gitignore.exists()) self.assertEqual(".aider*", gitignore.read_text().splitlines()[0]) - def test_check_gitignore(self): + async def test_check_gitignore(self): with GitTemporaryDirectory(): os.environ["GIT_CONFIG_GLOBAL"] = "globalgitconfig" @@ -139,20 +139,20 @@ def test_check_gitignore(self): gitignore = cwd / ".gitignore" self.assertFalse(gitignore.exists()) - check_gitignore(cwd, io) + await check_gitignore(cwd, io) self.assertTrue(gitignore.exists()) self.assertEqual(".aider*", gitignore.read_text().splitlines()[0]) # Test without .env file present gitignore.write_text("one\ntwo\n") - check_gitignore(cwd, io) + await check_gitignore(cwd, io) self.assertEqual("one\ntwo\n.aider*\n", gitignore.read_text()) # Test with .env file present env_file = cwd / ".env" env_file.touch() - check_gitignore(cwd, io) + await check_gitignore(cwd, io) self.assertEqual("one\ntwo\n.aider*\n.env\n", gitignore.read_text()) del os.environ["GIT_CONFIG_GLOBAL"] diff --git a/tests/basic/test_models.py b/tests/basic/test_models.py index 11e42b807af..144cd1f3227 100644 --- a/tests/basic/test_models.py +++ b/tests/basic/test_models.py @@ -50,7 +50,7 @@ def test_max_context_tokens(self): self.assertEqual(model.info["max_input_tokens"], 8 * 1024) @patch("os.environ") - def test_sanity_check_model_all_set(self, mock_environ): + async def test_sanity_check_model_all_set(self, mock_environ): mock_environ.get.return_value = "dummy_value" mock_io = MagicMock() model = MagicMock() @@ -59,7 +59,7 @@ def test_sanity_check_model_all_set(self, mock_environ): model.keys_in_environment = True model.info = {"some": "info"} - sanity_check_model(mock_io, model) + await sanity_check_model(mock_io, model) mock_io.tool_output.assert_called() calls = mock_io.tool_output.call_args_list @@ -67,7 +67,7 @@ def test_sanity_check_model_all_set(self, mock_environ): self.assertIn("- API_KEY2: Set", str(calls)) @patch("os.environ") - def test_sanity_check_model_not_set(self, mock_environ): + async def test_sanity_check_model_not_set(self, mock_environ): mock_environ.get.return_value = "" mock_io = MagicMock() model = MagicMock() @@ -76,19 +76,19 @@ def test_sanity_check_model_not_set(self, mock_environ): model.keys_in_environment = True model.info = {"some": "info"} - sanity_check_model(mock_io, model) + await sanity_check_model(mock_io, model) mock_io.tool_output.assert_called() calls = mock_io.tool_output.call_args_list self.assertIn("- API_KEY1: Not set", str(calls)) self.assertIn("- API_KEY2: Not set", str(calls)) - def test_sanity_check_models_bogus_editor(self): + async def test_sanity_check_models_bogus_editor(self): mock_io = MagicMock() main_model = Model("gpt-4") main_model.editor_model = Model("bogus-model") - result = sanity_check_models(mock_io, main_model) + result = await sanity_check_models(mock_io, main_model) self.assertTrue( result @@ -106,7 +106,7 @@ def test_sanity_check_models_bogus_editor(self): ) # Check that one of the warnings mentions the bogus model @patch("aider.models.check_for_dependencies") - def test_sanity_check_model_calls_check_dependencies(self, mock_check_deps): + async def test_sanity_check_model_calls_check_dependencies(self, mock_check_deps): """Test that sanity_check_model calls check_for_dependencies""" mock_io = MagicMock() model = MagicMock() @@ -115,7 +115,7 @@ def test_sanity_check_model_calls_check_dependencies(self, mock_check_deps): model.keys_in_environment = True model.info = {"some": "info"} - sanity_check_model(mock_io, model) + await sanity_check_model(mock_io, model) # Verify check_for_dependencies was called with the model name mock_check_deps.assert_called_once_with(mock_io, "test-model") @@ -206,7 +206,7 @@ def test_set_thinking_tokens(self): self.assertEqual(model.extra_params["thinking"]["budget_tokens"], 0.5 * 1024 * 1024) @patch("aider.models.check_pip_install_extra") - def test_check_for_dependencies_bedrock(self, mock_check_pip): + async def test_check_for_dependencies_bedrock(self, mock_check_pip): """Test that check_for_dependencies calls check_pip_install_extra for Bedrock models""" from aider.io import InputOutput @@ -215,7 +215,7 @@ def test_check_for_dependencies_bedrock(self, mock_check_pip): # Test with a Bedrock model from aider.models import check_for_dependencies - check_for_dependencies(io, "bedrock/anthropic.claude-3-sonnet-20240229-v1:0") + await check_for_dependencies(io, "bedrock/anthropic.claude-3-sonnet-20240229-v1:0") # Verify check_pip_install_extra was called with correct arguments mock_check_pip.assert_called_once_with( @@ -223,7 +223,7 @@ def test_check_for_dependencies_bedrock(self, mock_check_pip): ) @patch("aider.models.check_pip_install_extra") - def test_check_for_dependencies_vertex_ai(self, mock_check_pip): + async def test_check_for_dependencies_vertex_ai(self, mock_check_pip): """Test that check_for_dependencies calls check_pip_install_extra for Vertex AI models""" from aider.io import InputOutput @@ -232,7 +232,7 @@ def test_check_for_dependencies_vertex_ai(self, mock_check_pip): # Test with a Vertex AI model from aider.models import check_for_dependencies - check_for_dependencies(io, "vertex_ai/gemini-1.5-pro") + await check_for_dependencies(io, "vertex_ai/gemini-1.5-pro") # Verify check_pip_install_extra was called with correct arguments mock_check_pip.assert_called_once_with( @@ -243,7 +243,7 @@ def test_check_for_dependencies_vertex_ai(self, mock_check_pip): ) @patch("aider.models.check_pip_install_extra") - def test_check_for_dependencies_other_model(self, mock_check_pip): + async def test_check_for_dependencies_other_model(self, mock_check_pip): """Test that check_for_dependencies doesn't call check_pip_install_extra for other models""" from aider.io import InputOutput @@ -252,7 +252,7 @@ def test_check_for_dependencies_other_model(self, mock_check_pip): # Test with a non-Bedrock, non-Vertex AI model from aider.models import check_for_dependencies - check_for_dependencies(io, "gpt-4") + await check_for_dependencies(io, "gpt-4") # Verify check_pip_install_extra was not called mock_check_pip.assert_not_called() diff --git a/tests/scrape/test_playwright_disable.py b/tests/scrape/test_playwright_disable.py index e9d65073650..39f864ed5ff 100644 --- a/tests/scrape/test_playwright_disable.py +++ b/tests/scrape/test_playwright_disable.py @@ -17,7 +17,7 @@ def tool_error(self, msg): self.outputs.append(f"error: {msg}") -def test_scraper_disable_playwright_flag(monkeypatch): +async def test_scraper_disable_playwright_flag(monkeypatch): io = DummyIO() # Simulate that playwright is not available # (disable_playwright just means playwright_available=False) @@ -30,24 +30,24 @@ def fake_httpx(url): return "plain text", "text/plain" scraper.scrape_with_httpx = fake_httpx - content = scraper.scrape("http://example.com") + content = await scraper.scrape("http://example.com") assert content == "plain text" assert called["called"] -def test_scraper_enable_playwright(monkeypatch): +async def test_scraper_enable_playwright(monkeypatch): io = DummyIO() # Simulate that playwright is available and should be used scraper = Scraper(print_error=io.tool_error, playwright_available=True) # Patch scrape_with_playwright to check it is called called = {} - def fake_playwright(url): + async def fake_playwright(url): called["called"] = True return "hi", "text/html" scraper.scrape_with_playwright = fake_playwright - content = scraper.scrape("http://example.com") + content = await scraper.scrape("http://example.com") assert content.startswith("hi") or "" in content assert called["called"] @@ -111,15 +111,13 @@ def event(self, *a, **k): pass # Patch install_playwright to always return False (simulate not available) - monkeypatch.setattr("aider.scrape.install_playwright", lambda io: False) # Patch Scraper to always use scrape_with_httpx and never warn class DummyScraper: def __init__(self, **kwargs): self.called = False - def scrape(self, url): - self.called = True + async def scrape(self, url): return "dummy content" monkeypatch.setattr("aider.commands.Scraper", DummyScraper) diff --git a/tests/scrape/test_scrape.py b/tests/scrape/test_scrape.py index e769793bb04..8eadf92d60f 100644 --- a/tests/scrape/test_scrape.py +++ b/tests/scrape/test_scrape.py @@ -10,14 +10,14 @@ class TestScrape(unittest.TestCase): @patch("aider.scrape.Scraper.scrape_with_httpx") @patch("aider.scrape.Scraper.scrape_with_playwright") - def test_scrape_self_signed_ssl(self, mock_scrape_playwright, mock_scrape_httpx): + async def test_scrape_self_signed_ssl(self, mock_scrape_playwright, mock_scrape_httpx): # Test with SSL verification - playwright fails mock_scrape_playwright.return_value = (None, None) scraper_verify = Scraper( print_error=MagicMock(), playwright_available=True, verify_ssl=True ) - result_verify = scraper_verify.scrape("https://self-signed.badssl.com") + result_verify = await scraper_verify.scrape("https://self-signed.badssl.com") self.assertIsNone(result_verify) scraper_verify.print_error.assert_called() @@ -29,7 +29,7 @@ def test_scrape_self_signed_ssl(self, mock_scrape_playwright, mock_scrape_httpx) scraper_no_verify = Scraper( print_error=MagicMock(), playwright_available=True, verify_ssl=False ) - result_no_verify = scraper_no_verify.scrape("https://self-signed.badssl.com") + result_no_verify = await scraper_no_verify.scrape("https://self-signed.badssl.com") self.assertIsNotNone(result_no_verify) self.assertIn("self-signed", result_no_verify) scraper_no_verify.print_error.assert_not_called() @@ -85,7 +85,7 @@ def mock_install(*args, **kwargs): del sys.modules["playwright"] @patch("aider.scrape.Scraper.scrape_with_playwright") - def test_scrape_actual_url_with_playwright(self, mock_scrape_playwright): + async def test_scrape_actual_url_with_playwright(self, mock_scrape_playwright): # Create a Scraper instance with a mock print_error function mock_print_error = MagicMock() scraper = Scraper(print_error=mock_print_error, playwright_available=True) @@ -97,7 +97,7 @@ def test_scrape_actual_url_with_playwright(self, mock_scrape_playwright): ) # Scrape a mocked URL - result = scraper.scrape("https://example.com") + result = await scraper.scrape("https://example.com") # Assert that the result contains expected content self.assertIsNotNone(result) @@ -119,7 +119,7 @@ def test_scraper_print_error_not_called(self): # Assert that print_error was never called mock_print_error.assert_not_called() - def test_scrape_with_playwright_error_handling(self): + async def test_scrape_with_playwright_error_handling(self): # Create a Scraper instance with a mock print_error function mock_print_error = MagicMock() scraper = Scraper(print_error=mock_print_error, playwright_available=True) @@ -129,7 +129,7 @@ def test_scrape_with_playwright_error_handling(self): scraper.scrape_with_playwright.return_value = (None, None) # Call the scrape method - result = scraper.scrape("https://example.com") + result = await scraper.scrape("https://example.com") # Assert that the result is None self.assertIsNone(result) @@ -144,7 +144,7 @@ def test_scrape_with_playwright_error_handling(self): # Test with a different return value scraper.scrape_with_playwright.return_value = ("Some content", "text/html") - result = scraper.scrape("https://example.com") + result = await scraper.scrape("https://example.com") # Assert that the result is not None self.assertIsNotNone(result) @@ -152,7 +152,7 @@ def test_scrape_with_playwright_error_handling(self): # Assert that print_error was not called mock_print_error.assert_not_called() - def test_scrape_text_plain(self): + async def test_scrape_text_plain(self): # Create a Scraper instance scraper = Scraper(print_error=MagicMock(), playwright_available=True) @@ -161,12 +161,12 @@ def test_scrape_text_plain(self): scraper.scrape_with_playwright = MagicMock(return_value=(plain_text, "text/plain")) # Call the scrape method - result = scraper.scrape("https://example.com") + result = await scraper.scrape("https://example.com") # Assert that the result is the same as the input plain text self.assertEqual(result, plain_text) - def test_scrape_text_html(self): + async def test_scrape_text_html(self): # Create a Scraper instance scraper = Scraper(print_error=MagicMock(), playwright_available=True) @@ -179,7 +179,7 @@ def test_scrape_text_html(self): scraper.html_to_markdown = MagicMock(return_value=expected_markdown) # Call the scrape method - result = scraper.scrape("https://example.com") + result = await scraper.scrape("https://example.com") # Assert that the result is the expected markdown self.assertEqual(result, expected_markdown)