diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py
index fa0a3f302bc..3ef9dfa3d2d 100755
--- a/aider/coders/base_coder.py
+++ b/aider/coders/base_coder.py
@@ -3266,8 +3266,20 @@ def consolidate_chunks(self):
         self.partial_response_reasoning_content = reasoning_content or ""
 
         try:
-            if not self.partial_response_reasoning_content:
-                self.partial_response_content = response.choices[0].message.content or ""
+            content = response.choices[0].message.content
+            if isinstance(content, list):
+                # OpenAI-compatible APIs sometimes return content as a list
+                # of blocks; join the textual pieces for display.
+                content = "".join(
+                    block.get("text", "")
+                    for block in content
+                    if isinstance(block, dict) and block.get("type") == "output_text"
+                ) or "".join(
+                    block.get("text", "")
+                    for block in content
+                    if isinstance(block, dict) and block.get("type") == "text"
+                )
+            self.partial_response_content = content or ""
         except AttributeError as e:
             content_err = e
diff --git a/aider/commands/model.py b/aider/commands/model.py
index f058a2f5615..fd2a2d2b068 100644
--- a/aider/commands/model.py
+++ b/aider/commands/model.py
@@ -94,10 +94,7 @@ async def execute(cls, io, coder, args, **kwargs):
     @classmethod
     def get_completions(cls, io, coder, args) -> List[str]:
         """Get completion options for model command."""
-        from aider.llm import litellm
-
-        model_names = litellm.model_cost.keys()
-        return list(model_names)
+        return models.get_chat_model_names()
 
     @classmethod
     def get_help(cls) -> str:
diff --git a/aider/commands/models.py b/aider/commands/models.py
index 9d9624d1f84..2af2a56771d 100644
--- a/aider/commands/models.py
+++ b/aider/commands/models.py
@@ -24,10 +24,7 @@ async def execute(cls, io, coder, args, **kwargs):
     @classmethod
     def get_completions(cls, io, coder, args) -> List[str]:
         """Get completion options for models command."""
-        from aider.llm import litellm
-
-        model_names = litellm.model_cost.keys()
-        return list(model_names)
+        return models.get_chat_model_names()
 
     @classmethod
     def get_help(cls) -> str:
diff --git a/aider/helpers/model_providers.py b/aider/helpers/model_providers.py
new file mode 100644
index 00000000000..95843829d94
--- /dev/null
+++ b/aider/helpers/model_providers.py
@@ -0,0 +1,713 @@
+"""Unified model provider metadata caching and lookup.
+
+Historically aider kept separate modules per provider (OpenRouter vs OpenAI-like).
+Those grew unwieldy and duplicated caching, request, and normalization logic.
+This helper centralizes that behavior so every OpenAI-compatible endpoint defines
+a small config blob and inherits the same cache + LiteLLM registration plumbing.
+Provider configs remain curated via ``scripts/generate_providers.py`` and the
+static per-model fallback metadata is still cleaned up with ``clean_metadata.py``.
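+
+Provider definitions ship in ``aider/resources/providers.json`` and are loaded into
+``PROVIDER_CONFIGS`` when this module is imported. As an illustrative sketch (not a
+guaranteed return shape), a lookup flows through ``ModelProviderManager`` like so::
+
+    manager = ModelProviderManager()
+    info = manager.get_model_info("openrouter/mistralai/mistral-medium-3")
+    # info -> {"max_input_tokens": ..., "input_cost_per_token": ..., "litellm_provider": "openrouter"}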
+""" + +from __future__ import annotations + +import importlib.resources as importlib_resources +import json +import os +import re +import time +from copy import deepcopy +from pathlib import Path +from typing import Dict, Optional + +import requests + +try: # Optional imports; litellm might not be installed during docs builds + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm.llms.custom_llm import CustomLLM, CustomLLMError + from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler +except Exception: # pragma: no cover - only during partial installs + CustomLLM = None # type: ignore + CustomLLMError = Exception # type: ignore + OpenAILikeChatHandler = None # type: ignore + HTTPHandler = None # type: ignore + +RESOURCE_FILE = "providers.json" +_PROVIDERS_REGISTERED = False +_CUSTOM_HANDLERS: Dict[str, "_JSONOpenAIProvider"] = {} + + +def _coerce_str(value): + """Return the first string representation that litellm expects.""" + if isinstance(value, str): + return value + if isinstance(value, list) and value: + return value[0] + return None + + +def _first_env_value(names): + """Return the first non-empty environment variable for the provided names.""" + if not names: + return None + if isinstance(names, str): + names = [names] + for env_name in names or []: + if not env_name: + continue + val = os.environ.get(env_name) + if val: + return val + return None + + +class _JSONOpenAIProvider(CustomLLM if CustomLLM is not None else object): # type: ignore[misc] + """CustomLLM wrapper that routes OpenAI-compatible providers through LiteLLM.""" + + def __init__(self, slug: str, config: Dict): + if CustomLLM is None or OpenAILikeChatHandler is None: # pragma: no cover + raise RuntimeError("litellm custom handler support unavailable") + super().__init__() # type: ignore[misc] + self.slug = slug + self.config = config + self._chat_handler = OpenAILikeChatHandler() + + def _resolve_api_base(self, api_base: Optional[str]) -> str: + base = ( + api_base + or _first_env_value(self.config.get("base_url_env")) + or self.config.get("api_base") + ) + if not base: + raise CustomLLMError(500, f"{self.slug} missing base URL") # type: ignore[misc] + return base.rstrip("/") + + def _resolve_api_key(self, api_key: Optional[str]) -> Optional[str]: + if api_key: + return api_key + env_val = _first_env_value(self.config.get("api_key_env")) + return env_val + + def _apply_special_handling(self, messages): + special = self.config.get("special_handling") or {} + if special.get("convert_content_list_to_string"): + from litellm.litellm_core_utils.prompt_templates.common_utils import ( + handle_messages_with_content_list_to_str_conversion, + ) + + return handle_messages_with_content_list_to_str_conversion(messages) + return messages + + def _inject_headers(self, headers): + defaults = self.config.get("default_headers") or {} + combined = dict(defaults) + combined.update(headers or {}) + return combined + + def _normalize_model_name(self, model: str) -> str: + if not isinstance(model, str): + return model + trimmed = model + if trimmed.startswith(f"{self.slug}/"): + trimmed = trimmed.split("/", 1)[1] + hf_namespace = self.config.get("hf_namespace") + if hf_namespace and not trimmed.startswith("hf:"): + trimmed = f"hf:{trimmed}" + return trimmed + + def _build_request_params(self, optional_params, stream: bool): + params = dict(optional_params or {}) + default_headers = dict(self.config.get("default_headers") or {}) + headers = params.setdefault("extra_headers", default_headers) + if 
headers is default_headers and default_headers: + params["extra_headers"] = dict(default_headers) + if stream: + params["stream"] = True + return params + + def _invoke_handler( + self, + *, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params, + logger_fn, + headers, + timeout, + client, + stream: bool, + ): + api_base = self._resolve_api_base(api_base) + api_key = self._resolve_api_key(api_key) + headers = self._inject_headers(headers) + params = self._build_request_params(optional_params, stream) + cleaned_messages = self._apply_special_handling(messages) + api_model = self._normalize_model_name(model) + http_client = None + if HTTPHandler is not None and isinstance(client, HTTPHandler): + http_client = client + return self._chat_handler.completion( + model=api_model, + messages=cleaned_messages, + api_base=api_base, + custom_llm_provider="openai", + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=params, + litellm_params=litellm_params or {}, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=http_client, + ) + + def completion( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self._invoke_handler( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + stream=False, + ) + + def streaming( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self._invoke_handler( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + stream=True, + ) + + def acompletion( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self.completion( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + ) + + def astreaming( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + 
api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self.streaming( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + ) + + +def _register_provider_with_litellm(slug: str, config: Dict) -> None: + """Register provider metadata and custom handlers with LiteLLM.""" + try: + from litellm.llms.openai_like.json_loader import ( + JSONProviderRegistry, + SimpleProviderConfig, + ) + except Exception: + return + + JSONProviderRegistry.load() + + base_url = config.get("api_base") + api_key_env = _coerce_str(config.get("api_key_env")) + if not base_url or not api_key_env: + return + + if not JSONProviderRegistry.exists(slug): + payload = { + "base_url": base_url, + "api_key_env": api_key_env, + } + + api_base_env = _coerce_str(config.get("base_url_env")) + if api_base_env: + payload["api_base_env"] = api_base_env + + if config.get("param_mappings"): + payload["param_mappings"] = config["param_mappings"] + if config.get("special_handling"): + payload["special_handling"] = config["special_handling"] + if config.get("base_class"): + payload["base_class"] = config["base_class"] + + JSONProviderRegistry._providers[slug] = SimpleProviderConfig(slug, payload) + + try: + import litellm # noqa: WPS433 + except Exception: + return + + provider_list = getattr(litellm, "provider_list", None) + if isinstance(provider_list, list) and slug not in provider_list: + provider_list.append(slug) + + openai_like = getattr(litellm, "_openai_like_providers", None) + if isinstance(openai_like, list) and slug not in openai_like: + openai_like.append(slug) + + handler = _CUSTOM_HANDLERS.get(slug) + if handler is None and CustomLLM is not None and OpenAILikeChatHandler is not None: + handler = _JSONOpenAIProvider(slug, config) + _CUSTOM_HANDLERS[slug] = handler + + if handler is None: + return + + already_present = any(item.get("provider") == slug for item in litellm.custom_provider_map) + if not already_present: + litellm.custom_provider_map.append({"provider": slug, "custom_handler": handler}) + try: + litellm.custom_llm_setup() + except Exception: + pass + + +def _deep_merge(base: Dict, override: Dict) -> Dict: + """Recursively merge override dict into base without mutating inputs.""" + result = deepcopy(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(result.get(key), dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = deepcopy(value) + return result + + +def _load_provider_configs() -> Dict[str, Dict]: + """Load provider configuration overrides from the packaged JSON file.""" + configs: Dict[str, Dict] = {} + try: + resource = importlib_resources.files("aider.resources").joinpath(RESOURCE_FILE) + data = json.loads(resource.read_text()) + except (FileNotFoundError, json.JSONDecodeError): # pragma: no cover + data = {} + + for provider, override in data.items(): + base = configs.get(provider, {}) + configs[provider] = _deep_merge(base, override) + + return configs + + +PROVIDER_CONFIGS = _load_provider_configs() + + +class ModelProviderManager: + CACHE_TTL = 60 * 60 * 24 # 24 hours + + def __init__(self, provider_configs: 
Optional[Dict[str, Dict]] = None) -> None: + self.cache_dir = Path.home() / ".aider" / "caches" + self.verify_ssl: bool = True + self.provider_configs = provider_configs or deepcopy(PROVIDER_CONFIGS) + self._provider_cache: Dict[str, Dict | None] = {} + self._cache_loaded: Dict[str, bool] = {} + for name in self.provider_configs: + self._provider_cache[name] = None + self._cache_loaded[name] = False + + def set_verify_ssl(self, verify_ssl: bool) -> None: + self.verify_ssl = verify_ssl + + def supports_provider(self, provider: Optional[str]) -> bool: + return bool(provider and provider in self.provider_configs) + + def get_provider_config(self, provider: Optional[str]) -> Optional[Dict]: + if not provider: + return None + config = self.provider_configs.get(provider) + if not config: + return None + config = dict(config) + config.setdefault("litellm_provider", provider) + return config + + def get_provider_base_url(self, provider: Optional[str]) -> Optional[str]: + config = self.get_provider_config(provider) + if not config: + return None + base_envs = config.get("base_url_env") or [] + for env_var in base_envs: + val = os.environ.get(env_var) + if val: + return val.rstrip("/") + return config.get("api_base") + + def get_required_api_keys(self, provider: Optional[str]) -> list[str]: + config = self.get_provider_config(provider) + if not config: + return [] + return list(config.get("api_key_env", [])) + + def get_model_info(self, model: str) -> Dict: + provider, route = self._split_model(model) + if not provider or not self._ensure_provider_state(provider): + return {} + + content = self._ensure_content(provider) + record = self._find_record(content, route) + if not record and self.refresh_provider_cache(provider): + content = self._provider_cache.get(provider) + record = self._find_record(content, route) + if not record: + return {} + return self._record_to_info(record, provider) + + def get_models_for_listing(self) -> Dict[str, Dict]: + listings: Dict[str, Dict] = {} + for provider in list(self.provider_configs.keys()): + content = self._ensure_content(provider) + if not content or "data" not in content: + continue + for record in content["data"]: + model_id = record.get("id") + if not model_id: + continue + info = self._record_to_info(record, provider) + if info: + listings[model_id] = info + return listings + + def refresh_provider_cache(self, provider: str) -> bool: + if not self._ensure_provider_state(provider): + return False + config = self.provider_configs[provider] + if not config.get("models_url") and not config.get("api_base"): + return False + self._provider_cache[provider] = None + self._cache_loaded[provider] = True + self._update_cache(provider) + return bool(self._provider_cache.get(provider)) + + def _ensure_provider_state(self, provider: str) -> bool: + if provider not in self.provider_configs: + return False + self._provider_cache.setdefault(provider, None) + self._cache_loaded.setdefault(provider, False) + return True + + def _split_model(self, model: str) -> tuple[Optional[str], str]: + if "/" not in model: + return None, model + provider, route = model.split("/", 1) + return provider, route + + def _ensure_content(self, provider: str) -> Optional[Dict]: + self._load_cache(provider) + if not self._provider_cache.get(provider): + self._update_cache(provider) + return self._provider_cache.get(provider) + + def _find_record(self, content: Optional[Dict], route: str) -> Optional[Dict]: + if not content or "data" not in content: + return None + candidates = {route} + if ":" 
in route: + candidates.add(route.split(":", 1)[0]) + return next((item for item in content["data"] if item.get("id") in candidates), None) + + def _record_to_info(self, record: Dict, provider: str) -> Dict: + context_len = _first_value( + record, + "max_input_tokens", + "max_tokens", + "max_output_tokens", + "context_length", + "context_window", + "top_provider_context_length", + "top_provider", + ) + + if isinstance(context_len, dict): + context_len = context_len.get("context_length") or context_len.get("max_tokens") + + pricing = record.get("pricing", {}) if isinstance(record.get("pricing"), dict) else {} + input_cost = _cost_per_token( + _first_value(pricing, "prompt", "input", "prompt_tokens") + or _first_value(record, "input_cost_per_token", "prompt_cost_per_token") + ) + output_cost = _cost_per_token( + _first_value(pricing, "completion", "output", "completion_tokens") + or _first_value(record, "output_cost_per_token", "completion_cost_per_token") + ) + + max_tokens = _first_value( + record, + "max_tokens", + "max_input_tokens", + "context_length", + "context_window", + "top_provider_context_length", + ) + max_output_tokens = _first_value( + record, + "max_output_tokens", + "max_tokens", + "context_length", + "context_window", + "top_provider_context_length", + ) + + if max_tokens is None: + max_tokens = context_len + if max_output_tokens is None: + max_output_tokens = context_len + + info = { + "max_input_tokens": context_len, + "max_tokens": max_tokens, + "max_output_tokens": max_output_tokens, + "input_cost_per_token": input_cost, + "output_cost_per_token": output_cost, + "litellm_provider": provider, + "mode": record.get("mode", "chat"), + } + return {k: v for k, v in info.items() if v is not None} + + def _get_cache_file(self, provider: str) -> Path: + fname = f"{provider}_models.json" + return self.cache_dir / fname + + def _load_cache(self, provider: str) -> None: + if self._cache_loaded.get(provider): + return + cache_file = self._get_cache_file(provider) + try: + self.cache_dir.mkdir(parents=True, exist_ok=True) + if cache_file.exists(): + cache_age = time.time() - cache_file.stat().st_mtime + if cache_age < self.CACHE_TTL: + try: + self._provider_cache[provider] = json.loads(cache_file.read_text()) + except json.JSONDecodeError: + self._provider_cache[provider] = None + except OSError: + pass + self._cache_loaded[provider] = True + + def _update_cache(self, provider: str) -> None: + payload = self._fetch_provider_models(provider) + cache_file = self._get_cache_file(provider) + + if payload: + self._provider_cache[provider] = payload + try: + cache_file.write_text(json.dumps(payload, indent=2)) + except OSError: + pass + return + + static_models = self.provider_configs[provider].get("static_models") + if static_models and not self._provider_cache.get(provider): + self._provider_cache[provider] = {"data": static_models} + + def _fetch_provider_models(self, provider: str) -> Optional[Dict]: + config = self.provider_configs[provider] + models_url = config.get("models_url") + if not models_url: + api_base = config.get("api_base") + if api_base: + models_url = api_base.rstrip("/") + "/models" + if not models_url: + return None + + headers = {} + default_headers = config.get("default_headers") or {} + headers.update(default_headers) + + api_key = self._get_api_key(provider) + requires_api_key = config.get("requires_api_key", True) + + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + elif requires_api_key: + return None + + try: + response = requests.get( + 
models_url, + headers=headers or None, + timeout=config.get("timeout", 10), + verify=self.verify_ssl, + ) + response.raise_for_status() + return response.json() + except Exception as ex: # noqa: BLE001 + print(f"Failed to fetch {provider} model list: {ex}") + return None + + def _get_api_key(self, provider: str) -> Optional[str]: + config = self.provider_configs[provider] + for env_var in config.get("api_key_env", []): + value = os.environ.get(env_var) + if value: + return value + return None + + +def ensure_litellm_providers_registered() -> None: + """One-time registration guard for LiteLLM provider metadata.""" + global _PROVIDERS_REGISTERED + if _PROVIDERS_REGISTERED: + return + for slug, cfg in PROVIDER_CONFIGS.items(): + _register_provider_with_litellm(slug, cfg) + _PROVIDERS_REGISTERED = True + + +_NUMBER_RE = re.compile(r"-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?") + + +def _cost_per_token(val: Optional[str | float | int]) -> Optional[float]: + """Parse token pricing strings into floats, tolerating currency prefixes.""" + if val in (None, "", "-", "N/A"): + return None + if val == "0": + return 0.0 + if isinstance(val, str): + cleaned = val.strip().replace(",", "") + if cleaned.startswith("$"): + cleaned = cleaned[1:] + match = _NUMBER_RE.search(cleaned) + if not match: + return None + val = match.group(0) + try: + return float(val) + except (TypeError, ValueError): + return None + + +def _first_value(record: Dict, *keys: str): + """Return the first non-empty value for the provided keys.""" + for key in keys: + value = record.get(key) + if value not in (None, ""): + return value + return None diff --git a/aider/llm.py b/aider/llm.py index 31a166834c2..ff320dee3d3 100644 --- a/aider/llm.py +++ b/aider/llm.py @@ -6,6 +6,7 @@ from collections.abc import Coroutine from aider.dump import dump # noqa: F401 +from aider.helpers.model_providers import ensure_litellm_providers_registered warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") @@ -53,6 +54,9 @@ def _load_litellm(self): self._lazy_module.drop_params = True self._lazy_module._logging._disable_debugging() + # Make sure JSON-based OpenAI-compatible providers are registered + ensure_litellm_providers_registered() + # Patch GLOBAL_LOGGING_WORKER to avoid event loop binding issues # See: https://github.com/BerriAI/litellm/issues/16518 # See: https://github.com/BerriAI/litellm/issues/14521 diff --git a/aider/main.py b/aider/main.py index a5b79e7137b..61cae3972d1 100644 --- a/aider/main.py +++ b/aider/main.py @@ -1185,6 +1185,7 @@ def apply_model_overrides(model_name): if args.stream: io.tool_warning( f"Warning: Streaming is not supported by {main_model.name}. Disabling streaming." + " Set stream: false in config file or use --no-stream to skip this warning." 
) args.stream = False diff --git a/aider/models.py b/aider/models.py index 86c46fd4178..537226e2a27 100644 --- a/aider/models.py +++ b/aider/models.py @@ -17,9 +17,9 @@ from aider import __version__ from aider.dump import dump # noqa: F401 +from aider.helpers.model_providers import ModelProviderManager from aider.helpers.requests import model_request_parser from aider.llm import litellm -from aider.openrouter import OpenRouterModelManager from aider.sendchat import sanity_check_messages from aider.utils import check_pip_install_extra @@ -157,13 +157,13 @@ def __init__(self): self.verify_ssl = True self._cache_loaded = False - # Manager for the cached OpenRouter model database - self.openrouter_manager = OpenRouterModelManager() + # Manager for provider-specific cached model databases + self.provider_manager = ModelProviderManager() + self.openai_provider_manager = self.provider_manager # Backwards compatibility alias def set_verify_ssl(self, verify_ssl): self.verify_ssl = verify_ssl - if hasattr(self, "openrouter_manager"): - self.openrouter_manager.set_verify_ssl(verify_ssl) + self.provider_manager.set_verify_ssl(verify_ssl) def _load_cache(self): if self._cache_loaded: @@ -241,21 +241,45 @@ def get_model_info(self, model): if "model_prices_and_context_window.json" not in str(ex): print(str(ex)) + provider_info = self._resolve_via_provider(model, cached_info) + if provider_info: + return provider_info + if litellm_info: return litellm_info - if not cached_info and model.startswith("openrouter/"): - # First try using the locally cached OpenRouter model database - openrouter_info = self.openrouter_manager.get_model_info(model) - if openrouter_info: - return openrouter_info + return cached_info + + def _resolve_via_provider(self, model, cached_info): + if cached_info: + return None - # Fallback to legacy web-scraping if the API cache does not contain the model + provider = model.split("/", 1)[0] if "/" in model else None + if not self.provider_manager.supports_provider(provider): + return None + + provider_info = self.provider_manager.get_model_info(model) + if provider_info: + self._record_dynamic_model(model, provider_info) + return provider_info + + if provider == "openrouter": openrouter_info = self.fetch_openrouter_model_info(model) if openrouter_info: + openrouter_info.setdefault("litellm_provider", "openrouter") + self._record_dynamic_model(model, openrouter_info) return openrouter_info - return cached_info + return None + + def _record_dynamic_model(self, model, info): + self.local_model_metadata[model] = info + self._ensure_model_settings_entry(model) + + def _ensure_model_settings_entry(self, model): + if any(ms.name == model for ms in MODEL_SETTINGS): + return + MODEL_SETTINGS.append(ModelSettings(name=model)) def fetch_openrouter_model_info(self, model): """ @@ -300,6 +324,7 @@ def fetch_openrouter_model_info(self, model): "max_output_tokens": context_size, "input_cost_per_token": input_cost, "output_cost_per_token": output_cost, + "litellm_provider": "openrouter", } return params except Exception as e: @@ -355,6 +380,7 @@ def __init__( ) self.info = self.get_model_info(model) + self.litellm_provider = (self.info.get("litellm_provider") or "").lower() # Are all needed keys/params available? 
res = self.validate_environment() @@ -367,6 +393,7 @@ def __init__( self.max_chat_history_tokens = min(max(max_input_tokens / 16, 1024), 8192) self.configure_model_settings(model) + self._apply_provider_defaults() self.get_weak_model(weak_model) if editor_model is False: @@ -690,6 +717,49 @@ def get_editor_model(self, provided_editor_model, editor_edit_format): return self.editor_model + def _ensure_extra_params_dict(self): + if self.extra_params is None: + self.extra_params = {} + elif not isinstance(self.extra_params, dict): + self.extra_params = dict(self.extra_params) + + def _apply_provider_defaults(self): + provider = (self.info.get("litellm_provider") or "").lower() + self.litellm_provider = provider or None + + if not provider: + return + + provider_config = model_info_manager.provider_manager.get_provider_config(provider) + if not provider_config: + return + + self._ensure_extra_params_dict() + self.extra_params.setdefault("custom_llm_provider", provider) + + if provider_config.get("supports_stream") is False: + # Some OpenAI-compatible providers (e.g., Synthetic) only expose the + # non-streaming /chat/completions endpoint, so forcing streaming would + # loop through LiteLLM's fallback and explode mid-response. Disable the + # streaming flag up front so the caller transparently falls back to + # standard completions for those providers. + self.streaming = False + + base_url = model_info_manager.provider_manager.get_provider_base_url(provider) + if base_url: + self.extra_params.setdefault("base_url", base_url) + + default_headers = provider_config.get("default_headers") or {} + if default_headers: + headers = self.extra_params.setdefault("extra_headers", {}) + for key, value in default_headers.items(): + headers.setdefault(key, value) + + provider_extra = provider_config.get("extra_params") or {} + for key, value in provider_extra.items(): + if key not in self.extra_params: + self.extra_params[key] = value + def tokenizer(self, text): return litellm.encode(model=self.name, text=text) @@ -788,6 +858,12 @@ def fast_validate_environment(self): if var and os.environ.get(var): return dict(keys_in_environment=[var], missing_keys=[]) + if not var and provider and model_info_manager.provider_manager.supports_provider(provider): + provider_keys = model_info_manager.provider_manager.get_required_api_keys(provider) + for env_var in provider_keys: + if os.environ.get(env_var): + return dict(keys_in_environment=[env_var], missing_keys=[]) + def validate_environment(self): res = self.fast_validate_environment() if res: @@ -818,6 +894,14 @@ def validate_environment(self): return res provider = self.info.get("litellm_provider", "").lower() + provider_config = model_info_manager.provider_manager.get_provider_config(provider) + if provider_config: + envs = provider_config.get("api_key_env", []) + available = [env for env in envs if os.environ.get(env)] + if available: + return dict(keys_in_environment=available, missing_keys=[]) + if envs: + return dict(keys_in_environment=False, missing_keys=envs) if provider == "cohere_chat": return validate_variables(["COHERE_API_KEY"]) if provider == "gemini": @@ -1304,31 +1388,35 @@ async def check_for_dependencies(io, model_name): ) -def fuzzy_match_models(name): - name = name.lower() - +def get_chat_model_names(): chat_models = set() model_metadata = list(litellm.model_cost.items()) model_metadata += list(model_info_manager.local_model_metadata.items()) + openai_provider_models = model_info_manager.provider_manager.get_models_for_listing() + model_metadata 
+= list(openai_provider_models.items()) + for orig_model, attrs in model_metadata: - model = orig_model.lower() if attrs.get("mode") != "chat": continue - provider = attrs.get("litellm_provider", "").lower() - if not provider: - continue - provider += "/" - - if model.startswith(provider): - fq_model = orig_model - else: - fq_model = provider + orig_model + provider = (attrs.get("litellm_provider") or "").lower() + if provider: + prefix = provider + "/" + if orig_model.lower().startswith(prefix): + fq_model = orig_model + else: + fq_model = f"{provider}/{orig_model}" + chat_models.add(fq_model) - chat_models.add(fq_model) chat_models.add(orig_model) - chat_models = sorted(chat_models) + return sorted(chat_models) + + +def fuzzy_match_models(name): + name = name.lower() + + chat_models = get_chat_model_names() # exactly matching model # matching_models = [ # (fq,m) for fq,m in chat_models @@ -1338,7 +1426,7 @@ def fuzzy_match_models(name): # return matching_models # Check for model names containing the name - matching_models = [m for m in chat_models if name in m] + matching_models = [m for m in chat_models if name in m.lower()] if matching_models: return sorted(set(matching_models)) diff --git a/aider/openrouter.py b/aider/openrouter.py deleted file mode 100644 index ea641c17fda..00000000000 --- a/aider/openrouter.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -OpenRouter model metadata caching and lookup. - -This module keeps a local cached copy of the OpenRouter model list -(downloaded from ``https://openrouter.ai/api/v1/models``) and exposes a -helper class that returns metadata for a given model in a format compatible -with litellm’s ``get_model_info``. -""" - -from __future__ import annotations - -import json -import time -from pathlib import Path -from typing import Dict - -import requests - - -def _cost_per_token(val: str | None) -> float | None: - """Convert a price string (USD per token) to a float.""" - if val in (None, "", "0"): - return 0.0 if val == "0" else None - try: - return float(val) - except Exception: # noqa: BLE001 - return None - - -class OpenRouterModelManager: - MODELS_URL = "https://openrouter.ai/api/v1/models" - CACHE_TTL = 60 * 60 * 24 # 24 h - - def __init__(self) -> None: - self.cache_dir = Path.home() / ".aider" / "caches" - self.cache_file = self.cache_dir / "openrouter_models.json" - self.content: Dict | None = None - self.verify_ssl: bool = True - self._cache_loaded = False - - # ------------------------------------------------------------------ # - # Public API # - # ------------------------------------------------------------------ # - def set_verify_ssl(self, verify_ssl: bool) -> None: - """Enable/disable SSL verification for API requests.""" - self.verify_ssl = verify_ssl - - def get_model_info(self, model: str) -> Dict: - """ - Return metadata for *model* or an empty ``dict`` when unknown. - - ``model`` should use the aider naming convention, e.g. - ``openrouter/nousresearch/deephermes-3-mistral-24b-preview:free``. - """ - self._ensure_content() - if not self.content or "data" not in self.content: - return {} - - route = self._strip_prefix(model) - - # Consider both the exact id and id without any “:suffix”. 
- candidates = {route} - if ":" in route: - candidates.add(route.split(":", 1)[0]) - - record = next((item for item in self.content["data"] if item.get("id") in candidates), None) - if not record: - return {} - - context_len = ( - record.get("top_provider", {}).get("context_length") - or record.get("context_length") - or None - ) - - pricing = record.get("pricing", {}) - return { - "max_input_tokens": context_len, - "max_tokens": context_len, - "max_output_tokens": context_len, - "input_cost_per_token": _cost_per_token(pricing.get("prompt")), - "output_cost_per_token": _cost_per_token(pricing.get("completion")), - "litellm_provider": "openrouter", - } - - # ------------------------------------------------------------------ # - # Internal helpers # - # ------------------------------------------------------------------ # - def _strip_prefix(self, model: str) -> str: - return model[len("openrouter/") :] if model.startswith("openrouter/") else model - - def _ensure_content(self) -> None: - self._load_cache() - if not self.content: - self._update_cache() - - def _load_cache(self) -> None: - if self._cache_loaded: - return - try: - self.cache_dir.mkdir(parents=True, exist_ok=True) - if self.cache_file.exists(): - cache_age = time.time() - self.cache_file.stat().st_mtime - if cache_age < self.CACHE_TTL: - try: - self.content = json.loads(self.cache_file.read_text()) - except json.JSONDecodeError: - self.content = None - except OSError: - # Cache directory might be unwritable; ignore. - pass - - self._cache_loaded = True - - def _update_cache(self) -> None: - try: - response = requests.get(self.MODELS_URL, timeout=10, verify=self.verify_ssl) - if response.status_code == 200: - self.content = response.json() - try: - self.cache_file.write_text(json.dumps(self.content, indent=2)) - except OSError: - pass # Non-fatal if we can’t write the cache - except Exception as ex: # noqa: BLE001 - print(f"Failed to fetch OpenRouter model list: {ex}") - try: - self.cache_file.write_text("{}") - except OSError: - pass diff --git a/aider/resources/providers.json b/aider/resources/providers.json new file mode 100644 index 00000000000..7c022a21095 --- /dev/null +++ b/aider/resources/providers.json @@ -0,0 +1,90 @@ +{ + "openrouter": { + "api_base": "https://openrouter.ai/api/v1", + "models_url": "https://openrouter.ai/api/v1/models", + "api_key_env": [ + "OPENROUTER_API_KEY" + ], + "requires_api_key": false, + "default_headers": { + "HTTP-Referer": "https://aider.chat", + "X-Title": "aider" + } + }, + "openai": { + "api_base": "https://api.openai.com/v1", + "models_url": "https://api.openai.com/v1/models", + "api_key_env": [ + "OPENAI_API_KEY" + ], + "base_url_env": [ + "OPENAI_API_BASE" + ], + "display_name": "openai" + }, + "apertis": { + "api_base": "https://api.stima.tech/v1", + "api_key_env": [ + "STIMA_API_KEY" + ], + "display_name": "apertis" + }, + "chutes": { + "api_base": "https://llm.chutes.ai/v1/", + "api_key_env": [ + "CHUTES_API_KEY" + ], + "display_name": "chutes" + }, + "helicone": { + "api_base": "https://ai-gateway.helicone.ai/", + "api_key_env": [ + "HELICONE_API_KEY" + ], + "display_name": "helicone" + }, + "nano-gpt": { + "api_base": "https://nano-gpt.com/api/v1", + "api_key_env": [ + "NANOGPT_API_KEY" + ], + "display_name": "nano-gpt" + }, + "poe": { + "api_base": "https://api.poe.com/v1", + "api_key_env": [ + "POE_API_KEY" + ], + "display_name": "poe" + }, + "publicai": { + "api_base": "https://api.publicai.co/v1", + "api_key_env": [ + "PUBLICAI_API_KEY" + ], + "display_name": "publicai" + }, 
+ "synthetic": { + "api_base": "https://api.synthetic.new/openai/v1", + "api_key_env": [ + "SYNTHETIC_API_KEY" + ], + "display_name": "synthetic", + "hf_namespace": true, + "supports_stream": false + }, + "veniceai": { + "api_base": "https://api.venice.ai/api/v1", + "api_key_env": [ + "VENICE_AI_API_KEY" + ], + "display_name": "veniceai" + }, + "xiaomi_mimo": { + "api_base": "https://api.xiaomimimo.com/v1", + "api_key_env": [ + "XIAOMI_MIMO_API_KEY" + ], + "display_name": "xiaomi_mimo" + } +} diff --git a/scripts/generate_providers.py b/scripts/generate_providers.py new file mode 100644 index 00000000000..50d4b08907b --- /dev/null +++ b/scripts/generate_providers.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +""" +Interactively generate aider/resources/providers.json from litellm data. + +This script reads litellm's openai_like provider definitions and walks the user +through building cecli's provider registry, mirroring the workflow used by +clean_metadata.py (prompting when decisions are needed). +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any, Dict, Iterable + +AUTO_APPROVE = False + + +def prompt_yes_no(question: str, default: bool = True) -> bool: + """Prompt user for yes/no input, returning bool.""" + + suffix = " [Y/n] " if default else " [y/N] " + if AUTO_APPROVE: + print(f"{question}{suffix}-> {'Y' if default else 'N'} (auto)") + return default + while True: + resp = input(question + suffix).strip().lower() + if not resp: + return default + if resp in ("y", "yes"): + return True + if resp in ("n", "no"): + return False + print("Please enter 'y' or 'n'.") + + +def _format_default(value: str | None) -> str | None: + if value is None: + return None + if value.startswith("[") and value.endswith("]"): + try: + parsed = json.loads(value) + except json.JSONDecodeError: + return value + if isinstance(parsed, list): + return ", ".join(str(item) for item in parsed) + return value + + +def prompt_value(question: str, default: str | None = None) -> str | None: + """Prompt user for a string; empty input keeps default.""" + + display_default = _format_default(default) + suffix = f" [{display_default}]" if display_default is not None else "" + if AUTO_APPROVE: + print(f"{question}{suffix}: -> {display_default or ''} (auto)") + return default + resp = input(f"{question}{suffix}: ").strip() + if not resp: + return default + return resp + + +def ensure_json_object(prompt_text: str, default: Dict[str, Any] | None = None) -> Dict[str, Any]: + """Prompt for a JSON object, re-prompting on parse errors.""" + + default_str = json.dumps(default, indent=2) if default else "" + while True: + raw = prompt_value(prompt_text, default_str) + if not raw: + return default or {} + if AUTO_APPROVE: + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return default or {} + return parsed + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: # pragma: no cover - interactive error path + print(f"Invalid JSON ({exc}). 
Please try again.") + continue + if not isinstance(parsed, dict): + print('Please provide a JSON object (e.g., {"Header": "value"}).') + continue + return parsed + + +def _list_to_csv(value: Iterable[str] | str | None) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return ", ".join(str(item) for item in value) + + +def _parse_csv(value: str | None) -> list[str]: + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] + + +def main(): + global AUTO_APPROVE + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-y", + "--yes", + "--auto-approve", + dest="auto", + action="store_true", + help="Automatically include all providers and accept defaults without prompting.", + ) + args = parser.parse_args() + AUTO_APPROVE = args.auto + + script_dir = Path(__file__).parent.resolve() + repo_root = script_dir.parent + + litellm_providers_path = ( + script_dir.parent / "../litellm/litellm/llms/openai_like/providers.json" + ).resolve() + output_path = (repo_root / "aider" / "resources" / "providers.json").resolve() + + if not litellm_providers_path.exists(): + print(f"Error: Could not find litellm providers at {litellm_providers_path}") + return + + try: + litellm_data = json.loads(litellm_providers_path.read_text()) + except json.JSONDecodeError as exc: + print(f"Error: Failed to parse litellm providers ({exc}).") + return + + existing = {} + if output_path.exists(): + try: + existing = json.loads(output_path.read_text()) + except json.JSONDecodeError as exc: + print(f"Warning: Existing {output_path} is invalid JSON ({exc}); ignoring.") + + new_config: Dict[str, Dict[str, Any]] = {} + + for provider_name in sorted(litellm_data.keys()): + litellm_entry = litellm_data[provider_name] + existing_entry = existing.get(provider_name, {}) + default_keep = bool(existing_entry) + + print("\n" + "=" * 60) + print(f"Provider: {provider_name}") + print(f" Display name : {litellm_entry.get('display_name', provider_name)}") + print(f" Base URL : {litellm_entry.get('base_url', 'N/A')}") + + api_key_list = litellm_entry.get("api_key_env") + api_key_display = _list_to_csv(api_key_list) if api_key_list else "N/A" + print(f" API key env : {api_key_display}") + + include = prompt_yes_no( + f"Include provider '{provider_name}'?", default=default_keep or True + ) + if not include: + continue + + display_name = prompt_value( + "Display name", + existing_entry.get("display_name") + or litellm_entry.get("display_name") + or provider_name, + ) + api_base = prompt_value( + "API base URL", + existing_entry.get("api_base") or litellm_entry.get("base_url") or "", + ) + base_url_env = prompt_value( + "Comma-separated env vars for overriding base URL", + _list_to_csv(existing_entry.get("base_url_env")) or "", + ) + api_key_env = prompt_value( + "Comma-separated env vars for API key lookup", + _list_to_csv(existing_entry.get("api_key_env", litellm_entry.get("api_key_env", []))) + or "", + ) + models_url = prompt_value( + "Models endpoint URL (leave blank if none)", + existing_entry.get("models_url", ""), + ) + default_headers = ensure_json_object( + "Default headers JSON (empty for none)", + existing_entry.get("default_headers"), + ) + + record: Dict[str, Any] = {} + if display_name: + record["display_name"] = display_name + if api_base: + record["api_base"] = api_base + if api_key_env: + record["api_key_env"] = _parse_csv(api_key_env) + if base_url_env: + record["base_url_env"] = _parse_csv(base_url_env) + if models_url: + 
record["models_url"] = models_url + if default_headers: + record["default_headers"] = default_headers + + new_config[provider_name] = record + + # Preserve providers that only exist in the existing file (not litellm) if user wants. + for provider_name in sorted(existing.keys()): + if provider_name in new_config or provider_name in litellm_data: + continue + print("\n" + "=" * 60) + print(f"Provider '{provider_name}' exists only in {output_path}.") + if prompt_yes_no("Keep this provider?", default=True): + new_config[provider_name] = existing[provider_name] + + output_path.write_text(json.dumps(new_config, indent=2, sort_keys=True) + "\n") + print(f"\nWrote {len(new_config)} providers to {output_path}.\n") + + +if __name__ == "__main__": + main() diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index 7ed6564e5c3..7885b2e636f 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -2,6 +2,7 @@ import os import subprocess import tempfile +import types from io import StringIO from pathlib import Path from unittest import TestCase @@ -1415,6 +1416,71 @@ async def test_list_models_includes_all_model_sources(self): # Check that both models appear in the output self.assertIn("test-provider/metadata-only-model", output) + async def test_list_models_includes_openai_provider(self): + import aider.models as models_module + + provider_name = "openai" + manager = models_module.model_info_manager.provider_manager + provider_config = { + "api_base": "https://api.openai.com/v1", + "models_url": "https://api.openai.com/v1/models", + "api_key_env": ["OPENAI_API_KEY"], + "base_url_env": ["OPENAI_API_BASE"], + "default_headers": {}, + } + + had_config = provider_name in manager.provider_configs + previous_config = manager.provider_configs.get(provider_name) + had_cache = provider_name in manager._provider_cache + previous_cache = manager._provider_cache.get(provider_name) + had_loaded = provider_name in manager._cache_loaded + previous_loaded = manager._cache_loaded.get(provider_name) + + manager.provider_configs[provider_name] = provider_config + manager._provider_cache[provider_name] = None + manager._cache_loaded[provider_name] = False + + payload = { + "data": [ + { + "id": "demo/foo", + "max_input_tokens": 4096, + "pricing": {"prompt": "0.0001", "completion": "0.0002"}, + } + ] + } + + def _fake_get(url, *, headers=None, timeout=None, verify=None): + return types.SimpleNamespace(status_code=200, json=lambda: payload) + + try: + with GitTemporaryDirectory(): + with patch("requests.get", _fake_get): + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + await main( + ["--list-models", "openai/demo/foo", "--yes", "--no-gitignore"], + input=DummyInput(), + output=DummyOutput(), + ) + + output = mock_stdout.getvalue() + self.assertIn("openai/demo/foo", output) + finally: + if had_config: + manager.provider_configs[provider_name] = previous_config + else: + manager.provider_configs.pop(provider_name, None) + + if had_cache: + manager._provider_cache[provider_name] = previous_cache + else: + manager._provider_cache.pop(provider_name, None) + + if had_loaded: + manager._cache_loaded[provider_name] = previous_loaded + else: + manager._cache_loaded.pop(provider_name, None) + async def test_check_model_accepts_settings_flag(self): # Test that --check-model-accepts-settings affects whether settings are applied with GitTemporaryDirectory(): diff --git a/tests/basic/test_model_provider_manager.py b/tests/basic/test_model_provider_manager.py new file mode 100644 index 
00000000000..a4ddf36917d --- /dev/null +++ b/tests/basic/test_model_provider_manager.py @@ -0,0 +1,364 @@ +import json +import sys +import types + + +def _install_stubs(): + if "PIL" not in sys.modules: + pil_module = types.ModuleType("PIL") + image_module = types.ModuleType("PIL.Image") + image_grab_module = types.ModuleType("PIL.ImageGrab") + + class _DummyImage: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + @property + def size(self): + return (1024, 1024) + + def _dummy_open(*args, **kwargs): + return _DummyImage() + + image_module.open = _dummy_open + image_grab_module.grab = _dummy_open + pil_module.Image = image_module + pil_module.ImageGrab = image_grab_module + sys.modules["PIL"] = pil_module + sys.modules["PIL.Image"] = image_module + sys.modules["PIL.ImageGrab"] = image_grab_module + + if "numpy" not in sys.modules: + numpy_module = types.ModuleType("numpy") + numpy_module.ndarray = object + numpy_module.array = lambda *a, **k: None + numpy_module.dot = lambda *a, **k: 0.0 + numpy_module.linalg = types.SimpleNamespace(norm=lambda *a, **k: 1.0) + sys.modules["numpy"] = numpy_module + + if "oslex" not in sys.modules: + oslex_module = types.ModuleType("oslex") + oslex_module.__all__ = [] + sys.modules["oslex"] = oslex_module + + if "rich" not in sys.modules: + rich_module = types.ModuleType("rich") + console_module = types.ModuleType("rich.console") + + class _DummyConsole: + def __init__(self, *args, **kwargs): + pass + + def status(self, *args, **kwargs): + return self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def update(self, *args, **kwargs): + return None + + console_module.Console = _DummyConsole + rich_module.console = console_module + sys.modules["rich"] = rich_module + sys.modules["rich.console"] = console_module + + if "pyperclip" not in sys.modules: + pyperclip_module = types.ModuleType("pyperclip") + + class _DummyPyperclipException(Exception): + pass + + pyperclip_module.PyperclipException = _DummyPyperclipException + pyperclip_module.copy = lambda *args, **kwargs: None + sys.modules["pyperclip"] = pyperclip_module + + if "pexpect" not in sys.modules: + pexpect_module = types.ModuleType("pexpect") + + class _DummySpawn: + def __init__(self, *args, **kwargs): + pass + + def sendline(self, *args, **kwargs): + return 0 + + def close(self, *args, **kwargs): + return 0 + + pexpect_module.spawn = _DummySpawn + sys.modules["pexpect"] = pexpect_module + + if "psutil" not in sys.modules: + psutil_module = types.ModuleType("psutil") + + class _DummyProcess: + def __init__(self, *args, **kwargs): + pass + + def children(self, *args, **kwargs): + return [] + + def terminate(self): + return None + + psutil_module.Process = _DummyProcess + sys.modules["psutil"] = psutil_module + + if "pypandoc" not in sys.modules: + pypandoc_module = types.ModuleType("pypandoc") + pypandoc_module.convert_text = lambda *args, **kwargs: "" + sys.modules["pypandoc"] = pypandoc_module + + +_install_stubs() + +from aider.helpers.model_providers import ModelProviderManager # noqa: E402 +from aider.models import MODEL_SETTINGS, Model, ModelInfoManager # noqa: E402 + + +class DummyResponse: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +def _make_manager(tmp_path, config): + manager = ModelProviderManager(provider_configs=config) + manager.cache_dir = tmp_path # Avoid touching real home dir + 
return manager + + +def test_model_provider_matches_suffix_variants(monkeypatch, tmp_path): + payload = { + "data": [ + { + "id": "demo/model", + "context_length": 2048, + "pricing": {"prompt": "1.0", "completion": "2.0"}, + } + ] + } + + config = { + "openrouter": { + "api_base": "https://openrouter.ai/api/v1", + "models_url": "https://openrouter.ai/api/v1/models", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + cache_file = manager._get_cache_file("openrouter") + cache_file.write_text(json.dumps(payload)) + manager._cache_loaded["openrouter"] = True + manager._provider_cache["openrouter"] = payload + + info = manager.get_model_info("openrouter/demo/model:extended") + + assert info["max_input_tokens"] == 2048 + assert info["input_cost_per_token"] == 1.0 + assert info["litellm_provider"] == "openrouter" + + +def test_model_provider_uses_top_provider_context(tmp_path): + payload = { + "data": [ + { + "id": "demo/model", + "top_provider": {"context_length": 4096}, + "pricing": {"prompt": "3", "completion": "4"}, + } + ] + } + + config = { + "demo": { + "api_base": "https://example.com/v1", + "models_url": "https://example.com/v1/models", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + cache_file = manager._get_cache_file("demo") + cache_file.write_text(json.dumps(payload)) + manager._cache_loaded["demo"] = True + manager._provider_cache["demo"] = payload + + info = manager.get_model_info("demo/demo/model") + + assert info["max_input_tokens"] == 4096 + assert info["max_tokens"] == 4096 + assert info["max_output_tokens"] == 4096 + + +def test_fetch_provider_models_injects_headers(monkeypatch, tmp_path): + payload = {"data": []} + captured = {} + + def _fake_get(url, *, headers=None, timeout=None, verify=None): + captured["url"] = url + captured["headers"] = headers + captured["timeout"] = timeout + captured["verify"] = verify + return DummyResponse(payload) + + monkeypatch.setattr("requests.get", _fake_get) + + config = { + "demo": { + "api_base": "https://example.com/v1", + "default_headers": {"X-Test": "demo"}, + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + manager.set_verify_ssl(False) + + result = manager._fetch_provider_models("demo") + + assert result == payload + assert captured["url"] == "https://example.com/v1/models" + assert captured["headers"] == {"X-Test": "demo"} + assert captured["timeout"] == 10 + assert captured["verify"] is False + + +def test_get_api_key_prefers_first_valid(monkeypatch, tmp_path): + config = { + "demo": { + "api_base": "https://example.com/v1", + "api_key_env": ["DEMO_FALLBACK", "DEMO_KEY"], + "requires_api_key": True, + } + } + + manager = _make_manager(tmp_path, config) + monkeypatch.delenv("DEMO_FALLBACK", raising=False) + monkeypatch.setenv("DEMO_KEY", "secret") + + assert manager._get_api_key("demo") == "secret" + + +def test_refresh_provider_cache_uses_static_models(monkeypatch, tmp_path): + config = { + "demo": { + "api_base": "https://example.com/v1", + "static_models": [ + { + "id": "demo/foo", + "max_input_tokens": 1024, + "pricing": {"prompt": "0.5", "completion": "1.0"}, + } + ], + } + } + + manager = _make_manager(tmp_path, config) + + def _failing_fetch(*args, **kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr("requests.get", _failing_fetch) + + refreshed = manager.refresh_provider_cache("demo") + + assert refreshed is True + info = manager.get_model_info("demo/demo/foo") + assert info["max_input_tokens"] == 1024 + assert 
info["input_cost_per_token"] == 0.5 + + +def test_model_info_manager_delegates_to_provider(monkeypatch, tmp_path): + monkeypatch.setattr( + "aider.models.litellm", + types.SimpleNamespace( + _lazy_module=None, + get_model_info=lambda *a, **k: {}, + validate_environment=lambda model: {"keys_in_environment": True, "missing_keys": []}, + encode=lambda *a, **k: [], + token_counter=lambda *a, **k: 0, + ), + ) + + stub_info = { + "max_input_tokens": 512, + "max_tokens": 512, + "max_output_tokens": 512, + "input_cost_per_token": 1.0, + "output_cost_per_token": 2.0, + "litellm_provider": "openrouter", + } + + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.supports_provider", + lambda self, provider: provider == "openrouter", + ) + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.get_model_info", + lambda self, model: stub_info, + ) + + mim = ModelInfoManager() + info = mim.get_model_info("openrouter/demo/model") + + assert info == stub_info + + +def test_model_dynamic_settings_added(monkeypatch, tmp_path): + provider = "demo" + model_name = "demo/org/foo" + manager = ModelInfoManager() + + def _fake_supports(self, prov): + return prov == provider + + def _fake_get(self, model): + return { + "max_input_tokens": 2048, + "max_tokens": 2048, + "max_output_tokens": 2048, + "litellm_provider": provider, + } + + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.supports_provider", + _fake_supports, + ) + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.get_model_info", + _fake_get, + ) + monkeypatch.setattr( + "aider.models.litellm", + types.SimpleNamespace( + _lazy_module=None, + get_model_info=lambda *a, **k: {}, + validate_environment=lambda model: {"keys_in_environment": True, "missing_keys": []}, + encode=lambda *a, **k: [], + token_counter=lambda *a, **k: 0, + ), + ) + + assert not any(ms.name == model_name for ms in MODEL_SETTINGS) + + info = manager.get_model_info(model_name) + assert info["max_tokens"] == 2048 + + assert any(ms.name == model_name for ms in MODEL_SETTINGS) + + model = Model(model_name) + assert model.info["max_tokens"] == 2048 diff --git a/tests/basic/test_openrouter.py b/tests/basic/test_openrouter.py deleted file mode 100644 index f55c301572c..00000000000 --- a/tests/basic/test_openrouter.py +++ /dev/null @@ -1,73 +0,0 @@ -from pathlib import Path - -from aider.models import ModelInfoManager -from aider.openrouter import OpenRouterModelManager - - -class DummyResponse: - """Minimal stand-in for requests.Response used in tests.""" - - def __init__(self, json_data): - self.status_code = 200 - self._json_data = json_data - - def json(self): - return self._json_data - - -def test_openrouter_get_model_info_from_cache(monkeypatch, tmp_path): - """ - OpenRouterModelManager should return correct metadata taken from the - downloaded (and locally cached) models JSON payload. 
- """ - payload = { - "data": [ - { - "id": "mistralai/mistral-medium-3", - "context_length": 32768, - "pricing": {"prompt": "100", "completion": "200"}, - "top_provider": {"context_length": 32768}, - } - ] - } - - # Fake out the network call and the HOME directory used for the cache file - monkeypatch.setattr("requests.get", lambda *a, **k: DummyResponse(payload)) - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - - manager = OpenRouterModelManager() - info = manager.get_model_info("openrouter/mistralai/mistral-medium-3") - - assert info["max_input_tokens"] == 32768 - assert info["input_cost_per_token"] == 100.0 - assert info["output_cost_per_token"] == 200.0 - assert info["litellm_provider"] == "openrouter" - - -def test_model_info_manager_uses_openrouter_manager(monkeypatch): - """ - ModelInfoManager should delegate to OpenRouterModelManager when litellm - provides no data for an OpenRouter-prefixed model. - """ - # Ensure litellm path returns no info so that fallback logic triggers - monkeypatch.setattr("aider.models.litellm.get_model_info", lambda *a, **k: {}) - - stub_info = { - "max_input_tokens": 512, - "max_tokens": 512, - "max_output_tokens": 512, - "input_cost_per_token": 100.0, - "output_cost_per_token": 200.0, - "litellm_provider": "openrouter", - } - - # Force OpenRouterModelManager to return our stub info - monkeypatch.setattr( - "aider.models.OpenRouterModelManager.get_model_info", - lambda self, model: stub_info, - ) - - mim = ModelInfoManager() - info = mim.get_model_info("openrouter/fake/model") - - assert info == stub_info diff --git a/tests/basic/test_reasoning.py b/tests/basic/test_reasoning.py index 24aa9334197..31bfe3c05ed 100644 --- a/tests/basic/test_reasoning.py +++ b/tests/basic/test_reasoning.py @@ -1,6 +1,10 @@ +import json +import textwrap import unittest from unittest.mock import MagicMock, patch +import litellm + from aider.coders.base_coder import Coder from aider.dump import dump # noqa from aider.io import InputOutput @@ -13,6 +17,43 @@ class TestReasoning(unittest.TestCase): + SYNTHETIC_COMPLETION = textwrap.dedent("""\ + { + "id": "test-completion", + "created": 0, + "model": "synthetic/hf:MiniMaxAI/MiniMax-M2", + "object": "chat.completion", + "system_fingerprint": null, + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Final synthetic summary of the repository.", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "reasoning_content": "Internal reasoning about how to describe the repo." 
+                },
+                "token_ids": null
+            }
+        ],
+        "usage": {
+            "completion_tokens": 10,
+            "prompt_tokens": 5,
+            "total_tokens": 15,
+            "completion_tokens_details": null,
+            "prompt_tokens_details": {
+                "audio_tokens": null,
+                "cached_tokens": null,
+                "text_tokens": null,
+                "image_tokens": null
+            }
+        },
+        "prompt_token_ids": null
+    }
+    """)
+
     async def test_send_with_reasoning_content(self):
         """Test that reasoning content is properly formatted and output."""
         # Setup IO with no pretty
@@ -74,6 +115,31 @@ def __init__(self, content, reasoning_content):
             reasoning_pos, main_pos, "Reasoning content should appear before main content"
         )
 
+    async def test_reasoning_keeps_answer_block(self):
+        """Ensure providers returning reasoning+answer still show both sections."""
+        io = InputOutput(pretty=False)
+        io.assistant_output = MagicMock()
+        model = Model("gpt-4o")
+        coder = await Coder.create(model, None, io=io, stream=False)
+
+        completion = litellm.ModelResponse(**json.loads(self.SYNTHETIC_COMPLETION))
+        mock_hash = MagicMock()
+        mock_hash.hexdigest.return_value = "hash"
+
+        with patch.object(model, "send_completion", return_value=(mock_hash, completion)):
+            list(await coder.send([{"role": "user", "content": "describe"}]))
+
+        output = io.assistant_output.call_args[0][0]
+        self.assertIn(REASONING_START, output)
+        self.assertIn("Internal reasoning about how to describe the repo.", output)
+        self.assertIn("Final synthetic summary of the repository.", output)
+        self.assertIn(REASONING_END, output)
+
+        coder.remove_reasoning_content()
+        self.assertEqual(
+            coder.partial_response_content.strip(), "Final synthetic summary of the repository."
+        )
+
     async def test_send_with_reasoning_content_stream(self):
         """Test that streaming reasoning content is properly formatted and output."""
         # Setup IO with pretty output for streaming