Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions aider/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,18 @@ def get_parser(default_config_files, git_root):
help="Preserve the existing .aider.todo.txt file on startup (default: False)",
default=False,
)
group.add_argument(
"--auto-save",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable/disable automatic saving of sessions as 'auto-save' (default: False)",
)
group.add_argument(
"--auto-load",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable/disable automatic loading of 'auto-save' session on startup (default: False)",
)
group.add_argument(
"--disable-playwright",
action="store_true",
Expand Down
38 changes: 19 additions & 19 deletions aider/coders/agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs):
def _build_tool_registry(self):
"""
Build a registry of available tools with their normalized names and process_response functions.
Handles agent configuration with whitelist/blacklist functionality.
Handles agent configuration with includelist/excludelist functionality.

Returns:
dict: Mapping of normalized tool names to tool modules
Expand Down Expand Up @@ -182,10 +182,14 @@ def _build_tool_registry(self):

# Process agent configuration if provided
agent_config = self._get_agent_config()
tools_whitelist = agent_config.get("tools_whitelist", [])
tools_blacklist = agent_config.get("tools_blacklist", [])
tools_includelist = agent_config.get(
"tools_includelist", agent_config.get("tools_whitelist", [])
)
tools_excludelist = agent_config.get(
"tools_excludelist", agent_config.get("tools_blacklist", [])
)

# Always include essential tools regardless of whitelist/blacklist
# Always include essential tools regardless of includelist/excludelist
essential_tools = {"makeeditable", "replacetext", "view", "finished"}
for module in tool_modules:
if hasattr(module, "NORM_NAME") and hasattr(module, "process_response"):
Expand All @@ -194,16 +198,16 @@ def _build_tool_registry(self):
# Check if tool should be included based on configuration
should_include = True

# If whitelist is specified, only include tools in whitelist
if tools_whitelist:
should_include = tool_name in tools_whitelist
# If includelist is specified, only include tools in includelist
if tools_includelist:
should_include = tool_name in tools_includelist

# Always include essential tools
if tool_name in essential_tools:
should_include = True

# Exclude tools in blacklist (unless they're essential)
if tool_name in tools_blacklist and tool_name not in essential_tools:
# Exclude tools in excludelist (unless they're essential)
if tool_name in tools_excludelist and tool_name not in essential_tools:
should_include = False

if should_include:
Expand Down Expand Up @@ -236,10 +240,10 @@ def _get_agent_config(self):
# Set defaults for missing values
if "large_file_token_threshold" not in config:
config["large_file_token_threshold"] = 25000
if "tools_whitelist" not in config:
config["tools_whitelist"] = []
if "tools_blacklist" not in config:
config["tools_blacklist"] = []
if "tools_includelist" not in config:
config["tools_includelist"] = []
if "tools_excludelist" not in config:
config["tools_excludelist"] = []

# Apply configuration to instance
self.large_file_token_threshold = config["large_file_token_threshold"]
Expand All @@ -255,12 +259,6 @@ def get_local_tool_schemas(self):
if hasattr(tool_module, "schema"):
schemas.append(tool_module.schema)

# Add git schemas from the tool registry
git_tools = [git_diff, git_log, git_show, git_status]
for git_tool in git_tools:
if hasattr(git_tool, "schema"):
schemas.append(git_tool.schema)

return schemas

async def initialize_mcp_tools(self):
Expand Down Expand Up @@ -935,6 +933,7 @@ async def process_tool_calls(self, tool_call_response):
"""
Track tool usage before calling the base implementation.
"""
self.auto_save_session()

if self.partial_response_tool_calls:
for tool_call in self.partial_response_tool_calls:
Expand Down Expand Up @@ -976,6 +975,7 @@ async def reply_completed(self):
) = await self._process_tool_commands(content)

if self.agent_finished:
self.tool_usage_history = []
return True

# Since we are no longer suppressing, the partial_response_content IS the final content.
Expand Down
52 changes: 45 additions & 7 deletions aider/coders/base_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from aider.repo import ANY_GIT_ERROR, GitRepo
from aider.repomap import RepoMap
from aider.run_cmd import run_cmd
from aider.sessions import SessionManager
from aider.utils import format_tokens, is_image_file

from ..dump import dump # noqa: F401
Expand Down Expand Up @@ -1121,6 +1122,8 @@ async def _run_linear(self, with_message=None, preproc=True):
self.keyboard_interrupt()
except (asyncio.CancelledError, IndexError):
pass

self.auto_save_session()
except EOFError:
return
finally:
Expand Down Expand Up @@ -1272,6 +1275,8 @@ async def _run_patched(self, with_message=None, preproc=True):
self.io.stop_spinner()

self.keyboard_interrupt()

self.auto_save_session()
except EOFError:
return
finally:
Expand Down Expand Up @@ -2240,12 +2245,20 @@ def _print_tool_call_info(self, server_tool_calls):

for server, tool_calls in server_tool_calls.items():
for tool_call in tool_calls:
self.io.tool_output(f"Tool Call: {tool_call.function.name}")
color_start = "[blue]" if self.pretty else ""
color_end = "[/blue]" if self.pretty else ""

self.io.tool_output(
f"{color_start}Tool Call:{color_end} {server.name} • {tool_call.function.name}"
)
# Parse and format arguments as headers with values
if tool_call.function.arguments:
# Only do JSON unwrapping for tools containing "replace" in their name
if "replace" in tool_call.function.name.lower():
if (
"replace" in tool_call.function.name.lower()
or "insert" in tool_call.function.name.lower()
or "update" in tool_call.function.name.lower()
):
try:
args_dict = json.loads(tool_call.function.arguments)
first_key = True
Expand All @@ -2258,7 +2271,7 @@ def _print_tool_call_info(self, server_tool_calls):
if first_key:
self.io.tool_output("\n")
first_key = False
self.io.tool_output(f"{key}:")
self.io.tool_output(f"{color_start}{key}:{color_end}")
# Split the value by newlines and output each line separately
if isinstance(value, str):
for line in value.split("\n"):
Expand All @@ -2269,13 +2282,11 @@ def _print_tool_call_info(self, server_tool_calls):
except json.JSONDecodeError:
# If JSON parsing fails, show raw arguments
raw_args = tool_call.function.arguments
self.io.tool_output(f"Arguments: {raw_args}")
self.io.tool_output(f"{color_start}Arguments:{color_end} {raw_args}")
else:
# For non-replace tools, show raw arguments
raw_args = tool_call.function.arguments
self.io.tool_output(f"Arguments: {raw_args}")

self.io.tool_output(f"MCP Server: {server.name}")
self.io.tool_output(f"{color_start}Arguments:{color_end} {raw_args}")

if self.verbose:
self.io.tool_output(f"Tool ID: {tool_call.id}")
Expand All @@ -2295,7 +2306,15 @@ def _gather_server_tool_calls(self, tool_calls):
return None

server_tool_calls = {}
tool_id_set = set()

for tool_call in tool_calls:
# LLM APIs sometimes return duplicates and that's annoying part 3
if tool_call.get("id") in tool_id_set:
continue

tool_id_set.add(tool_call.get("id"))

# Check if this tool_call matches any MCP tool
for server_name, server_tools in self.mcp_tools:
for tool in server_tools:
Expand Down Expand Up @@ -2343,8 +2362,16 @@ async def _exec_server_tools(server, tool_calls_list):
try:
# Connect to the server once
session = await server.connect()
tool_id_set = set()

# Execute all tool calls for this server
for tool_call in tool_calls_list:
# LLM APIs sometimes return duplicates and that's annoying part 4
if tool_call.id in tool_id_set:
continue

tool_id_set.add(tool_call.id)

try:
# Arguments can be a stream of JSON objects.
# We need to parse them and run a tool call for each.
Expand Down Expand Up @@ -3491,6 +3518,17 @@ def apply_edits(self, edits):
def apply_edits_dry_run(self, edits):
    # Dry-run variant of apply_edits: returns the edits unchanged without
    # touching any files, so callers can preview what would be applied.
    return edits

def auto_save_session(self):
    """Persist the current chat session under the name 'auto-save'.

    Does nothing unless the user enabled ``--auto-save``.  Any failure
    while saving is swallowed on purpose: auto-saving is best-effort and
    must never interrupt the interactive session.
    """
    enabled = getattr(self.args, "auto_save", False)
    if not enabled:
        return
    try:
        manager = SessionManager(self, self.io)
        manager.save_session("auto-save", False)
    except Exception:
        # Best-effort: silently ignore save errors so the user isn't disturbed.
        pass

async def run_shell_commands(self):
if not self.suggest_shell_commands:
return ""
Expand Down
Loading
Loading