From 33f1be8f084bed5218e10a225dc478bbf91d8979 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sat, 27 Dec 2025 20:28:30 -0600 Subject: [PATCH 01/14] Add OpenAI-compatible provider registry Introduce OpenAIProviderManager plus JSON-backed metadata to hydrate /models payloads for OpenAI-like providers such as Synthetic. Hook ModelInfoManager, Model, and CLI completions/listings into that registry, expose configuration data in aider/resources/openai_providers.json, and ensure LiteLLM is initialized with the custom handler so cecli can call these endpoints reliably. --- aider/commands/model.py | 5 +- aider/commands/models.py | 5 +- aider/llm.py | 4 + aider/main.py | 3 +- aider/models.py | 116 ++++- aider/openai_providers.py | 699 ++++++++++++++++++++++++++ aider/resources/openai_providers.json | 78 +++ scripts/generate_openai_providers.py | 220 ++++++++ tests/basic/test_main.py | 66 +++ tests/basic/test_openai_providers.py | 402 +++++++++++++++ 10 files changed, 1572 insertions(+), 26 deletions(-) create mode 100644 aider/openai_providers.py create mode 100644 aider/resources/openai_providers.json create mode 100644 scripts/generate_openai_providers.py create mode 100644 tests/basic/test_openai_providers.py diff --git a/aider/commands/model.py b/aider/commands/model.py index f058a2f5615..fd2a2d2b068 100644 --- a/aider/commands/model.py +++ b/aider/commands/model.py @@ -94,10 +94,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for model command.""" - from aider.llm import litellm - - model_names = litellm.model_cost.keys() - return list(model_names) + return models.get_chat_model_names() @classmethod def get_help(cls) -> str: diff --git a/aider/commands/models.py b/aider/commands/models.py index 9d9624d1f84..2af2a56771d 100644 --- a/aider/commands/models.py +++ b/aider/commands/models.py @@ -24,10 +24,7 @@ async def execute(cls, io, coder, args, **kwargs): @classmethod def get_completions(cls, io, coder, args) -> List[str]: """Get completion options for models command.""" - from aider.llm import litellm - - model_names = litellm.model_cost.keys() - return list(model_names) + return models.get_chat_model_names() @classmethod def get_help(cls) -> str: diff --git a/aider/llm.py b/aider/llm.py index 31a166834c2..eb74aab25e1 100644 --- a/aider/llm.py +++ b/aider/llm.py @@ -6,6 +6,7 @@ from collections.abc import Coroutine from aider.dump import dump # noqa: F401 +from aider.openai_providers import ensure_litellm_providers_registered warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") @@ -53,6 +54,9 @@ def _load_litellm(self): self._lazy_module.drop_params = True self._lazy_module._logging._disable_debugging() + # Make sure JSON-based OpenAI-compatible providers are registered + ensure_litellm_providers_registered() + # Patch GLOBAL_LOGGING_WORKER to avoid event loop binding issues # See: https://github.com/BerriAI/litellm/issues/16518 # See: https://github.com/BerriAI/litellm/issues/14521 diff --git a/aider/main.py b/aider/main.py index a5b79e7137b..b13512eecdb 100644 --- a/aider/main.py +++ b/aider/main.py @@ -1184,7 +1184,8 @@ def apply_model_overrides(model_name): if not main_model.streaming: if args.stream: io.tool_warning( - f"Warning: Streaming is not supported by {main_model.name}. Disabling streaming." + f"Warning: Streaming is not supported by {main_model.name}. Disabling streaming. " + "Run with --no-stream to skip this warning." 
) args.stream = False diff --git a/aider/models.py b/aider/models.py index 86c46fd4178..87519de04bd 100644 --- a/aider/models.py +++ b/aider/models.py @@ -20,6 +20,7 @@ from aider.helpers.requests import model_request_parser from aider.llm import litellm from aider.openrouter import OpenRouterModelManager +from aider.openai_providers import OpenAIProviderManager from aider.sendchat import sanity_check_messages from aider.utils import check_pip_install_extra @@ -157,13 +158,16 @@ def __init__(self): self.verify_ssl = True self._cache_loaded = False - # Manager for the cached OpenRouter model database + # Manager for provider-specific cached model databases self.openrouter_manager = OpenRouterModelManager() + self.openai_provider_manager = OpenAIProviderManager() def set_verify_ssl(self, verify_ssl): self.verify_ssl = verify_ssl if hasattr(self, "openrouter_manager"): self.openrouter_manager.set_verify_ssl(verify_ssl) + if hasattr(self, "openai_provider_manager"): + self.openai_provider_manager.set_verify_ssl(verify_ssl) def _load_cache(self): if self._cache_loaded: @@ -255,6 +259,17 @@ def get_model_info(self, model): if openrouter_info: return openrouter_info + provider = model.split("/", 1)[0] if "/" in model else None + if self.openai_provider_manager.supports_provider(provider): + provider_info = self.openai_provider_manager.get_model_info(model) + if not provider_info and not cached_info: + refreshed = self.openai_provider_manager.refresh_provider_cache(provider) + if refreshed: + provider_info = self.openai_provider_manager.get_model_info(model) + if provider_info: + self.local_model_metadata[model] = provider_info + return provider_info + return cached_info def fetch_openrouter_model_info(self, model): @@ -355,6 +370,7 @@ def __init__( ) self.info = self.get_model_info(model) + self.litellm_provider = (self.info.get("litellm_provider") or "").lower() # Are all needed keys/params available? res = self.validate_environment() @@ -367,6 +383,7 @@ def __init__( self.max_chat_history_tokens = min(max(max_input_tokens / 16, 1024), 8192) self.configure_model_settings(model) + self._apply_provider_defaults() self.get_weak_model(weak_model) if editor_model is False: @@ -690,6 +707,49 @@ def get_editor_model(self, provided_editor_model, editor_edit_format): return self.editor_model + def _ensure_extra_params_dict(self): + if self.extra_params is None: + self.extra_params = {} + elif not isinstance(self.extra_params, dict): + self.extra_params = dict(self.extra_params) + + def _apply_provider_defaults(self): + provider = (self.info.get("litellm_provider") or "").lower() + self.litellm_provider = provider or None + + if not provider: + return + + provider_config = model_info_manager.openai_provider_manager.get_provider_config(provider) + if not provider_config: + return + + self._ensure_extra_params_dict() + self.extra_params.setdefault("custom_llm_provider", provider) + + if provider_config.get("supports_stream") is False: + # Some OpenAI-compatible providers (e.g., Synthetic) only expose the + # non-streaming /chat/completions endpoint, so forcing streaming would + # loop through LiteLLM's fallback and explode mid-response. Disable the + # streaming flag up front so the caller transparently falls back to + # standard completions for those providers. 
+ self.streaming = False + + base_url = model_info_manager.openai_provider_manager.get_provider_base_url(provider) + if base_url: + self.extra_params.setdefault("base_url", base_url) + + default_headers = provider_config.get("default_headers") or {} + if default_headers: + headers = self.extra_params.setdefault("extra_headers", {}) + for key, value in default_headers.items(): + headers.setdefault(key, value) + + provider_extra = provider_config.get("extra_params") or {} + for key, value in provider_extra.items(): + if key not in self.extra_params: + self.extra_params[key] = value + def tokenizer(self, text): return litellm.encode(model=self.name, text=text) @@ -788,6 +848,16 @@ def fast_validate_environment(self): if var and os.environ.get(var): return dict(keys_in_environment=[var], missing_keys=[]) + if ( + not var + and provider + and model_info_manager.openai_provider_manager.supports_provider(provider) + ): + provider_keys = model_info_manager.openai_provider_manager.get_required_api_keys(provider) + for env_var in provider_keys: + if os.environ.get(env_var): + return dict(keys_in_environment=[env_var], missing_keys=[]) + def validate_environment(self): res = self.fast_validate_environment() if res: @@ -818,6 +888,14 @@ def validate_environment(self): return res provider = self.info.get("litellm_provider", "").lower() + provider_config = model_info_manager.openai_provider_manager.get_provider_config(provider) + if provider_config: + envs = provider_config.get("api_key_env", []) + available = [env for env in envs if os.environ.get(env)] + if available: + return dict(keys_in_environment=available, missing_keys=[]) + if envs: + return dict(keys_in_environment=False, missing_keys=envs) if provider == "cohere_chat": return validate_variables(["COHERE_API_KEY"]) if provider == "gemini": @@ -1304,31 +1382,35 @@ async def check_for_dependencies(io, model_name): ) -def fuzzy_match_models(name): - name = name.lower() - +def get_chat_model_names(): chat_models = set() model_metadata = list(litellm.model_cost.items()) model_metadata += list(model_info_manager.local_model_metadata.items()) + openai_provider_models = model_info_manager.openai_provider_manager.get_models_for_listing() + model_metadata += list(openai_provider_models.items()) + for orig_model, attrs in model_metadata: - model = orig_model.lower() if attrs.get("mode") != "chat": continue - provider = attrs.get("litellm_provider", "").lower() - if not provider: - continue - provider += "/" - - if model.startswith(provider): - fq_model = orig_model - else: - fq_model = provider + orig_model + provider = (attrs.get("litellm_provider") or "").lower() + if provider: + prefix = provider + "/" + if orig_model.lower().startswith(prefix): + fq_model = orig_model + else: + fq_model = f"{provider}/{orig_model}" + chat_models.add(fq_model) - chat_models.add(fq_model) chat_models.add(orig_model) - chat_models = sorted(chat_models) + return sorted(chat_models) + + +def fuzzy_match_models(name): + name = name.lower() + + chat_models = get_chat_model_names() # exactly matching model # matching_models = [ # (fq,m) for fq,m in chat_models @@ -1338,7 +1420,7 @@ def fuzzy_match_models(name): # return matching_models # Check for model names containing the name - matching_models = [m for m in chat_models if name in m] + matching_models = [m for m in chat_models if name in m.lower()] if matching_models: return sorted(set(matching_models)) diff --git a/aider/openai_providers.py b/aider/openai_providers.py new file mode 100644 index 00000000000..97e5bb05c8e 
--- /dev/null +++ b/aider/openai_providers.py @@ -0,0 +1,699 @@ +"""OpenAI-compatible provider metadata caching and lookup. + +This module keeps local cached copies of provider-specific ``/models`` payloads +for OpenAI-compatible endpoints (Synthetic and others). The primary public API +is :class:`OpenAIProviderManager`, which exposes helper methods used throughout +cecli to look up provider details and model metadata. +""" + +from __future__ import annotations +import importlib.resources as importlib_resources +import json +import os +import time +from copy import deepcopy +from pathlib import Path +from typing import Dict, Iterable, Optional + +import requests + +try: # Optional imports; litellm might not be installed during docs builds + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm.llms.custom_llm import CustomLLM, CustomLLMError + from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler +except Exception: # pragma: no cover - only during partial installs + CustomLLM = None # type: ignore + CustomLLMError = Exception # type: ignore + OpenAILikeChatHandler = None # type: ignore + HTTPHandler = None # type: ignore + +RESOURCE_FILE = "openai_providers.json" +_PROVIDERS_REGISTERED = False +_CUSTOM_HANDLERS: Dict[str, "_JSONOpenAIProvider"] = {} + + +def _coerce_str(value): + if isinstance(value, str): + return value + if isinstance(value, list) and value: + return value[0] + return None + + +def _first_env_value(names): + if not names: + return None + if isinstance(names, str): + names = [names] + for env_name in names or []: + if not env_name: + continue + val = os.environ.get(env_name) + if val: + return val + return None + + +class _JSONOpenAIProvider(CustomLLM if CustomLLM is not None else object): # type: ignore[misc] + """CustomLLM wrapper that routes OpenAI-compatible providers through LiteLLM.""" + + def __init__(self, slug: str, config: Dict): + if CustomLLM is None or OpenAILikeChatHandler is None: # pragma: no cover + raise RuntimeError("litellm custom handler support unavailable") + super().__init__() # type: ignore[misc] + self.slug = slug + self.config = config + self._chat_handler = OpenAILikeChatHandler() + + def _resolve_api_base(self, api_base: Optional[str]) -> str: + base = api_base or _first_env_value(self.config.get("base_url_env")) or self.config.get( + "api_base" + ) + if not base: + raise CustomLLMError(500, f"{self.slug} missing base URL") # type: ignore[misc] + return base.rstrip("/") + + def _resolve_api_key(self, api_key: Optional[str]) -> Optional[str]: + if api_key: + return api_key + env_val = _first_env_value(self.config.get("api_key_env")) + return env_val + + def _apply_special_handling(self, messages): + special = self.config.get("special_handling") or {} + if special.get("convert_content_list_to_string"): + from litellm.litellm_core_utils.prompt_templates.common_utils import ( + handle_messages_with_content_list_to_str_conversion, + ) + + return handle_messages_with_content_list_to_str_conversion(messages) + return messages + + def _inject_headers(self, headers): + defaults = self.config.get("default_headers") or {} + combined = dict(defaults) + combined.update(headers or {}) + return combined + + def _normalize_model_name(self, model: str) -> str: + if not isinstance(model, str): + return model + trimmed = model + if trimmed.startswith(f"{self.slug}/"): + trimmed = trimmed.split("/", 1)[1] + hf_namespace = self.config.get("hf_namespace") + if hf_namespace and not trimmed.startswith("hf:"): + trimmed = 
f"hf:{trimmed}" + return trimmed + + def _build_request_params(self, optional_params, stream: bool): + params = dict(optional_params or {}) + default_headers = dict(self.config.get("default_headers") or {}) + headers = params.setdefault("extra_headers", default_headers) + if headers is default_headers and default_headers: + params["extra_headers"] = dict(default_headers) + if stream: + params["stream"] = True + return params + + def _invoke_handler( + self, + *, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params, + logger_fn, + headers, + timeout, + client, + stream: bool, + ): + api_base = self._resolve_api_base(api_base) + api_key = self._resolve_api_key(api_key) + headers = self._inject_headers(headers) + params = self._build_request_params(optional_params, stream) + cleaned_messages = self._apply_special_handling(messages) + api_model = self._normalize_model_name(model) + http_client = None + if HTTPHandler is not None and isinstance(client, HTTPHandler): + http_client = client + return self._chat_handler.completion( + model=api_model, + messages=cleaned_messages, + api_base=api_base, + custom_llm_provider="openai", + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=params, + litellm_params=litellm_params or {}, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=http_client, + ) + + def completion( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self._invoke_handler( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + stream=False, + ) + + def streaming( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + # The synchronous OpenAILikeChatHandler handles both regular and streaming + # responses; we reuse it even when LiteLLM calls into the async wrappers, + # since many OpenAI-compatible providers (Synthetic, Venice, etc.) only + # support the non-streaming /chat/completions endpoint. True streaming for + # those providers would require a dedicated SSE client layered on top of + # httpx, so for now we normalize them through the sync path. 
+ return self._invoke_handler( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + stream=True, + ) + + def acompletion( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self.completion( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + ) + + def astreaming( + self, + model, + messages, + api_base, + custom_prompt_dict, + model_response, + print_verbose, + encoding, + api_key, + logging_obj, + optional_params, + litellm_params=None, + acompletion=None, + logger_fn=None, + headers=None, + timeout=None, + client=None, + ): + return self.streaming( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + ) + + +def _register_provider_with_litellm(slug: str, config: Dict) -> None: + """Register provider with litellm's registry and custom handler.""" + try: + from litellm.llms.openai_like.json_loader import ( + JSONProviderRegistry, + SimpleProviderConfig, + ) + except Exception: + return + + JSONProviderRegistry.load() + + base_url = config.get("api_base") + api_key_env = _coerce_str(config.get("api_key_env")) + if not base_url or not api_key_env: + return + + if not JSONProviderRegistry.exists(slug): + payload = { + "base_url": base_url, + "api_key_env": api_key_env, + } + + api_base_env = _coerce_str(config.get("base_url_env")) + if api_base_env: + payload["api_base_env"] = api_base_env + + if config.get("param_mappings"): + payload["param_mappings"] = config["param_mappings"] + if config.get("special_handling"): + payload["special_handling"] = config["special_handling"] + if config.get("base_class"): + payload["base_class"] = config["base_class"] + + JSONProviderRegistry._providers[slug] = SimpleProviderConfig(slug, payload) + + try: + import litellm # noqa: WPS433 + except Exception: + return + + provider_list = getattr(litellm, "provider_list", None) + if isinstance(provider_list, list) and slug not in provider_list: + provider_list.append(slug) + + openai_like = getattr(litellm, "_openai_like_providers", None) + if isinstance(openai_like, list) and slug not in openai_like: + openai_like.append(slug) + + handler = _CUSTOM_HANDLERS.get(slug) + if handler is None and CustomLLM is not None and OpenAILikeChatHandler is not None: + handler = _JSONOpenAIProvider(slug, config) + _CUSTOM_HANDLERS[slug] = handler + + if handler is None: + return + + already_present = any(item.get("provider") == slug for item in 
litellm.custom_provider_map) + if not already_present: + litellm.custom_provider_map.append({"provider": slug, "custom_handler": handler}) + try: + litellm.custom_llm_setup() + except Exception: + pass + + +def _deep_merge(base: Dict, override: Dict) -> Dict: + result = deepcopy(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(result.get(key), dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = deepcopy(value) + return result + + +def _load_provider_configs() -> Dict[str, Dict]: + configs: Dict[str, Dict] = {} + try: + resource = importlib_resources.files("aider.resources").joinpath(RESOURCE_FILE) + data = json.loads(resource.read_text()) + except (FileNotFoundError, json.JSONDecodeError): # pragma: no cover - fallback path + data = {} + + for provider, override in data.items(): + base = configs.get(provider, {}) + configs[provider] = _deep_merge(base, override) + + return configs + + +PROVIDER_CONFIGS = _load_provider_configs() + + +def ensure_litellm_providers_registered() -> None: + global _PROVIDERS_REGISTERED + if _PROVIDERS_REGISTERED: + return + for slug, cfg in PROVIDER_CONFIGS.items(): + _register_provider_with_litellm(slug, cfg) + _PROVIDERS_REGISTERED = True + + +def _cost_per_token(val: Optional[str | float | int]) -> Optional[float]: + """Convert a price value (USD per token) to a float.""" + if val in (None, "", "-", "N/A"): + return None + if val == "0": + return 0.0 + try: + return float(val) + except (TypeError, ValueError): + return None + + +def _first_value(record: Dict, *keys: Iterable[str]) -> Optional[float]: + """Return the first non-None value from record for the provided keys.""" + for key in keys: + if key in record and record[key] not in (None, ""): + return record[key] + return None + + +class OpenAIProviderManager: + """Cached metadata manager for OpenAI-compatible providers.""" + + CACHE_TTL = 60 * 60 * 24 # 24 hours + + def __init__(self, provider_configs: Optional[Dict[str, Dict]] = None) -> None: + self.cache_dir = Path.home() / ".aider" / "caches" + self.verify_ssl: bool = True + + self.provider_configs = provider_configs or deepcopy(PROVIDER_CONFIGS) + self._provider_cache: Dict[str, Dict | None] = {name: None for name in self.provider_configs} + self._cache_loaded: Dict[str, bool] = {name: False for name in self.provider_configs} + + # ------------------------------------------------------------------ # + # Provider helpers # + # ------------------------------------------------------------------ # + def set_verify_ssl(self, verify_ssl: bool) -> None: + """Enable/disable SSL verification for API requests.""" + self.verify_ssl = verify_ssl + + def supports_provider(self, provider: Optional[str]) -> bool: + return bool(provider and provider in self.provider_configs) + + def get_provider_config(self, provider: Optional[str]) -> Optional[Dict]: + if not provider: + return None + config = self.provider_configs.get(provider) + if not config: + return None + config = dict(config) + config.setdefault("litellm_provider", provider) + return config + + def get_provider_base_url(self, provider: Optional[str]) -> Optional[str]: + config = self.get_provider_config(provider) + if not config: + return None + for env_var in config.get("base_url_env", []): + val = os.environ.get(env_var) + if val: + return val.rstrip("/") + return config.get("api_base") + + def get_required_api_keys(self, provider: Optional[str]) -> list[str]: + config = self.get_provider_config(provider) + if not config: + return [] + 
return list(config.get("api_key_env", [])) + + # ------------------------------------------------------------------ # + # Model metadata API # + # ------------------------------------------------------------------ # + def get_model_info(self, model: str) -> Dict: + """Return metadata for *model* or an empty ``dict`` when unknown.""" + provider, route = self._split_model(model) + if not self.supports_provider(provider): + return {} + + content = self._ensure_content(provider) + if not content or "data" not in content: + return {} + + candidates = {route} + if ":" in route: + candidates.add(route.split(":", 1)[0]) + + record = next( + (item for item in content["data"] if item.get("id") in candidates), + None, + ) + if not record: + return {} + + return self._record_to_info(record, provider) + + def get_models_for_listing(self) -> Dict[str, Dict]: + """Return all known models keyed by their bare ids across providers.""" + listings: Dict[str, Dict] = {} + for provider in self.provider_configs: + content = self._ensure_content(provider) + if not content or "data" not in content: + continue + for record in content["data"]: + model_id = record.get("id") + if not model_id: + continue + info = self._record_to_info(record, provider) + if not info: + continue + listings[model_id] = info + return listings + + # ------------------------------------------------------------------ # + # Internal helpers # + # ------------------------------------------------------------------ # + def _split_model(self, model: str) -> tuple[Optional[str], str]: + if "/" not in model: + return None, model + provider, route = model.split("/", 1) + return provider, route + + def _ensure_content(self, provider: str) -> Optional[Dict]: + self._load_cache(provider) + if not self._provider_cache.get(provider): + self._update_cache(provider) + return self._provider_cache.get(provider) + + def _record_to_info(self, record: Dict, provider: str) -> Dict: + context_len = _first_value( + record, + "max_input_tokens", + "max_tokens", + "max_output_tokens", + "context_length", + "context_window", + "top_provider_context_length", + ) + pricing = record.get("pricing", {}) if isinstance(record.get("pricing"), dict) else {} + + input_cost = _cost_per_token( + _first_value(pricing, "prompt", "input", "prompt_tokens") + or _first_value(record, "input_cost_per_token", "prompt_cost_per_token") + ) + output_cost = _cost_per_token( + _first_value(pricing, "completion", "output", "completion_tokens") + or _first_value(record, "output_cost_per_token", "completion_cost_per_token") + ) + + max_tokens = _first_value( + record, + "max_tokens", + "max_input_tokens", + "context_length", + "context_window", + "top_provider_context_length", + ) + max_output_tokens = _first_value( + record, + "max_output_tokens", + "max_tokens", + "context_length", + "context_window", + "top_provider_context_length", + ) + + if max_tokens is None: + max_tokens = context_len + if max_output_tokens is None: + max_output_tokens = context_len + + info = { + "max_input_tokens": context_len, + "max_tokens": max_tokens, + "max_output_tokens": max_output_tokens, + "input_cost_per_token": input_cost, + "output_cost_per_token": output_cost, + "litellm_provider": provider, + "mode": record.get("mode", "chat"), + } + + return {k: v for k, v in info.items() if v is not None} + + def refresh_provider_cache(self, provider: str) -> bool: + """Force-refresh the provider's /models cache if supported.""" + if not self.supports_provider(provider): + return False + config = 
self.provider_configs[provider] + if not config.get("models_url"): + return False + self._provider_cache[provider] = None + self._cache_loaded[provider] = True + self._update_cache(provider) + return bool(self._provider_cache.get(provider)) + + def _get_cache_file(self, provider: str) -> Path: + fname = f"{provider}_models.json" + return self.cache_dir / fname + + def _load_cache(self, provider: str) -> None: + if self._cache_loaded.get(provider): + return + cache_file = self._get_cache_file(provider) + try: + self.cache_dir.mkdir(parents=True, exist_ok=True) + if cache_file.exists(): + cache_age = time.time() - cache_file.stat().st_mtime + if cache_age < self.CACHE_TTL: + try: + self._provider_cache[provider] = json.loads(cache_file.read_text()) + except json.JSONDecodeError: + self._provider_cache[provider] = None + except OSError: + pass + self._cache_loaded[provider] = True + + def _update_cache(self, provider: str) -> None: + payload = self._fetch_provider_models(provider) + cache_file = self._get_cache_file(provider) + + if payload: + self._provider_cache[provider] = payload + try: + cache_file.write_text(json.dumps(payload, indent=2)) + except OSError: + pass + return + + static_models = self.provider_configs[provider].get("static_models") + if static_models and not self._provider_cache.get(provider): + self._provider_cache[provider] = {"data": static_models} + + def _fetch_provider_models(self, provider: str) -> Optional[Dict]: + config = self.provider_configs[provider] + models_url = config.get("models_url") + if not models_url: + api_base = config.get("api_base") + if api_base: + models_url = api_base.rstrip("/") + "/models" + if not models_url: + return None + + headers = {} + default_headers = config.get("default_headers") or {} + headers.update(default_headers) + + api_key = self._get_api_key(provider) + requires_api_key = config.get("requires_api_key", True) + + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + elif requires_api_key: + return None + + try: + response = requests.get( + models_url, + headers=headers or None, + timeout=config.get("timeout", 10), + verify=self.verify_ssl, + ) + response.raise_for_status() + return response.json() + except Exception as ex: # noqa: BLE001 + print(f"Failed to fetch {provider} model list: {ex}") + return None + + def _get_api_key(self, provider: str) -> Optional[str]: + config = self.provider_configs[provider] + for env_var in config.get("api_key_env", []): + value = os.environ.get(env_var) + if value: + return value + return None diff --git a/aider/resources/openai_providers.json b/aider/resources/openai_providers.json new file mode 100644 index 00000000000..1609d2d971b --- /dev/null +++ b/aider/resources/openai_providers.json @@ -0,0 +1,78 @@ +{ + "openai": { + "api_base": "https://api.openai.com/v1", + "models_url": "https://api.openai.com/v1/models", + "api_key_env": [ + "OPENAI_API_KEY" + ], + "base_url_env": [ + "OPENAI_API_BASE" + ], + "display_name": "openai" + }, + "apertis": { + "api_base": "https://api.stima.tech/v1", + "api_key_env": [ + "STIMA_API_KEY" + ], + "display_name": "apertis" + }, + "chutes": { + "api_base": "https://llm.chutes.ai/v1/", + "api_key_env": [ + "CHUTES_API_KEY" + ], + "display_name": "chutes" + }, + "helicone": { + "api_base": "https://ai-gateway.helicone.ai/", + "api_key_env": [ + "HELICONE_API_KEY" + ], + "display_name": "helicone" + }, + "nano-gpt": { + "api_base": "https://nano-gpt.com/api/v1", + "api_key_env": [ + "NANOGPT_API_KEY" + ], + "display_name": "nano-gpt" + }, + "poe": { 
+ "api_base": "https://api.poe.com/v1", + "api_key_env": [ + "POE_API_KEY" + ], + "display_name": "poe" + }, + "publicai": { + "api_base": "https://api.publicai.co/v1", + "api_key_env": [ + "PUBLICAI_API_KEY" + ], + "display_name": "publicai" + }, + "synthetic": { + "api_base": "https://api.synthetic.new/openai/v1", + "api_key_env": [ + "SYNTHETIC_API_KEY" + ], + "display_name": "synthetic", + "hf_namespace": true, + "supports_stream": false + }, + "veniceai": { + "api_base": "https://api.venice.ai/api/v1", + "api_key_env": [ + "VENICE_AI_API_KEY" + ], + "display_name": "veniceai" + }, + "xiaomi_mimo": { + "api_base": "https://api.xiaomimimo.com/v1", + "api_key_env": [ + "XIAOMI_MIMO_API_KEY" + ], + "display_name": "xiaomi_mimo" + } +} diff --git a/scripts/generate_openai_providers.py b/scripts/generate_openai_providers.py new file mode 100644 index 00000000000..34774610a4d --- /dev/null +++ b/scripts/generate_openai_providers.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +""" +Interactively generate aider/resources/openai_providers.json from litellm data. + +This script reads litellm's openai_like provider definitions and walks the user +through building cecli's provider registry, mirroring the workflow used by +clean_metadata.py (prompting when decisions are needed). +""" + +from __future__ import annotations + +import json +import argparse +from pathlib import Path +from typing import Any, Dict, Iterable + +AUTO_APPROVE = False + + +def prompt_yes_no(question: str, default: bool = True) -> bool: + """Prompt user for yes/no input, returning bool.""" + + suffix = " [Y/n] " if default else " [y/N] " + if AUTO_APPROVE: + print(f"{question}{suffix}-> {'Y' if default else 'N'} (auto)") + return default + while True: + resp = input(question + suffix).strip().lower() + if not resp: + return default + if resp in ("y", "yes"): + return True + if resp in ("n", "no"): + return False + print("Please enter 'y' or 'n'.") + + +def _format_default(value: str | None) -> str | None: + if value is None: + return None + if value.startswith("[") and value.endswith("]"): + try: + parsed = json.loads(value) + except json.JSONDecodeError: + return value + if isinstance(parsed, list): + return ", ".join(str(item) for item in parsed) + return value + + +def prompt_value(question: str, default: str | None = None) -> str | None: + """Prompt user for a string; empty input keeps default.""" + + display_default = _format_default(default) + suffix = f" [{display_default}]" if display_default is not None else "" + if AUTO_APPROVE: + print(f"{question}{suffix}: -> {display_default or ''} (auto)") + return default + resp = input(f"{question}{suffix}: ").strip() + if not resp: + return default + return resp + + +def ensure_json_object(prompt_text: str, default: Dict[str, Any] | None = None) -> Dict[str, Any]: + """Prompt for a JSON object, re-prompting on parse errors.""" + + default_str = json.dumps(default, indent=2) if default else "" + while True: + raw = prompt_value(prompt_text, default_str) + if not raw: + return default or {} + if AUTO_APPROVE: + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return default or {} + return parsed + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: # pragma: no cover - interactive error path + print(f"Invalid JSON ({exc}). 
Please try again.") + continue + if not isinstance(parsed, dict): + print("Please provide a JSON object (e.g., {\"Header\": \"value\"}).") + continue + return parsed + + +def _list_to_csv(value: Iterable[str] | str | None) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return ", ".join(str(item) for item in value) + + +def _parse_csv(value: str | None) -> list[str]: + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] + + +def main(): + global AUTO_APPROVE + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-y", + "--yes", + "--auto-approve", + dest="auto", + action="store_true", + help="Automatically include all providers and accept defaults without prompting.", + ) + args = parser.parse_args() + AUTO_APPROVE = args.auto + + script_dir = Path(__file__).parent.resolve() + repo_root = script_dir.parent + + litellm_providers_path = ( + script_dir.parent / "../litellm/litellm/llms/openai_like/providers.json" + ).resolve() + output_path = (repo_root / "aider" / "resources" / "openai_providers.json").resolve() + + if not litellm_providers_path.exists(): + print(f"Error: Could not find litellm providers at {litellm_providers_path}") + return + + try: + litellm_data = json.loads(litellm_providers_path.read_text()) + except json.JSONDecodeError as exc: + print(f"Error: Failed to parse litellm providers ({exc}).") + return + + existing = {} + if output_path.exists(): + try: + existing = json.loads(output_path.read_text()) + except json.JSONDecodeError as exc: + print(f"Warning: Existing {output_path} is invalid JSON ({exc}); ignoring.") + + new_config: Dict[str, Dict[str, Any]] = {} + + for provider_name in sorted(litellm_data.keys()): + litellm_entry = litellm_data[provider_name] + existing_entry = existing.get(provider_name, {}) + default_keep = bool(existing_entry) + + print("\n" + "=" * 60) + print(f"Provider: {provider_name}") + print(f" Display name : {litellm_entry.get('display_name', provider_name)}") + print(f" Base URL : {litellm_entry.get('base_url', 'N/A')}") + + api_key_list = litellm_entry.get("api_key_env") + api_key_display = _list_to_csv(api_key_list) if api_key_list else "N/A" + print(f" API key env : {api_key_display}") + + include = prompt_yes_no( + f"Include provider '{provider_name}'?", default=default_keep or True + ) + if not include: + continue + + display_name = prompt_value( + "Display name", existing_entry.get("display_name") or litellm_entry.get("display_name") or provider_name + ) + api_base = prompt_value( + "API base URL", + existing_entry.get("api_base") or litellm_entry.get("base_url") or "", + ) + base_url_env = prompt_value( + "Comma-separated env vars for overriding base URL", + _list_to_csv(existing_entry.get("base_url_env")) or "", + ) + api_key_env = prompt_value( + "Comma-separated env vars for API key lookup", + _list_to_csv(existing_entry.get("api_key_env", litellm_entry.get("api_key_env", []))) or "", + ) + models_url = prompt_value( + "Models endpoint URL (leave blank if none)", + existing_entry.get("models_url", ""), + ) + default_headers = ensure_json_object( + "Default headers JSON (empty for none)", + existing_entry.get("default_headers"), + ) + + record: Dict[str, Any] = {} + if display_name: + record["display_name"] = display_name + if api_base: + record["api_base"] = api_base + if api_key_env: + record["api_key_env"] = _parse_csv(api_key_env) + if base_url_env: + record["base_url_env"] = _parse_csv(base_url_env) + if models_url: + 
record["models_url"] = models_url + if default_headers: + record["default_headers"] = default_headers + + new_config[provider_name] = record + + # Preserve providers that only exist in the existing file (not litellm) if user wants. + for provider_name in sorted(existing.keys()): + if provider_name in new_config or provider_name in litellm_data: + continue + print("\n" + "=" * 60) + print(f"Provider '{provider_name}' exists only in {output_path}.") + if prompt_yes_no("Keep this provider?", default=True): + new_config[provider_name] = existing[provider_name] + + output_path.write_text(json.dumps(new_config, indent=2, sort_keys=True) + "\n") + print(f"\nWrote {len(new_config)} providers to {output_path}.\n") + + +if __name__ == "__main__": + main() diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index 7ed6564e5c3..c3998c50d1f 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -2,6 +2,7 @@ import os import subprocess import tempfile +import types from io import StringIO from pathlib import Path from unittest import TestCase @@ -1415,6 +1416,71 @@ async def test_list_models_includes_all_model_sources(self): # Check that both models appear in the output self.assertIn("test-provider/metadata-only-model", output) + async def test_list_models_includes_openai_provider(self): + import aider.models as models_module + + provider_name = "openai" + manager = models_module.model_info_manager.openai_provider_manager + provider_config = { + "api_base": "https://api.openai.com/v1", + "models_url": "https://api.openai.com/v1/models", + "api_key_env": ["OPENAI_API_KEY"], + "base_url_env": ["OPENAI_API_BASE"], + "default_headers": {}, + } + + had_config = provider_name in manager.provider_configs + previous_config = manager.provider_configs.get(provider_name) + had_cache = provider_name in manager._provider_cache + previous_cache = manager._provider_cache.get(provider_name) + had_loaded = provider_name in manager._cache_loaded + previous_loaded = manager._cache_loaded.get(provider_name) + + manager.provider_configs[provider_name] = provider_config + manager._provider_cache[provider_name] = None + manager._cache_loaded[provider_name] = False + + payload = { + "data": [ + { + "id": "demo/foo", + "max_input_tokens": 4096, + "pricing": {"prompt": "0.0001", "completion": "0.0002"}, + } + ] + } + + def _fake_get(url, *, timeout=None, verify=None): + return types.SimpleNamespace(status_code=200, json=lambda: payload) + + try: + with GitTemporaryDirectory(): + with patch("requests.get", _fake_get): + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + await main( + ["--list-models", "openai/demo/foo", "--yes", "--no-gitignore"], + input=DummyInput(), + output=DummyOutput(), + ) + + output = mock_stdout.getvalue() + self.assertIn("openai/demo/foo", output) + finally: + if had_config: + manager.provider_configs[provider_name] = previous_config + else: + manager.provider_configs.pop(provider_name, None) + + if had_cache: + manager._provider_cache[provider_name] = previous_cache + else: + manager._provider_cache.pop(provider_name, None) + + if had_loaded: + manager._cache_loaded[provider_name] = previous_loaded + else: + manager._cache_loaded.pop(provider_name, None) + async def test_check_model_accepts_settings_flag(self): # Test that --check-model-accepts-settings affects whether settings are applied with GitTemporaryDirectory(): diff --git a/tests/basic/test_openai_providers.py b/tests/basic/test_openai_providers.py new file mode 100644 index 00000000000..5be6974472e 
--- /dev/null +++ b/tests/basic/test_openai_providers.py @@ -0,0 +1,402 @@ +import json +from pathlib import Path +import sys +import types +if "PIL" not in sys.modules: + pil_module = types.ModuleType("PIL") + image_module = types.ModuleType("PIL.Image") + image_grab_module = types.ModuleType("PIL.ImageGrab") + + class _DummyImage: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + @property + def size(self): + return (1024, 1024) + + def _dummy_open(*args, **kwargs): + return _DummyImage() + + image_module.open = _dummy_open + image_grab_module.grab = _dummy_open + pil_module.Image = image_module + pil_module.ImageGrab = image_grab_module + sys.modules["PIL"] = pil_module + sys.modules["PIL.Image"] = image_module + sys.modules["PIL.ImageGrab"] = image_grab_module + +if "numpy" not in sys.modules: + numpy_module = types.ModuleType("numpy") + numpy_module.ndarray = object + numpy_module.array = lambda *a, **k: None + numpy_module.dot = lambda *a, **k: 0.0 + numpy_module.linalg = types.SimpleNamespace(norm=lambda *a, **k: 1.0) + sys.modules["numpy"] = numpy_module + +if "oslex" not in sys.modules: + oslex_module = types.ModuleType("oslex") + oslex_module.__all__ = [] + sys.modules["oslex"] = oslex_module + +if "rich" not in sys.modules: + rich_module = types.ModuleType("rich") + console_module = types.ModuleType("rich.console") + + class _DummyConsole: + def __init__(self, *args, **kwargs): + pass + + def status(self, *args, **kwargs): + return self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def update(self, *args, **kwargs): + return None + + console_module.Console = _DummyConsole + rich_module.console = console_module + sys.modules["rich"] = rich_module + sys.modules["rich.console"] = console_module + +if "pyperclip" not in sys.modules: + pyperclip_module = types.ModuleType("pyperclip") + + class _DummyPyperclipException(Exception): + pass + + pyperclip_module.PyperclipException = _DummyPyperclipException + pyperclip_module.copy = lambda *args, **kwargs: None + sys.modules["pyperclip"] = pyperclip_module + +if "pexpect" not in sys.modules: + pexpect_module = types.ModuleType("pexpect") + + class _DummySpawn: + def __init__(self, *args, **kwargs): + pass + + def sendline(self, *args, **kwargs): + return 0 + + def close(self, *args, **kwargs): + return 0 + + pexpect_module.spawn = _DummySpawn + sys.modules["pexpect"] = pexpect_module + +if "psutil" not in sys.modules: + psutil_module = types.ModuleType("psutil") + + class _DummyProcess: + def __init__(self, *args, **kwargs): + pass + + def children(self, *args, **kwargs): + return [] + + def terminate(self): + return None + + psutil_module.Process = _DummyProcess + sys.modules["psutil"] = psutil_module + +if "pypandoc" not in sys.modules: + pypandoc_module = types.ModuleType("pypandoc") + pypandoc_module.convert_text = lambda *args, **kwargs: "" + sys.modules["pypandoc"] = pypandoc_module + +import aider.models as models_module +from aider.commands.model import ModelCommand +from aider.models import ModelInfoManager +from aider.openai_providers import OpenAIProviderManager, _JSONOpenAIProvider + + +class DummyResponse: + """Minimal stand-in for requests.Response used in tests.""" + + def __init__(self, json_data): + self.status_code = 200 + self._json_data = json_data + + def json(self): + return self._json_data + + def raise_for_status(self): + return None + + +def _load_openai_fixture(): + return { + "data": [ + { + "id": 
"zai-org/GLM-4.6", + "object": "model", + "created": 1723500000, + "owned_by": "openai", + "max_input_tokens": 131072, + "max_output_tokens": 131072, + "max_tokens": 131072, + "context_length": 131072, + "context_window": 131072, + "top_provider_context_length": 131072, + "pricing": { + "prompt": "0.00000055", + "completion": "0.00000219", + }, + }, + { + "id": "zai-org/GLM-4.6:extended", + "object": "model", + "created": 1723500001, + "owned_by": "openai", + "max_tokens": 65536, + "pricing": { + "prompt": "0.00000060", + "completion": "0.00000250", + }, + }, + ] + } + + +def _test_provider_config(): + return { + "openai": { + "api_base": "https://api.openai.com/v1", + "models_url": "https://api.openai.com/v1/models", + "api_key_env": ["OPENAI_API_KEY"], + "base_url_env": ["OPENAI_API_BASE"], + "default_headers": {}, + } + } + + +def test_provider_manager_get_model_info_from_cache(monkeypatch, tmp_path): + """OpenAIProviderManager should hydrate from cached payloads.""" + + payload = _load_openai_fixture() + + def _fail_request(*args, **kwargs): # pragma: no cover - should never be called + raise AssertionError("Network request should not be made when cache is valid") + + monkeypatch.setattr("requests.get", _fail_request) + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + manager = OpenAIProviderManager(provider_configs=_test_provider_config()) + manager.cache_dir.mkdir(parents=True, exist_ok=True) + cache_file = manager._get_cache_file("openai") + cache_file.write_text(json.dumps(payload)) + + info = manager.get_model_info("openai/zai-org/GLM-4.6:extended") + + assert info["max_input_tokens"] == 131072 + assert info["max_output_tokens"] == 131072 + assert info["max_tokens"] == 131072 + assert info["input_cost_per_token"] == 0.00000055 + assert info["output_cost_per_token"] == 0.00000219 + assert info["litellm_provider"] == "openai" + assert manager._cache_loaded["openai"] + + + +def test_provider_manager_models_endpoint_fetch(monkeypatch, tmp_path): + """OpenAIProviderManager should fetch and cache the /models payload when missing.""" + + payload = _load_openai_fixture() + call_args = [] + + def _recording_request(url, *, headers=None, timeout=None, verify=None): + call_args.append((url, headers, timeout, verify)) + return DummyResponse(payload) + + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + monkeypatch.setattr("requests.get", _recording_request) + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + + provider_config = _test_provider_config() + manager = OpenAIProviderManager(provider_configs=provider_config) + manager.set_verify_ssl(False) + + info = manager.get_model_info("openai/zai-org/GLM-4.6:extended") + + expected_url = provider_config["openai"]["models_url"] + assert call_args == [ + ( + expected_url, + {"Authorization": "Bearer test-key"}, + 10, + False, + ) + ] + assert info["max_input_tokens"] == 131072 + assert info["max_output_tokens"] == 131072 + assert info["max_tokens"] == 131072 + assert info["input_cost_per_token"] == 0.00000055 + assert info["output_cost_per_token"] == 0.00000219 + + info_again = manager.get_model_info("openai/zai-org/GLM-4.6") + + +def test_provider_static_models_used_without_api_key(monkeypatch, tmp_path): + payload = _load_openai_fixture() + provider_config = _test_provider_config() + provider_config["openai"]["static_models"] = payload["data"] + + def _fail_request(*args, **kwargs): # pragma: no cover - should not run + raise AssertionError("Network request should not be attempted without API key") + + 
monkeypatch.setattr("requests.get", _fail_request) + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + manager = OpenAIProviderManager(provider_configs=provider_config) + info = manager.get_model_info("openai/zai-org/GLM-4.6") + + assert info["litellm_provider"] == "openai" + assert info["max_tokens"] == 131072 + + +def test_model_info_manager_uses_openai_provider_manager(monkeypatch): + """ModelInfoManager should delegate to OpenAIProviderManager for openai-like models.""" + + monkeypatch.setattr( + models_module, + "litellm", + types.SimpleNamespace(_lazy_module=None, get_model_info=lambda *a, **k: {}), + ) + + stub_info = { + "max_input_tokens": 1024, + "max_tokens": 1024, + "max_output_tokens": 1024, + "input_cost_per_token": 0.0001, + "output_cost_per_token": 0.0002, + "litellm_provider": "openai", + } + + monkeypatch.setattr( + "aider.models.OpenAIProviderManager.get_model_info", + lambda self, model: stub_info, + ) + + mim = ModelInfoManager() + info = mim.get_model_info("openai/demo/model") + + assert info == stub_info + + +def test_openai_provider_manager_listing(monkeypatch, tmp_path): + payload = _load_openai_fixture() + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + manager = OpenAIProviderManager(provider_configs=_test_provider_config()) + manager.cache_dir.mkdir(parents=True, exist_ok=True) + cache_file = manager._get_cache_file("openai") + cache_file.write_text(json.dumps(payload)) + + listings = manager.get_models_for_listing() + + assert "zai-org/GLM-4.6" in listings + assert listings["zai-org/GLM-4.6"]["litellm_provider"] == "openai" + assert listings["zai-org/GLM-4.6"]["mode"] == "chat" + + +def test_chat_model_names_include_openai_provider_models(monkeypatch, tmp_path): + payload = _load_openai_fixture() + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + import aider.models as models_module + + models_module.litellm = types.SimpleNamespace(model_cost={}, _lazy_module=None) + models_module.model_info_manager = models_module.ModelInfoManager() + models_module.model_info_manager.openai_provider_manager = OpenAIProviderManager( + provider_configs=_test_provider_config() + ) + manager = models_module.model_info_manager.openai_provider_manager + manager.cache_dir.mkdir(parents=True, exist_ok=True) + cache_file = manager._get_cache_file("openai") + cache_file.write_text(json.dumps(payload)) + + names = models_module.get_chat_model_names() + + assert "openai/zai-org/GLM-4.6" in names + + +def test_model_command_completions_include_openai_provider_models(monkeypatch, tmp_path): + payload = _load_openai_fixture() + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + import aider.models as models_module + + models_module.litellm = types.SimpleNamespace(model_cost={}, _lazy_module=None) + models_module.model_info_manager = models_module.ModelInfoManager() + models_module.model_info_manager.openai_provider_manager = OpenAIProviderManager( + provider_configs=_test_provider_config() + ) + manager = models_module.model_info_manager.openai_provider_manager + manager.cache_dir.mkdir(parents=True, exist_ok=True) + cache_file = manager._get_cache_file("openai") + cache_file.write_text(json.dumps(payload)) + + completions = ModelCommand.get_completions(io=None, coder=None, args="") + + assert "openai/zai-org/GLM-4.6" in completions + + +def test_model_disables_streaming_for_non_streaming_providers(monkeypatch): + provider_configs = { + "synthetic": { + "api_base": "https://api.synthetic.new/openai/v1", + 
"api_key_env": ["SYNTHETIC_API_KEY"], + "supports_stream": False, + } + } + + provider_manager = OpenAIProviderManager(provider_configs=provider_configs) + fake_info = { + "max_input_tokens": 4096, + "max_tokens": 4096, + "max_output_tokens": 4096, + "litellm_provider": "synthetic", + } + + fake_model_info_manager = types.SimpleNamespace( + get_model_info=lambda model: fake_info, + openai_provider_manager=provider_manager, + ) + + monkeypatch.setenv("SYNTHETIC_API_KEY", "test-key") + monkeypatch.setattr(models_module, "model_info_manager", fake_model_info_manager) + monkeypatch.setattr( + models_module, + "litellm", + types.SimpleNamespace( + encode=lambda *a, **k: [], + token_counter=lambda *a, **k: 0, + validate_environment=lambda model: {"keys_in_environment": True, "missing_keys": []}, + ), + ) + + model = models_module.Model("synthetic/deepseek-ai/DeepSeek-V3.1") + + assert model.streaming is False + assert model.extra_params["custom_llm_provider"] == "synthetic" + + +def test_json_provider_hf_namespace_normalization(): + provider = object.__new__(_JSONOpenAIProvider) + provider.slug = "synthetic" + provider.config = {"hf_namespace": True} + + rewritten = provider._normalize_model_name("synthetic/deepseek-ai/DeepSeek-V3.1") + assert rewritten == "hf:deepseek-ai/DeepSeek-V3.1" + + unchanged = provider._normalize_model_name("hf:deepseek-ai/DeepSeek-V3.1") + assert unchanged == "hf:deepseek-ai/DeepSeek-V3.1" From 5fb0c20e5024f0308fcd110ee12009aeb4f73f15 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sat, 27 Dec 2025 20:56:37 -0600 Subject: [PATCH 02/14] Normalize provider pricing strings Some OpenAI-compatible providers emit costs like '/bin/bash.00000055', which our float parser treated as invalid and left the UI without per-token pricing (e.g., synthetic MiniMax-M2). Strip currency symbols/commas before parsing and add a regression test that proves static model caches with dollar-prefixed pricing still populate ModelInfoManager. 
--- aider/openai_providers.py | 13 +++++++++++++ tests/basic/test_openai_providers.py | 29 ++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/aider/openai_providers.py b/aider/openai_providers.py index 97e5bb05c8e..c6a0376343d 100644 --- a/aider/openai_providers.py +++ b/aider/openai_providers.py @@ -7,6 +7,7 @@ """ from __future__ import annotations + import importlib.resources as importlib_resources import json import os @@ -14,6 +15,7 @@ from copy import deepcopy from pathlib import Path from typing import Dict, Iterable, Optional +import re import requests @@ -424,12 +426,23 @@ def ensure_litellm_providers_registered() -> None: _PROVIDERS_REGISTERED = True +_NUMBER_RE = re.compile(r"-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?") + + def _cost_per_token(val: Optional[str | float | int]) -> Optional[float]: """Convert a price value (USD per token) to a float.""" if val in (None, "", "-", "N/A"): return None if val == "0": return 0.0 + if isinstance(val, str): + cleaned = val.strip().replace(",", "") + if cleaned.startswith("$"): + cleaned = cleaned[1:] + match = _NUMBER_RE.search(cleaned) + if not match: + return None + val = match.group(0) try: return float(val) except (TypeError, ValueError): diff --git a/tests/basic/test_openai_providers.py b/tests/basic/test_openai_providers.py index 5be6974472e..61882be5049 100644 --- a/tests/basic/test_openai_providers.py +++ b/tests/basic/test_openai_providers.py @@ -1,7 +1,9 @@ import json +import math from pathlib import Path import sys import types +import pytest if "PIL" not in sys.modules: pil_module = types.ModuleType("PIL") image_module = types.ModuleType("PIL.Image") @@ -264,6 +266,33 @@ def _fail_request(*args, **kwargs): # pragma: no cover - should not run assert info["max_tokens"] == 131072 +def test_provider_models_price_strings(monkeypatch, tmp_path): + payload = { + "data": [ + { + "id": "demo/model", + "max_input_tokens": 4096, + "pricing": {"prompt": "$0.00000055", "completion": "$0.00000219"}, + } + ] + } + + provider_config = _test_provider_config() + provider_config["openai"]["static_models"] = payload["data"] + + def _fail_request(*args, **kwargs): # pragma: no cover - should not run + raise AssertionError("Network fetch should be skipped when static models exist") + + monkeypatch.setattr("requests.get", _fail_request) + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + manager = OpenAIProviderManager(provider_configs=provider_config) + info = manager.get_model_info("openai/demo/model") + + assert math.isclose(info["input_cost_per_token"], 0.00000055) + assert math.isclose(info["output_cost_per_token"], 0.00000219) + + def test_model_info_manager_uses_openai_provider_manager(monkeypatch): """ModelInfoManager should delegate to OpenAIProviderManager for openai-like models.""" From b3a8ae95c31884a37d16c18e95b0996702e52219 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sat, 27 Dec 2025 21:04:53 -0600 Subject: [PATCH 03/14] Ensure non-streaming answers survive reasoning blocks Synthetic/other OpenAI-like providers returned both reasoning_content and content, but our consolidation skipped storing the final content whenever reasoning existed, so cecli printed only the THINKING section. Always capture the message.content (including list-style OpenAI blocks) and add a regression test that feeds a recorded MiniMax completion via a heredoc JSON snippet to assert both THINKING and ANSWER text render. 
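Roughly, the consolidation change behaves like the standalone sketch below. The "output_text"/"text" block shapes are the ones named in the base_coder diff that follows; the helper name itself is only illustrative.

    def flatten_message_content(content):
        # Plain strings (and None) pass through unchanged, as before.
        if not isinstance(content, list):
            return content or ""
        # Prefer "output_text" blocks, falling back to plain "text" blocks.
        text = "".join(
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "output_text"
        )
        if not text:
            text = "".join(
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            )
        return text

    assert flatten_message_content("Final answer.") == "Final answer."
    assert flatten_message_content([{"type": "text", "text": "Final answer."}]) == "Final answer."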
--- aider/coders/base_coder.py | 16 +++++++-- tests/basic/test_reasoning.py | 67 +++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index fa0a3f302bc..3ef9dfa3d2d 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -3266,8 +3266,20 @@ def consolidate_chunks(self): self.partial_response_reasoning_content = reasoning_content or "" try: - if not self.partial_response_reasoning_content: - self.partial_response_content = response.choices[0].message.content or "" + content = response.choices[0].message.content + if isinstance(content, list): + # OpenAI-compatible APIs sometimes return content as a list + # of blocks; join the textual pieces for display. + content = "".join( + block.get("text", "") + for block in content + if isinstance(block, dict) and block.get("type") == "output_text" + ) or "".join( + block.get("text", "") + for block in content + if isinstance(block, dict) and block.get("type") == "text" + ) + self.partial_response_content = content or "" except AttributeError as e: content_err = e diff --git a/tests/basic/test_reasoning.py b/tests/basic/test_reasoning.py index 24aa9334197..7bf1eb6eea0 100644 --- a/tests/basic/test_reasoning.py +++ b/tests/basic/test_reasoning.py @@ -1,6 +1,10 @@ +import json +import textwrap import unittest from unittest.mock import MagicMock, patch +import litellm + from aider.coders.base_coder import Coder from aider.dump import dump # noqa from aider.io import InputOutput @@ -13,6 +17,44 @@ class TestReasoning(unittest.TestCase): + SYNTHETIC_COMPLETION = textwrap.dedent( + """\ + { + "id": "test-completion", + "created": 0, + "model": "synthetic/hf:MiniMaxAI/MiniMax-M2", + "object": "chat.completion", + "system_fingerprint": null, + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Final synthetic summary of the repository.", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "reasoning_content": "Internal reasoning about how to describe the repo." 
+ }, + "token_ids": null + } + ], + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15, + "completion_tokens_details": null, + "prompt_tokens_details": { + "audio_tokens": null, + "cached_tokens": null, + "text_tokens": null, + "image_tokens": null + } + }, + "prompt_token_ids": null + } + """ + ) async def test_send_with_reasoning_content(self): """Test that reasoning content is properly formatted and output.""" # Setup IO with no pretty @@ -74,6 +116,31 @@ def __init__(self, content, reasoning_content): reasoning_pos, main_pos, "Reasoning content should appear before main content" ) + async def test_reasoning_keeps_answer_block(self): + """Ensure providers returning reasoning+answer still show both sections.""" + io = InputOutput(pretty=False) + io.assistant_output = MagicMock() + model = Model("gpt-4o") + coder = await Coder.create(model, None, io=io, stream=False) + + completion = litellm.ModelResponse(**json.loads(self.SYNTHETIC_COMPLETION)) + mock_hash = MagicMock() + mock_hash.hexdigest.return_value = "hash" + + with patch.object(model, "send_completion", return_value=(mock_hash, completion)): + list(await coder.send([{"role": "user", "content": "describe"}])) + + output = io.assistant_output.call_args[0][0] + self.assertIn(REASONING_START, output) + self.assertIn("Internal reasoning about how to describe the repo.", output) + self.assertIn("Final synthetic summary of the repository.", output) + self.assertIn(REASONING_END, output) + + coder.remove_reasoning_content() + self.assertEqual( + coder.partial_response_content.strip(), "Final synthetic summary of the repository." + ) + async def test_send_with_reasoning_content_stream(self): """Test that streaming reasoning content is properly formatted and output.""" # Setup IO with pretty output for streaming From 01f530b3d34ec7a2ee0cbc96bd239f16bb251cb8 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sat, 27 Dec 2025 21:20:31 -0600 Subject: [PATCH 04/14] Run formatting hooks after provider/test updates MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reinstalled the dev toolchain, ran the documented pre-commit hooks, and applied the resulting isort/black fixes across provider modules plus removed an unused import/variable in tests so the branch now passes the project’s formatting gate. 
--- aider/models.py | 6 ++++-- aider/openai_providers.py | 12 ++++++++---- scripts/generate_openai_providers.py | 12 ++++++++---- tests/basic/test_openai_providers.py | 7 ++----- tests/basic/test_reasoning.py | 7 +++---- 5 files changed, 25 insertions(+), 19 deletions(-) diff --git a/aider/models.py b/aider/models.py index 87519de04bd..0513141289b 100644 --- a/aider/models.py +++ b/aider/models.py @@ -19,8 +19,8 @@ from aider.dump import dump # noqa: F401 from aider.helpers.requests import model_request_parser from aider.llm import litellm -from aider.openrouter import OpenRouterModelManager from aider.openai_providers import OpenAIProviderManager +from aider.openrouter import OpenRouterModelManager from aider.sendchat import sanity_check_messages from aider.utils import check_pip_install_extra @@ -853,7 +853,9 @@ def fast_validate_environment(self): and provider and model_info_manager.openai_provider_manager.supports_provider(provider) ): - provider_keys = model_info_manager.openai_provider_manager.get_required_api_keys(provider) + provider_keys = model_info_manager.openai_provider_manager.get_required_api_keys( + provider + ) for env_var in provider_keys: if os.environ.get(env_var): return dict(keys_in_environment=[env_var], missing_keys=[]) diff --git a/aider/openai_providers.py b/aider/openai_providers.py index c6a0376343d..ab9894ff4c3 100644 --- a/aider/openai_providers.py +++ b/aider/openai_providers.py @@ -11,11 +11,11 @@ import importlib.resources as importlib_resources import json import os +import re import time from copy import deepcopy from pathlib import Path from typing import Dict, Iterable, Optional -import re import requests @@ -68,8 +68,10 @@ def __init__(self, slug: str, config: Dict): self._chat_handler = OpenAILikeChatHandler() def _resolve_api_base(self, api_base: Optional[str]) -> str: - base = api_base or _first_env_value(self.config.get("base_url_env")) or self.config.get( - "api_base" + base = ( + api_base + or _first_env_value(self.config.get("base_url_env")) + or self.config.get("api_base") ) if not base: raise CustomLLMError(500, f"{self.slug} missing base URL") # type: ignore[misc] @@ -467,7 +469,9 @@ def __init__(self, provider_configs: Optional[Dict[str, Dict]] = None) -> None: self.verify_ssl: bool = True self.provider_configs = provider_configs or deepcopy(PROVIDER_CONFIGS) - self._provider_cache: Dict[str, Dict | None] = {name: None for name in self.provider_configs} + self._provider_cache: Dict[str, Dict | None] = { + name: None for name in self.provider_configs + } self._cache_loaded: Dict[str, bool] = {name: False for name in self.provider_configs} # ------------------------------------------------------------------ # diff --git a/scripts/generate_openai_providers.py b/scripts/generate_openai_providers.py index 34774610a4d..22b98ade8ad 100644 --- a/scripts/generate_openai_providers.py +++ b/scripts/generate_openai_providers.py @@ -9,8 +9,8 @@ from __future__ import annotations -import json import argparse +import json from pathlib import Path from typing import Any, Dict, Iterable @@ -82,7 +82,7 @@ def ensure_json_object(prompt_text: str, default: Dict[str, Any] | None = None) print(f"Invalid JSON ({exc}). 
Please try again.") continue if not isinstance(parsed, dict): - print("Please provide a JSON object (e.g., {\"Header\": \"value\"}).") + print('Please provide a JSON object (e.g., {"Header": "value"}).') continue return parsed @@ -164,7 +164,10 @@ def main(): continue display_name = prompt_value( - "Display name", existing_entry.get("display_name") or litellm_entry.get("display_name") or provider_name + "Display name", + existing_entry.get("display_name") + or litellm_entry.get("display_name") + or provider_name, ) api_base = prompt_value( "API base URL", @@ -176,7 +179,8 @@ def main(): ) api_key_env = prompt_value( "Comma-separated env vars for API key lookup", - _list_to_csv(existing_entry.get("api_key_env", litellm_entry.get("api_key_env", []))) or "", + _list_to_csv(existing_entry.get("api_key_env", litellm_entry.get("api_key_env", []))) + or "", ) models_url = prompt_value( "Models endpoint URL (leave blank if none)", diff --git a/tests/basic/test_openai_providers.py b/tests/basic/test_openai_providers.py index 61882be5049..48edb72bab8 100644 --- a/tests/basic/test_openai_providers.py +++ b/tests/basic/test_openai_providers.py @@ -1,9 +1,9 @@ import json import math -from pathlib import Path import sys import types -import pytest +from pathlib import Path + if "PIL" not in sys.modules: pil_module = types.ModuleType("PIL") image_module = types.ModuleType("PIL.Image") @@ -209,7 +209,6 @@ def _fail_request(*args, **kwargs): # pragma: no cover - should never be called assert manager._cache_loaded["openai"] - def test_provider_manager_models_endpoint_fetch(monkeypatch, tmp_path): """OpenAIProviderManager should fetch and cache the /models payload when missing.""" @@ -245,8 +244,6 @@ def _recording_request(url, *, headers=None, timeout=None, verify=None): assert info["input_cost_per_token"] == 0.00000055 assert info["output_cost_per_token"] == 0.00000219 - info_again = manager.get_model_info("openai/zai-org/GLM-4.6") - def test_provider_static_models_used_without_api_key(monkeypatch, tmp_path): payload = _load_openai_fixture() diff --git a/tests/basic/test_reasoning.py b/tests/basic/test_reasoning.py index 7bf1eb6eea0..31bfe3c05ed 100644 --- a/tests/basic/test_reasoning.py +++ b/tests/basic/test_reasoning.py @@ -17,8 +17,7 @@ class TestReasoning(unittest.TestCase): - SYNTHETIC_COMPLETION = textwrap.dedent( - """\ + SYNTHETIC_COMPLETION = textwrap.dedent("""\ { "id": "test-completion", "created": 0, @@ -53,8 +52,8 @@ class TestReasoning(unittest.TestCase): }, "prompt_token_ids": null } - """ - ) + """) + async def test_send_with_reasoning_content(self): """Test that reasoning content is properly formatted and output.""" # Setup IO with no pretty From e5978a69dc76617ff707cbef27feaf9ce4aee88c Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sat, 27 Dec 2025 23:05:01 -0600 Subject: [PATCH 05/14] docs: clarify streaming warning with config key guidance Co-authored-by: aider-ce (synthetic/hf:deepseek-ai/DeepSeek-V3.2) --- aider/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aider/main.py b/aider/main.py index b13512eecdb..61cae3972d1 100644 --- a/aider/main.py +++ b/aider/main.py @@ -1184,8 +1184,8 @@ def apply_model_overrides(model_name): if not main_model.streaming: if args.stream: io.tool_warning( - f"Warning: Streaming is not supported by {main_model.name}. Disabling streaming. " - "Run with --no-stream to skip this warning." + f"Warning: Streaming is not supported by {main_model.name}. Disabling streaming." 
+ " Set stream: false in config file or use --no-stream to skip this warning." ) args.stream = False From 61ca039f853366cb5857a68d81ab076cb96e8460 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:00:05 -0600 Subject: [PATCH 06/14] feat: unify model provider metadata management --- .../model_providers.py} | 197 ++++++++---------- aider/llm.py | 2 +- aider/models.py | 72 ++++--- aider/openrouter.py | 129 ------------ aider/resources/openai_providers.json | 12 ++ tests/basic/test_main.py | 4 +- tests/basic/test_openai_providers.py | 36 ++-- tests/basic/test_openrouter.py | 147 ++++++++++++- 8 files changed, 303 insertions(+), 296 deletions(-) rename aider/{openai_providers.py => helpers/model_providers.py} (87%) delete mode 100644 aider/openrouter.py diff --git a/aider/openai_providers.py b/aider/helpers/model_providers.py similarity index 87% rename from aider/openai_providers.py rename to aider/helpers/model_providers.py index ab9894ff4c3..d05eeccbf38 100644 --- a/aider/openai_providers.py +++ b/aider/helpers/model_providers.py @@ -1,10 +1,4 @@ -"""OpenAI-compatible provider metadata caching and lookup. - -This module keeps local cached copies of provider-specific ``/models`` payloads -for OpenAI-compatible endpoints (Synthetic and others). The primary public API -is :class:`OpenAIProviderManager`, which exposes helper methods used throughout -cecli to look up provider details and model metadata. -""" +"""Unified model provider metadata caching and lookup.""" from __future__ import annotations @@ -15,7 +9,7 @@ import time from copy import deepcopy from pathlib import Path -from typing import Dict, Iterable, Optional +from typing import Dict, Optional import requests @@ -225,12 +219,6 @@ def streaming( timeout=None, client=None, ): - # The synchronous OpenAILikeChatHandler handles both regular and streaming - # responses; we reuse it even when LiteLLM calls into the async wrappers, - # since many OpenAI-compatible providers (Synthetic, Venice, etc.) only - # support the non-streaming /chat/completions endpoint. True streaming for - # those providers would require a dedicated SSE client layered on top of - # httpx, so for now we normalize them through the sync path. 
return self._invoke_handler( model=model, messages=messages, @@ -326,7 +314,6 @@ def astreaming( def _register_provider_with_litellm(slug: str, config: Dict) -> None: - """Register provider with litellm's registry and custom handler.""" try: from litellm.llms.openai_like.json_loader import ( JSONProviderRegistry, @@ -406,7 +393,7 @@ def _load_provider_configs() -> Dict[str, Dict]: try: resource = importlib_resources.files("aider.resources").joinpath(RESOURCE_FILE) data = json.loads(resource.read_text()) - except (FileNotFoundError, json.JSONDecodeError): # pragma: no cover - fallback path + except (FileNotFoundError, json.JSONDecodeError): # pragma: no cover data = {} for provider, override in data.items(): @@ -419,66 +406,20 @@ def _load_provider_configs() -> Dict[str, Dict]: PROVIDER_CONFIGS = _load_provider_configs() -def ensure_litellm_providers_registered() -> None: - global _PROVIDERS_REGISTERED - if _PROVIDERS_REGISTERED: - return - for slug, cfg in PROVIDER_CONFIGS.items(): - _register_provider_with_litellm(slug, cfg) - _PROVIDERS_REGISTERED = True - - -_NUMBER_RE = re.compile(r"-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?") - - -def _cost_per_token(val: Optional[str | float | int]) -> Optional[float]: - """Convert a price value (USD per token) to a float.""" - if val in (None, "", "-", "N/A"): - return None - if val == "0": - return 0.0 - if isinstance(val, str): - cleaned = val.strip().replace(",", "") - if cleaned.startswith("$"): - cleaned = cleaned[1:] - match = _NUMBER_RE.search(cleaned) - if not match: - return None - val = match.group(0) - try: - return float(val) - except (TypeError, ValueError): - return None - - -def _first_value(record: Dict, *keys: Iterable[str]) -> Optional[float]: - """Return the first non-None value from record for the provided keys.""" - for key in keys: - if key in record and record[key] not in (None, ""): - return record[key] - return None - - -class OpenAIProviderManager: - """Cached metadata manager for OpenAI-compatible providers.""" - +class ModelProviderManager: CACHE_TTL = 60 * 60 * 24 # 24 hours def __init__(self, provider_configs: Optional[Dict[str, Dict]] = None) -> None: self.cache_dir = Path.home() / ".aider" / "caches" self.verify_ssl: bool = True - self.provider_configs = provider_configs or deepcopy(PROVIDER_CONFIGS) - self._provider_cache: Dict[str, Dict | None] = { - name: None for name in self.provider_configs - } - self._cache_loaded: Dict[str, bool] = {name: False for name in self.provider_configs} + self._provider_cache: Dict[str, Dict | None] = {} + self._cache_loaded: Dict[str, bool] = {} + for name in self.provider_configs: + self._provider_cache[name] = None + self._cache_loaded[name] = False - # ------------------------------------------------------------------ # - # Provider helpers # - # ------------------------------------------------------------------ # def set_verify_ssl(self, verify_ssl: bool) -> None: - """Enable/disable SSL verification for API requests.""" self.verify_ssl = verify_ssl def supports_provider(self, provider: Optional[str]) -> bool: @@ -498,7 +439,8 @@ def get_provider_base_url(self, provider: Optional[str]) -> Optional[str]: config = self.get_provider_config(provider) if not config: return None - for env_var in config.get("base_url_env", []): + base_envs = config.get("base_url_env") or [] + for env_var in base_envs: val = os.environ.get(env_var) if val: return val.rstrip("/") @@ -510,36 +452,23 @@ def get_required_api_keys(self, provider: Optional[str]) -> list[str]: return [] return 
list(config.get("api_key_env", [])) - # ------------------------------------------------------------------ # - # Model metadata API # - # ------------------------------------------------------------------ # def get_model_info(self, model: str) -> Dict: - """Return metadata for *model* or an empty ``dict`` when unknown.""" provider, route = self._split_model(model) - if not self.supports_provider(provider): + if not provider or not self._ensure_provider_state(provider): return {} content = self._ensure_content(provider) - if not content or "data" not in content: - return {} - - candidates = {route} - if ":" in route: - candidates.add(route.split(":", 1)[0]) - - record = next( - (item for item in content["data"] if item.get("id") in candidates), - None, - ) + record = self._find_record(content, route) + if not record and self.refresh_provider_cache(provider): + content = self._provider_cache.get(provider) + record = self._find_record(content, route) if not record: return {} - return self._record_to_info(record, provider) def get_models_for_listing(self) -> Dict[str, Dict]: - """Return all known models keyed by their bare ids across providers.""" listings: Dict[str, Dict] = {} - for provider in self.provider_configs: + for provider in list(self.provider_configs.keys()): content = self._ensure_content(provider) if not content or "data" not in content: continue @@ -548,14 +477,28 @@ def get_models_for_listing(self) -> Dict[str, Dict]: if not model_id: continue info = self._record_to_info(record, provider) - if not info: - continue - listings[model_id] = info + if info: + listings[model_id] = info return listings - # ------------------------------------------------------------------ # - # Internal helpers # - # ------------------------------------------------------------------ # + def refresh_provider_cache(self, provider: str) -> bool: + if not self._ensure_provider_state(provider): + return False + config = self.provider_configs[provider] + if not config.get("models_url") and not config.get("api_base"): + return False + self._provider_cache[provider] = None + self._cache_loaded[provider] = True + self._update_cache(provider) + return bool(self._provider_cache.get(provider)) + + def _ensure_provider_state(self, provider: str) -> bool: + if provider not in self.provider_configs: + return False + self._provider_cache.setdefault(provider, None) + self._cache_loaded.setdefault(provider, False) + return True + def _split_model(self, model: str) -> tuple[Optional[str], str]: if "/" not in model: return None, model @@ -568,6 +511,14 @@ def _ensure_content(self, provider: str) -> Optional[Dict]: self._update_cache(provider) return self._provider_cache.get(provider) + def _find_record(self, content: Optional[Dict], route: str) -> Optional[Dict]: + if not content or "data" not in content: + return None + candidates = {route} + if ":" in route: + candidates.add(route.split(":", 1)[0]) + return next((item for item in content["data"] if item.get("id") in candidates), None) + def _record_to_info(self, record: Dict, provider: str) -> Dict: context_len = _first_value( record, @@ -577,9 +528,13 @@ def _record_to_info(self, record: Dict, provider: str) -> Dict: "context_length", "context_window", "top_provider_context_length", + "top_provider", ) - pricing = record.get("pricing", {}) if isinstance(record.get("pricing"), dict) else {} + if isinstance(context_len, dict): + context_len = context_len.get("context_length") or context_len.get("max_tokens") + + pricing = record.get("pricing", {}) if 
isinstance(record.get("pricing"), dict) else {} input_cost = _cost_per_token( _first_value(pricing, "prompt", "input", "prompt_tokens") or _first_value(record, "input_cost_per_token", "prompt_cost_per_token") @@ -620,21 +575,8 @@ def _record_to_info(self, record: Dict, provider: str) -> Dict: "litellm_provider": provider, "mode": record.get("mode", "chat"), } - return {k: v for k, v in info.items() if v is not None} - def refresh_provider_cache(self, provider: str) -> bool: - """Force-refresh the provider's /models cache if supported.""" - if not self.supports_provider(provider): - return False - config = self.provider_configs[provider] - if not config.get("models_url"): - return False - self._provider_cache[provider] = None - self._cache_loaded[provider] = True - self._update_cache(provider) - return bool(self._provider_cache.get(provider)) - def _get_cache_file(self, provider: str) -> Path: fname = f"{provider}_models.json" return self.cache_dir / fname @@ -714,3 +656,42 @@ def _get_api_key(self, provider: str) -> Optional[str]: if value: return value return None + + +def ensure_litellm_providers_registered() -> None: + global _PROVIDERS_REGISTERED + if _PROVIDERS_REGISTERED: + return + for slug, cfg in PROVIDER_CONFIGS.items(): + _register_provider_with_litellm(slug, cfg) + _PROVIDERS_REGISTERED = True + + +_NUMBER_RE = re.compile(r"-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?") + + +def _cost_per_token(val: Optional[str | float | int]) -> Optional[float]: + if val in (None, "", "-", "N/A"): + return None + if val == "0": + return 0.0 + if isinstance(val, str): + cleaned = val.strip().replace(",", "") + if cleaned.startswith("$"): + cleaned = cleaned[1:] + match = _NUMBER_RE.search(cleaned) + if not match: + return None + val = match.group(0) + try: + return float(val) + except (TypeError, ValueError): + return None + + +def _first_value(record: Dict, *keys: str): + for key in keys: + value = record.get(key) + if value not in (None, ""): + return value + return None diff --git a/aider/llm.py b/aider/llm.py index eb74aab25e1..ff320dee3d3 100644 --- a/aider/llm.py +++ b/aider/llm.py @@ -6,7 +6,7 @@ from collections.abc import Coroutine from aider.dump import dump # noqa: F401 -from aider.openai_providers import ensure_litellm_providers_registered +from aider.helpers.model_providers import ensure_litellm_providers_registered warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") diff --git a/aider/models.py b/aider/models.py index 0513141289b..bb7e78eca25 100644 --- a/aider/models.py +++ b/aider/models.py @@ -19,8 +19,7 @@ from aider.dump import dump # noqa: F401 from aider.helpers.requests import model_request_parser from aider.llm import litellm -from aider.openai_providers import OpenAIProviderManager -from aider.openrouter import OpenRouterModelManager +from aider.helpers.model_providers import ModelProviderManager from aider.sendchat import sanity_check_messages from aider.utils import check_pip_install_extra @@ -159,15 +158,12 @@ def __init__(self): self._cache_loaded = False # Manager for provider-specific cached model databases - self.openrouter_manager = OpenRouterModelManager() - self.openai_provider_manager = OpenAIProviderManager() + self.provider_manager = ModelProviderManager() + self.openai_provider_manager = self.provider_manager # Backwards compatibility alias def set_verify_ssl(self, verify_ssl): self.verify_ssl = verify_ssl - if hasattr(self, "openrouter_manager"): - self.openrouter_manager.set_verify_ssl(verify_ssl) - if hasattr(self, 
"openai_provider_manager"): - self.openai_provider_manager.set_verify_ssl(verify_ssl) + self.provider_manager.set_verify_ssl(verify_ssl) def _load_cache(self): if self._cache_loaded: @@ -245,32 +241,45 @@ def get_model_info(self, model): if "model_prices_and_context_window.json" not in str(ex): print(str(ex)) + provider_info = self._resolve_via_provider(model, cached_info) + if provider_info: + return provider_info + if litellm_info: return litellm_info - if not cached_info and model.startswith("openrouter/"): - # First try using the locally cached OpenRouter model database - openrouter_info = self.openrouter_manager.get_model_info(model) - if openrouter_info: - return openrouter_info + return cached_info + + def _resolve_via_provider(self, model, cached_info): + if cached_info: + return None + + provider = model.split("/", 1)[0] if "/" in model else None + if not self.provider_manager.supports_provider(provider): + return None - # Fallback to legacy web-scraping if the API cache does not contain the model + provider_info = self.provider_manager.get_model_info(model) + if provider_info: + self._record_dynamic_model(model, provider_info) + return provider_info + + if provider == "openrouter": openrouter_info = self.fetch_openrouter_model_info(model) if openrouter_info: + openrouter_info.setdefault("litellm_provider", "openrouter") + self._record_dynamic_model(model, openrouter_info) return openrouter_info - provider = model.split("/", 1)[0] if "/" in model else None - if self.openai_provider_manager.supports_provider(provider): - provider_info = self.openai_provider_manager.get_model_info(model) - if not provider_info and not cached_info: - refreshed = self.openai_provider_manager.refresh_provider_cache(provider) - if refreshed: - provider_info = self.openai_provider_manager.get_model_info(model) - if provider_info: - self.local_model_metadata[model] = provider_info - return provider_info + return None - return cached_info + def _record_dynamic_model(self, model, info): + self.local_model_metadata[model] = info + self._ensure_model_settings_entry(model) + + def _ensure_model_settings_entry(self, model): + if any(ms.name == model for ms in MODEL_SETTINGS): + return + MODEL_SETTINGS.append(ModelSettings(name=model)) def fetch_openrouter_model_info(self, model): """ @@ -315,6 +324,7 @@ def fetch_openrouter_model_info(self, model): "max_output_tokens": context_size, "input_cost_per_token": input_cost, "output_cost_per_token": output_cost, + "litellm_provider": "openrouter", } return params except Exception as e: @@ -720,7 +730,7 @@ def _apply_provider_defaults(self): if not provider: return - provider_config = model_info_manager.openai_provider_manager.get_provider_config(provider) + provider_config = model_info_manager.provider_manager.get_provider_config(provider) if not provider_config: return @@ -735,7 +745,7 @@ def _apply_provider_defaults(self): # standard completions for those providers. 
self.streaming = False - base_url = model_info_manager.openai_provider_manager.get_provider_base_url(provider) + base_url = model_info_manager.provider_manager.get_provider_base_url(provider) if base_url: self.extra_params.setdefault("base_url", base_url) @@ -851,9 +861,9 @@ def fast_validate_environment(self): if ( not var and provider - and model_info_manager.openai_provider_manager.supports_provider(provider) + and model_info_manager.provider_manager.supports_provider(provider) ): - provider_keys = model_info_manager.openai_provider_manager.get_required_api_keys( + provider_keys = model_info_manager.provider_manager.get_required_api_keys( provider ) for env_var in provider_keys: @@ -890,7 +900,7 @@ def validate_environment(self): return res provider = self.info.get("litellm_provider", "").lower() - provider_config = model_info_manager.openai_provider_manager.get_provider_config(provider) + provider_config = model_info_manager.provider_manager.get_provider_config(provider) if provider_config: envs = provider_config.get("api_key_env", []) available = [env for env in envs if os.environ.get(env)] @@ -1389,7 +1399,7 @@ def get_chat_model_names(): model_metadata = list(litellm.model_cost.items()) model_metadata += list(model_info_manager.local_model_metadata.items()) - openai_provider_models = model_info_manager.openai_provider_manager.get_models_for_listing() + openai_provider_models = model_info_manager.provider_manager.get_models_for_listing() model_metadata += list(openai_provider_models.items()) for orig_model, attrs in model_metadata: diff --git a/aider/openrouter.py b/aider/openrouter.py deleted file mode 100644 index ea641c17fda..00000000000 --- a/aider/openrouter.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -OpenRouter model metadata caching and lookup. - -This module keeps a local cached copy of the OpenRouter model list -(downloaded from ``https://openrouter.ai/api/v1/models``) and exposes a -helper class that returns metadata for a given model in a format compatible -with litellm’s ``get_model_info``. -""" - -from __future__ import annotations - -import json -import time -from pathlib import Path -from typing import Dict - -import requests - - -def _cost_per_token(val: str | None) -> float | None: - """Convert a price string (USD per token) to a float.""" - if val in (None, "", "0"): - return 0.0 if val == "0" else None - try: - return float(val) - except Exception: # noqa: BLE001 - return None - - -class OpenRouterModelManager: - MODELS_URL = "https://openrouter.ai/api/v1/models" - CACHE_TTL = 60 * 60 * 24 # 24 h - - def __init__(self) -> None: - self.cache_dir = Path.home() / ".aider" / "caches" - self.cache_file = self.cache_dir / "openrouter_models.json" - self.content: Dict | None = None - self.verify_ssl: bool = True - self._cache_loaded = False - - # ------------------------------------------------------------------ # - # Public API # - # ------------------------------------------------------------------ # - def set_verify_ssl(self, verify_ssl: bool) -> None: - """Enable/disable SSL verification for API requests.""" - self.verify_ssl = verify_ssl - - def get_model_info(self, model: str) -> Dict: - """ - Return metadata for *model* or an empty ``dict`` when unknown. - - ``model`` should use the aider naming convention, e.g. - ``openrouter/nousresearch/deephermes-3-mistral-24b-preview:free``. 
- """ - self._ensure_content() - if not self.content or "data" not in self.content: - return {} - - route = self._strip_prefix(model) - - # Consider both the exact id and id without any “:suffix”. - candidates = {route} - if ":" in route: - candidates.add(route.split(":", 1)[0]) - - record = next((item for item in self.content["data"] if item.get("id") in candidates), None) - if not record: - return {} - - context_len = ( - record.get("top_provider", {}).get("context_length") - or record.get("context_length") - or None - ) - - pricing = record.get("pricing", {}) - return { - "max_input_tokens": context_len, - "max_tokens": context_len, - "max_output_tokens": context_len, - "input_cost_per_token": _cost_per_token(pricing.get("prompt")), - "output_cost_per_token": _cost_per_token(pricing.get("completion")), - "litellm_provider": "openrouter", - } - - # ------------------------------------------------------------------ # - # Internal helpers # - # ------------------------------------------------------------------ # - def _strip_prefix(self, model: str) -> str: - return model[len("openrouter/") :] if model.startswith("openrouter/") else model - - def _ensure_content(self) -> None: - self._load_cache() - if not self.content: - self._update_cache() - - def _load_cache(self) -> None: - if self._cache_loaded: - return - try: - self.cache_dir.mkdir(parents=True, exist_ok=True) - if self.cache_file.exists(): - cache_age = time.time() - self.cache_file.stat().st_mtime - if cache_age < self.CACHE_TTL: - try: - self.content = json.loads(self.cache_file.read_text()) - except json.JSONDecodeError: - self.content = None - except OSError: - # Cache directory might be unwritable; ignore. - pass - - self._cache_loaded = True - - def _update_cache(self) -> None: - try: - response = requests.get(self.MODELS_URL, timeout=10, verify=self.verify_ssl) - if response.status_code == 200: - self.content = response.json() - try: - self.cache_file.write_text(json.dumps(self.content, indent=2)) - except OSError: - pass # Non-fatal if we can’t write the cache - except Exception as ex: # noqa: BLE001 - print(f"Failed to fetch OpenRouter model list: {ex}") - try: - self.cache_file.write_text("{}") - except OSError: - pass diff --git a/aider/resources/openai_providers.json b/aider/resources/openai_providers.json index 1609d2d971b..7c022a21095 100644 --- a/aider/resources/openai_providers.json +++ b/aider/resources/openai_providers.json @@ -1,4 +1,16 @@ { + "openrouter": { + "api_base": "https://openrouter.ai/api/v1", + "models_url": "https://openrouter.ai/api/v1/models", + "api_key_env": [ + "OPENROUTER_API_KEY" + ], + "requires_api_key": false, + "default_headers": { + "HTTP-Referer": "https://aider.chat", + "X-Title": "aider" + } + }, "openai": { "api_base": "https://api.openai.com/v1", "models_url": "https://api.openai.com/v1/models", diff --git a/tests/basic/test_main.py b/tests/basic/test_main.py index c3998c50d1f..7885b2e636f 100644 --- a/tests/basic/test_main.py +++ b/tests/basic/test_main.py @@ -1420,7 +1420,7 @@ async def test_list_models_includes_openai_provider(self): import aider.models as models_module provider_name = "openai" - manager = models_module.model_info_manager.openai_provider_manager + manager = models_module.model_info_manager.provider_manager provider_config = { "api_base": "https://api.openai.com/v1", "models_url": "https://api.openai.com/v1/models", @@ -1450,7 +1450,7 @@ async def test_list_models_includes_openai_provider(self): ] } - def _fake_get(url, *, timeout=None, verify=None): + def 
_fake_get(url, *, headers=None, timeout=None, verify=None): return types.SimpleNamespace(status_code=200, json=lambda: payload) try: diff --git a/tests/basic/test_openai_providers.py b/tests/basic/test_openai_providers.py index 48edb72bab8..3f0b301fc64 100644 --- a/tests/basic/test_openai_providers.py +++ b/tests/basic/test_openai_providers.py @@ -119,7 +119,7 @@ def terminate(self): import aider.models as models_module from aider.commands.model import ModelCommand from aider.models import ModelInfoManager -from aider.openai_providers import OpenAIProviderManager, _JSONOpenAIProvider +from aider.helpers.model_providers import ModelProviderManager, _JSONOpenAIProvider class DummyResponse: @@ -183,7 +183,7 @@ def _test_provider_config(): def test_provider_manager_get_model_info_from_cache(monkeypatch, tmp_path): - """OpenAIProviderManager should hydrate from cached payloads.""" + """ModelProviderManager should hydrate from cached payloads.""" payload = _load_openai_fixture() @@ -193,7 +193,7 @@ def _fail_request(*args, **kwargs): # pragma: no cover - should never be called monkeypatch.setattr("requests.get", _fail_request) monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - manager = OpenAIProviderManager(provider_configs=_test_provider_config()) + manager = ModelProviderManager(provider_configs=_test_provider_config()) manager.cache_dir.mkdir(parents=True, exist_ok=True) cache_file = manager._get_cache_file("openai") cache_file.write_text(json.dumps(payload)) @@ -210,7 +210,7 @@ def _fail_request(*args, **kwargs): # pragma: no cover - should never be called def test_provider_manager_models_endpoint_fetch(monkeypatch, tmp_path): - """OpenAIProviderManager should fetch and cache the /models payload when missing.""" + """ModelProviderManager should fetch and cache the /models payload when missing.""" payload = _load_openai_fixture() call_args = [] @@ -224,7 +224,7 @@ def _recording_request(url, *, headers=None, timeout=None, verify=None): monkeypatch.setenv("OPENAI_API_KEY", "test-key") provider_config = _test_provider_config() - manager = OpenAIProviderManager(provider_configs=provider_config) + manager = ModelProviderManager(provider_configs=provider_config) manager.set_verify_ssl(False) info = manager.get_model_info("openai/zai-org/GLM-4.6:extended") @@ -256,7 +256,7 @@ def _fail_request(*args, **kwargs): # pragma: no cover - should not run monkeypatch.setattr("requests.get", _fail_request) monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - manager = OpenAIProviderManager(provider_configs=provider_config) + manager = ModelProviderManager(provider_configs=provider_config) info = manager.get_model_info("openai/zai-org/GLM-4.6") assert info["litellm_provider"] == "openai" @@ -283,15 +283,15 @@ def _fail_request(*args, **kwargs): # pragma: no cover - should not run monkeypatch.setattr("requests.get", _fail_request) monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - manager = OpenAIProviderManager(provider_configs=provider_config) + manager = ModelProviderManager(provider_configs=provider_config) info = manager.get_model_info("openai/demo/model") assert math.isclose(info["input_cost_per_token"], 0.00000055) assert math.isclose(info["output_cost_per_token"], 0.00000219) -def test_model_info_manager_uses_openai_provider_manager(monkeypatch): - """ModelInfoManager should delegate to OpenAIProviderManager for openai-like models.""" +def test_model_info_manager_uses_provider_manager(monkeypatch): + """ModelInfoManager should delegate to 
ModelProviderManager for openai-like models.""" monkeypatch.setattr( models_module, @@ -309,7 +309,7 @@ def test_model_info_manager_uses_openai_provider_manager(monkeypatch): } monkeypatch.setattr( - "aider.models.OpenAIProviderManager.get_model_info", + "aider.helpers.model_providers.ModelProviderManager.get_model_info", lambda self, model: stub_info, ) @@ -319,10 +319,10 @@ def test_model_info_manager_uses_openai_provider_manager(monkeypatch): assert info == stub_info -def test_openai_provider_manager_listing(monkeypatch, tmp_path): +def test_provider_manager_listing(monkeypatch, tmp_path): payload = _load_openai_fixture() monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - manager = OpenAIProviderManager(provider_configs=_test_provider_config()) + manager = ModelProviderManager(provider_configs=_test_provider_config()) manager.cache_dir.mkdir(parents=True, exist_ok=True) cache_file = manager._get_cache_file("openai") cache_file.write_text(json.dumps(payload)) @@ -342,10 +342,10 @@ def test_chat_model_names_include_openai_provider_models(monkeypatch, tmp_path): models_module.litellm = types.SimpleNamespace(model_cost={}, _lazy_module=None) models_module.model_info_manager = models_module.ModelInfoManager() - models_module.model_info_manager.openai_provider_manager = OpenAIProviderManager( + models_module.model_info_manager.provider_manager = ModelProviderManager( provider_configs=_test_provider_config() ) - manager = models_module.model_info_manager.openai_provider_manager + manager = models_module.model_info_manager.provider_manager manager.cache_dir.mkdir(parents=True, exist_ok=True) cache_file = manager._get_cache_file("openai") cache_file.write_text(json.dumps(payload)) @@ -363,10 +363,10 @@ def test_model_command_completions_include_openai_provider_models(monkeypatch, t models_module.litellm = types.SimpleNamespace(model_cost={}, _lazy_module=None) models_module.model_info_manager = models_module.ModelInfoManager() - models_module.model_info_manager.openai_provider_manager = OpenAIProviderManager( + models_module.model_info_manager.provider_manager = ModelProviderManager( provider_configs=_test_provider_config() ) - manager = models_module.model_info_manager.openai_provider_manager + manager = models_module.model_info_manager.provider_manager manager.cache_dir.mkdir(parents=True, exist_ok=True) cache_file = manager._get_cache_file("openai") cache_file.write_text(json.dumps(payload)) @@ -385,7 +385,7 @@ def test_model_disables_streaming_for_non_streaming_providers(monkeypatch): } } - provider_manager = OpenAIProviderManager(provider_configs=provider_configs) + provider_manager = ModelProviderManager(provider_configs=provider_configs) fake_info = { "max_input_tokens": 4096, "max_tokens": 4096, @@ -395,7 +395,7 @@ def test_model_disables_streaming_for_non_streaming_providers(monkeypatch): fake_model_info_manager = types.SimpleNamespace( get_model_info=lambda model: fake_info, - openai_provider_manager=provider_manager, + provider_manager=provider_manager, ) monkeypatch.setenv("SYNTHETIC_API_KEY", "test-key") diff --git a/tests/basic/test_openrouter.py b/tests/basic/test_openrouter.py index f55c301572c..8659f7a997a 100644 --- a/tests/basic/test_openrouter.py +++ b/tests/basic/test_openrouter.py @@ -1,7 +1,122 @@ from pathlib import Path +import sys +import types +if "numpy" not in sys.modules: + numpy_module = types.ModuleType("numpy") + numpy_module.ndarray = object + numpy_module.array = lambda *a, **k: None + numpy_module.dot = lambda *a, **k: 0.0 + 
numpy_module.linalg = types.SimpleNamespace(norm=lambda *a, **k: 1.0) + sys.modules["numpy"] = numpy_module + +if "PIL" not in sys.modules: + pil_module = types.ModuleType("PIL") + image_module = types.ModuleType("PIL.Image") + image_grab_module = types.ModuleType("PIL.ImageGrab") + + class _DummyImage: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + @property + def size(self): + return (1024, 1024) + + def _dummy_open(*args, **kwargs): + return _DummyImage() + + image_module.open = _dummy_open + image_grab_module.grab = _dummy_open + pil_module.Image = image_module + pil_module.ImageGrab = image_grab_module + sys.modules["PIL"] = pil_module + sys.modules["PIL.Image"] = image_module + sys.modules["PIL.ImageGrab"] = image_grab_module + +if "oslex" not in sys.modules: + oslex_module = types.ModuleType("oslex") + oslex_module.__all__ = [] + sys.modules["oslex"] = oslex_module + +if "rich" not in sys.modules: + rich_module = types.ModuleType("rich") + console_module = types.ModuleType("rich.console") + + class _DummyConsole: + def __init__(self, *args, **kwargs): + pass + + def status(self, *args, **kwargs): + return self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def update(self, *args, **kwargs): + return None + + console_module.Console = _DummyConsole + rich_module.console = console_module + sys.modules["rich"] = rich_module + sys.modules["rich.console"] = console_module + +if "pyperclip" not in sys.modules: + pyperclip_module = types.ModuleType("pyperclip") + + class _DummyPyperclipException(Exception): + pass + + pyperclip_module.PyperclipException = _DummyPyperclipException + pyperclip_module.copy = lambda *args, **kwargs: None + sys.modules["pyperclip"] = pyperclip_module + +if "pexpect" not in sys.modules: + pexpect_module = types.ModuleType("pexpect") + + class _DummySpawn: + def __init__(self, *args, **kwargs): + pass + + def sendline(self, *args, **kwargs): + return 0 + + def close(self, *args, **kwargs): + return 0 + + pexpect_module.spawn = _DummySpawn + sys.modules["pexpect"] = pexpect_module + +if "psutil" not in sys.modules: + psutil_module = types.ModuleType("psutil") + + class _DummyProcess: + def __init__(self, *args, **kwargs): + pass + + def children(self, *args, **kwargs): + return [] + + def terminate(self): + return None + + psutil_module.Process = _DummyProcess + sys.modules["psutil"] = psutil_module + +if "pypandoc" not in sys.modules: + pypandoc_module = types.ModuleType("pypandoc") + pypandoc_module.convert_text = lambda *args, **kwargs: "" + sys.modules["pypandoc"] = pypandoc_module + +from aider.helpers.model_providers import ModelProviderManager from aider.models import ModelInfoManager -from aider.openrouter import OpenRouterModelManager +import aider.models as models_module class DummyResponse: @@ -14,10 +129,13 @@ def __init__(self, json_data): def json(self): return self._json_data + def raise_for_status(self): + return None + def test_openrouter_get_model_info_from_cache(monkeypatch, tmp_path): """ - OpenRouterModelManager should return correct metadata taken from the + ModelProviderManager should return correct metadata taken from the downloaded (and locally cached) models JSON payload. 
""" payload = { @@ -35,7 +153,14 @@ def test_openrouter_get_model_info_from_cache(monkeypatch, tmp_path): monkeypatch.setattr("requests.get", lambda *a, **k: DummyResponse(payload)) monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - manager = OpenRouterModelManager() + provider_config = { + "openrouter": { + "api_base": "https://openrouter.ai/api/v1", + "models_url": "https://openrouter.ai/api/v1/models", + "requires_api_key": False, + } + } + manager = ModelProviderManager(provider_configs=provider_config) info = manager.get_model_info("openrouter/mistralai/mistral-medium-3") assert info["max_input_tokens"] == 32768 @@ -46,11 +171,15 @@ def test_openrouter_get_model_info_from_cache(monkeypatch, tmp_path): def test_model_info_manager_uses_openrouter_manager(monkeypatch): """ - ModelInfoManager should delegate to OpenRouterModelManager when litellm + ModelInfoManager should delegate to ModelProviderManager when litellm provides no data for an OpenRouter-prefixed model. """ # Ensure litellm path returns no info so that fallback logic triggers - monkeypatch.setattr("aider.models.litellm.get_model_info", lambda *a, **k: {}) + monkeypatch.setattr( + models_module, + "litellm", + types.SimpleNamespace(_lazy_module=None, get_model_info=lambda *a, **k: {}), + ) stub_info = { "max_input_tokens": 512, @@ -61,11 +190,15 @@ def test_model_info_manager_uses_openrouter_manager(monkeypatch): "litellm_provider": "openrouter", } - # Force OpenRouterModelManager to return our stub info + # Force ModelProviderManager to return our stub info monkeypatch.setattr( - "aider.models.OpenRouterModelManager.get_model_info", + "aider.helpers.model_providers.ModelProviderManager.get_model_info", lambda self, model: stub_info, ) + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.supports_provider", + lambda self, provider: provider == "openrouter", + ) mim = ModelInfoManager() info = mim.get_model_info("openrouter/fake/model") From 564dcdf29d4da937f1a0d83475d1a84e8da4613b Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:06:32 -0600 Subject: [PATCH 07/14] test: cover unified model provider manager --- tests/basic/test_model_provider_manager.py | 335 +++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 tests/basic/test_model_provider_manager.py diff --git a/tests/basic/test_model_provider_manager.py b/tests/basic/test_model_provider_manager.py new file mode 100644 index 00000000000..72dd103bfac --- /dev/null +++ b/tests/basic/test_model_provider_manager.py @@ -0,0 +1,335 @@ +import json +import sys +import types + + +def _install_stubs(): + if "PIL" not in sys.modules: + pil_module = types.ModuleType("PIL") + image_module = types.ModuleType("PIL.Image") + image_grab_module = types.ModuleType("PIL.ImageGrab") + + class _DummyImage: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + @property + def size(self): + return (1024, 1024) + + def _dummy_open(*args, **kwargs): + return _DummyImage() + + image_module.open = _dummy_open + image_grab_module.grab = _dummy_open + pil_module.Image = image_module + pil_module.ImageGrab = image_grab_module + sys.modules["PIL"] = pil_module + sys.modules["PIL.Image"] = image_module + sys.modules["PIL.ImageGrab"] = image_grab_module + + if "numpy" not in sys.modules: + numpy_module = types.ModuleType("numpy") + numpy_module.ndarray = object + numpy_module.array = lambda *a, **k: None + numpy_module.dot = lambda *a, **k: 0.0 + numpy_module.linalg = 
types.SimpleNamespace(norm=lambda *a, **k: 1.0) + sys.modules["numpy"] = numpy_module + + if "oslex" not in sys.modules: + oslex_module = types.ModuleType("oslex") + oslex_module.__all__ = [] + sys.modules["oslex"] = oslex_module + + if "rich" not in sys.modules: + rich_module = types.ModuleType("rich") + console_module = types.ModuleType("rich.console") + + class _DummyConsole: + def __init__(self, *args, **kwargs): + pass + + def status(self, *args, **kwargs): + return self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def update(self, *args, **kwargs): + return None + + console_module.Console = _DummyConsole + rich_module.console = console_module + sys.modules["rich"] = rich_module + sys.modules["rich.console"] = console_module + + if "pyperclip" not in sys.modules: + pyperclip_module = types.ModuleType("pyperclip") + + class _DummyPyperclipException(Exception): + pass + + pyperclip_module.PyperclipException = _DummyPyperclipException + pyperclip_module.copy = lambda *args, **kwargs: None + sys.modules["pyperclip"] = pyperclip_module + + if "pexpect" not in sys.modules: + pexpect_module = types.ModuleType("pexpect") + + class _DummySpawn: + def __init__(self, *args, **kwargs): + pass + + def sendline(self, *args, **kwargs): + return 0 + + def close(self, *args, **kwargs): + return 0 + + pexpect_module.spawn = _DummySpawn + sys.modules["pexpect"] = pexpect_module + + if "psutil" not in sys.modules: + psutil_module = types.ModuleType("psutil") + + class _DummyProcess: + def __init__(self, *args, **kwargs): + pass + + def children(self, *args, **kwargs): + return [] + + def terminate(self): + return None + + psutil_module.Process = _DummyProcess + sys.modules["psutil"] = psutil_module + + if "pypandoc" not in sys.modules: + pypandoc_module = types.ModuleType("pypandoc") + pypandoc_module.convert_text = lambda *args, **kwargs: "" + sys.modules["pypandoc"] = pypandoc_module + + +_install_stubs() + +from aider.helpers.model_providers import ModelProviderManager +from aider.models import MODEL_SETTINGS, Model, ModelInfoManager + + +class DummyResponse: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +def _make_manager(tmp_path, config): + manager = ModelProviderManager(provider_configs=config) + manager.cache_dir = tmp_path # Avoid touching real home dir + return manager + + +def test_model_provider_matches_suffix_variants(monkeypatch, tmp_path): + payload = { + "data": [ + { + "id": "demo/model", + "context_length": 2048, + "pricing": {"prompt": "1.0", "completion": "2.0"}, + } + ] + } + + config = { + "openrouter": { + "api_base": "https://openrouter.ai/api/v1", + "models_url": "https://openrouter.ai/api/v1/models", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + cache_file = manager._get_cache_file("openrouter") + cache_file.write_text(json.dumps(payload)) + manager._cache_loaded["openrouter"] = True + manager._provider_cache["openrouter"] = payload + + info = manager.get_model_info("openrouter/demo/model:extended") + + assert info["max_input_tokens"] == 2048 + assert info["input_cost_per_token"] == 1.0 + assert info["litellm_provider"] == "openrouter" + + +def test_model_provider_uses_top_provider_context(tmp_path): + payload = { + "data": [ + { + "id": "demo/model", + "top_provider": {"context_length": 4096}, + "pricing": {"prompt": "3", "completion": "4"}, + } + ] + } + + config = { + "demo": { + 
"api_base": "https://example.com/v1", + "models_url": "https://example.com/v1/models", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + cache_file = manager._get_cache_file("demo") + cache_file.write_text(json.dumps(payload)) + manager._cache_loaded["demo"] = True + manager._provider_cache["demo"] = payload + + info = manager.get_model_info("demo/demo/model") + + assert info["max_input_tokens"] == 4096 + assert info["max_tokens"] == 4096 + assert info["max_output_tokens"] == 4096 + + +def test_fetch_provider_models_injects_headers(monkeypatch, tmp_path): + payload = {"data": []} + captured = {} + + def _fake_get(url, *, headers=None, timeout=None, verify=None): + captured["url"] = url + captured["headers"] = headers + captured["timeout"] = timeout + captured["verify"] = verify + return DummyResponse(payload) + + monkeypatch.setattr("requests.get", _fake_get) + + config = { + "demo": { + "api_base": "https://example.com/v1", + "default_headers": {"X-Test": "demo"}, + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + manager.set_verify_ssl(False) + + result = manager._fetch_provider_models("demo") + + assert result == payload + assert captured["url"] == "https://example.com/v1/models" + assert captured["headers"] == {"X-Test": "demo"} + assert captured["timeout"] == 10 + assert captured["verify"] is False + + +def test_get_api_key_prefers_first_valid(monkeypatch, tmp_path): + config = { + "demo": { + "api_base": "https://example.com/v1", + "api_key_env": ["DEMO_FALLBACK", "DEMO_KEY"], + "requires_api_key": True, + } + } + + manager = _make_manager(tmp_path, config) + monkeypatch.delenv("DEMO_FALLBACK", raising=False) + monkeypatch.setenv("DEMO_KEY", "secret") + + assert manager._get_api_key("demo") == "secret" + + +def test_model_info_manager_delegates_to_provider(monkeypatch, tmp_path): + monkeypatch.setattr( + "aider.models.litellm", + types.SimpleNamespace( + _lazy_module=None, + get_model_info=lambda *a, **k: {}, + validate_environment=lambda model: {"keys_in_environment": True, "missing_keys": []}, + encode=lambda *a, **k: [], + token_counter=lambda *a, **k: 0, + ), + ) + + stub_info = { + "max_input_tokens": 512, + "max_tokens": 512, + "max_output_tokens": 512, + "input_cost_per_token": 1.0, + "output_cost_per_token": 2.0, + "litellm_provider": "openrouter", + } + + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.supports_provider", + lambda self, provider: provider == "openrouter", + ) + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.get_model_info", + lambda self, model: stub_info, + ) + + mim = ModelInfoManager() + info = mim.get_model_info("openrouter/demo/model") + + assert info == stub_info + + +def test_model_dynamic_settings_added(monkeypatch, tmp_path): + provider = "demo" + model_name = "demo/org/foo" + manager = ModelInfoManager() + + def _fake_supports(self, prov): + return prov == provider + + def _fake_get(self, model): + return { + "max_input_tokens": 2048, + "max_tokens": 2048, + "max_output_tokens": 2048, + "litellm_provider": provider, + } + + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.supports_provider", + _fake_supports, + ) + monkeypatch.setattr( + "aider.helpers.model_providers.ModelProviderManager.get_model_info", + _fake_get, + ) + monkeypatch.setattr( + "aider.models.litellm", + types.SimpleNamespace( + _lazy_module=None, + get_model_info=lambda *a, **k: {}, + validate_environment=lambda model: 
{"keys_in_environment": True, "missing_keys": []}, + encode=lambda *a, **k: [], + token_counter=lambda *a, **k: 0, + ), + ) + + assert not any(ms.name == model_name for ms in MODEL_SETTINGS) + + info = manager.get_model_info(model_name) + assert info["max_tokens"] == 2048 + + assert any(ms.name == model_name for ms in MODEL_SETTINGS) + + model = Model(model_name) + assert model.info["max_tokens"] == 2048 From ffbc8beabd37c2985bab2b38182cb9f6d030183d Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:08:50 -0600 Subject: [PATCH 08/14] test: add fallback coverage for unified provider cache --- tests/basic/test_model_provider_manager.py | 29 ++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/basic/test_model_provider_manager.py b/tests/basic/test_model_provider_manager.py index 72dd103bfac..42a5e77018b 100644 --- a/tests/basic/test_model_provider_manager.py +++ b/tests/basic/test_model_provider_manager.py @@ -253,6 +253,35 @@ def test_get_api_key_prefers_first_valid(monkeypatch, tmp_path): assert manager._get_api_key("demo") == "secret" +def test_refresh_provider_cache_uses_static_models(monkeypatch, tmp_path): + config = { + "demo": { + "api_base": "https://example.com/v1", + "static_models": [ + { + "id": "demo/foo", + "max_input_tokens": 1024, + "pricing": {"prompt": "0.5", "completion": "1.0"}, + } + ], + } + } + + manager = _make_manager(tmp_path, config) + + def _failing_fetch(*args, **kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr("requests.get", _failing_fetch) + + refreshed = manager.refresh_provider_cache("demo") + + assert refreshed is True + info = manager.get_model_info("demo/demo/foo") + assert info["max_input_tokens"] == 1024 + assert info["input_cost_per_token"] == 0.5 + + def test_model_info_manager_delegates_to_provider(monkeypatch, tmp_path): monkeypatch.setattr( "aider.models.litellm", From ab7e70ea295ca99f4183c0b23dc075d3e20bb433 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:09:30 -0600 Subject: [PATCH 09/14] test: retire legacy provider-specific suites --- tests/basic/test_openai_providers.py | 428 --------------------------- tests/basic/test_openrouter.py | 206 ------------- 2 files changed, 634 deletions(-) delete mode 100644 tests/basic/test_openai_providers.py delete mode 100644 tests/basic/test_openrouter.py diff --git a/tests/basic/test_openai_providers.py b/tests/basic/test_openai_providers.py deleted file mode 100644 index 3f0b301fc64..00000000000 --- a/tests/basic/test_openai_providers.py +++ /dev/null @@ -1,428 +0,0 @@ -import json -import math -import sys -import types -from pathlib import Path - -if "PIL" not in sys.modules: - pil_module = types.ModuleType("PIL") - image_module = types.ModuleType("PIL.Image") - image_grab_module = types.ModuleType("PIL.ImageGrab") - - class _DummyImage: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - @property - def size(self): - return (1024, 1024) - - def _dummy_open(*args, **kwargs): - return _DummyImage() - - image_module.open = _dummy_open - image_grab_module.grab = _dummy_open - pil_module.Image = image_module - pil_module.ImageGrab = image_grab_module - sys.modules["PIL"] = pil_module - sys.modules["PIL.Image"] = image_module - sys.modules["PIL.ImageGrab"] = image_grab_module - -if "numpy" not in sys.modules: - numpy_module = types.ModuleType("numpy") - numpy_module.ndarray = object - numpy_module.array = lambda *a, **k: None - numpy_module.dot = lambda *a, **k: 0.0 - 
numpy_module.linalg = types.SimpleNamespace(norm=lambda *a, **k: 1.0) - sys.modules["numpy"] = numpy_module - -if "oslex" not in sys.modules: - oslex_module = types.ModuleType("oslex") - oslex_module.__all__ = [] - sys.modules["oslex"] = oslex_module - -if "rich" not in sys.modules: - rich_module = types.ModuleType("rich") - console_module = types.ModuleType("rich.console") - - class _DummyConsole: - def __init__(self, *args, **kwargs): - pass - - def status(self, *args, **kwargs): - return self - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def update(self, *args, **kwargs): - return None - - console_module.Console = _DummyConsole - rich_module.console = console_module - sys.modules["rich"] = rich_module - sys.modules["rich.console"] = console_module - -if "pyperclip" not in sys.modules: - pyperclip_module = types.ModuleType("pyperclip") - - class _DummyPyperclipException(Exception): - pass - - pyperclip_module.PyperclipException = _DummyPyperclipException - pyperclip_module.copy = lambda *args, **kwargs: None - sys.modules["pyperclip"] = pyperclip_module - -if "pexpect" not in sys.modules: - pexpect_module = types.ModuleType("pexpect") - - class _DummySpawn: - def __init__(self, *args, **kwargs): - pass - - def sendline(self, *args, **kwargs): - return 0 - - def close(self, *args, **kwargs): - return 0 - - pexpect_module.spawn = _DummySpawn - sys.modules["pexpect"] = pexpect_module - -if "psutil" not in sys.modules: - psutil_module = types.ModuleType("psutil") - - class _DummyProcess: - def __init__(self, *args, **kwargs): - pass - - def children(self, *args, **kwargs): - return [] - - def terminate(self): - return None - - psutil_module.Process = _DummyProcess - sys.modules["psutil"] = psutil_module - -if "pypandoc" not in sys.modules: - pypandoc_module = types.ModuleType("pypandoc") - pypandoc_module.convert_text = lambda *args, **kwargs: "" - sys.modules["pypandoc"] = pypandoc_module - -import aider.models as models_module -from aider.commands.model import ModelCommand -from aider.models import ModelInfoManager -from aider.helpers.model_providers import ModelProviderManager, _JSONOpenAIProvider - - -class DummyResponse: - """Minimal stand-in for requests.Response used in tests.""" - - def __init__(self, json_data): - self.status_code = 200 - self._json_data = json_data - - def json(self): - return self._json_data - - def raise_for_status(self): - return None - - -def _load_openai_fixture(): - return { - "data": [ - { - "id": "zai-org/GLM-4.6", - "object": "model", - "created": 1723500000, - "owned_by": "openai", - "max_input_tokens": 131072, - "max_output_tokens": 131072, - "max_tokens": 131072, - "context_length": 131072, - "context_window": 131072, - "top_provider_context_length": 131072, - "pricing": { - "prompt": "0.00000055", - "completion": "0.00000219", - }, - }, - { - "id": "zai-org/GLM-4.6:extended", - "object": "model", - "created": 1723500001, - "owned_by": "openai", - "max_tokens": 65536, - "pricing": { - "prompt": "0.00000060", - "completion": "0.00000250", - }, - }, - ] - } - - -def _test_provider_config(): - return { - "openai": { - "api_base": "https://api.openai.com/v1", - "models_url": "https://api.openai.com/v1/models", - "api_key_env": ["OPENAI_API_KEY"], - "base_url_env": ["OPENAI_API_BASE"], - "default_headers": {}, - } - } - - -def test_provider_manager_get_model_info_from_cache(monkeypatch, tmp_path): - """ModelProviderManager should hydrate from cached payloads.""" - - payload = _load_openai_fixture() - - 
def _fail_request(*args, **kwargs): # pragma: no cover - should never be called - raise AssertionError("Network request should not be made when cache is valid") - - monkeypatch.setattr("requests.get", _fail_request) - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - - manager = ModelProviderManager(provider_configs=_test_provider_config()) - manager.cache_dir.mkdir(parents=True, exist_ok=True) - cache_file = manager._get_cache_file("openai") - cache_file.write_text(json.dumps(payload)) - - info = manager.get_model_info("openai/zai-org/GLM-4.6:extended") - - assert info["max_input_tokens"] == 131072 - assert info["max_output_tokens"] == 131072 - assert info["max_tokens"] == 131072 - assert info["input_cost_per_token"] == 0.00000055 - assert info["output_cost_per_token"] == 0.00000219 - assert info["litellm_provider"] == "openai" - assert manager._cache_loaded["openai"] - - -def test_provider_manager_models_endpoint_fetch(monkeypatch, tmp_path): - """ModelProviderManager should fetch and cache the /models payload when missing.""" - - payload = _load_openai_fixture() - call_args = [] - - def _recording_request(url, *, headers=None, timeout=None, verify=None): - call_args.append((url, headers, timeout, verify)) - return DummyResponse(payload) - - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - monkeypatch.setattr("requests.get", _recording_request) - monkeypatch.setenv("OPENAI_API_KEY", "test-key") - - provider_config = _test_provider_config() - manager = ModelProviderManager(provider_configs=provider_config) - manager.set_verify_ssl(False) - - info = manager.get_model_info("openai/zai-org/GLM-4.6:extended") - - expected_url = provider_config["openai"]["models_url"] - assert call_args == [ - ( - expected_url, - {"Authorization": "Bearer test-key"}, - 10, - False, - ) - ] - assert info["max_input_tokens"] == 131072 - assert info["max_output_tokens"] == 131072 - assert info["max_tokens"] == 131072 - assert info["input_cost_per_token"] == 0.00000055 - assert info["output_cost_per_token"] == 0.00000219 - - -def test_provider_static_models_used_without_api_key(monkeypatch, tmp_path): - payload = _load_openai_fixture() - provider_config = _test_provider_config() - provider_config["openai"]["static_models"] = payload["data"] - - def _fail_request(*args, **kwargs): # pragma: no cover - should not run - raise AssertionError("Network request should not be attempted without API key") - - monkeypatch.setattr("requests.get", _fail_request) - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - - manager = ModelProviderManager(provider_configs=provider_config) - info = manager.get_model_info("openai/zai-org/GLM-4.6") - - assert info["litellm_provider"] == "openai" - assert info["max_tokens"] == 131072 - - -def test_provider_models_price_strings(monkeypatch, tmp_path): - payload = { - "data": [ - { - "id": "demo/model", - "max_input_tokens": 4096, - "pricing": {"prompt": "$0.00000055", "completion": "$0.00000219"}, - } - ] - } - - provider_config = _test_provider_config() - provider_config["openai"]["static_models"] = payload["data"] - - def _fail_request(*args, **kwargs): # pragma: no cover - should not run - raise AssertionError("Network fetch should be skipped when static models exist") - - monkeypatch.setattr("requests.get", _fail_request) - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - - manager = ModelProviderManager(provider_configs=provider_config) - info = manager.get_model_info("openai/demo/model") - - assert 
math.isclose(info["input_cost_per_token"], 0.00000055) - assert math.isclose(info["output_cost_per_token"], 0.00000219) - - -def test_model_info_manager_uses_provider_manager(monkeypatch): - """ModelInfoManager should delegate to ModelProviderManager for openai-like models.""" - - monkeypatch.setattr( - models_module, - "litellm", - types.SimpleNamespace(_lazy_module=None, get_model_info=lambda *a, **k: {}), - ) - - stub_info = { - "max_input_tokens": 1024, - "max_tokens": 1024, - "max_output_tokens": 1024, - "input_cost_per_token": 0.0001, - "output_cost_per_token": 0.0002, - "litellm_provider": "openai", - } - - monkeypatch.setattr( - "aider.helpers.model_providers.ModelProviderManager.get_model_info", - lambda self, model: stub_info, - ) - - mim = ModelInfoManager() - info = mim.get_model_info("openai/demo/model") - - assert info == stub_info - - -def test_provider_manager_listing(monkeypatch, tmp_path): - payload = _load_openai_fixture() - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - manager = ModelProviderManager(provider_configs=_test_provider_config()) - manager.cache_dir.mkdir(parents=True, exist_ok=True) - cache_file = manager._get_cache_file("openai") - cache_file.write_text(json.dumps(payload)) - - listings = manager.get_models_for_listing() - - assert "zai-org/GLM-4.6" in listings - assert listings["zai-org/GLM-4.6"]["litellm_provider"] == "openai" - assert listings["zai-org/GLM-4.6"]["mode"] == "chat" - - -def test_chat_model_names_include_openai_provider_models(monkeypatch, tmp_path): - payload = _load_openai_fixture() - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - - import aider.models as models_module - - models_module.litellm = types.SimpleNamespace(model_cost={}, _lazy_module=None) - models_module.model_info_manager = models_module.ModelInfoManager() - models_module.model_info_manager.provider_manager = ModelProviderManager( - provider_configs=_test_provider_config() - ) - manager = models_module.model_info_manager.provider_manager - manager.cache_dir.mkdir(parents=True, exist_ok=True) - cache_file = manager._get_cache_file("openai") - cache_file.write_text(json.dumps(payload)) - - names = models_module.get_chat_model_names() - - assert "openai/zai-org/GLM-4.6" in names - - -def test_model_command_completions_include_openai_provider_models(monkeypatch, tmp_path): - payload = _load_openai_fixture() - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - - import aider.models as models_module - - models_module.litellm = types.SimpleNamespace(model_cost={}, _lazy_module=None) - models_module.model_info_manager = models_module.ModelInfoManager() - models_module.model_info_manager.provider_manager = ModelProviderManager( - provider_configs=_test_provider_config() - ) - manager = models_module.model_info_manager.provider_manager - manager.cache_dir.mkdir(parents=True, exist_ok=True) - cache_file = manager._get_cache_file("openai") - cache_file.write_text(json.dumps(payload)) - - completions = ModelCommand.get_completions(io=None, coder=None, args="") - - assert "openai/zai-org/GLM-4.6" in completions - - -def test_model_disables_streaming_for_non_streaming_providers(monkeypatch): - provider_configs = { - "synthetic": { - "api_base": "https://api.synthetic.new/openai/v1", - "api_key_env": ["SYNTHETIC_API_KEY"], - "supports_stream": False, - } - } - - provider_manager = ModelProviderManager(provider_configs=provider_configs) - fake_info = { - "max_input_tokens": 4096, - "max_tokens": 4096, - "max_output_tokens": 4096, - 
"litellm_provider": "synthetic", - } - - fake_model_info_manager = types.SimpleNamespace( - get_model_info=lambda model: fake_info, - provider_manager=provider_manager, - ) - - monkeypatch.setenv("SYNTHETIC_API_KEY", "test-key") - monkeypatch.setattr(models_module, "model_info_manager", fake_model_info_manager) - monkeypatch.setattr( - models_module, - "litellm", - types.SimpleNamespace( - encode=lambda *a, **k: [], - token_counter=lambda *a, **k: 0, - validate_environment=lambda model: {"keys_in_environment": True, "missing_keys": []}, - ), - ) - - model = models_module.Model("synthetic/deepseek-ai/DeepSeek-V3.1") - - assert model.streaming is False - assert model.extra_params["custom_llm_provider"] == "synthetic" - - -def test_json_provider_hf_namespace_normalization(): - provider = object.__new__(_JSONOpenAIProvider) - provider.slug = "synthetic" - provider.config = {"hf_namespace": True} - - rewritten = provider._normalize_model_name("synthetic/deepseek-ai/DeepSeek-V3.1") - assert rewritten == "hf:deepseek-ai/DeepSeek-V3.1" - - unchanged = provider._normalize_model_name("hf:deepseek-ai/DeepSeek-V3.1") - assert unchanged == "hf:deepseek-ai/DeepSeek-V3.1" diff --git a/tests/basic/test_openrouter.py b/tests/basic/test_openrouter.py deleted file mode 100644 index 8659f7a997a..00000000000 --- a/tests/basic/test_openrouter.py +++ /dev/null @@ -1,206 +0,0 @@ -from pathlib import Path -import sys -import types - -if "numpy" not in sys.modules: - numpy_module = types.ModuleType("numpy") - numpy_module.ndarray = object - numpy_module.array = lambda *a, **k: None - numpy_module.dot = lambda *a, **k: 0.0 - numpy_module.linalg = types.SimpleNamespace(norm=lambda *a, **k: 1.0) - sys.modules["numpy"] = numpy_module - -if "PIL" not in sys.modules: - pil_module = types.ModuleType("PIL") - image_module = types.ModuleType("PIL.Image") - image_grab_module = types.ModuleType("PIL.ImageGrab") - - class _DummyImage: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - @property - def size(self): - return (1024, 1024) - - def _dummy_open(*args, **kwargs): - return _DummyImage() - - image_module.open = _dummy_open - image_grab_module.grab = _dummy_open - pil_module.Image = image_module - pil_module.ImageGrab = image_grab_module - sys.modules["PIL"] = pil_module - sys.modules["PIL.Image"] = image_module - sys.modules["PIL.ImageGrab"] = image_grab_module - -if "oslex" not in sys.modules: - oslex_module = types.ModuleType("oslex") - oslex_module.__all__ = [] - sys.modules["oslex"] = oslex_module - -if "rich" not in sys.modules: - rich_module = types.ModuleType("rich") - console_module = types.ModuleType("rich.console") - - class _DummyConsole: - def __init__(self, *args, **kwargs): - pass - - def status(self, *args, **kwargs): - return self - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def update(self, *args, **kwargs): - return None - - console_module.Console = _DummyConsole - rich_module.console = console_module - sys.modules["rich"] = rich_module - sys.modules["rich.console"] = console_module - -if "pyperclip" not in sys.modules: - pyperclip_module = types.ModuleType("pyperclip") - - class _DummyPyperclipException(Exception): - pass - - pyperclip_module.PyperclipException = _DummyPyperclipException - pyperclip_module.copy = lambda *args, **kwargs: None - sys.modules["pyperclip"] = pyperclip_module - -if "pexpect" not in sys.modules: - pexpect_module = types.ModuleType("pexpect") - - class _DummySpawn: - def 
__init__(self, *args, **kwargs): - pass - - def sendline(self, *args, **kwargs): - return 0 - - def close(self, *args, **kwargs): - return 0 - - pexpect_module.spawn = _DummySpawn - sys.modules["pexpect"] = pexpect_module - -if "psutil" not in sys.modules: - psutil_module = types.ModuleType("psutil") - - class _DummyProcess: - def __init__(self, *args, **kwargs): - pass - - def children(self, *args, **kwargs): - return [] - - def terminate(self): - return None - - psutil_module.Process = _DummyProcess - sys.modules["psutil"] = psutil_module - -if "pypandoc" not in sys.modules: - pypandoc_module = types.ModuleType("pypandoc") - pypandoc_module.convert_text = lambda *args, **kwargs: "" - sys.modules["pypandoc"] = pypandoc_module - -from aider.helpers.model_providers import ModelProviderManager -from aider.models import ModelInfoManager -import aider.models as models_module - - -class DummyResponse: - """Minimal stand-in for requests.Response used in tests.""" - - def __init__(self, json_data): - self.status_code = 200 - self._json_data = json_data - - def json(self): - return self._json_data - - def raise_for_status(self): - return None - - -def test_openrouter_get_model_info_from_cache(monkeypatch, tmp_path): - """ - ModelProviderManager should return correct metadata taken from the - downloaded (and locally cached) models JSON payload. - """ - payload = { - "data": [ - { - "id": "mistralai/mistral-medium-3", - "context_length": 32768, - "pricing": {"prompt": "100", "completion": "200"}, - "top_provider": {"context_length": 32768}, - } - ] - } - - # Fake out the network call and the HOME directory used for the cache file - monkeypatch.setattr("requests.get", lambda *a, **k: DummyResponse(payload)) - monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) - - provider_config = { - "openrouter": { - "api_base": "https://openrouter.ai/api/v1", - "models_url": "https://openrouter.ai/api/v1/models", - "requires_api_key": False, - } - } - manager = ModelProviderManager(provider_configs=provider_config) - info = manager.get_model_info("openrouter/mistralai/mistral-medium-3") - - assert info["max_input_tokens"] == 32768 - assert info["input_cost_per_token"] == 100.0 - assert info["output_cost_per_token"] == 200.0 - assert info["litellm_provider"] == "openrouter" - - -def test_model_info_manager_uses_openrouter_manager(monkeypatch): - """ - ModelInfoManager should delegate to ModelProviderManager when litellm - provides no data for an OpenRouter-prefixed model. 
- """ - # Ensure litellm path returns no info so that fallback logic triggers - monkeypatch.setattr( - models_module, - "litellm", - types.SimpleNamespace(_lazy_module=None, get_model_info=lambda *a, **k: {}), - ) - - stub_info = { - "max_input_tokens": 512, - "max_tokens": 512, - "max_output_tokens": 512, - "input_cost_per_token": 100.0, - "output_cost_per_token": 200.0, - "litellm_provider": "openrouter", - } - - # Force ModelProviderManager to return our stub info - monkeypatch.setattr( - "aider.helpers.model_providers.ModelProviderManager.get_model_info", - lambda self, model: stub_info, - ) - monkeypatch.setattr( - "aider.helpers.model_providers.ModelProviderManager.supports_provider", - lambda self, provider: provider == "openrouter", - ) - - mim = ModelInfoManager() - info = mim.get_model_info("openrouter/fake/model") - - assert info == stub_info From a2f773f25a5bd062f79b039b39082eb9c2137ab8 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:22:09 -0600 Subject: [PATCH 10/14] chore: rename provider resource file --- aider/helpers/model_providers.py | 2 +- aider/resources/{openai_providers.json => providers.json} | 0 scripts/generate_openai_providers.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) rename aider/resources/{openai_providers.json => providers.json} (100%) diff --git a/aider/helpers/model_providers.py b/aider/helpers/model_providers.py index d05eeccbf38..d5a31a7329e 100644 --- a/aider/helpers/model_providers.py +++ b/aider/helpers/model_providers.py @@ -23,7 +23,7 @@ OpenAILikeChatHandler = None # type: ignore HTTPHandler = None # type: ignore -RESOURCE_FILE = "openai_providers.json" +RESOURCE_FILE = "providers.json" _PROVIDERS_REGISTERED = False _CUSTOM_HANDLERS: Dict[str, "_JSONOpenAIProvider"] = {} diff --git a/aider/resources/openai_providers.json b/aider/resources/providers.json similarity index 100% rename from aider/resources/openai_providers.json rename to aider/resources/providers.json diff --git a/scripts/generate_openai_providers.py b/scripts/generate_openai_providers.py index 22b98ade8ad..50d4b08907b 100644 --- a/scripts/generate_openai_providers.py +++ b/scripts/generate_openai_providers.py @@ -1,6 +1,6 @@ #!/usr/bin/env python """ -Interactively generate aider/resources/openai_providers.json from litellm data. +Interactively generate aider/resources/providers.json from litellm data. 
This script reads litellm's openai_like provider definitions and walks the user through building cecli's provider registry, mirroring the workflow used by @@ -122,7 +122,7 @@ def main(): litellm_providers_path = ( script_dir.parent / "../litellm/litellm/llms/openai_like/providers.json" ).resolve() - output_path = (repo_root / "aider" / "resources" / "openai_providers.json").resolve() + output_path = (repo_root / "aider" / "resources" / "providers.json").resolve() if not litellm_providers_path.exists(): print(f"Error: Could not find litellm providers at {litellm_providers_path}") From eb511ff03721afff53031e72b5b6615556701e68 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:24:59 -0600 Subject: [PATCH 11/14] chore: document helper functions in provider manager --- aider/helpers/model_providers.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aider/helpers/model_providers.py b/aider/helpers/model_providers.py index d5a31a7329e..a3950aadf12 100644 --- a/aider/helpers/model_providers.py +++ b/aider/helpers/model_providers.py @@ -29,6 +29,7 @@ def _coerce_str(value): + """Return the first string representation that litellm expects.""" if isinstance(value, str): return value if isinstance(value, list) and value: @@ -37,6 +38,7 @@ def _coerce_str(value): def _first_env_value(names): + """Return the first non-empty environment variable for the provided names.""" if not names: return None if isinstance(names, str): @@ -314,6 +316,7 @@ def astreaming( def _register_provider_with_litellm(slug: str, config: Dict) -> None: + """Register provider metadata and custom handlers with LiteLLM.""" try: from litellm.llms.openai_like.json_loader import ( JSONProviderRegistry, @@ -379,6 +382,7 @@ def _register_provider_with_litellm(slug: str, config: Dict) -> None: def _deep_merge(base: Dict, override: Dict) -> Dict: + """Recursively merge override dict into base without mutating inputs.""" result = deepcopy(base) for key, value in override.items(): if isinstance(value, dict) and isinstance(result.get(key), dict): @@ -389,6 +393,7 @@ def _deep_merge(base: Dict, override: Dict) -> Dict: def _load_provider_configs() -> Dict[str, Dict]: + """Load provider configuration overrides from the packaged JSON file.""" configs: Dict[str, Dict] = {} try: resource = importlib_resources.files("aider.resources").joinpath(RESOURCE_FILE) @@ -659,6 +664,7 @@ def _get_api_key(self, provider: str) -> Optional[str]: def ensure_litellm_providers_registered() -> None: + """One-time registration guard for LiteLLM provider metadata.""" global _PROVIDERS_REGISTERED if _PROVIDERS_REGISTERED: return @@ -671,6 +677,7 @@ def ensure_litellm_providers_registered() -> None: def _cost_per_token(val: Optional[str | float | int]) -> Optional[float]: + """Parse token pricing strings into floats, tolerating currency prefixes.""" if val in (None, "", "-", "N/A"): return None if val == "0": @@ -690,6 +697,7 @@ def _cost_per_token(val: Optional[str | float | int]) -> Optional[float]: def _first_value(record: Dict, *keys: str): + """Return the first non-empty value for the provided keys.""" for key in keys: value = record.get(key) if value not in (None, ""): From 4164dca7c0ed1917619156d5ad34e1c632a17c0c Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:26:27 -0600 Subject: [PATCH 12/14] docs: clarify unified provider rationale --- aider/helpers/model_providers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aider/helpers/model_providers.py b/aider/helpers/model_providers.py 
index a3950aadf12..c25ec50c8e9 100644 --- a/aider/helpers/model_providers.py +++ b/aider/helpers/model_providers.py @@ -1,4 +1,10 @@ -"""Unified model provider metadata caching and lookup.""" +"""Unified model provider metadata caching and lookup. + +Historically aider kept separate modules per provider (OpenRouter vs OpenAI-like). +Those grew unwieldy and duplicated caching, request, and normalization logic. +This helper centralizes that behavior so every OpenAI-compatible endpoint defines +a small config blob and inherits the same cache + LiteLLM registration plumbing. +""" from __future__ import annotations From ce524d3196203c503dcfd0996808f9b5f0ea67ab Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:31:47 -0600 Subject: [PATCH 13/14] chore: document provider scripts and rename generator --- aider/helpers/model_providers.py | 2 ++ scripts/{generate_openai_providers.py => generate_providers.py} | 0 2 files changed, 2 insertions(+) rename scripts/{generate_openai_providers.py => generate_providers.py} (100%) diff --git a/aider/helpers/model_providers.py b/aider/helpers/model_providers.py index c25ec50c8e9..95843829d94 100644 --- a/aider/helpers/model_providers.py +++ b/aider/helpers/model_providers.py @@ -4,6 +4,8 @@ Those grew unwieldy and duplicated caching, request, and normalization logic. This helper centralizes that behavior so every OpenAI-compatible endpoint defines a small config blob and inherits the same cache + LiteLLM registration plumbing. +Provider configs remain curated via ``scripts/generate_providers.py`` and the +static per-model fallback metadata is still cleaned up with ``clean_metadata.py``. """ from __future__ import annotations diff --git a/scripts/generate_openai_providers.py b/scripts/generate_providers.py similarity index 100% rename from scripts/generate_openai_providers.py rename to scripts/generate_providers.py From e365eaab061ac1141c15e2d9dc4a6e9c8791de15 Mon Sep 17 00:00:00 2001 From: Chris Nestrud Date: Sun, 28 Dec 2025 13:40:15 -0600 Subject: [PATCH 14/14] style: apply pre-commit formatting fixes --- aider/models.py | 12 +++--------- tests/basic/test_model_provider_manager.py | 4 ++-- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/aider/models.py b/aider/models.py index bb7e78eca25..537226e2a27 100644 --- a/aider/models.py +++ b/aider/models.py @@ -17,9 +17,9 @@ from aider import __version__ from aider.dump import dump # noqa: F401 +from aider.helpers.model_providers import ModelProviderManager from aider.helpers.requests import model_request_parser from aider.llm import litellm -from aider.helpers.model_providers import ModelProviderManager from aider.sendchat import sanity_check_messages from aider.utils import check_pip_install_extra @@ -858,14 +858,8 @@ def fast_validate_environment(self): if var and os.environ.get(var): return dict(keys_in_environment=[var], missing_keys=[]) - if ( - not var - and provider - and model_info_manager.provider_manager.supports_provider(provider) - ): - provider_keys = model_info_manager.provider_manager.get_required_api_keys( - provider - ) + if not var and provider and model_info_manager.provider_manager.supports_provider(provider): + provider_keys = model_info_manager.provider_manager.get_required_api_keys(provider) for env_var in provider_keys: if os.environ.get(env_var): return dict(keys_in_environment=[env_var], missing_keys=[]) diff --git a/tests/basic/test_model_provider_manager.py b/tests/basic/test_model_provider_manager.py index 42a5e77018b..a4ddf36917d 100644 --- 
a/tests/basic/test_model_provider_manager.py +++ b/tests/basic/test_model_provider_manager.py @@ -119,8 +119,8 @@ def terminate(self): _install_stubs() -from aider.helpers.model_providers import ModelProviderManager -from aider.models import MODEL_SETTINGS, Model, ModelInfoManager +from aider.helpers.model_providers import ModelProviderManager # noqa: E402 +from aider.models import MODEL_SETTINGS, Model, ModelInfoManager # noqa: E402 class DummyResponse:
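
Taken together, the tests in this series drive the unified manager entirely from a small per-provider config dict, so the end-to-end flow can be sketched outside pytest. The sketch below is not part of the patches: it reuses only the config keys (api_base, models_url, requires_api_key, static_models, pricing) and calls (ModelProviderManager(provider_configs=...), get_model_info) that the test suite above exercises, while the "demo" slug, model id, and prices are illustrative values, and the expected results are assumed to match what the static-models tests assert.

    from aider.helpers.model_providers import ModelProviderManager

    # Illustrative provider entry; the keys mirror those used in the tests above.
    provider_configs = {
        "demo": {
            "api_base": "https://example.com/v1",
            "models_url": "https://example.com/v1/models",
            "requires_api_key": False,
            # static_models lets metadata resolve without any network fetch
            "static_models": [
                {
                    "id": "demo/foo",
                    "max_input_tokens": 1024,
                    "pricing": {"prompt": "0.5", "completion": "1.0"},
                }
            ],
        }
    }

    manager = ModelProviderManager(provider_configs=provider_configs)
    info = manager.get_model_info("demo/demo/foo")

    # Per the static-models tests, these fields hydrate from the entry above:
    # info["max_input_tokens"] == 1024
    # info["input_cost_per_token"] == 0.5  (parsed from the "prompt" price string)

Note that a real run also populates the manager's on-disk model cache under the user's home directory, which is why the tests monkeypatch Path.home to a temporary directory.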