diff --git a/.github/workflows/ubuntu-tests.yml b/.github/workflows/ubuntu-tests.yml index e26127489a3..f84933f6b9d 100644 --- a/.github/workflows/ubuntu-tests.yml +++ b/.github/workflows/ubuntu-tests.yml @@ -50,6 +50,7 @@ jobs: uv pip install --system \ pytest \ pytest-asyncio \ + pytest-mock \ -r requirements/requirements.in \ -r requirements/requirements-help.in \ -r requirements/requirements-playwright.in \ diff --git a/.github/workflows/windows-tests.yml b/.github/workflows/windows-tests.yml index c5e3f12d988..8b809a17405 100644 --- a/.github/workflows/windows-tests.yml +++ b/.github/workflows/windows-tests.yml @@ -42,11 +42,10 @@ jobs: run: | python -m pip install --upgrade pip pip install uv - uv pip install --system pytest pytest-asyncio -r requirements/requirements.in -r requirements/requirements-help.in -r requirements/requirements-playwright.in '.[help,playwright]' + uv pip install --system pytest pytest-asyncio pytest-mock -r requirements/requirements.in -r requirements/requirements-help.in -r requirements/requirements-playwright.in '.[help,playwright]' - name: Run tests env: AIDER_ANALYTICS: false run: | pytest - diff --git a/aider/__init__.py b/aider/__init__.py index 6189cdf4fae..18e3a333286 100644 --- a/aider/__init__.py +++ b/aider/__init__.py @@ -1,6 +1,6 @@ from packaging import version -__version__ = "0.91.2.dev" +__version__ = "0.91.3.dev" safe_version = __version__ try: diff --git a/aider/args.py b/aider/args.py index 41d883249ad..1178e2cbadf 100644 --- a/aider/args.py +++ b/aider/args.py @@ -838,6 +838,7 @@ def get_parser(default_config_files, git_root): ) group.add_argument( "--yes-always", + "--yes", action="store_true", help="Always say yes to every confirmation (not including cli commands)", default=None, diff --git a/aider/coders/agent_coder.py b/aider/coders/agent_coder.py index 6577c0cd881..03da769386b 100644 --- a/aider/coders/agent_coder.py +++ b/aider/coders/agent_coder.py @@ -369,23 +369,26 @@ def get_local_tool_schemas(self): async def initialize_mcp_tools(self): await super().initialize_mcp_tools() - local_tools = self.get_local_tool_schemas() - if not local_tools: - return + server_name = "Local" + + if server_name not in [name for name, _ in self.mcp_tools]: + local_tools = self.get_local_tool_schemas() + if not local_tools: + return - local_server_config = {"name": "Local"} - local_server = LocalServer(local_server_config) + local_server_config = {"name": server_name} + local_server = LocalServer(local_server_config) - if not self.mcp_servers: - self.mcp_servers = [] - if not any(isinstance(s, LocalServer) for s in self.mcp_servers): - self.mcp_servers.append(local_server) + if not self.mcp_servers: + self.mcp_servers = [] + if not any(isinstance(s, LocalServer) for s in self.mcp_servers): + self.mcp_servers.append(local_server) - if not self.mcp_tools: - self.mcp_tools = [] + if not self.mcp_tools: + self.mcp_tools = [] - if "local_tools" not in [name for name, _ in self.mcp_tools]: - self.mcp_tools.append((local_server.name, local_tools)) + if server_name not in [name for name, _ in self.mcp_tools]: + self.mcp_tools.append((local_server.name, local_tools)) async def _execute_local_tool_calls(self, tool_calls_list): tool_responses = [] diff --git a/aider/coders/architect_coder.py b/aider/coders/architect_coder.py index 4174528e748..cf76fb6eb0a 100644 --- a/aider/coders/architect_coder.py +++ b/aider/coders/architect_coder.py @@ -41,14 +41,6 @@ async def reply_completed(self): kwargs["cache_prompts"] = False kwargs["num_cache_warming_pings"] = 0 kwargs["summarize_from_coder"] = False - kwargs["mcp_servers"] = [] # Empty to skip initialization - - coder = await Coder.create(**kwargs) - # Transfer MCP state to avoid re-initialization - coder.mcp_servers = self.mcp_servers - coder.mcp_tools = self.mcp_tools - # Transfer TUI app weak reference - coder.tui = self.tui new_kwargs = dict(io=self.io, from_coder=self) new_kwargs.update(kwargs) diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index fa0a3f302bc..51b04807273 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -233,22 +233,32 @@ async def create( kwargs = use_kwargs from_coder.ok_to_warm_cache = False + res = None if ( getattr(main_model, "copy_paste_mode", False) and getattr(main_model, "copy_paste_transport", "api") == "clipboard" ): res = coders.CopyPasteCoder(main_model, io, args=args, **kwargs) + + if not res: + for coder in coders.__all__: + if hasattr(coder, "edit_format") and coder.edit_format == edit_format: + res = coder(main_model, io, args=args, **kwargs) + + if res is not None: + if from_coder: + if from_coder.mcp_servers and kwargs.get("mcp_servers", False): + res.mcp_servers = from_coder.mcp_servers + res.mcp_tools = from_coder.mcp_tools + + # Transfer TUI app weak reference + res.tui = from_coder.tui + await res.initialize_mcp_tools() + res.original_kwargs = dict(kwargs) return res - for coder in coders.__all__: - if hasattr(coder, "edit_format") and coder.edit_format == edit_format: - res = coder(main_model, io, args=args, **kwargs) - await res.initialize_mcp_tools() - res.original_kwargs = dict(kwargs) - return res - valid_formats = [ str(c.edit_format) for c in coders.__all__ @@ -2703,6 +2713,12 @@ async def initialize_mcp_tools(self): tools = [] async def get_server_tools(server): + # Check if we already have tools for this server in mcp_tools + if self.mcp_tools: + for server_name, server_tools in self.mcp_tools: + if server_name == server.name: + return (server.name, server_tools) + try: session = await server.connect() server_tools = await experimental_mcp_client.load_mcp_tools( @@ -3266,8 +3282,20 @@ def consolidate_chunks(self): self.partial_response_reasoning_content = reasoning_content or "" try: - if not self.partial_response_reasoning_content: - self.partial_response_content = response.choices[0].message.content or "" + content = response.choices[0].message.content + if isinstance(content, list): + # OpenAI-compatible APIs sometimes return content as a list + # of blocks; join the textual pieces for display. + content = "".join( + block.get("text", "") + for block in content + if isinstance(block, dict) and block.get("type") == "output_text" + ) or "".join( + block.get("text", "") + for block in content + if isinstance(block, dict) and block.get("type") == "text" + ) + self.partial_response_content = content or "" except AttributeError as e: content_err = e diff --git a/aider/commands/help.py b/aider/commands/help.py index 545fd5816fa..866e7f28e91 100644 --- a/aider/commands/help.py +++ b/aider/commands/help.py @@ -52,14 +52,8 @@ async def execute(cls, io, coder, args, **kwargs): kwargs["suggest_shell_commands"] = False kwargs["cache_prompts"] = False kwargs["num_cache_warming_pings"] = 0 - kwargs["mcp_servers"] = [] # Empty to skip initialization help_coder = await Coder.create(**kwargs) - # Transfer MCP state to avoid re-initialization - help_coder.mcp_servers = coder.mcp_servers - help_coder.mcp_tools = coder.mcp_tools - # Transfer TUI app weak reference - help_coder.tui = coder.tui user_msg = help_instance.ask(args) user_msg += """ # Announcement lines from when this session of aider was launched: diff --git a/aider/commands/lint.py b/aider/commands/lint.py index fc6d45ead57..939bd6b5372 100644 --- a/aider/commands/lint.py +++ b/aider/commands/lint.py @@ -2,6 +2,7 @@ from aider.commands.utils.base_command import BaseCommand from aider.commands.utils.helpers import format_command_result +from aider.utils import expand_glob_patterns class LintCommand(BaseCommand): @@ -11,7 +12,16 @@ class LintCommand(BaseCommand): @classmethod async def execute(cls, io, coder, args, **kwargs): """Execute the lint command with given parameters.""" - fnames = kwargs.get("fnames", None) + fnames = None + + # Get files from CLI arguments if available + system_args = kwargs.get("system_args") + if system_args: + cli_files = getattr(system_args, "files", []) or [] + cli_file_arg = getattr(system_args, "file", []) or [] + all_cli_files = cli_files + cli_file_arg + if all_cli_files: + fnames = expand_glob_patterns(all_cli_files) if not coder.repo: io.tool_error("No git repository found.") @@ -21,7 +31,7 @@ async def execute(cls, io, coder, args, **kwargs): fnames = coder.get_inchat_relative_files() # If still no files, get all dirty files in the repo - if not fnames and coder.repo: + if not fnames: fnames = coder.repo.get_dirty_files() if not fnames: diff --git a/aider/commands/model.py b/aider/commands/model.py index f058a2f5615..fd2a2d2b068 100644 --- a/aider/commands/model.py +++ b/aider/commands/model.py @@ -94,10 +94,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for model command.""" - from aider.llm import litellm - - model_names = litellm.model_cost.keys() - return list(model_names) + return models.get_chat_model_names() @classmethod def get_help(cls) -> str: diff --git a/aider/commands/models.py b/aider/commands/models.py index 9d9624d1f84..2af2a56771d 100644 --- a/aider/commands/models.py +++ b/aider/commands/models.py @@ -24,10 +24,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for models command.""" - from aider.llm import litellm - - model_names = litellm.model_cost.keys() - return list(model_names) + return models.get_chat_model_names() @classmethod def get_help(cls) -> str: diff --git a/aider/commands/read_only.py b/aider/commands/read_only.py index 2fc43bcb647..848f368da4b 100644 --- a/aider/commands/read_only.py +++ b/aider/commands/read_only.py @@ -207,9 +207,43 @@ def _add_read_only_directory( @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for read-only command.""" - # For read-only command, we could return file paths for completion - # For now, return empty list - the completion system will handle path completion - return [] + from pathlib import Path + + root = Path(coder.root) if hasattr(coder, "root") else Path.cwd() + + # Handle the prefix - could be partial path like "src/ma" or just "ma" + if "/" in args: + # Has directory component + dir_part, file_part = args.rsplit("/", 1) + search_dir = root / dir_part + search_prefix = file_part.lower() + path_prefix = dir_part + "/" + else: + search_dir = root + search_prefix = args.lower() + path_prefix = "" + + completions = [] + try: + if search_dir.exists() and search_dir.is_dir(): + for entry in search_dir.iterdir(): + name = entry.name + if search_prefix and search_prefix not in name.lower(): + continue + # Add trailing slash for directories + if entry.is_dir(): + completions.append(path_prefix + name + "/") + else: + completions.append(path_prefix + name) + except (PermissionError, OSError): + pass + + add_completions = coder.commands.get_completions("/add") + for c in add_completions: + if args.lower() in str(c).lower() and str(c) not in completions: + completions.append(str(c)) + + return sorted(completions) @classmethod def get_help(cls) -> str: diff --git a/aider/commands/read_only_stub.py b/aider/commands/read_only_stub.py index 5d626e877da..cb98b592123 100644 --- a/aider/commands/read_only_stub.py +++ b/aider/commands/read_only_stub.py @@ -206,10 +206,44 @@ def _add_read_only_directory( @classmethod def get_completions(cls, io, coder, args) -> List[str]: - """Get completion options for read-only-stub command.""" - # For read-only-stub command, we could return file paths for completion - # For now, return empty list - the completion system will handle path completion - return [] + """Get completion options for read-only command.""" + from pathlib import Path + + root = Path(coder.root) if hasattr(coder, "root") else Path.cwd() + + # Handle the prefix - could be partial path like "src/ma" or just "ma" + if "/" in args: + # Has directory component + dir_part, file_part = args.rsplit("/", 1) + search_dir = root / dir_part + search_prefix = file_part.lower() + path_prefix = dir_part + "/" + else: + search_dir = root + search_prefix = args.lower() + path_prefix = "" + + completions = [] + try: + if search_dir.exists() and search_dir.is_dir(): + for entry in search_dir.iterdir(): + name = entry.name + if search_prefix and search_prefix not in name.lower(): + continue + # Add trailing slash for directories + if entry.is_dir(): + completions.append(path_prefix + name + "/") + else: + completions.append(path_prefix + name) + except (PermissionError, OSError): + pass + + add_completions = coder.commands.get_completions("/add") + for c in add_completions: + if args.lower() in str(c).lower() and str(c) not in completions: + completions.append(str(c)) + + return sorted(completions) @classmethod def get_help(cls) -> str: diff --git a/aider/commands/utils/base_command.py b/aider/commands/utils/base_command.py index 6ae2faa26c1..6603ea11cfc 100644 --- a/aider/commands/utils/base_command.py +++ b/aider/commands/utils/base_command.py @@ -116,14 +116,7 @@ async def _generic_chat_command(cls, io, coder, args, edit_format, placeholder=N "args": coder.args, } - kwargs["mcp_servers"] = [] # Empty to skip initialization - new_coder = await Coder.create(**kwargs) - # Transfer MCP state to avoid re-initialization - new_coder.mcp_servers = coder.mcp_servers - new_coder.mcp_tools = coder.mcp_tools - # Transfer TUI app weak reference - new_coder.tui = coder.tui await new_coder.generate(user_message=user_msg, preproc=False) coder.aider_commit_hashes = new_coder.aider_commit_hashes diff --git a/aider/helpers/model_providers.py b/aider/helpers/model_providers.py new file mode 100644 index 00000000000..fe9c58b4054 --- /dev/null +++ b/aider/helpers/model_providers.py @@ -0,0 +1,716 @@ +"""Unified model provider metadata caching and lookup. + +Historically aider kept separate modules per provider (OpenRouter vs OpenAI-like). +Those grew unwieldy and duplicated caching, request, and normalization logic. +This helper centralizes that behavior so every OpenAI-compatible endpoint defines +a small config blob and inherits the same cache + LiteLLM registration plumbing. +Provider configs remain curated via ``scripts/generate_providers.py`` and the +static per-model fallback metadata is still cleaned up with ``clean_metadata.py``. +""" + +from __future__ import annotations + +import importlib.resources as importlib_resources +import json +import os +import re +import time +from copy import deepcopy +from pathlib import Path +from typing import Dict, Optional + +import requests + +try: # Optional imports; litellm might not be installed during docs builds + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm.llms.custom_llm import CustomLLM, CustomLLMError + from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler +except Exception: # pragma: no cover - only during partial installs + CustomLLM = None # type: ignore + CustomLLMError = Exception # type: ignore + OpenAILikeChatHandler = None # type: ignore + HTTPHandler = None # type: ignore + +RESOURCE_FILE = "providers.json" +_PROVIDERS_REGISTERED = False +_CUSTOM_HANDLERS: Dict[str, "_JSONOpenAIProvider"] = {} + + +def _coerce_str(value): + """Return the first string representation that litellm expects.""" + if isinstance(value, str): + return value + if isinstance(value, list) and value: + return value[0] + return None + + +def _first_env_value(names): + """Return the first non-empty environment variable for the provided names.""" + if not names: + return None + if isinstance(names, str): + names = [names] + for env_name in names or []: + if not env_name: + continue + val = os.environ.get(env_name) + if val: + return val + return None + + +class _JSONOpenAIProvider(CustomLLM if CustomLLM is not None else object): # type: ignore[misc] + """CustomLLM wrapper that routes OpenAI-compatible providers through LiteLLM.""" + + def __init__(self, slug: str, config: Dict): + if CustomLLM is None or OpenAILikeChatHandler is None: # pragma: no cover + raise RuntimeError("litellm custom handler support unavailable") + super().__init__() # type: ignore[misc] + self.slug = slug + self.config = config + self._chat_handler = OpenAILikeChatHandler() + + def _resolve_api_base(self, api_base: Optional[str]) -> str: + base = ( + api_base + or _first_env_value(self.config.get("base_url_env")) + or self.config.get("api_base") + ) + if not base: + raise CustomLLMError(500, f"{self.slug} missing base URL") # type: ignore[misc] + return base.rstrip("/") + + def _resolve_api_key(self, api_key: Optional[str]) -> Optional[str]: + if api_key: + return api_key + env_val = _first_env_value(self.config.get("api_key_env")) + return env_val + + def _apply_special_handling(self, messages): + special = self.config.get("special_handling") or {} + if special.get("convert_content_list_to_string"): + from litellm.litellm_core_utils.prompt_templates.common_utils import ( + handle_messages_with_content_list_to_str_conversion, + ) + + return handle_messages_with_content_list_to_str_conversion(messages) + return messages + + def _inject_headers(self, headers): + defaults = self.config.get("default_headers") or {} + combined = dict(defaults) + combined.update(headers or {}) + return combined + + def _normalize_model_name(self, model: str) -> str: + if not isinstance(model, str): + return model + trimmed = model + if trimmed.startswith(f"{self.slug}/"): + trimmed = trimmed.split("/", 1)[1] + hf_namespace = self.config.get("hf_namespace") + if hf_namespace and not trimmed.startswith("hf:"): + trimmed = f"hf:{trimmed}" + return trimmed + + def _build_request_params(self, optional_params, stream: bool): + params = dict(optional_params or {}) + default_headers = dict(self.config.get("default_headers") or {}) + headers = params.setdefault("extra_headers", default_headers) + if headers is default_headers and default_headers: + params["extra_headers"] = dict(default_headers) + if stream: + params["stream"] = True + return params + + def _invoke_handler( + self, + *, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params, + logger_fn, + headers, + timeout, + client, + stream: bool, + ): + api_base = self._resolve_api_base(api_base) + api_key = self._resolve_api_key(api_key) + headers = self._inject_headers(headers) + params = self._build_request_params(optional_params, stream) + cleaned_messages = self._apply_special_handling(messages) + api_model = self._normalize_model_name(model) + http_client = None + if HTTPHandler is not None and isinstance(client, HTTPHandler): + http_client = client + return self._chat_handler.completion( + model=api_model, + messages=cleaned_messages, + api_base=api_base, + custom_llm_provider="openai", + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=params, + litellm_params=litellm_params or {}, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=http_client, + ) + + def completion( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self._invoke_handler( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + stream=False, + ) + + def streaming( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self._invoke_handler( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + stream=True, + ) + + def acompletion( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self.completion( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + ) + + def astreaming( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self.streaming( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + ) + + +def _register_provider_with_litellm(slug: str, config: Dict) -> None: + """Register provider metadata and custom handlers with LiteLLM.""" + try: + from litellm.llms.openai_like.json_loader import ( + JSONProviderRegistry, + SimpleProviderConfig, + ) + except Exception: + return + + JSONProviderRegistry.load() + + base_url = config.get("api_base") + api_key_env = _coerce_str(config.get("api_key_env")) + if not base_url or not api_key_env: + return + + if not JSONProviderRegistry.exists(slug): + payload = { + "base_url": base_url, + "api_key_env": api_key_env, + } + + api_base_env = _coerce_str(config.get("base_url_env")) + if api_base_env: + payload["api_base_env"] = api_base_env + + if config.get("param_mappings"): + payload["param_mappings"] = config["param_mappings"] + if config.get("special_handling"): + payload["special_handling"] = config["special_handling"] + if config.get("base_class"): + payload["base_class"] = config["base_class"] + + JSONProviderRegistry._providers[slug] = SimpleProviderConfig(slug, payload) + + try: + import litellm # noqa: WPS433 + except Exception: + return + + provider_list = getattr(litellm, "provider_list", None) + if isinstance(provider_list, list) and slug not in provider_list: + provider_list.append(slug) + + openai_like = getattr(litellm, "_openai_like_providers", None) + if isinstance(openai_like, list) and slug not in openai_like: + openai_like.append(slug) + + handler = _CUSTOM_HANDLERS.get(slug) + if handler is None and CustomLLM is not None and OpenAILikeChatHandler is not None: + handler = _JSONOpenAIProvider(slug, config) + _CUSTOM_HANDLERS[slug] = handler + + if handler is None: + return + + already_present = any(item.get("provider") == slug for item in litellm.custom_provider_map) + if not already_present: + litellm.custom_provider_map.append({"provider": slug, "custom_handler": handler}) + try: + litellm.custom_llm_setup() + except Exception: + pass + + +def _deep_merge(base: Dict, override: Dict) -> Dict: + """Recursively merge override dict into base without mutating inputs.""" + result = deepcopy(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(result.get(key), dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = deepcopy(value) + return result + + +def _load_provider_configs() -> Dict[str, Dict]: + """Load provider configuration overrides from the packaged JSON file.""" + configs: Dict[str, Dict] = {} + try: + resource = importlib_resources.files("aider.resources").joinpath(RESOURCE_FILE) + data = json.loads(resource.read_text()) + except (FileNotFoundError, json.JSONDecodeError): # pragma: no cover + data = {} + + for provider, override in data.items(): + base = configs.get(provider, {}) + configs[provider] = _deep_merge(base, override) + + return configs + + +PROVIDER_CONFIGS = _load_provider_configs() + + +class ModelProviderManager: + CACHE_TTL = 60 * 60 * 24 # 24 hours + DEFAULT_TOKEN_PRICE_RATIO = 1000000 + + def __init__(self, provider_configs: Optional[Dict[str, Dict]] = None) -> None: + self.cache_dir = Path.home() / ".aider" / "caches" + self.verify_ssl: bool = True + self.provider_configs = provider_configs or deepcopy(PROVIDER_CONFIGS) + self._provider_cache: Dict[str, Dict | None] = {} + self._cache_loaded: Dict[str, bool] = {} + for name in self.provider_configs: + self._provider_cache[name] = None + self._cache_loaded[name] = False + + def set_verify_ssl(self, verify_ssl: bool) -> None: + self.verify_ssl = verify_ssl + + def supports_provider(self, provider: Optional[str]) -> bool: + return bool(provider and provider in self.provider_configs) + + def get_provider_config(self, provider: Optional[str]) -> Optional[Dict]: + if not provider: + return None + config = self.provider_configs.get(provider) + if not config: + return None + config = dict(config) + config.setdefault("litellm_provider", provider) + return config + + def get_provider_base_url(self, provider: Optional[str]) -> Optional[str]: + config = self.get_provider_config(provider) + if not config: + return None + base_envs = config.get("base_url_env") or [] + for env_var in base_envs: + val = os.environ.get(env_var) + if val: + return val.rstrip("/") + return config.get("api_base") + + def get_required_api_keys(self, provider: Optional[str]) -> list[str]: + config = self.get_provider_config(provider) + if not config: + return [] + return list(config.get("api_key_env", [])) + + def get_model_info(self, model: str) -> Dict: + provider, route = self._split_model(model) + if not provider or not self._ensure_provider_state(provider): + return {} + + content = self._ensure_content(provider) + record = self._find_record(content, route) + if not record and self.refresh_provider_cache(provider): + content = self._provider_cache.get(provider) + record = self._find_record(content, route) + if not record: + return {} + return self._record_to_info(record, provider) + + def get_models_for_listing(self) -> Dict[str, Dict]: + listings: Dict[str, Dict] = {} + for provider in list(self.provider_configs.keys()): + content = self._ensure_content(provider) + if not content or "data" not in content: + continue + for record in content["data"]: + model_id = record.get("id") + if not model_id: + continue + info = self._record_to_info(record, provider) + if info: + listings[model_id] = info + return listings + + def refresh_provider_cache(self, provider: str) -> bool: + if not self._ensure_provider_state(provider): + return False + config = self.provider_configs[provider] + if not config.get("models_url") and not config.get("api_base"): + return False + self._provider_cache[provider] = None + self._cache_loaded[provider] = True + self._update_cache(provider) + return bool(self._provider_cache.get(provider)) + + def _ensure_provider_state(self, provider: str) -> bool: + if provider not in self.provider_configs: + return False + self._provider_cache.setdefault(provider, None) + self._cache_loaded.setdefault(provider, False) + return True + + def _split_model(self, model: str) -> tuple[Optional[str], str]: + if "/" not in model: + return None, model + provider, route = model.split("/", 1) + return provider, route + + def _ensure_content(self, provider: str) -> Optional[Dict]: + self._load_cache(provider) + if not self._provider_cache.get(provider): + self._update_cache(provider) + return self._provider_cache.get(provider) + + def _find_record(self, content: Optional[Dict], route: str) -> Optional[Dict]: + if not content or "data" not in content: + return None + candidates = {route} + if ":" in route: + candidates.add(route.split(":", 1)[0]) + return next((item for item in content["data"] if item.get("id") in candidates), None) + + def _record_to_info(self, record: Dict, provider: str) -> Dict: + context_len = _first_value( + record, + "max_input_tokens", + "max_tokens", + "max_output_tokens", + "context_length", + "context_window", + "top_provider_context_length", + "top_provider", + ) + + if isinstance(context_len, dict): + context_len = context_len.get("context_length") or context_len.get("max_tokens") + + pricing = record.get("pricing", {}) if isinstance(record.get("pricing"), dict) else {} + input_cost = _cost_per_token( + _first_value(pricing, "prompt", "input", "prompt_tokens") + or _first_value(record, "input_cost_per_token", "prompt_cost_per_token") + ) + output_cost = _cost_per_token( + _first_value(pricing, "completion", "output", "completion_tokens") + or _first_value(record, "output_cost_per_token", "completion_cost_per_token") + ) + + max_tokens = _first_value( + record, + "max_tokens", + "max_input_tokens", + "context_length", + "context_window", + "top_provider_context_length", + ) + max_output_tokens = _first_value( + record, + "max_output_tokens", + "max_tokens", + "context_length", + "context_window", + "top_provider_context_length", + ) + + if max_tokens is None: + max_tokens = context_len + if max_output_tokens is None: + max_output_tokens = context_len + + info = { + "max_input_tokens": context_len, + "max_tokens": max_tokens, + "max_output_tokens": max_output_tokens, + "input_cost_per_token": ( + input_cost or 0 + ) / self.DEFAULT_TOKEN_PRICE_RATIO, # Might Only Apply to Chutes and Be a thing we configure per-provider + "output_cost_per_token": (output_cost or 0) / self.DEFAULT_TOKEN_PRICE_RATIO, + "litellm_provider": provider, + "mode": record.get("mode", "chat"), + } + return {k: v for k, v in info.items() if v is not None} + + def _get_cache_file(self, provider: str) -> Path: + fname = f"{provider}_models.json" + return self.cache_dir / fname + + def _load_cache(self, provider: str) -> None: + if self._cache_loaded.get(provider): + return + cache_file = self._get_cache_file(provider) + try: + self.cache_dir.mkdir(parents=True, exist_ok=True) + if cache_file.exists(): + cache_age = time.time() - cache_file.stat().st_mtime + if cache_age < self.CACHE_TTL: + try: + self._provider_cache[provider] = json.loads(cache_file.read_text()) + except json.JSONDecodeError: + self._provider_cache[provider] = None + except OSError: + pass + self._cache_loaded[provider] = True + + def _update_cache(self, provider: str) -> None: + payload = self._fetch_provider_models(provider) + cache_file = self._get_cache_file(provider) + + if payload: + self._provider_cache[provider] = payload + try: + cache_file.write_text(json.dumps(payload, indent=2)) + except OSError: + pass + return + + static_models = self.provider_configs[provider].get("static_models") + if static_models and not self._provider_cache.get(provider): + self._provider_cache[provider] = {"data": static_models} + + def _fetch_provider_models(self, provider: str) -> Optional[Dict]: + config = self.provider_configs[provider] + models_url = config.get("models_url") + if not models_url: + api_base = config.get("api_base") + if api_base: + models_url = api_base.rstrip("/") + "/models" + if not models_url: + return None + + headers = {} + default_headers = config.get("default_headers") or {} + headers.update(default_headers) + + api_key = self._get_api_key(provider) + requires_api_key = config.get("requires_api_key", True) + + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + elif requires_api_key: + return None + + try: + response = requests.get( + models_url, + headers=headers or None, + timeout=config.get("timeout", 10), + verify=self.verify_ssl, + ) + response.raise_for_status() + return response.json() + except Exception as ex: # noqa: BLE001 + print(f"Failed to fetch {provider} model list: {ex}") + return None + + def _get_api_key(self, provider: str) -> Optional[str]: + config = self.provider_configs[provider] + for env_var in config.get("api_key_env", []): + value = os.environ.get(env_var) + if value: + return value + return None + + +def ensure_litellm_providers_registered() -> None: + """One-time registration guard for LiteLLM provider metadata.""" + global _PROVIDERS_REGISTERED + if _PROVIDERS_REGISTERED: + return + for slug, cfg in PROVIDER_CONFIGS.items(): + _register_provider_with_litellm(slug, cfg) + _PROVIDERS_REGISTERED = True + + +_NUMBER_RE = re.compile(r"-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?") + + +def _cost_per_token(val: Optional[str | float | int]) -> Optional[float]: + """Parse token pricing strings into floats, tolerating currency prefixes.""" + if val in (None, "", "-", "N/A"): + return None + if val == "0": + return 0.0 + if isinstance(val, str): + cleaned = val.strip().replace(",", "") + if cleaned.startswith("$"): + cleaned = cleaned[1:] + match = _NUMBER_RE.search(cleaned) + if not match: + return None + val = match.group(0) + try: + return float(val) + except (TypeError, ValueError): + return None + + +def _first_value(record: Dict, *keys: str): + """Return the first non-empty value for the provided keys.""" + for key in keys: + value = record.get(key) + if value not in (None, ""): + return value + return None diff --git a/aider/llm.py b/aider/llm.py index 31a166834c2..ff320dee3d3 100644 --- a/aider/llm.py +++ b/aider/llm.py @@ -6,6 +6,7 @@ from collections.abc import Coroutine from aider.dump import dump # noqa: F401 +from aider.helpers.model_providers import ensure_litellm_providers_registered warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") @@ -53,6 +54,9 @@ def _load_litellm(self): self._lazy_module.drop_params = True self._lazy_module._logging._disable_debugging() + # Make sure JSON-based OpenAI-compatible providers are registered + ensure_litellm_providers_registered() + # Patch GLOBAL_LOGGING_WORKER to avoid event loop binding issues # See: https://github.com/BerriAI/litellm/issues/16518 # See: https://github.com/BerriAI/litellm/issues/14521 diff --git a/aider/main.py b/aider/main.py index a5b79e7137b..2e8891db99a 100644 --- a/aider/main.py +++ b/aider/main.py @@ -10,7 +10,6 @@ pass import asyncio -import glob import json import os import re @@ -515,25 +514,6 @@ async def sanity_check_repo(repo, io): return False -def expand_glob_patterns(patterns, root="."): - """Expand glob patterns in a list of file paths.""" - expanded_files = [] - for pattern in patterns: - # Check if the pattern contains glob characters - if any(c in pattern for c in "*?[]"): - # Use glob to expand the pattern - matches = glob.glob(pattern, recursive=True) - if matches: - expanded_files.extend(matches) - else: - # If no matches, keep the original pattern - expanded_files.append(pattern) - else: - # Not a glob pattern, keep as is - expanded_files.append(pattern) - return expanded_files - - PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__)) log_file = None file_excludelist = { @@ -841,12 +821,12 @@ def get_io(pretty): # Expand glob patterns in files and file arguments all_files = args.files + (args.file or []) - all_files = expand_glob_patterns(all_files) + all_files = utils.expand_glob_patterns(all_files) fnames = [str(Path(fn).resolve()) for fn in all_files] # Expand glob patterns in read arguments read_patterns = args.read or [] - read_expanded = expand_glob_patterns(read_patterns) + read_expanded = utils.expand_glob_patterns(read_patterns) read_only_fnames = [] for fn in read_expanded: path = Path(fn).expanduser().resolve() @@ -1185,6 +1165,7 @@ def apply_model_overrides(model_name): if args.stream: io.tool_warning( f"Warning: Streaming is not supported by {main_model.name}. Disabling streaming." + " Set stream: false in config file or use --no-stream to skip this warning." ) args.stream = False @@ -1323,7 +1304,7 @@ def apply_model_overrides(model_name): return await graceful_exit(coder) if args.lint: - await coder.commands.cmd_lint(fnames=fnames) + await coder.commands.do_run("lint", "") if args.test: if not args.test_cmd: diff --git a/aider/models.py b/aider/models.py index 86c46fd4178..537226e2a27 100644 --- a/aider/models.py +++ b/aider/models.py @@ -17,9 +17,9 @@ from aider import __version__ from aider.dump import dump # noqa: F401 +from aider.helpers.model_providers import ModelProviderManager from aider.helpers.requests import model_request_parser from aider.llm import litellm -from aider.openrouter import OpenRouterModelManager from aider.sendchat import sanity_check_messages from aider.utils import check_pip_install_extra @@ -157,13 +157,13 @@ def __init__(self): self.verify_ssl = True self._cache_loaded = False - # Manager for the cached OpenRouter model database - self.openrouter_manager = OpenRouterModelManager() + # Manager for provider-specific cached model databases + self.provider_manager = ModelProviderManager() + self.openai_provider_manager = self.provider_manager # Backwards compatibility alias def set_verify_ssl(self, verify_ssl): self.verify_ssl = verify_ssl - if hasattr(self, "openrouter_manager"): - self.openrouter_manager.set_verify_ssl(verify_ssl) + self.provider_manager.set_verify_ssl(verify_ssl) def _load_cache(self): if self._cache_loaded: @@ -241,21 +241,45 @@ def get_model_info(self, model): if "model_prices_and_context_window.json" not in str(ex): print(str(ex)) + provider_info = self._resolve_via_provider(model, cached_info) + if provider_info: + return provider_info + if litellm_info: return litellm_info - if not cached_info and model.startswith("openrouter/"): - # First try using the locally cached OpenRouter model database - openrouter_info = self.openrouter_manager.get_model_info(model) - if openrouter_info: - return openrouter_info + return cached_info + + def _resolve_via_provider(self, model, cached_info): + if cached_info: + return None - # Fallback to legacy web-scraping if the API cache does not contain the model + provider = model.split("/", 1)[0] if "/" in model else None + if not self.provider_manager.supports_provider(provider): + return None + + provider_info = self.provider_manager.get_model_info(model) + if provider_info: + self._record_dynamic_model(model, provider_info) + return provider_info + + if provider == "openrouter": openrouter_info = self.fetch_openrouter_model_info(model) if openrouter_info: + openrouter_info.setdefault("litellm_provider", "openrouter") + self._record_dynamic_model(model, openrouter_info) return openrouter_info - return cached_info + return None + + def _record_dynamic_model(self, model, info): + self.local_model_metadata[model] = info + self._ensure_model_settings_entry(model) + + def _ensure_model_settings_entry(self, model): + if any(ms.name == model for ms in MODEL_SETTINGS): + return + MODEL_SETTINGS.append(ModelSettings(name=model)) def fetch_openrouter_model_info(self, model): """ @@ -300,6 +324,7 @@ def fetch_openrouter_model_info(self, model): "max_output_tokens": context_size, "input_cost_per_token": input_cost, "output_cost_per_token": output_cost, + "litellm_provider": "openrouter", } return params except Exception as e: @@ -355,6 +380,7 @@ def __init__( ) self.info = self.get_model_info(model) + self.litellm_provider = (self.info.get("litellm_provider") or "").lower() # Are all needed keys/params available? res = self.validate_environment() @@ -367,6 +393,7 @@ def __init__( self.max_chat_history_tokens = min(max(max_input_tokens / 16, 1024), 8192) self.configure_model_settings(model) + self._apply_provider_defaults() self.get_weak_model(weak_model) if editor_model is False: @@ -690,6 +717,49 @@ def get_editor_model(self, provided_editor_model, editor_edit_format): return self.editor_model + def _ensure_extra_params_dict(self): + if self.extra_params is None: + self.extra_params = {} + elif not isinstance(self.extra_params, dict): + self.extra_params = dict(self.extra_params) + + def _apply_provider_defaults(self): + provider = (self.info.get("litellm_provider") or "").lower() + self.litellm_provider = provider or None + + if not provider: + return + + provider_config = model_info_manager.provider_manager.get_provider_config(provider) + if not provider_config: + return + + self._ensure_extra_params_dict() + self.extra_params.setdefault("custom_llm_provider", provider) + + if provider_config.get("supports_stream") is False: + # Some OpenAI-compatible providers (e.g., Synthetic) only expose the + # non-streaming /chat/completions endpoint, so forcing streaming would + # loop through LiteLLM's fallback and explode mid-response. Disable the + # streaming flag up front so the caller transparently falls back to + # standard completions for those providers. + self.streaming = False + + base_url = model_info_manager.provider_manager.get_provider_base_url(provider) + if base_url: + self.extra_params.setdefault("base_url", base_url) + + default_headers = provider_config.get("default_headers") or {} + if default_headers: + headers = self.extra_params.setdefault("extra_headers", {}) + for key, value in default_headers.items(): + headers.setdefault(key, value) + + provider_extra = provider_config.get("extra_params") or {} + for key, value in provider_extra.items(): + if key not in self.extra_params: + self.extra_params[key] = value + def tokenizer(self, text): return litellm.encode(model=self.name, text=text) @@ -788,6 +858,12 @@ def fast_validate_environment(self): if var and os.environ.get(var): return dict(keys_in_environment=[var], missing_keys=[]) + if not var and provider and model_info_manager.provider_manager.supports_provider(provider): + provider_keys = model_info_manager.provider_manager.get_required_api_keys(provider) + for env_var in provider_keys: + if os.environ.get(env_var): + return dict(keys_in_environment=[env_var], missing_keys=[]) + def validate_environment(self): res = self.fast_validate_environment() if res: @@ -818,6 +894,14 @@ def validate_environment(self): return res provider = self.info.get("litellm_provider", "").lower() + provider_config = model_info_manager.provider_manager.get_provider_config(provider) + if provider_config: + envs = provider_config.get("api_key_env", []) + available = [env for env in envs if os.environ.get(env)] + if available: + return dict(keys_in_environment=available, missing_keys=[]) + if envs: + return dict(keys_in_environment=False, missing_keys=envs) if provider == "cohere_chat": return validate_variables(["COHERE_API_KEY"]) if provider == "gemini": @@ -1304,31 +1388,35 @@ async def check_for_dependencies(io, model_name): ) -def fuzzy_match_models(name): - name = name.lower() - +def get_chat_model_names(): chat_models = set() model_metadata = list(litellm.model_cost.items()) model_metadata += list(model_info_manager.local_model_metadata.items()) + openai_provider_models = model_info_manager.provider_manager.get_models_for_listing() + model_metadata += list(openai_provider_models.items()) + for orig_model, attrs in model_metadata: - model = orig_model.lower() if attrs.get("mode") != "chat": continue - provider = attrs.get("litellm_provider", "").lower() - if not provider: - continue - provider += "/" - - if model.startswith(provider): - fq_model = orig_model - else: - fq_model = provider + orig_model + provider = (attrs.get("litellm_provider") or "").lower() + if provider: + prefix = provider + "/" + if orig_model.lower().startswith(prefix): + fq_model = orig_model + else: + fq_model = f"{provider}/{orig_model}" + chat_models.add(fq_model) - chat_models.add(fq_model) chat_models.add(orig_model) - chat_models = sorted(chat_models) + return sorted(chat_models) + + +def fuzzy_match_models(name): + name = name.lower() + + chat_models = get_chat_model_names() # exactly matching model # matching_models = [ # (fq,m) for fq,m in chat_models @@ -1338,7 +1426,7 @@ def fuzzy_match_models(name): # return matching_models # Check for model names containing the name - matching_models = [m for m in chat_models if name in m] + matching_models = [m for m in chat_models if name in m.lower()] if matching_models: return sorted(set(matching_models)) diff --git a/aider/openrouter.py b/aider/openrouter.py deleted file mode 100644 index ea641c17fda..00000000000 --- a/aider/openrouter.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -OpenRouter model metadata caching and lookup. - -This module keeps a local cached copy of the OpenRouter model list -(downloaded from ``https://openrouter.ai/api/v1/models``) and exposes a -helper class that returns metadata for a given model in a format compatible -with litellm’s ``get_model_info``. -""" - -from __future__ import annotations - -import json -import time -from pathlib import Path -from typing import Dict - -import requests - - -def _cost_per_token(val: str | None) -> float | None: - """Convert a price string (USD per token) to a float.""" - if val in (None, "", "0"): - return 0.0 if val == "0" else None - try: - return float(val) - except Exception: # noqa: BLE001 - return None - - -class OpenRouterModelManager: - MODELS_URL = "https://openrouter.ai/api/v1/models" - CACHE_TTL = 60 * 60 * 24 # 24 h - - def __init__(self) -> None: - self.cache_dir = Path.home() / ".aider" / "caches" - self.cache_file = self.cache_dir / "openrouter_models.json" - self.content: Dict | None = None - self.verify_ssl: bool = True - self._cache_loaded = False - - # ------------------------------------------------------------------ # - # Public API # - # ------------------------------------------------------------------ # - def set_verify_ssl(self, verify_ssl: bool) -> None: - """Enable/disable SSL verification for API requests.""" - self.verify_ssl = verify_ssl - - def get_model_info(self, model: str) -> Dict: - """ - Return metadata for *model* or an empty ``dict`` when unknown. - - ``model`` should use the aider naming convention, e.g. - ``openrouter/nousresearch/deephermes-3-mistral-24b-preview:free``. - """ - self._ensure_content() - if not self.content or "data" not in self.content: - return {} - - route = self._strip_prefix(model) - - # Consider both the exact id and id without any “:suffix”. - candidates = {route} - if ":" in route: - candidates.add(route.split(":", 1)[0]) - - record = next((item for item in self.content["data"] if item.get("id") in candidates), None) - if not record: - return {} - - context_len = ( - record.get("top_provider", {}).get("context_length") - or record.get("context_length") - or None - ) - - pricing = record.get("pricing", {}) - return { - "max_input_tokens": context_len, - "max_tokens": context_len, - "max_output_tokens": context_len, - "input_cost_per_token": _cost_per_token(pricing.get("prompt")), - "output_cost_per_token": _cost_per_token(pricing.get("completion")), - "litellm_provider": "openrouter", - } - - # ------------------------------------------------------------------ # - # Internal helpers # - # ------------------------------------------------------------------ # - def _strip_prefix(self, model: str) -> str: - return model[len("openrouter/") :] if model.startswith("openrouter/") else model - - def _ensure_content(self) -> None: - self._load_cache() - if not self.content: - self._update_cache() - - def _load_cache(self) -> None: - if self._cache_loaded: - return - try: - self.cache_dir.mkdir(parents=True, exist_ok=True) - if self.cache_file.exists(): - cache_age = time.time() - self.cache_file.stat().st_mtime - if cache_age < self.CACHE_TTL: - try: - self.content = json.loads(self.cache_file.read_text()) - except json.JSONDecodeError: - self.content = None - except OSError: - # Cache directory might be unwritable; ignore. - pass - - self._cache_loaded = True - - def _update_cache(self) -> None: - try: - response = requests.get(self.MODELS_URL, timeout=10, verify=self.verify_ssl) - if response.status_code == 200: - self.content = response.json() - try: - self.cache_file.write_text(json.dumps(self.content, indent=2)) - except OSError: - pass # Non-fatal if we can’t write the cache - except Exception as ex: # noqa: BLE001 - print(f"Failed to fetch OpenRouter model list: {ex}") - try: - self.cache_file.write_text("{}") - except OSError: - pass diff --git a/aider/resources/providers.json b/aider/resources/providers.json new file mode 100644 index 00000000000..7c022a21095 --- /dev/null +++ b/aider/resources/providers.json @@ -0,0 +1,90 @@ +{ + "openrouter": { + "api_base": "https://openrouter.ai/api/v1", + "models_url": "https://openrouter.ai/api/v1/models", + "api_key_env": [ + "OPENROUTER_API_KEY" + ], + "requires_api_key": false, + "default_headers": { + "HTTP-Referer": "https://aider.chat", + "X-Title": "aider" + } + }, + "openai": { + "api_base": "https://api.openai.com/v1", + "models_url": "https://api.openai.com/v1/models", + "api_key_env": [ + "OPENAI_API_KEY" + ], + "base_url_env": [ + "OPENAI_API_BASE" + ], + "display_name": "openai" + }, + "apertis": { + "api_base": "https://api.stima.tech/v1", + "api_key_env": [ + "STIMA_API_KEY" + ], + "display_name": "apertis" + }, + "chutes": { + "api_base": "https://llm.chutes.ai/v1/", + "api_key_env": [ + "CHUTES_API_KEY" + ], + "display_name": "chutes" + }, + "helicone": { + "api_base": "https://ai-gateway.helicone.ai/", + "api_key_env": [ + "HELICONE_API_KEY" + ], + "display_name": "helicone" + }, + "nano-gpt": { + "api_base": "https://nano-gpt.com/api/v1", + "api_key_env": [ + "NANOGPT_API_KEY" + ], + "display_name": "nano-gpt" + }, + "poe": { + "api_base": "https://api.poe.com/v1", + "api_key_env": [ + "POE_API_KEY" + ], + "display_name": "poe" + }, + "publicai": { + "api_base": "https://api.publicai.co/v1", + "api_key_env": [ + "PUBLICAI_API_KEY" + ], + "display_name": "publicai" + }, + "synthetic": { + "api_base": "https://api.synthetic.new/openai/v1", + "api_key_env": [ + "SYNTHETIC_API_KEY" + ], + "display_name": "synthetic", + "hf_namespace": true, + "supports_stream": false + }, + "veniceai": { + "api_base": "https://api.venice.ai/api/v1", + "api_key_env": [ + "VENICE_AI_API_KEY" + ], + "display_name": "veniceai" + }, + "xiaomi_mimo": { + "api_base": "https://api.xiaomimimo.com/v1", + "api_key_env": [ + "XIAOMI_MIMO_API_KEY" + ], + "display_name": "xiaomi_mimo" + } +} diff --git a/aider/tui/worker.py b/aider/tui/worker.py index dcaf6f0dce0..4596200c9df 100644 --- a/aider/tui/worker.py +++ b/aider/tui/worker.py @@ -102,7 +102,6 @@ async def _async_run(self): kwargs["args"] = self.coder.args # Skip summarization to avoid blocking LLM calls during mode switch kwargs["summarize_from_coder"] = False - kwargs["mcp_servers"] = [] # Empty to skip initialization new_coder = await Coder.create(**kwargs) new_coder.args = self.coder.args @@ -110,12 +109,6 @@ async def _async_run(self): if switch.kwargs.get("show_announcements") is False: new_coder.suppress_announcements_for_next_prompt = True - # Transfer MCP state to avoid re-initialization - new_coder.mcp_servers = self.coder.mcp_servers - new_coder.mcp_tools = self.coder.mcp_tools - # Transfer TUI app weak reference - new_coder.tui = self.coder.tui - # Notify TUI of mode change self.coder = new_coder edit_format = getattr(self.coder, "edit_format", "code") or "code" diff --git a/aider/utils.py b/aider/utils.py index a171ca8466d..dbaede294a3 100644 --- a/aider/utils.py +++ b/aider/utils.py @@ -1,3 +1,4 @@ +import glob import os import platform import shutil @@ -14,6 +15,25 @@ IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp", ".pdf"} +def expand_glob_patterns(patterns): + """Expand glob patterns in a list of file paths.""" + expanded_files = [] + for pattern in patterns: + # Check if the pattern contains glob characters + if any(c in pattern for c in "*?[]"): + # Use glob to expand the pattern + matches = glob.glob(pattern, recursive=True) + if matches: + expanded_files.extend(matches) + else: + # If no matches, keep the original pattern + expanded_files.append(pattern) + else: + # Not a glob pattern, keep as is + expanded_files.append(pattern) + return expanded_files + + def _execute_fzf(input_data, multi=False): """ Runs fzf as a subprocess, feeding it input_data. diff --git a/aider/website/docs/config/mcp.md b/aider/website/docs/config/mcp.md index b23b23c70b7..811c6a7a658 100644 --- a/aider/website/docs/config/mcp.md +++ b/aider/website/docs/config/mcp.md @@ -188,3 +188,23 @@ mcp-servers: "http://127.0.0.1:9222" ] ``` + +### GitHub + +GitHub MCP provides access to GitHub repositories, issues, pull requests, and other GitHub resources. It enables AI models to interact with GitHub data, read repository contents, and perform various GitHub operations. The server runs in a Docker container and requires a GitHub Personal Access Token for authentication. + +```yaml +mcp-servers: + mcpServers: + github: + transport: stdio + command: "docker" + args: [ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN=", + "ghcr.io/github/github-mcp-server" + ] +``` diff --git a/docker/Dockerfile.local.nvidia.cuda.ubuntu b/docker/Dockerfile.local.nvidia.cuda.ubuntu new file mode 100644 index 00000000000..a95b46ed38f --- /dev/null +++ b/docker/Dockerfile.local.nvidia.cuda.ubuntu @@ -0,0 +1,90 @@ +# Usage steps: +# docker build -f Dockerfile.local.nvidia.cuda.ubuntu -t aider-ce:local-nvidia-cuda13.1.0-ubuntu24.04 . +# docker run -it --rm --network host --runtime=nvidia --gpus all -v "$(pwd)":/app aider-ce:local-nvidia-cuda13.1.0-ubuntu24.04 +FROM nvidia/cuda:13.1.0-runtime-ubuntu24.04 + +ARG DEFAULTUSER=ubuntu +ARG DEFAULTGROUP=ubuntu +ENV OLDHOME=/home/${DEFAULTUSER} + +# Install system dependencies +RUN apt-get update && \ + apt-get install --no-install-recommends -y python3.12 python3-pip python3-full python3-venv python3.12-venv curl jq nvidia-utils-580 && \ + apt-get install --no-install-recommends -y build-essential fzf git libportaudio2 pandoc wget fonts-unifont fontconfig && \ + wget -O /tmp/fonts-ubuntu.deb http://ftp.de.debian.org/debian/pool/non-free/f/fonts-ubuntu/fonts-ubuntu_0.83-6_all.deb && \ + wget -O /tmp/ttf-unifont.deb http://ftp.de.debian.org/debian/pool/main/u/unifont/ttf-unifont_13.0.06-1_all.deb && \ + wget -O /tmp/ttf-ubuntu-font-family.deb http://ftp.de.debian.org/debian/pool/non-free/f/fonts-ubuntu/ttf-ubuntu-font-family_0.83-4_all.deb && \ + dpkg -i /tmp/fonts-ubuntu.deb /tmp/ttf-unifont.deb /tmp/ttf-ubuntu-font-family.deb && \ + apt-get install -f -y && \ + rm /tmp/fonts-ubuntu.deb /tmp/ttf-unifont.deb /tmp/ttf-ubuntu-font-family.deb && \ + fc-cache -fv && \ + rm -rf /var/lib/apt/lists/* && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Create virtual environment +RUN ln -s /usr/bin/python3 /usr/bin/python +RUN python -m venv /venv +ENV PATH="/venv/bin:$PATH" +RUN /venv/bin/python -m pip install --no-cache-dir uv + +# Playwright browser settings +ENV PLAYWRIGHT_BROWSERS_PATH=${OLDHOME}/pw-browsers +ENV PLAYWRIGHT_SKIP_BROWSER_GC=1 + +# Create directories with proper permissions +RUN mkdir -p ${OLDHOME}/.aider ${OLDHOME}/.cache ${OLDHOME}/pw-browsers && \ + chown -R ${DEFAULTUSER}:${DEFAULTGROUP} ${OLDHOME} /app /venv && \ + chmod -R 777 ${OLDHOME}/.aider ${OLDHOME}/.cache ${OLDHOME}/pw-browsers + +# So git doesn't complain about unusual permissions +RUN git config --system --add safe.directory /app + +# This puts the container's ~/.aider into the host's project directory (usually host's cwd). +# That way caches, version checks, etc get stored in the host filesystem not +# simply discarded every time the container exits. +ENV HOME=/app + +# Copy requirements files +COPY requirements.txt /tmp/aider/ +COPY requirements/ /tmp/aider/requirements/ + +# Install dependencies as root +RUN uv pip install --no-cache-dir -r /tmp/aider/requirements.txt && \ + rm -rf /tmp/aider + +# Install playwright browsers +RUN uv pip install --no-cache-dir playwright && \ + /venv/bin/python -m playwright install --with-deps chromium + +# Fix site-packages permissions +RUN find /venv/lib/python3.12/site-packages \( -type d -exec chmod a+rwx {} + \) -o \( -type f -exec chmod a+rw {} + \) + +# Copy the rest of the application code +COPY . /app/ + +# Install the application as a package +RUN uv pip install . && \ + find . -mindepth 1 -delete + +# give some guidance +RUN cat >> ${OLDHOME}/.bashrc <<'EOF' +export OLLAMA_API_BASE=http://localhost:11434 +echo '* Activating venv' +source /venv/bin/activate +echo '* Checking GPU availability' +nvidia-smi -L +echo '* Checking available models' +curl -sS http://localhost:11434/v1/models | jq .data[].id | tr -d '\"' +echo '* Start cecli with stream: false, see https://github.com/Aider-AI/aider/issues/4594' +echo 'cecli --no-stream --model ollama/' +echo ' later you can best add it to .aider.conf.yml, see https://aider.chat/docs/config/aider_conf.html' +EOF + +# switch to non-root and set the host-mounted /app to HOME to persist ~/.aider +USER ${DEFAULTUSER} +ENV HOME=/app + +# WARNING! entrypoint must use explicit path, perhaps needs an update to OLDHOME +ENTRYPOINT ["/bin/bash", "--rcfile", "/home/ubuntu/.bashrc", "-i"] diff --git a/requirements/common-constraints.txt b/requirements/common-constraints.txt index a4c04c8caf5..f1e5ab20765 100644 --- a/requirements/common-constraints.txt +++ b/requirements/common-constraints.txt @@ -457,10 +457,13 @@ pytest==9.0.1 # -r requirements/requirements-dev.in # pytest-asyncio # pytest-env + # pytest-mock pytest-asyncio==1.3.0 # via -r requirements/requirements-dev.in pytest-env==1.2.0 # via -r requirements/requirements-dev.in +pytest-mock==3.15.1 + # via -r requirements/requirements-dev.in python-dateutil==2.9.0.post0 # via # google-cloud-bigquery diff --git a/requirements/requirements-dev.in b/requirements/requirements-dev.in index a7bbf3aeaf2..e52d0cdc30a 100644 --- a/requirements/requirements-dev.in +++ b/requirements/requirements-dev.in @@ -1,6 +1,7 @@ pytest pytest-asyncio pytest-env +pytest-mock pip-tools lox matplotlib diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 2c21ad40b50..d1cebb9dff6 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -215,6 +215,7 @@ pytest==9.0.1 # -r requirements/requirements-dev.in # pytest-asyncio # pytest-env + # pytest-mock pytest-asyncio==1.3.0 # via # -c requirements/common-constraints.txt @@ -223,6 +224,10 @@ pytest-env==1.2.0 # via # -c requirements/common-constraints.txt # -r requirements/requirements-dev.in +pytest-mock==3.15.1 + # via + # -c requirements/common-constraints.txt + # -r requirements/requirements-dev.in python-dateutil==2.9.0.post0 # via # -c requirements/common-constraints.txt diff --git a/scripts/generate_providers.py b/scripts/generate_providers.py new file mode 100644 index 00000000000..50d4b08907b --- /dev/null +++ b/scripts/generate_providers.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +""" +Interactively generate aider/resources/providers.json from litellm data. + +This script reads litellm's openai_like provider definitions and walks the user +through building cecli's provider registry, mirroring the workflow used by +clean_metadata.py (prompting when decisions are needed). +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any, Dict, Iterable + +AUTO_APPROVE = False + + +def prompt_yes_no(question: str, default: bool = True) -> bool: + """Prompt user for yes/no input, returning bool.""" + + suffix = " [Y/n] " if default else " [y/N] " + if AUTO_APPROVE: + print(f"{question}{suffix}-> {'Y' if default else 'N'} (auto)") + return default + while True: + resp = input(question + suffix).strip().lower() + if not resp: + return default + if resp in ("y", "yes"): + return True + if resp in ("n", "no"): + return False + print("Please enter 'y' or 'n'.") + + +def _format_default(value: str | None) -> str | None: + if value is None: + return None + if value.startswith("[") and value.endswith("]"): + try: + parsed = json.loads(value) + except json.JSONDecodeError: + return value + if isinstance(parsed, list): + return ", ".join(str(item) for item in parsed) + return value + + +def prompt_value(question: str, default: str | None = None) -> str | None: + """Prompt user for a string; empty input keeps default.""" + + display_default = _format_default(default) + suffix = f" [{display_default}]" if display_default is not None else "" + if AUTO_APPROVE: + print(f"{question}{suffix}: -> {display_default or ''} (auto)") + return default + resp = input(f"{question}{suffix}: ").strip() + if not resp: + return default + return resp + + +def ensure_json_object(prompt_text: str, default: Dict[str, Any] | None = None) -> Dict[str, Any]: + """Prompt for a JSON object, re-prompting on parse errors.""" + + default_str = json.dumps(default, indent=2) if default else "" + while True: + raw = prompt_value(prompt_text, default_str) + if not raw: + return default or {} + if AUTO_APPROVE: + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return default or {} + return parsed + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: # pragma: no cover - interactive error path + print(f"Invalid JSON ({exc}). Please try again.") + continue + if not isinstance(parsed, dict): + print('Please provide a JSON object (e.g., {"Header": "value"}).') + continue + return parsed + + +def _list_to_csv(value: Iterable[str] | str | None) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return ", ".join(str(item) for item in value) + + +def _parse_csv(value: str | None) -> list[str]: + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] + + +def main(): + global AUTO_APPROVE + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-y", + "--yes", + "--auto-approve", + dest="auto", + action="store_true", + help="Automatically include all providers and accept defaults without prompting.", + ) + args = parser.parse_args() + AUTO_APPROVE = args.auto + + script_dir = Path(__file__).parent.resolve() + repo_root = script_dir.parent + + litellm_providers_path = ( + script_dir.parent / "../litellm/litellm/llms/openai_like/providers.json" + ).resolve() + output_path = (repo_root / "aider" / "resources" / "providers.json").resolve() + + if not litellm_providers_path.exists(): + print(f"Error: Could not find litellm providers at {litellm_providers_path}") + return + + try: + litellm_data = json.loads(litellm_providers_path.read_text()) + except json.JSONDecodeError as exc: + print(f"Error: Failed to parse litellm providers ({exc}).") + return + + existing = {} + if output_path.exists(): + try: + existing = json.loads(output_path.read_text()) + except json.JSONDecodeError as exc: + print(f"Warning: Existing {output_path} is invalid JSON ({exc}); ignoring.") + + new_config: Dict[str, Dict[str, Any]] = {} + + for provider_name in sorted(litellm_data.keys()): + litellm_entry = litellm_data[provider_name] + existing_entry = existing.get(provider_name, {}) + default_keep = bool(existing_entry) + + print("\n" + "=" * 60) + print(f"Provider: {provider_name}") + print(f" Display name : {litellm_entry.get('display_name', provider_name)}") + print(f" Base URL : {litellm_entry.get('base_url', 'N/A')}") + + api_key_list = litellm_entry.get("api_key_env") + api_key_display = _list_to_csv(api_key_list) if api_key_list else "N/A" + print(f" API key env : {api_key_display}") + + include = prompt_yes_no( + f"Include provider '{provider_name}'?", default=default_keep or True + ) + if not include: + continue + + display_name = prompt_value( + "Display name", + existing_entry.get("display_name") + or litellm_entry.get("display_name") + or provider_name, + ) + api_base = prompt_value( + "API base URL", + existing_entry.get("api_base") or litellm_entry.get("base_url") or "", + ) + base_url_env = prompt_value( + "Comma-separated env vars for overriding base URL", + _list_to_csv(existing_entry.get("base_url_env")) or "", + ) + api_key_env = prompt_value( + "Comma-separated env vars for API key lookup", + _list_to_csv(existing_entry.get("api_key_env", litellm_entry.get("api_key_env", []))) + or "", + ) + models_url = prompt_value( + "Models endpoint URL (leave blank if none)", + existing_entry.get("models_url", ""), + ) + default_headers = ensure_json_object( + "Default headers JSON (empty for none)", + existing_entry.get("default_headers"), + ) + + record: Dict[str, Any] = {} + if display_name: + record["display_name"] = display_name + if api_base: + record["api_base"] = api_base + if api_key_env: + record["api_key_env"] = _parse_csv(api_key_env) + if base_url_env: + record["base_url_env"] = _parse_csv(base_url_env) + if models_url: + record["models_url"] = models_url + if default_headers: + record["default_headers"] = default_headers + + new_config[provider_name] = record + + # Preserve providers that only exist in the existing file (not litellm) if user wants. + for provider_name in sorted(existing.keys()): + if provider_name in new_config or provider_name in litellm_data: + continue + print("\n" + "=" * 60) + print(f"Provider '{provider_name}' exists only in {output_path}.") + if prompt_yes_no("Keep this provider?", default=True): + new_config[provider_name] = existing[provider_name] + + output_path.write_text(json.dumps(new_config, indent=2, sort_keys=True) + "\n") + print(f"\nWrote {len(new_config)} providers to {output_path}.\n") + + +if __name__ == "__main__": + main() diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index 7ed6564e5c3..256df16f6a8 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -1,7 +1,9 @@ +import asyncio import json import os import subprocess import tempfile +import types from io import StringIO from pathlib import Path from unittest import TestCase @@ -12,12 +14,22 @@ from prompt_toolkit.output import DummyOutput from aider.coders import Coder, CopyPasteCoder +from aider.commands import SwitchCoder from aider.dump import dump # noqa: F401 from aider.io import InputOutput from aider.main import check_gitignore, load_dotenv_files, main, setup_git from aider.utils import GitTemporaryDirectory, IgnorantTemporaryDirectory, make_repo +def mock_autosave_future(): + """Create an awaitable mock for _autosave_future. + + Returns AsyncMock()() - the first call creates an async mock function, + the second call invokes it to get an awaitable coroutine object. + """ + return AsyncMock()() + + class TestMain(TestCase): def setUp(self): self.original_env = os.environ.copy() @@ -45,57 +57,61 @@ def tearDown(self): self.input_patcher.stop() self.webbrowser_patcher.stop() - async def test_main_with_empty_dir_no_files_on_command(self): - await main(["--no-git", "--exit", "--yes"], input=DummyInput(), output=DummyOutput()) + def test_main_with_empty_dir_no_files_on_command(self): + main(["--no-git", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput()) - async def test_main_with_emptqy_dir_new_file(self): - await main( - ["foo.txt", "--yes", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput() + def test_main_with_emptqy_dir_new_file(self): + main( + ["foo.txt", "--yes-always", "--no-git", "--exit"], + input=DummyInput(), + output=DummyOutput(), ) self.assertTrue(os.path.exists("foo.txt")) @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") - async def test_main_with_empty_git_dir_new_file(self, _): + def test_main_with_empty_git_dir_new_file(self, _): make_repo() - await main(["--yes", "foo.txt", "--exit"], input=DummyInput(), output=DummyOutput()) + main(["--yes-always", "foo.txt", "--exit"], input=DummyInput(), output=DummyOutput()) self.assertTrue(os.path.exists("foo.txt")) @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") - async def test_main_with_empty_git_dir_new_files(self, _): + def test_main_with_empty_git_dir_new_files(self, _): make_repo() - await main( - ["--yes", "foo.txt", "bar.txt", "--exit"], input=DummyInput(), output=DummyOutput() + main( + ["--yes-always", "foo.txt", "bar.txt", "--exit"], + input=DummyInput(), + output=DummyOutput(), ) self.assertTrue(os.path.exists("foo.txt")) self.assertTrue(os.path.exists("bar.txt")) - async def test_main_with_dname_and_fname(self): + def test_main_with_dname_and_fname(self): subdir = Path("subdir") subdir.mkdir() make_repo(str(subdir)) - res = await main(["subdir", "foo.txt"], input=DummyInput(), output=DummyOutput()) + res = main(["subdir", "foo.txt"], input=DummyInput(), output=DummyOutput()) self.assertNotEqual(res, None) @patch("aider.repo.GitRepo.get_commit_message", return_value="mock commit message") - async def test_main_with_subdir_repo_fnames(self, _): + def test_main_with_subdir_repo_fnames(self, _): subdir = Path("subdir") subdir.mkdir() make_repo(str(subdir)) - await main( - ["--yes", str(subdir / "foo.txt"), str(subdir / "bar.txt"), "--exit"], + main( + ["--yes-always", str(subdir / "foo.txt"), str(subdir / "bar.txt"), "--exit"], input=DummyInput(), output=DummyOutput(), ) self.assertTrue((subdir / "foo.txt").exists()) self.assertTrue((subdir / "bar.txt").exists()) - async def test_main_copy_paste_model_overrides(self): + def test_main_copy_paste_model_overrides(self): overrides = json.dumps({"gpt-4o": {"fast": {"temperature": 0.42}}}) - coder = await main( + coder = main( [ "--no-git", "--exit", - "--yes", + "--yes-always", "--model", "cp:gpt-4o:fast", "--model-overrides", @@ -112,11 +128,11 @@ async def test_main_copy_paste_model_overrides(self): self.assertEqual(coder.main_model.override_kwargs, {"temperature": 0.42}) @patch("aider.main.ClipboardWatcher") - async def test_main_copy_paste_flag_sets_mode(self, mock_watcher): + def test_main_copy_paste_flag_sets_mode(self, mock_watcher): mock_watcher.return_value = MagicMock() - coder = await main( - ["--no-git", "--exit", "--yes", "--copy-paste"], + coder = main( + ["--no-git", "--exit", "--yes-always", "--copy-paste"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -128,22 +144,26 @@ async def test_main_copy_paste_flag_sets_mode(self, mock_watcher): self.assertTrue(coder.copy_paste_mode) self.assertFalse(coder.manual_copy_paste) - async def test_main_with_git_config_yml(self): + def test_main_with_git_config_yml(self): make_repo() Path(".aider.conf.yml").write_text("auto-commits: false\n") with patch("aider.coders.Coder.create") as MockCoder: - await main(["--yes"], input=DummyInput(), output=DummyOutput()) + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--yes-always"], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args assert kwargs["auto_commits"] is False Path(".aider.conf.yml").write_text("auto-commits: true\n") with patch("aider.coders.Coder.create") as MockCoder: - await main([], input=DummyInput(), output=DummyOutput()) + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() + main([], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args assert kwargs["auto_commits"] is True - async def test_main_with_empty_git_dir_new_subdir_file(self): + def test_main_with_empty_git_dir_new_subdir_file(self): make_repo() subdir = Path("subdir") subdir.mkdir() @@ -155,11 +175,11 @@ async def test_main_with_empty_git_dir_new_subdir_file(self): # This will throw a git error on windows if get_tracked_files doesn't # properly convert git/posix/paths to git\posix\paths. # Because aider will try and `git add` a file that's already in the repo. - await main(["--yes", str(fname), "--exit"], input=DummyInput(), output=DummyOutput()) + main(["--yes-always", str(fname), "--exit"], input=DummyInput(), output=DummyOutput()) - async def test_setup_git(self): + def test_setup_git(self): io = InputOutput(pretty=False, yes=True) - git_root = await setup_git(None, io) + git_root = asyncio.run(setup_git(None, io)) git_root = Path(git_root).resolve() self.assertEqual(git_root, Path(self.tempdir).resolve()) @@ -169,7 +189,7 @@ async def test_setup_git(self): self.assertTrue(gitignore.exists()) self.assertEqual(".aider*", gitignore.read_text().splitlines()[0]) - async def test_check_gitignore(self): + def test_check_gitignore(self): with GitTemporaryDirectory(): os.environ["GIT_CONFIG_GLOBAL"] = "globalgitconfig" @@ -178,24 +198,24 @@ async def test_check_gitignore(self): gitignore = cwd / ".gitignore" self.assertFalse(gitignore.exists()) - await check_gitignore(cwd, io) + asyncio.run(check_gitignore(cwd, io)) self.assertTrue(gitignore.exists()) self.assertEqual(".aider*", gitignore.read_text().splitlines()[0]) # Test without .env file present gitignore.write_text("one\ntwo\n") - await check_gitignore(cwd, io) + asyncio.run(check_gitignore(cwd, io)) self.assertEqual("one\ntwo\n.aider*\n", gitignore.read_text()) # Test with .env file present env_file = cwd / ".env" env_file.touch() - await check_gitignore(cwd, io) + asyncio.run(check_gitignore(cwd, io)) self.assertEqual("one\ntwo\n.aider*\n.env\n", gitignore.read_text()) del os.environ["GIT_CONFIG_GLOBAL"] - async def test_command_line_gitignore_files_flag(self): + def test_command_line_gitignore_files_flag(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -211,8 +231,8 @@ async def test_command_line_gitignore_files_flag(self): abs_ignored_file = str(ignored_file.resolve()) # Test without the --add-gitignore-files flag (default: False) - coder = await main( - ["--exit", "--yes", abs_ignored_file], + coder = main( + ["--exit", "--yes-always", abs_ignored_file], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -222,8 +242,8 @@ async def test_command_line_gitignore_files_flag(self): self.assertNotIn(abs_ignored_file, coder.abs_fnames) # Test with --add-gitignore-files set to True - coder = await main( - ["--add-gitignore-files", "--exit", "--yes", abs_ignored_file], + coder = main( + ["--add-gitignore-files", "--exit", "--yes-always", abs_ignored_file], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -233,8 +253,8 @@ async def test_command_line_gitignore_files_flag(self): self.assertIn(abs_ignored_file, coder.abs_fnames) # Test with --add-gitignore-files set to False - coder = await main( - ["--no-add-gitignore-files", "--exit", "--yes", abs_ignored_file], + coder = main( + ["--no-add-gitignore-files", "--exit", "--yes-always", abs_ignored_file], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -243,7 +263,7 @@ async def test_command_line_gitignore_files_flag(self): # Verify the ignored file is not in the chat self.assertNotIn(abs_ignored_file, coder.abs_fnames) - async def test_add_command_gitignore_files_flag(self): + def test_add_command_gitignore_files_flag(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -260,79 +280,95 @@ async def test_add_command_gitignore_files_flag(self): rel_ignored_file = "ignored.txt" # Test without the --add-gitignore-files flag (default: False) - coder = await main( - ["--exit", "--yes"], + coder = main( + ["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, force_git_root=git_dir, ) - with patch.object(coder.io, "confirm_ask", return_value=True): - coder.commands.cmd_add(rel_ignored_file) + try: + asyncio.run(coder.commands.do_run("add", rel_ignored_file)) + except SwitchCoder: + pass # Verify the ignored file is not in the chat self.assertNotIn(abs_ignored_file, coder.abs_fnames) # Test with --add-gitignore-files set to True - coder = await main( - ["--add-gitignore-files", "--exit", "--yes"], + coder = main( + ["--add-gitignore-files", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, force_git_root=git_dir, ) - with patch.object(coder.io, "confirm_ask", return_value=True): - coder.commands.cmd_add(rel_ignored_file) + try: + asyncio.run(coder.commands.do_run("add", rel_ignored_file)) + except SwitchCoder: + pass # Verify the ignored file is in the chat self.assertIn(abs_ignored_file, coder.abs_fnames) # Test with --add-gitignore-files set to False - coder = await main( - ["--no-add-gitignore-files", "--exit", "--yes"], + coder = main( + ["--no-add-gitignore-files", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, force_git_root=git_dir, ) - with patch.object(coder.io, "confirm_ask", return_value=True): - coder.commands.cmd_add(rel_ignored_file) + try: + asyncio.run(coder.commands.do_run("add", rel_ignored_file)) + except SwitchCoder: + pass # Verify the ignored file is not in the chat self.assertNotIn(abs_ignored_file, coder.abs_fnames) - async def test_main_args(self): + def test_main_args(self): with patch("aider.coders.Coder.create") as MockCoder: + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() # --yes will just ok the git repo without blocking on input # following calls to main will see the new repo already - await main(["--no-auto-commits", "--yes"], input=DummyInput()) + main(["--no-auto-commits", "--yes-always"], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["auto_commits"] is False with patch("aider.coders.Coder.create") as MockCoder: - await main(["--auto-commits"], input=DummyInput()) + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--auto-commits"], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["auto_commits"] is True with patch("aider.coders.Coder.create") as MockCoder: - await main([], input=DummyInput()) + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() + main([], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["dirty_commits"] is True assert kwargs["auto_commits"] is True with patch("aider.coders.Coder.create") as MockCoder: - await main(["--no-dirty-commits"], input=DummyInput()) + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--no-dirty-commits"], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["dirty_commits"] is False with patch("aider.coders.Coder.create") as MockCoder: - await main(["--dirty-commits"], input=DummyInput()) + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--dirty-commits"], input=DummyInput()) _, kwargs = MockCoder.call_args assert kwargs["dirty_commits"] is True - async def test_env_file_override(self): + def test_env_file_override(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) git_env = git_dir / ".env" @@ -356,7 +392,7 @@ async def test_env_file_override(self): named_env.write_text("A=named") with patch("pathlib.Path.home", return_value=fake_home): - await main(["--yes", "--exit", "--env-file", str(named_env)]) + main(["--yes-always", "--exit", "--env-file", str(named_env)]) self.assertEqual(os.environ["A"], "named") self.assertEqual(os.environ["B"], "cwd") @@ -364,7 +400,7 @@ async def test_env_file_override(self): self.assertEqual(os.environ["D"], "home") self.assertEqual(os.environ["E"], "existing") - async def test_message_file_flag(self): + def test_message_file_flag(self): message_file_content = "This is a test message from a file." message_file_path = tempfile.mktemp() with open(message_file_path, "w", encoding="utf-8") as message_file: @@ -378,10 +414,11 @@ async def mock_run(*args, **kwargs): # Create a mock coder instance with an async run method mock_coder_instance = MagicMock() mock_coder_instance.run = AsyncMock() + mock_coder_instance._autosave_future = mock_autosave_future() MockCoder.return_value = mock_coder_instance - await main( - ["--yes", "--message-file", message_file_path], + main( + ["--yes-always", "--message-file", message_file_path], input=DummyInput(), output=DummyOutput(), ) @@ -390,96 +427,100 @@ async def mock_run(*args, **kwargs): os.remove(message_file_path) - async def test_encodings_arg(self): + def test_encodings_arg(self): fname = "foo.py" with GitTemporaryDirectory(): - with patch("aider.coders.Coder.create") as MockCoder: # noqa: F841 + with patch("aider.coders.Coder.create") as MockCoder: + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() with patch("aider.main.InputOutput") as MockSend: def side_effect(*args, **kwargs): self.assertEqual(kwargs["encoding"], "iso-8859-15") - return MagicMock() + mock_io = MagicMock() + mock_io.confirm_ask = AsyncMock(return_value=True) + return mock_io MockSend.side_effect = side_effect - await main(["--yes", fname, "--encoding", "iso-8859-15"]) + main(["--yes-always", fname, "--encoding", "iso-8859-15"]) - async def test_main_exit_calls_version_check(self): + def test_main_exit_calls_version_check(self): with GitTemporaryDirectory(): with ( patch("aider.main.check_version") as mock_check_version, patch("aider.main.InputOutput") as mock_input_output, ): - await main(["--exit", "--check-update"], input=DummyInput(), output=DummyOutput()) + mock_input_output.return_value.confirm_ask = AsyncMock(return_value=True) + main(["--exit", "--check-update"], input=DummyInput(), output=DummyOutput()) mock_check_version.assert_called_once() mock_input_output.assert_called_once() - @patch("aider.main.InputOutput") + @patch("aider.main.InputOutput", autospec=True) @patch("aider.coders.base_coder.Coder.run") - async def test_main_message_adds_to_input_history(self, mock_run, MockInputOutput): + def test_main_message_adds_to_input_history(self, mock_run, MockInputOutput): test_message = "test message" mock_io_instance = MockInputOutput.return_value + mock_io_instance.pretty = True - await main(["--message", test_message], input=DummyInput(), output=DummyOutput()) + main(["--message", test_message], input=DummyInput(), output=DummyOutput()) mock_io_instance.add_to_input_history.assert_called_once_with(test_message) - @patch("aider.main.InputOutput") + @patch("aider.main.InputOutput", autospec=True) @patch("aider.coders.base_coder.Coder.run") - async def test_yes(self, mock_run, MockInputOutput): + def test_yes(self, mock_run, MockInputOutput): test_message = "test message" + MockInputOutput.return_value.pretty = True - await main(["--yes", "--message", test_message]) + main(["--yes-always", "--message", test_message]) args, kwargs = MockInputOutput.call_args self.assertTrue(args[1]) - @patch("aider.main.InputOutput") + @patch("aider.main.InputOutput", autospec=True) @patch("aider.coders.base_coder.Coder.run") - async def test_default_yes(self, mock_run, MockInputOutput): + def test_default_yes(self, mock_run, MockInputOutput): test_message = "test message" + MockInputOutput.return_value.pretty = True - await main(["--message", test_message]) + main(["--message", test_message]) args, kwargs = MockInputOutput.call_args self.assertEqual(args[1], None) - async def test_dark_mode_sets_code_theme(self): + def test_dark_mode_sets_code_theme(self): # Mock InputOutput to capture the configuration with patch("aider.main.InputOutput") as MockInputOutput: MockInputOutput.return_value.get_input.return_value = None - await main( - ["--dark-mode", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput() - ) + main(["--dark-mode", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput()) # Ensure InputOutput was called MockInputOutput.assert_called_once() # Check if the code_theme setting is for dark mode _, kwargs = MockInputOutput.call_args self.assertEqual(kwargs["code_theme"], "monokai") - async def test_light_mode_sets_code_theme(self): + def test_light_mode_sets_code_theme(self): # Mock InputOutput to capture the configuration with patch("aider.main.InputOutput") as MockInputOutput: MockInputOutput.return_value.get_input.return_value = None - await main( - ["--light-mode", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput() - ) + main(["--light-mode", "--no-git", "--exit"], input=DummyInput(), output=DummyOutput()) # Ensure InputOutput was called MockInputOutput.assert_called_once() # Check if the code_theme setting is for light mode _, kwargs = MockInputOutput.call_args self.assertEqual(kwargs["code_theme"], "default") - async def create_env_file(self, file_name, content): + def create_env_file(self, file_name, content): env_file_path = Path(self.tempdir) / file_name env_file_path.write_text(content) return env_file_path - async def test_env_file_flag_sets_automatic_variable(self): + def test_env_file_flag_sets_automatic_variable(self): env_file_path = self.create_env_file(".env.test", "AIDER_DARK_MODE=True") with patch("aider.main.InputOutput") as MockInputOutput: MockInputOutput.return_value.get_input.return_value = None MockInputOutput.return_value.get_input.confirm_ask = True - await main( + main( ["--env-file", str(env_file_path), "--no-git", "--exit"], input=DummyInput(), output=DummyOutput(), @@ -489,35 +530,39 @@ async def test_env_file_flag_sets_automatic_variable(self): _, kwargs = MockInputOutput.call_args self.assertEqual(kwargs["code_theme"], "monokai") - async def test_default_env_file_sets_automatic_variable(self): + def test_default_env_file_sets_automatic_variable(self): self.create_env_file(".env", "AIDER_DARK_MODE=True") with patch("aider.main.InputOutput") as MockInputOutput: MockInputOutput.return_value.get_input.return_value = None MockInputOutput.return_value.get_input.confirm_ask = True - await main(["--no-git", "--exit"], input=DummyInput(), output=DummyOutput()) + main(["--no-git", "--exit"], input=DummyInput(), output=DummyOutput()) # Ensure InputOutput was called MockInputOutput.assert_called_once() # Check if the color settings are for dark mode _, kwargs = MockInputOutput.call_args self.assertEqual(kwargs["code_theme"], "monokai") - async def test_false_vals_in_env_file(self): + def test_false_vals_in_env_file(self): self.create_env_file(".env", "AIDER_SHOW_DIFFS=off") - with patch("aider.coders.Coder.create") as MockCoder: - await main(["--no-git", "--yes"], input=DummyInput(), output=DummyOutput()) + with patch("aider.coders.Coder.create", autospec=True) as MockCoder: + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--no-git", "--yes-always"], input=DummyInput(), output=DummyOutput()) MockCoder.assert_called_once() _, kwargs = MockCoder.call_args self.assertEqual(kwargs["show_diffs"], False) - async def test_true_vals_in_env_file(self): + def test_true_vals_in_env_file(self): self.create_env_file(".env", "AIDER_SHOW_DIFFS=on") with patch("aider.coders.Coder.create") as MockCoder: - await main(["--no-git", "--yes"], input=DummyInput(), output=DummyOutput()) + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--no-git", "--yes-always"], input=DummyInput(), output=DummyOutput()) MockCoder.assert_called_once() _, kwargs = MockCoder.call_args self.assertEqual(kwargs["show_diffs"], True) - async def test_lint_option(self): + def test_lint_option(self): with GitTemporaryDirectory() as git_dir: # Create a dirty file in the root dirty_file = Path("dirty_file.py") @@ -541,7 +586,7 @@ async def test_lint_option(self): MockLinter.return_value = "" # Run main with --lint option - await main(["--lint", "--yes"]) + main(["--lint", "--yes-always"], input=DummyInput(), output=DummyOutput()) # Check if the Linter was called with a filename ending in "dirty_file.py" # but not ending in "subdir/dirty_file.py" @@ -550,11 +595,69 @@ async def test_lint_option(self): self.assertTrue(called_arg.endswith("dirty_file.py")) self.assertFalse(called_arg.endswith(f"subdir{os.path.sep}dirty_file.py")) - async def test_verbose_mode_lists_env_vars(self): + def test_lint_option_with_explicit_files(self): + with GitTemporaryDirectory(): + # Create two files + file1 = Path("file1.py") + file1.write_text("def foo(): pass") + file2 = Path("file2.py") + file2.write_text("def bar(): pass") + + # Mock the Linter class + with patch("aider.linter.Linter.lint") as MockLinter: + MockLinter.return_value = "" + + # Run main with --lint and explicit files + main( + ["--lint", "file1.py", "file2.py", "--yes-always"], + input=DummyInput(), + output=DummyOutput(), + ) + + # Check if the Linter was called twice (once for each file) + self.assertEqual(MockLinter.call_count, 2) + + # Check that both files were linted + called_files = [call[0][0] for call in MockLinter.call_args_list] + self.assertTrue(any(f.endswith("file1.py") for f in called_files)) + self.assertTrue(any(f.endswith("file2.py") for f in called_files)) + + def test_lint_option_with_glob_pattern(self): + with GitTemporaryDirectory(): + # Create multiple Python files + file1 = Path("test1.py") + file1.write_text("def foo(): pass") + file2 = Path("test2.py") + file2.write_text("def bar(): pass") + file3 = Path("readme.txt") + file3.write_text("not a python file") + + # Mock the Linter class + with patch("aider.linter.Linter.lint") as MockLinter: + MockLinter.return_value = "" + + # Run main with --lint and glob pattern + main( + ["--lint", "test*.py", "--yes-always"], + input=DummyInput(), + output=DummyOutput(), + ) + + # Check if the Linter was called for Python files matching the glob + self.assertGreaterEqual(MockLinter.call_count, 2) + + # Check that Python files were linted + called_files = [call[0][0] for call in MockLinter.call_args_list] + self.assertTrue(any(f.endswith("test1.py") for f in called_files)) + self.assertTrue(any(f.endswith("test2.py") for f in called_files)) + # Check that non-Python file was not linted + self.assertFalse(any(f.endswith("readme.txt") for f in called_files)) + + def test_verbose_mode_lists_env_vars(self): self.create_env_file(".env", "AIDER_DARK_MODE=on") with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - await main( - ["--no-git", "--verbose", "--exit", "--yes"], + main( + ["--no-git", "--verbose", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) @@ -569,7 +672,7 @@ async def test_verbose_mode_lists_env_vars(self): self.assertRegex(relevant_output, r"AIDER_DARK_MODE:\s+on") self.assertRegex(relevant_output, r"dark_mode:\s+True") - async def test_yaml_config_file_loading(self): + def test_yaml_config_file_loading(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -598,9 +701,11 @@ async def test_yaml_config_file_loading(self): patch("pathlib.Path.home", return_value=fake_home), patch("aider.coders.Coder.create") as MockCoder, ): + mock_coder_instance = MockCoder.return_value + mock_coder_instance._autosave_future = mock_autosave_future() # Test loading from specified config file - await main( - ["--yes", "--exit", "--config", str(named_config)], + main( + ["--yes-always", "--exit", "--config", str(named_config)], input=DummyInput(), output=DummyOutput(), ) @@ -609,7 +714,8 @@ async def test_yaml_config_file_loading(self): self.assertEqual(kwargs["map_tokens"], 8192) # Test loading from current working directory - await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--yes-always", "--exit"], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args print("kwargs:", kwargs) # Add this line for debugging self.assertIn("main_model", kwargs, "main_model key not found in kwargs") @@ -618,47 +724,49 @@ async def test_yaml_config_file_loading(self): # Test loading from git root cwd_config.unlink() - await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--yes-always", "--exit"], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args self.assertEqual(kwargs["main_model"].name, "gpt-4") self.assertEqual(kwargs["map_tokens"], 2048) # Test loading from home directory git_config.unlink() - await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + mock_coder_instance._autosave_future = mock_autosave_future() + main(["--yes-always", "--exit"], input=DummyInput(), output=DummyOutput()) _, kwargs = MockCoder.call_args self.assertEqual(kwargs["main_model"].name, "gpt-3.5-turbo") self.assertEqual(kwargs["map_tokens"], 1024) - async def test_map_tokens_option(self): + def test_map_tokens_option(self): with GitTemporaryDirectory(): with patch("aider.coders.base_coder.RepoMap") as MockRepoMap: MockRepoMap.return_value.max_map_tokens = 0 - await main( - ["--model", "gpt-4", "--map-tokens", "0", "--exit", "--yes"], + main( + ["--model", "gpt-4", "--map-tokens", "0", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) MockRepoMap.assert_not_called() - async def test_map_tokens_option_with_non_zero_value(self): + def test_map_tokens_option_with_non_zero_value(self): with GitTemporaryDirectory(): with patch("aider.coders.base_coder.RepoMap") as MockRepoMap: MockRepoMap.return_value.max_map_tokens = 1000 - await main( - ["--model", "gpt-4", "--map-tokens", "1000", "--exit", "--yes"], + main( + ["--model", "gpt-4", "--map-tokens", "1000", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) MockRepoMap.assert_called_once() - async def test_read_option(self): + def test_read_option(self): with GitTemporaryDirectory(): test_file = "test_file.txt" Path(test_file).touch() - coder = await main( - ["--read", test_file, "--exit", "--yes"], + coder = main( + ["--read", test_file, "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -666,15 +774,15 @@ async def test_read_option(self): self.assertIn(str(Path(test_file).resolve()), coder.abs_read_only_fnames) - async def test_read_option_with_external_file(self): + def test_read_option_with_external_file(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as external_file: external_file.write("External file content") external_file_path = external_file.name try: with GitTemporaryDirectory(): - coder = await main( - ["--read", external_file_path, "--exit", "--yes"], + coder = main( + ["--read", external_file_path, "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -685,7 +793,7 @@ async def test_read_option_with_external_file(self): finally: os.unlink(external_file_path) - async def test_model_metadata_file(self): + def test_model_metadata_file(self): # Re-init so we don't have old data lying around from earlier test cases from aider import models @@ -702,14 +810,14 @@ async def test_model_metadata_file(self): metadata_content = {"deepseek/deepseek-chat": {"max_input_tokens": 1234}} metadata_file.write_text(json.dumps(metadata_content)) - coder = await main( + coder = main( [ "--model", "deepseek/deepseek-chat", "--model-metadata-file", str(metadata_file), "--exit", - "--yes", + "--yes-always", ], input=DummyInput(), output=DummyOutput(), @@ -718,15 +826,15 @@ async def test_model_metadata_file(self): self.assertEqual(coder.main_model.info["max_input_tokens"], 1234) - async def test_sonnet_and_cache_options(self): + def test_sonnet_and_cache_options(self): with GitTemporaryDirectory(): with patch("aider.coders.base_coder.RepoMap") as MockRepoMap: mock_repo_map = MagicMock() mock_repo_map.max_map_tokens = 1000 # Set a specific value MockRepoMap.return_value = mock_repo_map - await main( - ["--sonnet", "--cache-prompts", "--exit", "--yes"], + main( + ["--sonnet", "--cache-prompts", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) @@ -737,10 +845,10 @@ async def test_sonnet_and_cache_options(self): call_kwargs.get("refresh"), "files" ) # Check the 'refresh' keyword argument - async def test_sonnet_and_cache_prompts_options(self): + def test_sonnet_and_cache_prompts_options(self): with GitTemporaryDirectory(): - coder = await main( - ["--sonnet", "--cache-prompts", "--exit", "--yes"], + coder = main( + ["--sonnet", "--cache-prompts", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -748,10 +856,10 @@ async def test_sonnet_and_cache_prompts_options(self): self.assertTrue(coder.add_cache_headers) - async def test_4o_and_cache_options(self): + def test_4o_and_cache_options(self): with GitTemporaryDirectory(): - coder = await main( - ["--4o", "--cache-prompts", "--exit", "--yes"], + coder = main( + ["--4o", "--cache-prompts", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -759,28 +867,28 @@ async def test_4o_and_cache_options(self): self.assertFalse(coder.add_cache_headers) - async def test_return_coder(self): + def test_return_coder(self): with GitTemporaryDirectory(): - result = await main( - ["--exit", "--yes"], + result = main( + ["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertIsInstance(result, Coder) - result = await main( - ["--exit", "--yes"], + result = main( + ["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=False, ) - self.assertIsNone(result) + self.assertEqual(result, 0) - async def test_map_mul_option(self): + def test_map_mul_option(self): with GitTemporaryDirectory(): - coder = await main( - ["--map-mul", "5", "--exit", "--yes"], + coder = main( + ["--map-mul", "5", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -788,67 +896,67 @@ async def test_map_mul_option(self): self.assertIsInstance(coder, Coder) self.assertEqual(coder.repo_map.map_mul_no_files, 5) - async def test_suggest_shell_commands_default(self): + def test_suggest_shell_commands_default(self): with GitTemporaryDirectory(): - coder = await main( - ["--exit", "--yes"], + coder = main( + ["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertTrue(coder.suggest_shell_commands) - async def test_suggest_shell_commands_disabled(self): + def test_suggest_shell_commands_disabled(self): with GitTemporaryDirectory(): - coder = await main( - ["--no-suggest-shell-commands", "--exit", "--yes"], + coder = main( + ["--no-suggest-shell-commands", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertFalse(coder.suggest_shell_commands) - async def test_suggest_shell_commands_enabled(self): + def test_suggest_shell_commands_enabled(self): with GitTemporaryDirectory(): - coder = await main( - ["--suggest-shell-commands", "--exit", "--yes"], + coder = main( + ["--suggest-shell-commands", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertTrue(coder.suggest_shell_commands) - async def test_detect_urls_default(self): + def test_detect_urls_default(self): with GitTemporaryDirectory(): - coder = await main( - ["--exit", "--yes"], + coder = main( + ["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertTrue(coder.detect_urls) - async def test_detect_urls_disabled(self): + def test_detect_urls_disabled(self): with GitTemporaryDirectory(): - coder = await main( - ["--no-detect-urls", "--exit", "--yes"], + coder = main( + ["--no-detect-urls", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertFalse(coder.detect_urls) - async def test_detect_urls_enabled(self): + def test_detect_urls_enabled(self): with GitTemporaryDirectory(): - coder = await main( - ["--detect-urls", "--exit", "--yes"], + coder = main( + ["--detect-urls", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertTrue(coder.detect_urls) - async def test_accepts_settings_warnings(self): + def test_accepts_settings_warnings(self): # Test that appropriate warnings are shown based on accepts_settings configuration with GitTemporaryDirectory(): # Test model that accepts the thinking_tokens setting @@ -856,13 +964,13 @@ async def test_accepts_settings_warnings(self): patch("aider.io.InputOutput.tool_warning") as mock_warning, patch("aider.models.Model.set_thinking_tokens") as mock_set_thinking, ): - await main( + main( [ "--model", "anthropic/claude-3-7-sonnet-20250219", "--thinking-tokens", "1000", - "--yes", + "--yes-always", "--exit", ], input=DummyInput(), @@ -879,14 +987,14 @@ async def test_accepts_settings_warnings(self): patch("aider.io.InputOutput.tool_warning") as mock_warning, patch("aider.models.Model.set_thinking_tokens") as mock_set_thinking, ): - await main( + main( [ "--model", "gpt-4o", "--thinking-tokens", "1000", "--check-model-accepts-settings", - "--yes", + "--yes-always", "--exit", ], input=DummyInput(), @@ -906,8 +1014,8 @@ async def test_accepts_settings_warnings(self): patch("aider.io.InputOutput.tool_warning") as mock_warning, patch("aider.models.Model.set_reasoning_effort") as mock_set_reasoning, ): - await main( - ["--model", "o1", "--reasoning-effort", "3", "--yes", "--exit"], + main( + ["--model", "o1", "--reasoning-effort", "3", "--yes-always", "--exit"], input=DummyInput(), output=DummyOutput(), ) @@ -922,8 +1030,15 @@ async def test_accepts_settings_warnings(self): patch("aider.io.InputOutput.tool_warning") as mock_warning, patch("aider.models.Model.set_reasoning_effort") as mock_set_reasoning, ): - await main( - ["--model", "gpt-3.5-turbo", "--reasoning-effort", "3", "--yes", "--exit"], + main( + [ + "--model", + "gpt-3.5-turbo", + "--reasoning-effort", + "3", + "--yes-always", + "--exit", + ], input=DummyInput(), output=DummyOutput(), ) @@ -937,7 +1052,7 @@ async def test_accepts_settings_warnings(self): mock_set_reasoning.assert_not_called() @patch("aider.models.ModelInfoManager.set_verify_ssl") - async def test_no_verify_ssl_sets_model_info_manager(self, mock_set_verify_ssl): + def test_no_verify_ssl_sets_model_info_manager(self, mock_set_verify_ssl): with GitTemporaryDirectory(): # Mock Model class to avoid actual model initialization with patch("aider.models.Model") as mock_model: @@ -951,73 +1066,80 @@ async def test_no_verify_ssl_sets_model_info_manager(self, mock_set_verify_ssl): # Mock fuzzy_match_models to avoid string operations on MagicMock with patch("aider.models.fuzzy_match_models", return_value=[]): - await main( - ["--no-verify-ssl", "--exit", "--yes"], + main( + ["--no-verify-ssl", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) mock_set_verify_ssl.assert_called_once_with(False) - async def test_pytest_env_vars(self): + def test_pytest_env_vars(self): # Verify that environment variables from pytest.ini are properly set self.assertEqual(os.environ.get("AIDER_ANALYTICS"), "false") - async def test_set_env_single(self): + def test_set_env_single(self): # Test setting a single environment variable with GitTemporaryDirectory(): - await main(["--set-env", "TEST_VAR=test_value", "--exit", "--yes"]) + main(["--set-env", "TEST_VAR=test_value", "--exit", "--yes-always"]) self.assertEqual(os.environ.get("TEST_VAR"), "test_value") - async def test_set_env_multiple(self): + def test_set_env_multiple(self): # Test setting multiple environment variables with GitTemporaryDirectory(): - await main( + main( [ "--set-env", "TEST_VAR1=value1", "--set-env", "TEST_VAR2=value2", "--exit", - "--yes", + "--yes-always", ] ) self.assertEqual(os.environ.get("TEST_VAR1"), "value1") self.assertEqual(os.environ.get("TEST_VAR2"), "value2") - async def test_set_env_with_spaces(self): + def test_set_env_with_spaces(self): # Test setting env var with spaces in value with GitTemporaryDirectory(): - await main(["--set-env", "TEST_VAR=test value with spaces", "--exit", "--yes"]) + main(["--set-env", "TEST_VAR=test value with spaces", "--exit", "--yes-always"]) self.assertEqual(os.environ.get("TEST_VAR"), "test value with spaces") - async def test_set_env_invalid_format(self): + def test_set_env_invalid_format(self): # Test invalid format handling with GitTemporaryDirectory(): - result = await main(["--set-env", "INVALID_FORMAT", "--exit", "--yes"]) + result = main(["--set-env", "INVALID_FORMAT", "--exit", "--yes-always"]) self.assertEqual(result, 1) - async def test_api_key_single(self): + def test_api_key_single(self): # Test setting a single API key with GitTemporaryDirectory(): - await main(["--api-key", "anthropic=test-key", "--exit", "--yes"]) + main(["--api-key", "anthropic=test-key", "--exit", "--yes-always"]) self.assertEqual(os.environ.get("ANTHROPIC_API_KEY"), "test-key") - async def test_api_key_multiple(self): + def test_api_key_multiple(self): # Test setting multiple API keys with GitTemporaryDirectory(): - await main( - ["--api-key", "anthropic=key1", "--api-key", "openai=key2", "--exit", "--yes"] + main( + [ + "--api-key", + "anthropic=key1", + "--api-key", + "openai=key2", + "--exit", + "--yes-always", + ] ) self.assertEqual(os.environ.get("ANTHROPIC_API_KEY"), "key1") self.assertEqual(os.environ.get("OPENAI_API_KEY"), "key2") - async def test_api_key_invalid_format(self): + def test_api_key_invalid_format(self): # Test invalid format handling with GitTemporaryDirectory(): - result = await main(["--api-key", "INVALID_FORMAT", "--exit", "--yes"]) + result = main(["--api-key", "INVALID_FORMAT", "--exit", "--yes-always"]) self.assertEqual(result, 1) - async def test_git_config_include(self): + def test_git_config_include(self): # Test that aider respects git config includes for user.name and user.email with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -1042,7 +1164,7 @@ async def test_git_config_include(self): git_config_content = git_config_path.read_text() # Run aider and verify it doesn't change the git config - await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + main(["--yes-always", "--exit"], input=DummyInput(), output=DummyOutput()) # Check that the user settings are still the same using git command repo = git.Repo(git_dir) # Re-open repo to ensure we get fresh config @@ -1053,7 +1175,7 @@ async def test_git_config_include(self): git_config_content_after = git_config_path.read_text() self.assertEqual(git_config_content, git_config_content_after) - async def test_git_config_include_directive(self): + def test_git_config_include_directive(self): # Test that aider respects the include directive in git config with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -1083,7 +1205,7 @@ async def test_git_config_include_directive(self): self.assertEqual(repo.git.config("user.email"), "directive@example.com") # Run aider and verify it doesn't change the git config - await main(["--yes", "--exit"], input=DummyInput(), output=DummyOutput()) + main(["--yes-always", "--exit"], input=DummyInput(), output=DummyOutput()) # Check that the git config file wasn't modified config_after_aider = git_config.read_text() @@ -1094,7 +1216,7 @@ async def test_git_config_include_directive(self): self.assertEqual(repo.git.config("user.name"), "Directive User") self.assertEqual(repo.git.config("user.email"), "directive@example.com") - async def test_resolve_aiderignore_path(self): + def test_resolve_aiderignore_path(self): # Import the function directly to test it from aider.args import resolve_aiderignore_path @@ -1113,13 +1235,13 @@ async def test_resolve_aiderignore_path(self): rel_path = ".aiderignore" self.assertEqual(resolve_aiderignore_path(rel_path), rel_path) - async def test_invalid_edit_format(self): + def test_invalid_edit_format(self): with GitTemporaryDirectory(): # Suppress stderr for this test as argparse prints an error message with patch("sys.stderr", new_callable=StringIO) as mock_stderr: with self.assertRaises(SystemExit) as cm: - _ = await main( - ["--edit-format", "not-a-real-format", "--exit", "--yes"], + _ = main( + ["--edit-format", "not-a-real-format", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) @@ -1129,44 +1251,59 @@ async def test_invalid_edit_format(self): self.assertIn("invalid choice", stderr_output) self.assertIn("not-a-real-format", stderr_output) - async def test_default_model_selection(self): + def test_default_model_selection(self): with GitTemporaryDirectory(): # Test Anthropic API key os.environ["ANTHROPIC_API_KEY"] = "test-key" - coder = await main( - ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True + coder = main( + ["--exit", "--yes-always"], + input=DummyInput(), + output=DummyOutput(), + return_coder=True, ) self.assertIn("sonnet", coder.main_model.name.lower()) del os.environ["ANTHROPIC_API_KEY"] # Test DeepSeek API key os.environ["DEEPSEEK_API_KEY"] = "test-key" - coder = await main( - ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True + coder = main( + ["--exit", "--yes-always"], + input=DummyInput(), + output=DummyOutput(), + return_coder=True, ) self.assertIn("deepseek", coder.main_model.name.lower()) del os.environ["DEEPSEEK_API_KEY"] # Test OpenRouter API key os.environ["OPENROUTER_API_KEY"] = "test-key" - coder = await main( - ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True + coder = main( + ["--exit", "--yes-always"], + input=DummyInput(), + output=DummyOutput(), + return_coder=True, ) self.assertIn("openrouter/", coder.main_model.name.lower()) del os.environ["OPENROUTER_API_KEY"] # Test OpenAI API key os.environ["OPENAI_API_KEY"] = "test-key" - coder = await main( - ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True + coder = main( + ["--exit", "--yes-always"], + input=DummyInput(), + output=DummyOutput(), + return_coder=True, ) self.assertIn("gpt-4", coder.main_model.name.lower()) del os.environ["OPENAI_API_KEY"] # Test Gemini API key os.environ["GEMINI_API_KEY"] = "test-key" - coder = await main( - ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True + coder = main( + ["--exit", "--yes-always"], + input=DummyInput(), + output=DummyOutput(), + return_coder=True, ) self.assertIn("gemini", coder.main_model.name.lower()) del os.environ["GEMINI_API_KEY"] @@ -1174,23 +1311,26 @@ async def test_default_model_selection(self): # Test no API keys - should offer OpenRouter OAuth with patch("aider.onboarding.offer_openrouter_oauth") as mock_offer_oauth: mock_offer_oauth.return_value = None # Simulate user declining or failure - result = await main(["--exit", "--yes"], input=DummyInput(), output=DummyOutput()) + result = main(["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput()) self.assertEqual(result, 1) # Expect failure since no model could be selected mock_offer_oauth.assert_called_once() - async def test_model_precedence(self): + def test_model_precedence(self): with GitTemporaryDirectory(): # Test that earlier API keys take precedence os.environ["ANTHROPIC_API_KEY"] = "test-key" os.environ["OPENAI_API_KEY"] = "test-key" - coder = await main( - ["--exit", "--yes"], input=DummyInput(), output=DummyOutput(), return_coder=True + coder = main( + ["--exit", "--yes-always"], + input=DummyInput(), + output=DummyOutput(), + return_coder=True, ) self.assertIn("sonnet", coder.main_model.name.lower()) del os.environ["ANTHROPIC_API_KEY"] del os.environ["OPENAI_API_KEY"] - async def test_model_overrides_suffix_applied(self): + def test_model_overrides_suffix_applied(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) overrides_file = git_dir / ".aider.model.overrides.yml" @@ -1201,6 +1341,7 @@ async def test_model_overrides_suffix_applied(self): patch("aider.coders.Coder.create") as MockCoder, ): mock_coder_instance = MagicMock() + mock_coder_instance._autosave_future = mock_autosave_future() MockCoder.return_value = mock_coder_instance mock_instance = MockModel.return_value @@ -1214,8 +1355,8 @@ async def test_model_overrides_suffix_applied(self): mock_instance.weak_model_name = None mock_instance.get_weak_model.return_value = None - await main( - ["--model", "gpt-4o:fast", "--exit", "--yes", "--no-git"], + main( + ["--model", "gpt-4o:fast", "--exit", "--yes-always", "--no-git"], input=DummyInput(), output=DummyOutput(), force_git_root=git_dir, @@ -1241,7 +1382,7 @@ async def test_model_overrides_suffix_applied(self): ), ) - async def test_model_overrides_no_match_preserves_model_name(self): + def test_model_overrides_no_match_preserves_model_name(self): with GitTemporaryDirectory() as git_dir: git_dir = Path(git_dir) @@ -1250,6 +1391,7 @@ async def test_model_overrides_no_match_preserves_model_name(self): patch("aider.coders.Coder.create") as MockCoder, ): mock_coder_instance = MagicMock() + mock_coder_instance._autosave_future = mock_autosave_future() MockCoder.return_value = mock_coder_instance mock_instance = MockModel.return_value @@ -1265,8 +1407,8 @@ async def test_model_overrides_no_match_preserves_model_name(self): model_name = "hf:moonshotai/Kimi-K2-Thinking" - await main( - ["--model", model_name, "--exit", "--yes", "--no-git"], + main( + ["--model", model_name, "--exit", "--yes-always", "--no-git"], input=DummyInput(), output=DummyOutput(), force_git_root=git_dir, @@ -1287,10 +1429,10 @@ async def test_model_overrides_no_match_preserves_model_name(self): ), ) - async def test_chat_language_spanish(self): + def test_chat_language_spanish(self): with GitTemporaryDirectory(): - coder = await main( - ["--chat-language", "Spanish", "--exit", "--yes"], + coder = main( + ["--chat-language", "Spanish", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -1298,10 +1440,10 @@ async def test_chat_language_spanish(self): system_info = coder.get_platform_info() self.assertIn("Spanish", system_info) - async def test_commit_language_japanese(self): + def test_commit_language_japanese(self): with GitTemporaryDirectory(): - coder = await main( - ["--commit-language", "japanese", "--exit", "--yes"], + coder = main( + ["--commit-language", "japanese", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -1309,19 +1451,25 @@ async def test_commit_language_japanese(self): self.assertIn("japanese", coder.commit_language) @patch("git.Repo.init") - async def test_main_exit_with_git_command_not_found(self, mock_git_init): + def test_main_exit_with_git_command_not_found(self, mock_git_init): mock_git_init.side_effect = git.exc.GitCommandNotFound("git", "Command 'git' not found") try: - result = await main(["--exit", "--yes"], input=DummyInput(), output=DummyOutput()) + result = main(["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput()) except Exception as e: - self.fail(f"await main() raised an unexpected exception: {e}") + self.fail(f"main() raised an unexpected exception: {e}") - self.assertIsNone(result, "await main() should return None when called with --exit") + self.assertEqual(result, 0, "main() should return 0 (success) when called with --exit") - async def test_reasoning_effort_option(self): - coder = await main( - ["--reasoning-effort", "3", "--no-check-model-accepts-settings", "--yes", "--exit"], + def test_reasoning_effort_option(self): + coder = main( + [ + "--reasoning-effort", + "3", + "--no-check-model-accepts-settings", + "--yes-always", + "--exit", + ], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -1330,9 +1478,9 @@ async def test_reasoning_effort_option(self): coder.main_model.extra_params.get("extra_body", {}).get("reasoning_effort"), "3" ) - async def test_thinking_tokens_option(self): - coder = await main( - ["--model", "sonnet", "--thinking-tokens", "1000", "--yes", "--exit"], + def test_thinking_tokens_option(self): + coder = main( + ["--model", "sonnet", "--thinking-tokens", "1000", "--yes-always", "--exit"], input=DummyInput(), output=DummyOutput(), return_coder=True, @@ -1341,7 +1489,7 @@ async def test_thinking_tokens_option(self): coder.main_model.extra_params.get("thinking", {}).get("budget_tokens"), 1000 ) - async def test_list_models_includes_metadata_models(self): + def test_list_models_includes_metadata_models(self): # Test that models from model-metadata.json appear in list-models output with GitTemporaryDirectory(): # Create a temporary model-metadata.json with test models @@ -1362,13 +1510,13 @@ async def test_list_models_includes_metadata_models(self): # Capture stdout to check the output with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - await main( + main( [ "--list-models", "unique-model", "--model-metadata-file", str(metadata_file), - "--yes", + "--yes-always", "--no-gitignore", ], input=DummyInput(), @@ -1379,7 +1527,7 @@ async def test_list_models_includes_metadata_models(self): # Check that the unique model name from our metadata file is listed self.assertIn("test-provider/unique-model-name", output) - async def test_list_models_includes_all_model_sources(self): + def test_list_models_includes_all_model_sources(self): # Test that models from both litellm.model_cost and model-metadata.json # appear in list-models with GitTemporaryDirectory(): @@ -1396,13 +1544,13 @@ async def test_list_models_includes_all_model_sources(self): # Capture stdout to check the output with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - await main( + main( [ "--list-models", "metadata-only-model", "--model-metadata-file", str(metadata_file), - "--yes", + "--yes-always", "--no-gitignore", ], input=DummyInput(), @@ -1415,19 +1563,84 @@ async def test_list_models_includes_all_model_sources(self): # Check that both models appear in the output self.assertIn("test-provider/metadata-only-model", output) - async def test_check_model_accepts_settings_flag(self): + def test_list_models_includes_openai_provider(self): + import aider.models as models_module + + provider_name = "openai" + manager = models_module.model_info_manager.provider_manager + provider_config = { + "api_base": "https://api.openai.com/v1", + "models_url": "https://api.openai.com/v1/models", + "api_key_env": ["OPENAI_API_KEY"], + "base_url_env": ["OPENAI_API_BASE"], + "default_headers": {}, + } + + had_config = provider_name in manager.provider_configs + previous_config = manager.provider_configs.get(provider_name) + had_cache = provider_name in manager._provider_cache + previous_cache = manager._provider_cache.get(provider_name) + had_loaded = provider_name in manager._cache_loaded + previous_loaded = manager._cache_loaded.get(provider_name) + + manager.provider_configs[provider_name] = provider_config + manager._provider_cache[provider_name] = None + manager._cache_loaded[provider_name] = False + + payload = { + "data": [ + { + "id": "demo/foo", + "max_input_tokens": 4096, + "pricing": {"prompt": "0.0001", "completion": "0.0002"}, + } + ] + } + + def _fake_get(url, *, headers=None, timeout=None, verify=None): + return types.SimpleNamespace(status_code=200, json=lambda: payload) + + try: + with GitTemporaryDirectory(): + with patch("requests.get", _fake_get): + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + main( + ["--list-models", "openai/demo/foo", "--yes", "--no-gitignore"], + input=DummyInput(), + output=DummyOutput(), + ) + + output = mock_stdout.getvalue() + self.assertIn("openai/demo/foo", output) + finally: + if had_config: + manager.provider_configs[provider_name] = previous_config + else: + manager.provider_configs.pop(provider_name, None) + + if had_cache: + manager._provider_cache[provider_name] = previous_cache + else: + manager._provider_cache.pop(provider_name, None) + + if had_loaded: + manager._cache_loaded[provider_name] = previous_loaded + else: + manager._cache_loaded.pop(provider_name, None) + + def test_check_model_accepts_settings_flag(self): # Test that --check-model-accepts-settings affects whether settings are applied with GitTemporaryDirectory(): # When flag is on, setting shouldn't be applied to non-supporting model with patch("aider.models.Model.set_thinking_tokens") as mock_set_thinking: - await main( + main( [ "--model", "gpt-4o", "--thinking-tokens", "1000", "--check-model-accepts-settings", - "--yes", + "--yes-always", "--exit", ], input=DummyInput(), @@ -1436,7 +1649,7 @@ async def test_check_model_accepts_settings_flag(self): # Method should not be called because model doesn't support it and flag is on mock_set_thinking.assert_not_called() - async def test_list_models_with_direct_resource_patch(self): + def test_list_models_with_direct_resource_patch(self): # Test that models from resources/model-metadata.json are included in list-models output with GitTemporaryDirectory(): # Create a temporary file with test model metadata @@ -1461,8 +1674,8 @@ async def test_list_models_with_direct_resource_patch(self): with patch("aider.main.importlib_resources.files", return_value=mock_files): # Capture stdout to check the output with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - await main( - ["--list-models", "special", "--yes", "--no-gitignore"], + main( + ["--list-models", "special", "--yes-always", "--no-gitignore"], input=DummyInput(), output=DummyOutput(), ) @@ -1473,14 +1686,14 @@ async def test_list_models_with_direct_resource_patch(self): # When flag is off, setting should be applied regardless of support with patch("aider.models.Model.set_reasoning_effort") as mock_set_reasoning: - await main( + main( [ "--model", "gpt-3.5-turbo", "--reasoning-effort", "3", "--no-check-model-accepts-settings", - "--yes", + "--yes-always", "--exit", ], input=DummyInput(), @@ -1489,7 +1702,7 @@ async def test_list_models_with_direct_resource_patch(self): # Method should be called because flag is off mock_set_reasoning.assert_called_once_with("3") - async def test_model_accepts_settings_attribute(self): + def test_model_accepts_settings_attribute(self): with GitTemporaryDirectory(): # Test with a model where we override the accepts_settings attribute with patch("aider.models.Model") as MockModel: @@ -1506,7 +1719,7 @@ async def test_model_accepts_settings_attribute(self): mock_instance.get_weak_model.return_value = None # Run with both settings, but model only accepts reasoning_effort - await main( + main( [ "--model", "test-model", @@ -1515,7 +1728,7 @@ async def test_model_accepts_settings_attribute(self): "--thinking-tokens", "1000", "--check-model-accepts-settings", - "--yes", + "--yes-always", "--exit", ], input=DummyInput(), @@ -1526,12 +1739,13 @@ async def test_model_accepts_settings_attribute(self): mock_instance.set_reasoning_effort.assert_called_once_with("3") mock_instance.set_thinking_tokens.assert_not_called() - @patch("aider.main.InputOutput") - async def test_stream_and_cache_warning(self, MockInputOutput): + @patch("aider.main.InputOutput", autospec=True) + def test_stream_and_cache_warning(self, MockInputOutput): mock_io_instance = MockInputOutput.return_value + mock_io_instance.pretty = True with GitTemporaryDirectory(): - await main( - ["--stream", "--cache-prompts", "--exit", "--yes"], + main( + ["--stream", "--cache-prompts", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) @@ -1539,32 +1753,33 @@ async def test_stream_and_cache_warning(self, MockInputOutput): "Cost estimates may be inaccurate when using streaming and caching." ) - @patch("aider.main.InputOutput") - async def test_stream_without_cache_no_warning(self, MockInputOutput): + @patch("aider.main.InputOutput", autospec=True) + def test_stream_without_cache_no_warning(self, MockInputOutput): mock_io_instance = MockInputOutput.return_value + mock_io_instance.pretty = True with GitTemporaryDirectory(): - await main( - ["--stream", "--exit", "--yes"], + main( + ["--stream", "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) for call in mock_io_instance.tool_warning.call_args_list: self.assertNotIn("Cost estimates may be inaccurate", call[0][0]) - async def test_argv_file_respects_git(self): + def test_argv_file_respects_git(self): with GitTemporaryDirectory(): fname = Path("not_in_git.txt") fname.touch() with open(".gitignore", "w+") as f: f.write("not_in_git.txt") - coder = await main( + coder = main( argv=["--file", "not_in_git.txt"], input=DummyInput(), output=DummyOutput(), return_coder=True, ) self.assertNotIn("not_in_git.txt", str(coder.abs_fnames)) - self.assertFalse(await coder.allowed_to_edit("not_in_git.txt")) + self.assertFalse(asyncio.run(coder.allowed_to_edit("not_in_git.txt"))) def test_load_dotenv_files_override(self): with GitTemporaryDirectory() as git_dir: @@ -1625,12 +1840,13 @@ def test_load_dotenv_files_override(self): # Restore CWD os.chdir(original_cwd) - @patch("aider.main.InputOutput") - async def test_cache_without_stream_no_warning(self, MockInputOutput): + @patch("aider.main.InputOutput", autospec=True) + def test_cache_without_stream_no_warning(self, MockInputOutput): mock_io_instance = MockInputOutput.return_value + mock_io_instance.pretty = True with GitTemporaryDirectory(): - await main( - ["--cache-prompts", "--exit", "--yes", "--no-stream"], + main( + ["--cache-prompts", "--exit", "--yes-always", "--no-stream"], input=DummyInput(), output=DummyOutput(), ) @@ -1638,19 +1854,20 @@ async def test_cache_without_stream_no_warning(self, MockInputOutput): self.assertNotIn("Cost estimates may be inaccurate", call[0][0]) @patch("aider.coders.Coder.create") - async def test_mcp_servers_parsing(self, mock_coder_create): + def test_mcp_servers_parsing(self, mock_coder_create): # Setup mock coder mock_coder_instance = MagicMock() + mock_coder_instance._autosave_future = mock_autosave_future() mock_coder_create.return_value = mock_coder_instance # Test with --mcp-servers option with GitTemporaryDirectory(): - await main( + main( [ "--mcp-servers", '{"mcpServers":{"git":{"command":"uvx","args":["mcp-server-git"]}}}', "--exit", - "--yes", + "--yes-always", ], input=DummyInput(), output=DummyOutput(), @@ -1668,6 +1885,7 @@ async def test_mcp_servers_parsing(self, mock_coder_create): # Test with --mcp-servers-file option mock_coder_create.reset_mock() + mock_coder_instance._autosave_future = mock_autosave_future() with GitTemporaryDirectory(): # Create a temporary MCP servers file @@ -1675,8 +1893,8 @@ async def test_mcp_servers_parsing(self, mock_coder_create): mcp_content = {"mcpServers": {"git": {"command": "uvx", "args": ["mcp-server-git"]}}} mcp_file.write_text(json.dumps(mcp_content)) - await main( - ["--mcp-servers-file", str(mcp_file), "--exit", "--yes"], + main( + ["--mcp-servers-file", str(mcp_file), "--exit", "--yes-always"], input=DummyInput(), output=DummyOutput(), ) diff --git a/tests/basic/test_main_smoke.py b/tests/basic/test_main_smoke.py new file mode 100644 index 00000000000..1586d88a3eb --- /dev/null +++ b/tests/basic/test_main_smoke.py @@ -0,0 +1,44 @@ +import os +import platform + +import pytest +from prompt_toolkit.input import DummyInput +from prompt_toolkit.output import DummyOutput + +from aider.main import main, main_async + + +@pytest.fixture(autouse=True) +def isolated_env(tmp_path, monkeypatch, mocker): + """Completely isolated test environment with no real API keys.""" + fake_home = tmp_path / "home" + fake_home.mkdir() + + clean_env = { + "OPENAI_API_KEY": "test-key", + "AIDER_CHECK_UPDATE": "false", + "AIDER_ANALYTICS": "false", + } + + if platform.system() == "Windows": + clean_env["USERPROFILE"] = str(fake_home) + else: + clean_env["HOME"] = str(fake_home) + + mocker.patch.dict(os.environ, clean_env, clear=True) + mocker.patch( + "aider.io.webbrowser.open", + side_effect=AssertionError("Browser should not open during tests"), + ) + mocker.patch("builtins.input", return_value=None) + monkeypatch.chdir(tmp_path) + + yield tmp_path + + +async def test_main_async_executes(): + await main_async(["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput()) + + +def test_main_executes(): + main(["--exit", "--yes-always"], input=DummyInput(), output=DummyOutput()) diff --git a/tests/basic/test_model_provider_manager.py b/tests/basic/test_model_provider_manager.py new file mode 100644 index 00000000000..ed8ac769d2d --- /dev/null +++ b/tests/basic/test_model_provider_manager.py @@ -0,0 +1,364 @@ +import json +import sys +import types + + +def _install_stubs(): + if "PIL" not in sys.modules: + pil_module = types.ModuleType("PIL") + image_module = types.ModuleType("PIL.Image") + image_grab_module = types.ModuleType("PIL.ImageGrab") + + class _DummyImage: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + @property + def size(self): + return (1024, 1024) + + def _dummy_open(*args, **kwargs): + return _DummyImage() + + image_module.open = _dummy_open + image_grab_module.grab = _dummy_open + pil_module.Image = image_module + pil_module.ImageGrab = image_grab_module + sys.modules["PIL"] = pil_module + sys.modules["PIL.Image"] = image_module + sys.modules["PIL.ImageGrab"] = image_grab_module + + if "numpy" not in sys.modules: + numpy_module = types.ModuleType("numpy") + numpy_module.ndarray = object + numpy_module.array = lambda *a, **k: None + numpy_module.dot = lambda *a, **k: 0.0 + numpy_module.linalg = types.SimpleNamespace(norm=lambda *a, **k: 1.0) + sys.modules["numpy"] = numpy_module + + if "oslex" not in sys.modules: + oslex_module = types.ModuleType("oslex") + oslex_module.__all__ = [] + sys.modules["oslex"] = oslex_module + + if "rich" not in sys.modules: + rich_module = types.ModuleType("rich") + console_module = types.ModuleType("rich.console") + + class _DummyConsole: + def __init__(self, *args, **kwargs): + pass + + def status(self, *args, **kwargs): + return self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def update(self, *args, **kwargs): + return None + + console_module.Console = _DummyConsole + rich_module.console = console_module + sys.modules["rich"] = rich_module + sys.modules["rich.console"] = console_module + + if "pyperclip" not in sys.modules: + pyperclip_module = types.ModuleType("pyperclip") + + class _DummyPyperclipException(Exception): + pass + + pyperclip_module.PyperclipException = _DummyPyperclipException + pyperclip_module.copy = lambda *args, **kwargs: None + sys.modules["pyperclip"] = pyperclip_module + + if "pexpect" not in sys.modules: + pexpect_module = types.ModuleType("pexpect") + + class _DummySpawn: + def __init__(self, *args, **kwargs): + pass + + def sendline(self, *args, **kwargs): + return 0 + + def close(self, *args, **kwargs): + return 0 + + pexpect_module.spawn = _DummySpawn + sys.modules["pexpect"] = pexpect_module + + if "psutil" not in sys.modules: + psutil_module = types.ModuleType("psutil") + + class _DummyProcess: + def __init__(self, *args, **kwargs): + pass + + def children(self, *args, **kwargs): + return [] + + def terminate(self): + return None + + psutil_module.Process = _DummyProcess + sys.modules["psutil"] = psutil_module + + if "pypandoc" not in sys.modules: + pypandoc_module = types.ModuleType("pypandoc") + pypandoc_module.convert_text = lambda *args, **kwargs: "" + sys.modules["pypandoc"] = pypandoc_module + + +_install_stubs() + +from aider.helpers.model_providers import ModelProviderManager # noqa: E402 +from aider.models import MODEL_SETTINGS, Model, ModelInfoManager # noqa: E402 + + +class DummyResponse: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +def _make_manager(tmp_path, config): + manager = ModelProviderManager(provider_configs=config) + manager.cache_dir = tmp_path # Avoid touching real home dir + return manager + + +def test_model_provider_matches_suffix_variants(monkeypatch, tmp_path): + payload = { + "data": [ + { + "id": "demo/model", + "context_length": 2048, + "pricing": {"prompt": "1.0", "completion": "2.0"}, + } + ] + } + + config = { + "openrouter": { + "api_base": "https://openrouter.ai/api/v1", + "models_url": "https://openrouter.ai/api/v1/models", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + cache_file = manager._get_cache_file("openrouter") + cache_file.write_text(json.dumps(payload)) + manager._cache_loaded["openrouter"] = True + manager._provider_cache["openrouter"] = payload + + info = manager.get_model_info("openrouter/demo/model:extended") + + assert info["max_input_tokens"] == 2048 + assert info["input_cost_per_token"] == 1.0 / manager.DEFAULT_TOKEN_PRICE_RATIO + assert info["litellm_provider"] == "openrouter" + + +def test_model_provider_uses_top_provider_context(tmp_path): + payload = { + "data": [ + { + "id": "demo/model", + "top_provider": {"context_length": 4096}, + "pricing": {"prompt": "3", "completion": "4"}, + } + ] + } + + config = { + "demo": { + "api_base": "https://example.com/v1", + "models_url": "https://example.com/v1/models", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + cache_file = manager._get_cache_file("demo") + cache_file.write_text(json.dumps(payload)) + manager._cache_loaded["demo"] = True + manager._provider_cache["demo"] = payload + + info = manager.get_model_info("demo/demo/model") + + assert info["max_input_tokens"] == 4096 + assert info["max_tokens"] == 4096 + assert info["max_output_tokens"] == 4096 + + +def test_fetch_provider_models_injects_headers(monkeypatch, tmp_path): + payload = {"data": []} + captured = {} + + def _fake_get(url, *, headers=None, timeout=None, verify=None): + captured["url"] = url + captured["headers"] = headers + captured["timeout"] = timeout + captured["verify"] = verify + return DummyResponse(payload) + + monkeypatch.setattr("requests.get", _fake_get) + + config = { + "demo": { + "api_base": "https://example.com/v1", + "default_headers": {"X-Test": "demo"}, + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + manager.set_verify_ssl(False) + + result = manager._fetch_provider_models("demo") + + assert result == payload + assert captured["url"] == "https://example.com/v1/models" + assert captured["headers"] == {"X-Test": "demo"} + assert captured["timeout"] == 10 + assert captured["verify"] is False + + +def test_get_api_key_prefers_first_valid(monkeypatch, tmp_path): + config = { + "demo": { + "api_base": "https://example.com/v1", + "api_key_env": ["DEMO_FALLBACK", "DEMO_KEY"], + "requires_api_key": True, + } + } + + manager = _make_manager(tmp_path, config) + monkeypatch.delenv("DEMO_FALLBACK", raising=False) + monkeypatch.setenv("DEMO_KEY", "secret") + + assert manager._get_api_key("demo") == "secret" + + +def test_refresh_provider_cache_uses_static_models(monkeypatch, tmp_path): + config = { + "demo": { + "api_base": "https://example.com/v1", + "static_models": [ + { + "id": "demo/foo", + "max_input_tokens": 1024, + "pricing": {"prompt": "0.5", "completion": "1.0"}, + } + ], + } + } + + manager = _make_manager(tmp_path, config) + + def _failing_fetch(*args, **kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr("requests.get", _failing_fetch) + + refreshed = manager.refresh_provider_cache("demo") + + assert refreshed is True + info = manager.get_model_info("demo/demo/foo") + assert info["max_input_tokens"] == 1024 + assert info["input_cost_per_token"] == 0.5 / manager.DEFAULT_TOKEN_PRICE_RATIO + + +def test_model_info_manager_delegates_to_provider(monkeypatch, tmp_path): + monkeypatch.setattr( + "aider.models.litellm", + types.SimpleNamespace( + _lazy_module=None, + get_model_info=lambda *a, **k: {}, + validate_environment=lambda model: {"keys_in_environment": True, "missing_keys": []}, + encode=lambda *a, **k: [], + token_counter=lambda *a, **k: 0, + ), + ) + + stub_info = { + "max_input_tokens": 512, + "max_tokens": 512, + "max_output_tokens": 512, + "input_cost_per_token": 1.0, + "output_cost_per_token": 2.0, + "litellm_provider": "openrouter", + } + + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.supports_provider", + lambda self, provider: provider == "openrouter", + ) + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.get_model_info", + lambda self, model: stub_info, + ) + + mim = ModelInfoManager() + info = mim.get_model_info("openrouter/demo/model") + + assert info == stub_info + + +def test_model_dynamic_settings_added(monkeypatch, tmp_path): + provider = "demo" + model_name = "demo/org/foo" + manager = ModelInfoManager() + + def _fake_supports(self, prov): + return prov == provider + + def _fake_get(self, model): + return { + "max_input_tokens": 2048, + "max_tokens": 2048, + "max_output_tokens": 2048, + "litellm_provider": provider, + } + + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.supports_provider", + _fake_supports, + ) + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.get_model_info", + _fake_get, + ) + monkeypatch.setattr( + "aider.models.litellm", + types.SimpleNamespace( + _lazy_module=None, + get_model_info=lambda *a, **k: {}, + validate_environment=lambda model: {"keys_in_environment": True, "missing_keys": []}, + encode=lambda *a, **k: [], + token_counter=lambda *a, **k: 0, + ), + ) + + assert not any(ms.name == model_name for ms in MODEL_SETTINGS) + + info = manager.get_model_info(model_name) + assert info["max_tokens"] == 2048 + + assert any(ms.name == model_name for ms in MODEL_SETTINGS) + + model = Model(model_name) + assert model.info["max_tokens"] == 2048 diff --git a/tests/basic/test_openrouter.py b/tests/basic/test_openrouter.py deleted file mode 100644 index f55c301572c..00000000000 --- a/tests/basic/test_openrouter.py +++ /dev/null @@ -1,73 +0,0 @@ -from pathlib import Path - -from aider.models import ModelInfoManager -from aider.openrouter import OpenRouterModelManager - - -class DummyResponse: - """Minimal stand-in for requests.Response used in tests.""" - - def __init__(self, json_data): - self.status_code = 200 - self._json_data = json_data - - def json(self): - return self._json_data - - -def test_openrouter_get_model_info_from_cache(monkeypatch, tmp_path): - """ - OpenRouterModelManager should return correct metadata taken from the - downloaded (and locally cached) models JSON payload. - """ - payload = { - "data": [ - { - "id": "mistralai/mistral-medium-3", - "context_length": 32768, - "pricing": {"prompt": "100", "completion": "200"}, - "top_provider": {"context_length": 32768}, - } - ] - } - - # Fake out the network call and the HOME directory used for the cache file - monkeypatch.setattr("requests.get", lambda *a, **k: DummyResponse(payload)) - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - - manager = OpenRouterModelManager() - info = manager.get_model_info("openrouter/mistralai/mistral-medium-3") - - assert info["max_input_tokens"] == 32768 - assert info["input_cost_per_token"] == 100.0 - assert info["output_cost_per_token"] == 200.0 - assert info["litellm_provider"] == "openrouter" - - -def test_model_info_manager_uses_openrouter_manager(monkeypatch): - """ - ModelInfoManager should delegate to OpenRouterModelManager when litellm - provides no data for an OpenRouter-prefixed model. - """ - # Ensure litellm path returns no info so that fallback logic triggers - monkeypatch.setattr("aider.models.litellm.get_model_info", lambda *a, **k: {}) - - stub_info = { - "max_input_tokens": 512, - "max_tokens": 512, - "max_output_tokens": 512, - "input_cost_per_token": 100.0, - "output_cost_per_token": 200.0, - "litellm_provider": "openrouter", - } - - # Force OpenRouterModelManager to return our stub info - monkeypatch.setattr( - "aider.models.OpenRouterModelManager.get_model_info", - lambda self, model: stub_info, - ) - - mim = ModelInfoManager() - info = mim.get_model_info("openrouter/fake/model") - - assert info == stub_info diff --git a/tests/basic/test_reasoning.py b/tests/basic/test_reasoning.py index 24aa9334197..31bfe3c05ed 100644 --- a/tests/basic/test_reasoning.py +++ b/tests/basic/test_reasoning.py @@ -1,6 +1,10 @@ +import json +import textwrap import unittest from unittest.mock import MagicMock, patch +import litellm + from aider.coders.base_coder import Coder from aider.dump import dump # noqa from aider.io import InputOutput @@ -13,6 +17,43 @@ class TestReasoning(unittest.TestCase): + SYNTHETIC_COMPLETION = textwrap.dedent("""\ + { + "id": "test-completion", + "created": 0, + "model": "synthetic/hf:MiniMaxAI/MiniMax-M2", + "object": "chat.completion", + "system_fingerprint": null, + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Final synthetic summary of the repository.", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "reasoning_content": "Internal reasoning about how to describe the repo." + }, + "token_ids": null + } + ], + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15, + "completion_tokens_details": null, + "prompt_tokens_details": { + "audio_tokens": null, + "cached_tokens": null, + "text_tokens": null, + "image_tokens": null + } + }, + "prompt_token_ids": null + } + """) + async def test_send_with_reasoning_content(self): """Test that reasoning content is properly formatted and output.""" # Setup IO with no pretty @@ -74,6 +115,31 @@ def __init__(self, content, reasoning_content): reasoning_pos, main_pos, "Reasoning content should appear before main content" ) + async def test_reasoning_keeps_answer_block(self): + """Ensure providers returning reasoning+answer still show both sections.""" + io = InputOutput(pretty=False) + io.assistant_output = MagicMock() + model = Model("gpt-4o") + coder = await Coder.create(model, None, io=io, stream=False) + + completion = litellm.ModelResponse(**json.loads(self.SYNTHETIC_COMPLETION)) + mock_hash = MagicMock() + mock_hash.hexdigest.return_value = "hash" + + with patch.object(model, "send_completion", return_value=(mock_hash, completion)): + list(await coder.send([{"role": "user", "content": "describe"}])) + + output = io.assistant_output.call_args[0][0] + self.assertIn(REASONING_START, output) + self.assertIn("Internal reasoning about how to describe the repo.", output) + self.assertIn("Final synthetic summary of the repository.", output) + self.assertIn(REASONING_END, output) + + coder.remove_reasoning_content() + self.assertEqual( + coder.partial_response_content.strip(), "Final synthetic summary of the repository." + ) + async def test_send_with_reasoning_content_stream(self): """Test that streaming reasoning content is properly formatted and output.""" # Setup IO with pretty output for streaming