diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py
index c3fd88bc..5f09859a 100644
--- a/agent/core/agent_loop.py
+++ b/agent/core/agent_loop.py
@@ -13,6 +13,7 @@
 from agent.config import Config
 from agent.core.doom_loop import check_for_doom_loop
+from agent.core.llm_errors import friendly_llm_error_message, render_llm_error_message
 from agent.core.llm_params import _resolve_llm_params
 from agent.core.prompt_caching import with_prompt_caching
 from agent.core.session import Event, OpType, Session
@@ -191,44 +192,7 @@ async def _heal_effort_and_rebuild_params(
 
 def _friendly_error_message(error: Exception) -> str | None:
     """Return a user-friendly message for known error types, or None to fall back to traceback."""
-    err_str = str(error).lower()
-
-    if "authentication" in err_str or "unauthorized" in err_str or "invalid x-api-key" in err_str:
-        return (
-            "Authentication failed — your API key is missing or invalid.\n\n"
-            "To fix this, set the API key for your model provider:\n"
-            "  • Anthropic: export ANTHROPIC_API_KEY=sk-...\n"
-            "  • OpenAI: export OPENAI_API_KEY=sk-...\n"
-            "  • HF Router: export HF_TOKEN=hf_...\n\n"
-            "You can also add it to a .env file in the project root.\n"
-            "To switch models, use the /model command."
-        )
-
-    if "insufficient" in err_str and "credit" in err_str:
-        return (
-            "Insufficient API credits. Please check your account balance "
-            "at your model provider's dashboard."
-        )
-
-    if "not supported by provider" in err_str or "no provider supports" in err_str:
-        return (
-            "The model isn't served by the provider you pinned.\n\n"
-            "Drop the ':<provider>' suffix to let the HF router auto-pick a "
-            "provider, or use '/model' (no arg) to see which providers host "
-            "which models."
-        )
-
-    if "model_not_found" in err_str or (
-        "model" in err_str
-        and ("not found" in err_str or "does not exist" in err_str)
-    ):
-        return (
-            "Model not found. Use '/model' to list suggestions, or paste an "
-            "HF model id like 'MiniMaxAI/MiniMax-M2.7'. Availability is shown "
-            "when you switch."
-        )
-
-    return None
+    return friendly_llm_error_message(error)
 
 
 async def _compact_and_notify(session: Session) -> None:
@@ -870,11 +834,9 @@ async def _exec_tool(
                 continue
             except Exception as e:
-                import traceback
-
-                error_msg = _friendly_error_message(e)
-                if error_msg is None:
-                    error_msg = str(e) + "\n" + traceback.format_exc()
+                logger.info("Agent turn failed: %s", e)
+                logger.debug("Agent turn failed", exc_info=True)
+                error_msg = render_llm_error_message(e)
 
                 await session.send_event(
                     Event(
diff --git a/agent/core/llm_errors.py b/agent/core/llm_errors.py
new file mode 100644
index 00000000..a41fe4b5
--- /dev/null
+++ b/agent/core/llm_errors.py
@@ -0,0 +1,147 @@
+"""Shared LLM error classification and user-facing messages."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+LlmErrorType = Literal[
+    "auth",
+    "credits",
+    "model",
+    "provider",
+    "rate_limit",
+    "network",
+    "unknown",
+]
+
+_AUTH_MARKERS = (
+    "authentication failed",
+    "authentication_error",
+    "authentication error",
+    "unauthorized",
+    "invalid x-api-key",
+    "invalid api key",
+    "incorrect api key",
+    "didn't provide an api key",
+    "did not provide an api key",
+    "no api key provided",
+    "provide your api key",
+    "x-api-key header is required",
+    "api key header is required",
+    "api key required",
+    "api key is missing or invalid",
+    "api_key_invalid",
+    "401",
+)
+_CREDITS_MARKERS = (
+    "insufficient credit",
+    "insufficient credits",
+    "out of credits",
+    "insufficient_quota",
+    "credit balance is too low",
+    "balance is too low",
+    "purchase credits",
+    "plans & billing",
+    "quota",
+    "billing",
+    "payment required",
+    "402",
+)
+_RATE_LIMIT_MARKERS = ("429", "rate limit", "too many requests")
+_NETWORK_MARKERS = (
+    "timeout",
+    "timed out",
+    "connect",
+    "connection error",
+    "connection refused",
+    "connection reset",
+    "network",
+    "service unavailable",
+    "bad gateway",
+    "overloaded",
+    "capacity",
+)
+
+
+def _has_any(err_str: str, markers: tuple[str, ...]) -> bool:
+    return any(marker in err_str for marker in markers)
+
+
+def classify_llm_error(error: Exception) -> LlmErrorType:
+    """Classify common provider/API failures from the exception text."""
+    err_str = str(error).lower()
+
+    if _has_any(err_str, _AUTH_MARKERS):
+        return "auth"
+    if _has_any(err_str, _CREDITS_MARKERS):
+        return "credits"
+    if "not supported by provider" in err_str or "no provider supports" in err_str:
+        return "provider"
+    if "model_not_found" in err_str or "unknown model" in err_str:
+        return "model"
+    if "model" in err_str and (
+        "not found" in err_str
+        or "does not exist" in err_str
+        or "not available" in err_str
+    ):
+        return "model"
+    if _has_any(err_str, _RATE_LIMIT_MARKERS):
+        return "rate_limit"
+    if _has_any(err_str, _NETWORK_MARKERS):
+        return "network"
+    return "unknown"
+
+
+def friendly_llm_error_message(error: Exception) -> str | None:
+    """Return a clean user-facing message for common LLM failures."""
+    error_type = classify_llm_error(error)
+
+    if error_type == "auth":
+        return (
+            "Authentication failed — your API key is missing or invalid.\n\n"
+            "To fix this, set the API key for your model provider:\n"
+            "  • Anthropic: export ANTHROPIC_API_KEY=sk-...\n"
+            "  • OpenAI: export OPENAI_API_KEY=sk-...\n"
+            "  • HF Router: export HF_TOKEN=hf_...\n\n"
+            "You can also add it to a .env file in the project root.\n"
+            "To switch models, use the /model command."
+        )
+    if error_type == "credits":
+        return (
+            "Insufficient API credits or quota for this model/provider.\n\n"
+            "Check billing for the current provider, or switch models with /model."
+        )
+    if error_type == "provider":
+        return (
+            "The model isn't served by the provider you pinned.\n\n"
+            "Drop the ':<provider>' suffix to let the HF router auto-pick a "
+            "provider, or use '/model' (no arg) to see which providers host "
+            "which models."
+        )
+    if error_type == "model":
+        return (
+            "Model not found. Use '/model' to list suggestions, or paste an "
+            "HF model id like 'MiniMaxAI/MiniMax-M2.7'. Availability is shown "
+            "when you switch."
+        )
+    if error_type == "rate_limit":
+        return (
+            "Rate limit reached. Wait a moment and retry, or switch models/providers "
+            "with /model."
+        )
+    if error_type == "network":
+        return "The model provider is unavailable or timed out. Retry in a moment."
+    return None
+
+
+def render_llm_error_message(error: Exception) -> str:
+    """Return the message safe to show to users."""
+    return friendly_llm_error_message(error) or str(error)
+
+
+def health_error_type(error: Exception) -> str:
+    """Map LLM failures to the backend health endpoint error_type values."""
+    error_type = classify_llm_error(error)
+    if error_type in {"auth", "credits", "rate_limit", "network"}:
+        return error_type
+    return "unknown"
diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py
index d6843df1..8bc8fc9d 100644
--- a/agent/core/llm_params.py
+++ b/agent/core/llm_params.py
@@ -1,37 +1,15 @@
-"""LiteLLM kwargs resolution for the model ids this agent accepts.
+"""LiteLLM kwargs resolution for the model ids this agent accepts."""
 
-Kept separate from ``agent_loop`` so tools (research, context compaction, etc.)
-can import it without pulling in the whole agent loop / tool router and
-creating circular imports.
-"""
+from agent.core.provider_adapters import (
+    UnsupportedEffortError,
+    resolve_adapter,
+)
 
-import os
+__all__ = ["UnsupportedEffortError", "_resolve_llm_params"]
 
 
 def _patch_litellm_effort_validation() -> None:
-    """Neuter LiteLLM 1.83's hardcoded effort-level validation.
-
-    Context: at ``litellm/llms/anthropic/chat/transformation.py:~1443`` the
-    Anthropic adapter validates ``output_config.effort ∈ {high, medium,
-    low, max}`` and gates ``max`` behind an ``_is_opus_4_6_model`` check
-    that only matches the substring ``opus-4-6`` / ``opus_4_6``. Result:
-
-    * ``xhigh`` — valid on Anthropic's real API for Claude 4.7 — is
-      rejected pre-flight with "Invalid effort value: xhigh".
-    * ``max`` on Opus 4.7 is rejected with "effort='max' is only supported
-      by Claude Opus 4.6", even though Opus 4.7 accepts it in practice.
-
-    We don't want to maintain a parallel model table, so we let the
-    Anthropic API itself be the validator: widen ``_is_opus_4_6_model``
-    to also match ``opus-4-7``+ families, and drop the valid-effort-set
-    check entirely. If Anthropic rejects an effort level, we see a 400
-    and the cascade walks down — exactly the behavior we want for any
-    future model family.
-
-    Removable once litellm ships 1.83.8-stable (which merges PR #25867,
-    "Litellm day 0 opus 4.7 support") — see commit 0868a82 on their main
-    branch. Until then, this one-time patch is the escape hatch.
-    """
+    """Patch LiteLLM's Anthropic effort validation for Claude Opus 4.7."""
     try:
         from litellm.llms.anthropic.chat import transformation as _t
     except Exception:
@@ -64,59 +42,15 @@ def _widened(model: str) -> bool:
 
 _patch_litellm_effort_validation()
 
-# Effort levels accepted on the wire.
-# Anthropic (4.6+): low | medium | high | xhigh | max  (output_config.effort)
-# OpenAI direct:    minimal | low | medium | high      (reasoning_effort top-level)
-# HF router:        low | medium | high                (extra_body.reasoning_effort)
-#
-# We validate *shape* here and let the probe cascade walk down on rejection;
-# we deliberately do NOT maintain a per-model capability table.
-_ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"}
-_OPENAI_EFFORTS = {"minimal", "low", "medium", "high"}
-_HF_EFFORTS = {"low", "medium", "high"}
-
-
-class UnsupportedEffortError(ValueError):
-    """The requested effort isn't valid for this provider's API surface.
-
-    Raised synchronously before any network call so the probe cascade can
-    skip levels the provider can't accept (e.g. ``max`` on HF router).
-    """
-
 
 def _resolve_llm_params(
     model_name: str,
     session_hf_token: str | None = None,
     reasoning_effort: str | None = None,
     strict: bool = False,
 ) -> dict:
-    """
-    Build LiteLLM kwargs for a given model id.
-
-    • ``anthropic/<model>`` — native thinking config. We bypass LiteLLM's
-      ``reasoning_effort`` → ``thinking`` mapping (which lags new Claude
-      releases like 4.7 and sends the wrong API shape). Instead we pass
-      both ``thinking={"type": "adaptive"}`` and ``output_config=
-      {"effort": <level>}`` as top-level kwargs — LiteLLM's Anthropic
-      adapter forwards unknown top-level kwargs into the request body
-      verbatim (confirmed by live probe; ``extra_body`` does NOT work
-      here because Anthropic's API rejects it as "Extra inputs are not
-      permitted"). This is the stable API for 4.6 and 4.7. Older
-      extended-thinking models that only accept ``thinking.type.enabled``
-      will reject this; the probe's cascade catches that and falls back
-      to no thinking.
+    """Build LiteLLM kwargs for a given model id.
 
-    • ``openai/<model>`` — ``reasoning_effort`` forwarded as a top-level
-      kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``.
-
-    • Anything else is treated as a HuggingFace router id. We hit the
-      auto-routing OpenAI-compatible endpoint at
-      ``https://router.huggingface.co/v1``. The id can be bare or carry an
-      HF routing suffix (``:fastest`` / ``:cheapest`` / ``:<provider>``).
-      A leading ``huggingface/`` is stripped. ``reasoning_effort`` is
-      forwarded via ``extra_body`` (LiteLLM's OpenAI adapter refuses it as
-      a top-level kwarg for non-OpenAI models). "minimal" normalizes to
-      "low".
+    Delegates to the matching provider adapter.
 
     ``strict=True`` raises ``UnsupportedEffortError`` when the requested
     effort isn't in the provider's accepted set, instead of silently
@@ -131,70 +65,12 @@ def _resolve_llm_params(
     2. session.hf_token — the user's own token (CLI / OAuth / cache file).
     3. HF_TOKEN env — belt-and-suspenders fallback for CLI users.
     """
-    if model_name.startswith("anthropic/"):
-        params: dict = {"model": model_name}
-        if reasoning_effort:
-            level = reasoning_effort
-            if level == "minimal":
-                level = "low"
-            if level not in _ANTHROPIC_EFFORTS:
-                if strict:
-                    raise UnsupportedEffortError(
-                        f"Anthropic doesn't accept effort={level!r}"
-                    )
-            else:
-                # Adaptive thinking + output_config.effort is the stable
-                # Anthropic API for Claude 4.6 / 4.7. Both kwargs are
-                # passed top-level: LiteLLM forwards unknown params into
-                # the request body for Anthropic, so ``output_config``
-                # reaches the API. ``extra_body`` does NOT work here —
-                # Anthropic rejects it as "Extra inputs are not
-                # permitted".
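(Aside, not part of the patch: the removed comment above documents the wire shape the Anthropic path emits. A minimal sketch of what the adapter-backed `_resolve_llm_params` now returns for an Anthropic id — the model id and effort level are just examples, and the expected dict mirrors `AnthropicAdapter.build_params` in `provider_adapters.py` below.)

```python
# Illustrative only: expected kwargs shape for the Anthropic path,
# per AnthropicAdapter.build_params in provider_adapters.py below.
from agent.core.llm_params import _resolve_llm_params

params = _resolve_llm_params("anthropic/claude-opus-4-7", reasoning_effort="xhigh")
assert params == {
    "model": "anthropic/claude-opus-4-7",
    "thinking": {"type": "adaptive"},      # top-level kwarg, not extra_body
    "output_config": {"effort": "xhigh"},  # forwarded verbatim by LiteLLM
}
```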
- params["thinking"] = {"type": "adaptive"} - params["output_config"] = {"effort": level} - return params - - if model_name.startswith("bedrock/"): - # LiteLLM routes ``bedrock/...`` through the Converse adapter, which - # picks up AWS credentials from the standard env vars - # (``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` / ``AWS_REGION``). - # The Anthropic thinking/effort shape is not forwarded through Converse - # the same way, so we leave it off for now. - return {"model": model_name} - - if model_name.startswith("openai/"): - params = {"model": model_name} - if reasoning_effort: - if reasoning_effort not in _OPENAI_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"OpenAI doesn't accept effort={reasoning_effort!r}" - ) - else: - params["reasoning_effort"] = reasoning_effort - return params - - hf_model = model_name.removeprefix("huggingface/") - api_key = ( - os.environ.get("INFERENCE_TOKEN") - or session_hf_token - or os.environ.get("HF_TOKEN") + adapter = resolve_adapter(model_name) + if adapter is None: + raise ValueError(f"No provider adapter for model: {model_name}") + return adapter.build_params( + model_name, + session_hf_token=session_hf_token, + reasoning_effort=reasoning_effort, + strict=strict, ) - params = { - "model": f"openai/{hf_model}", - "api_base": "https://router.huggingface.co/v1", - "api_key": api_key, - } - if os.environ.get("INFERENCE_TOKEN"): - bill_to = os.environ.get("HF_BILL_TO", "smolagents") - params["extra_headers"] = {"X-HF-Bill-To": bill_to} - if reasoning_effort: - hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort - if hf_level not in _HF_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"HF router doesn't accept effort={hf_level!r}" - ) - else: - params["extra_body"] = {"reasoning_effort": hf_level} - return params diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index afb8d52c..433d4236 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -3,7 +3,7 @@ Split out of ``agent.main`` so the REPL dispatcher stays focused on input parsing. Exposes: -* ``SUGGESTED_MODELS`` — the short list shown by ``/model`` with no arg. +* Adapter-driven model catalog rendered by ``/model`` with no arg. * ``is_valid_model_id`` — loose format check on user input. * ``probe_and_switch_model`` — async: checks routing, fires a 1-token probe to resolve the effort cascade, then commits the switch (or @@ -16,20 +16,14 @@ from __future__ import annotations from agent.core.effort_probe import ProbeInconclusive, probe_effort - - -# Suggested models shown by `/model` (not a gate). Users can paste any HF -# model id (e.g. "MiniMaxAI/MiniMax-M2.7") or an `anthropic/` / `openai/` -# prefix for direct API access. For HF ids, append ":fastest" / -# ":cheapest" / ":preferred" / ":" to override the default -# routing policy (auto = fastest with failover). 
-SUGGESTED_MODELS = [
-    {"id": "bedrock/us.anthropic.claude-opus-4-7", "label": "Claude Opus 4.7"},
-    {"id": "bedrock/us.anthropic.claude-opus-4-6-v1", "label": "Claude Opus 4.6"},
-    {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"},
-    {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
-    {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
-]
+from agent.core.llm_errors import render_llm_error_message
+from agent.core.provider_adapters import (
+    ADAPTERS,
+    find_model_option,
+    get_available_models,
+    is_valid_model_name,
+    resolve_adapter,
+)
 
 _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"}
 
@@ -47,11 +41,7 @@ def is_valid_model_id(model_id: str) -> bool:
     Actual availability is verified against the HF router catalog on
     switch, and by the provider on the probe's ping call.
     """
-    if not model_id or "/" not in model_id:
-        return False
-    head = model_id.split(":", 1)[0]
-    parts = head.split("/")
-    return len(parts) >= 2 and all(parts)
+    return is_valid_model_name(model_id)
 
 
 def _print_hf_routing_info(model_id: str, console) -> bool:
@@ -63,7 +53,8 @@ def _print_hf_routing_info(model_id: str, console) -> bool:
     Anthropic / OpenAI ids return ``True`` without printing anything — the
     probe below covers "does this model exist".
     """
-    if model_id.startswith(("anthropic/", "openai/")):
+    adapter = resolve_adapter(model_id)
+    if adapter and adapter.provider_id != "huggingface":
         return True
 
     from agent.core import hf_router_catalog as cat
@@ -125,18 +116,45 @@ def _print_hf_routing_info(model_id: str, console) -> bool:
 
 
 def print_model_listing(config, console) -> None:
-    """Render the default ``/model`` (no-arg) view: current + suggested."""
+    """Render the default ``/model`` (no-arg) view: current + available."""
     current = config.model_name if config else ""
+    current_info = find_model_option(current)
+    available = get_available_models()
+
     console.print("[bold]Current model:[/bold]")
-    console.print(f"  {current}")
-    console.print("\n[bold]Suggested:[/bold]")
-    for m in SUGGESTED_MODELS:
-        marker = " [dim]<-- current[/dim]" if m["id"] == current else ""
-        console.print(f"  {m['id']} [dim]({m['label']})[/dim]{marker}")
+    if current_info:
+        console.print(f"  {current} [dim]({current_info['label']})[/dim]")
+    else:
+        console.print(f"  {current}")
+
+    console.print("\n[bold]Available:[/bold]")
+    for adapter in ADAPTERS:
+        section = [m for m in available if m.get("provider") == adapter.provider_id]
+        if not section:
+            if adapter.provider_id == "openai_compat" and adapter.should_show():
+                console.print(f"\n[bold]{adapter.provider_label}[/bold]")
+                console.print("  [dim]Use openai-compat/<model-id>[/dim]")
+            continue
+
+        console.print(f"\n[bold]{adapter.provider_label}[/bold]")
+        section_sorted = sorted(
+            section,
+            key=lambda m: (
+                not bool(m.get("recommended")),
+                str(m.get("label", "")).lower(),
+            ),
+        )
+        for m in section_sorted:
+            marker = " [dim]<-- current[/dim]" if m["id"] == current else ""
+            source = m.get("source")
+            source_tag = " [dim](dynamic)[/dim]" if source == "dynamic" else ""
+            console.print(f"  {m['id']} [dim]({m['label']})[/dim]{source_tag}{marker}")
+
     console.print(
         "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n"
         "Add ':fastest', ':cheapest', ':preferred', or ':<provider>' to override routing.\n"
-        "Use 'anthropic/' or 'openai/' for direct API access.[/dim]"
+        "Direct prefixes: 'anthropic/', 'openai/', 'openrouter/', 'opencode/',\n"
+        "'opencode-go/', 'ollama/', 'lm_studio/', 'vllm/', 'openai-compat/'.[/dim]"
     )
 
 
@@ -146,7 +164,16 @@ def print_invalid_id(arg: str, console) -> None:
         "[dim]Expected:\n"
         "  • <org>/<model>[:tag] (HF router — paste from huggingface.co)\n"
         "  • anthropic/<model>\n"
-        "  • openai/<model>[/dim]"
+        "  • openai/<model>\n"
+        "  • bedrock/<model>\n"
+        "  • gemini/<model>\n"
+        "  • openrouter/<model>\n"
+        "  • opencode/<model>\n"
+        "  • opencode-go/<model>\n"
+        "  • ollama/<model>\n"
+        "  • lm_studio/<model>\n"
+        "  • vllm/<model>\n"
+        "  • openai-compat/<model>[/dim]"
     )
 
 
@@ -187,14 +214,15 @@ async def probe_and_switch_model(
         outcome = await probe_effort(model_id, preference, hf_token)
     except ProbeInconclusive as e:
         _commit_switch(model_id, config, session, effective=None, cache=False)
+        warning = render_llm_error_message(e)
         console.print(
             f"[yellow]Model switched to {model_id}[/yellow] "
-            f"[dim](couldn't validate: {e}; will verify on first message)[/dim]"
+            f"[dim](couldn't validate: {warning}; will verify on first message)[/dim]"
         )
         return
     except Exception as e:
         # Hard persistent error — auth, unknown model, quota. Don't switch.
-        console.print(f"[bold red]Switch failed:[/bold red] {e}")
+        console.print(f"[bold red]Switch failed:[/bold red] {render_llm_error_message(e)}")
         console.print(f"[dim]Keeping current model: {config.model_name}[/dim]")
         return
diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py
new file mode 100644
index 00000000..a7f8a278
--- /dev/null
+++ b/agent/core/provider_adapters.py
@@ -0,0 +1,718 @@
+"""Provider adapters for runtime params and model catalog metadata."""
+
+import functools
+import json
+import os
+import time
+from dataclasses import dataclass
+from typing import Any, ClassVar
+from urllib.error import URLError
+from urllib.request import urlopen
+
+_DISCOVERY_TIMEOUT_SECONDS = 2.0
+_DISCOVERY_CACHE_TTL_SECONDS = 30.0
+_discovery_cache: dict[str, tuple[float, tuple["SuggestedModel", ...]]] = {}
+
+
+class UnsupportedEffortError(ValueError):
+    """The requested effort isn't valid for this provider's API surface.
+
+    Raised synchronously before any network call so the probe cascade can
+    skip levels the provider can't accept (e.g. ``max`` on HF router).
+ """ + + +@dataclass(frozen=True) +class SuggestedModel: + id: str + label: str + description: str + provider: str + provider_label: str + avatar_url: str + recommended: bool = False + source: str = "static" + + +def _has_model_suffix(model_name: str, prefix: str) -> bool: + if not model_name.startswith(prefix): + return False + tail = model_name[len(prefix) :].split(":", 1)[0] + return bool(tail) and all(tail.split("/")) + + +def _normalize_openai_api_base(api_base: str) -> str: + base = api_base.rstrip("/") + if base.endswith("/v1"): + return base + return f"{base}/v1" + + +def _provider_avatar_url(provider_id: str) -> str: + avatars = { + "anthropic": "https://huggingface.co/api/avatars/Anthropic", + "gemini": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06.svg", + "openai": "https://openai.com/favicon.ico", + "ollama": "https://ollama.com/public/ollama.png", + "lm_studio": "https://avatars.githubusercontent.com/u/16906759?s=200&v=4", + "vllm": "https://avatars.githubusercontent.com/u/132129714?s=200&v=4", + "openrouter": "https://openrouter.ai/favicon.ico", + "opencode_zen": "https://huggingface.co/api/avatars/opencode-ai", + "opencode_go": "https://huggingface.co/api/avatars/opencode-ai", + "openai_compat": "https://openai.com/favicon.ico", + "huggingface": "https://huggingface.co/api/avatars/huggingface", + } + return avatars.get(provider_id, "https://huggingface.co/api/avatars/huggingface") + + +@functools.lru_cache(maxsize=1) +def _all_adapter_prefixes() -> tuple[str, ...]: + prefixes: list[str] = [] + for adapter in ADAPTERS: + prefixes.extend(adapter.prefixes) + return tuple(dict.fromkeys(prefixes)) + + +def _is_hf_model_name(model_name: str) -> bool: + if model_name.startswith(_all_adapter_prefixes()): + return False + bare = model_name.removeprefix("huggingface/").split(":", 1)[0] + parts = bare.split("/") + return len(parts) >= 2 and all(parts) + + +def _discover_models( + *, + provider_id: str, + provider_label: str, + prefix: str, + api_base: str, +) -> tuple[SuggestedModel, ...]: + now = time.monotonic() + cached = _discovery_cache.get(api_base) + if cached and cached[0] > now: + return cached[1] + + models: list[SuggestedModel] = [] + try: + with urlopen( + f"{api_base}/models", timeout=_DISCOVERY_TIMEOUT_SECONDS + ) as response: + payload = json.load(response) + except (OSError, URLError, TimeoutError, ValueError): + payload = {"data": []} + + for item in payload.get("data", []): + if not isinstance(item, dict): + continue + model_id = item.get("id") + if not isinstance(model_id, str) or not model_id: + continue + models.append( + SuggestedModel( + id=f"{prefix}{model_id}", + label=model_id, + description=provider_label, + provider=provider_id, + provider_label=provider_label, + avatar_url=_provider_avatar_url(provider_id), + source="dynamic", + ) + ) + + resolved = tuple(sorted(models, key=lambda m: m.label.lower())) + _discovery_cache[api_base] = (now + _DISCOVERY_CACHE_TTL_SECONDS, resolved) + return resolved + + +@dataclass(frozen=True) +class ProviderAdapter: + provider_id: str + provider_label: str + prefixes: tuple[str, ...] 
= () + supports_custom_model: bool = False + custom_model_hint: str | None = None + custom_model_mode: str | None = None + + def matches(self, model_name: str) -> bool: + return bool(self.prefixes) and model_name.startswith(self.prefixes) + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return () + + def available_models(self) -> tuple[SuggestedModel, ...]: + return self.suggested_models() + + def should_show(self) -> bool: + return True + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + raise NotImplementedError + + def allows_model_name(self, model_name: str) -> bool: + return self.matches(model_name) + + def to_summary(self) -> dict[str, Any]: + return { + "id": self.provider_id, + "label": self.provider_label, + "avatarUrl": _provider_avatar_url(self.provider_id), + "supportsCustomModel": self.supports_custom_model, + "customModelHint": self.custom_model_hint, + "customModelMode": self.custom_model_mode, + "prefix": self.prefixes[0] if self.prefixes else "", + } + + +@dataclass(frozen=True) +class AnthropicAdapter(ProviderAdapter): + """Anthropic models via native API (thinking + output_config.effort).""" + + prefixes: tuple[str, ...] = ("anthropic/",) + _EFFORTS: ClassVar[frozenset[str]] = frozenset( + {"low", "medium", "high", "xhigh", "max"} + ) + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="anthropic/claude-opus-4-7", + label="Claude Opus 4.7", + description="Anthropic", + provider="anthropic", + provider_label="Anthropic", + avatar_url=_provider_avatar_url("anthropic"), + recommended=True, + ), + SuggestedModel( + id="anthropic/claude-opus-4-6", + label="Claude Opus 4.6", + description="Anthropic", + provider="anthropic", + provider_label="Anthropic", + avatar_url=_provider_avatar_url("anthropic"), + ), + ) + + def allows_model_name(self, model_name: str) -> bool: + return _has_model_suffix(model_name, "anthropic/") + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + params: dict[str, Any] = {"model": model_name} + if reasoning_effort: + level = "low" if reasoning_effort == "minimal" else reasoning_effort + if level not in self._EFFORTS: + if strict: + raise UnsupportedEffortError( + f"Anthropic doesn't accept effort={level!r}" + ) + else: + params["thinking"] = {"type": "adaptive"} + params["output_config"] = {"effort": level} + return params + + +@dataclass(frozen=True) +class OpenAIAdapter(ProviderAdapter): + """OpenAI models via native API (reasoning_effort top-level kwarg).""" + + prefixes: tuple[str, ...] 
= ("openai/",) + _EFFORTS: ClassVar[frozenset[str]] = frozenset({"minimal", "low", "medium", "high"}) + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="openai/gpt-5", + label="GPT-5", + description="OpenAI", + provider="openai", + provider_label="OpenAI", + avatar_url=_provider_avatar_url("openai"), + recommended=True, + ), + ) + + def allows_model_name(self, model_name: str) -> bool: + return _has_model_suffix(model_name, "openai/") + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + params: dict[str, Any] = {"model": model_name} + if reasoning_effort: + if reasoning_effort not in self._EFFORTS: + if strict: + raise UnsupportedEffortError( + f"OpenAI doesn't accept effort={reasoning_effort!r}" + ) + else: + params["reasoning_effort"] = reasoning_effort + return params + + +@dataclass(frozen=True) +class BedrockAdapter(ProviderAdapter): + """AWS Bedrock models via LiteLLM Converse adapter. + + Picks up AWS credentials from standard env vars. + Thinking/effort not forwarded through Converse for now. + """ + + prefixes: tuple[str, ...] = ("bedrock/",) + + def allows_model_name(self, model_name: str) -> bool: + return _has_model_suffix(model_name, "bedrock/") + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + return {"model": model_name} + + +@dataclass(frozen=True) +class GeminiAdapter(ProviderAdapter): + """Google Gemini via LiteLLM's native Gemini adapter. + + GEMINI_API_KEY picked up automatically by LiteLLM. + reasoning_effort forwarded directly — LiteLLM maps to thinking_config. + """ + + prefixes: tuple[str, ...] 
= ("gemini/",) + _EFFORTS: ClassVar[frozenset[str]] = frozenset({"low", "medium", "high"}) + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="gemini/gemini-2.5-pro", + label="Gemini 2.5 Pro", + description="Google Gemini", + provider="gemini", + provider_label="Google Gemini", + avatar_url=_provider_avatar_url("gemini"), + recommended=True, + ), + SuggestedModel( + id="gemini/gemini-2.5-flash", + label="Gemini 2.5 Flash", + description="Google Gemini", + provider="gemini", + provider_label="Google Gemini", + avatar_url=_provider_avatar_url("gemini"), + ), + ) + + def allows_model_name(self, model_name: str) -> bool: + return _has_model_suffix(model_name, "gemini/") + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + params: dict[str, Any] = {"model": model_name} + if reasoning_effort: + level = "low" if reasoning_effort == "minimal" else reasoning_effort + if level not in self._EFFORTS: + if strict: + raise UnsupportedEffortError( + f"Gemini doesn't accept effort={level!r}" + ) + else: + params["reasoning_effort"] = level + return params + + +@dataclass(frozen=True) +class OpenAICompatAdapter(ProviderAdapter): + api_base_url: str = "" + api_key_env: str = "" + default_api_key: str = "" + supports_reasoning_effort: bool = True + use_raw_model_name: bool = False + + def resolved_api_base(self) -> str: + return _normalize_openai_api_base(self.api_base_url) + + def resolved_api_key(self) -> str | None: + if self.api_key_env: + return os.environ.get(self.api_key_env, self.default_api_key) + return self.default_api_key or None + + def suggested_model_defs(self) -> tuple[tuple[str, str, bool], ...]: + return () + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + prefix = self.prefixes[0] + return tuple( + SuggestedModel( + id=f"{prefix}{model_id}", + label=label, + description=self.provider_label, + provider=self.provider_id, + provider_label=self.provider_label, + avatar_url=_provider_avatar_url(self.provider_id), + recommended=recommended, + ) + for model_id, label, recommended in self.suggested_model_defs() + ) + + def allows_model_name(self, model_name: str) -> bool: + return bool(self.prefixes) and _has_model_suffix(model_name, self.prefixes[0]) + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + del session_hf_token + + model_id = model_name.removeprefix(self.prefixes[0]) + params: dict[str, Any] = { + "model": model_id if self.use_raw_model_name else f"openai/{model_id}", + "api_base": self.resolved_api_base(), + "api_key": self.resolved_api_key(), + } + + if reasoning_effort: + if not self.supports_reasoning_effort: + if strict: + raise UnsupportedEffortError( + f"{self.provider_id} doesn't accept effort={reasoning_effort!r}" + ) + else: + params["reasoning_effort"] = reasoning_effort + + return params + + +@dataclass(frozen=True) +class OllamaAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] 
= ("ollama/",) + api_key_env: str = "OLLAMA_API_KEY" + default_api_key: str = "ollama" + supports_reasoning_effort: bool = False + + def resolved_api_base(self) -> str: + return _normalize_openai_api_base( + os.environ.get("OLLAMA_API_BASE", "http://localhost:11434/v1") + ) + + def available_models(self) -> tuple[SuggestedModel, ...]: + return _discover_models( + provider_id=self.provider_id, + provider_label=self.provider_label, + prefix=self.prefixes[0], + api_base=self.resolved_api_base(), + ) + + def should_show(self) -> bool: + return bool(self.available_models()) + + +@dataclass(frozen=True) +class LmStudioAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("lm_studio/",) + api_key_env: str = "LMSTUDIO_API_KEY" + default_api_key: str = "lm-studio" + supports_reasoning_effort: bool = False + use_raw_model_name: bool = True + + def resolved_api_base(self) -> str: + return _normalize_openai_api_base( + os.environ.get("LMSTUDIO_BASE_URL", "http://127.0.0.1:1234/v1") + ) + + def available_models(self) -> tuple[SuggestedModel, ...]: + return _discover_models( + provider_id=self.provider_id, + provider_label=self.provider_label, + prefix=self.prefixes[0], + api_base=self.resolved_api_base(), + ) + + def should_show(self) -> bool: + return bool(self.available_models()) + + +@dataclass(frozen=True) +class VllmAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("vllm/",) + api_key_env: str = "VLLM_API_KEY" + default_api_key: str = "vllm" + supports_reasoning_effort: bool = False + + def resolved_api_base(self) -> str: + return _normalize_openai_api_base( + os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1") + ) + + def available_models(self) -> tuple[SuggestedModel, ...]: + return _discover_models( + provider_id=self.provider_id, + provider_label=self.provider_label, + prefix=self.prefixes[0], + api_base=self.resolved_api_base(), + ) + + def should_show(self) -> bool: + return bool(self.available_models()) + + +@dataclass(frozen=True) +class OpenRouterAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("openrouter/",) + api_base_url: str = "https://openrouter.ai/api/v1" + api_key_env: str = "OPENROUTER_API_KEY" + + def suggested_model_defs(self) -> tuple[tuple[str, str, bool], ...]: + return (("anthropic/claude-sonnet-4.5", "Claude Sonnet 4.5", True),) + + +@dataclass(frozen=True) +class OpenCodeZenAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("opencode/",) + api_base_url: str = "https://opencode.ai/zen/v1" + api_key_env: str = "OPENCODE_ZEN_API_KEY" + + def suggested_model_defs(self) -> tuple[tuple[str, str, bool], ...]: + return (("kimi-k2.6", "Kimi K2.6", True),) + + +@dataclass(frozen=True) +class OpenCodeGoAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("opencode-go/",) + api_base_url: str = "https://opencode.ai/zen/go/v1" + api_key_env: str = "OPENCODE_GO_API_KEY" + + def suggested_model_defs(self) -> tuple[tuple[str, str, bool], ...]: + return (("kimi-k2.6", "Kimi K2.6", True),) + + +@dataclass(frozen=True) +class GenericOpenAICompatAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("openai-compat/",) + api_key_env: str = "OPENAI_COMPAT_API_KEY" + supports_custom_model: bool = True + custom_model_hint: str | None = ( + "Use openai-compat/. Configure OPENAI_COMPAT_BASE_URL on server." 
+    )
+    custom_model_mode: str | None = "suffix"
+
+    def resolved_api_base(self) -> str:
+        api_base = os.environ.get("OPENAI_COMPAT_BASE_URL", "")
+        if not api_base:
+            raise ValueError("OPENAI_COMPAT_BASE_URL is required for openai-compat/<model>")
+        return _normalize_openai_api_base(api_base)
+
+    def should_show(self) -> bool:
+        return bool(os.environ.get("OPENAI_COMPAT_BASE_URL"))
+
+
+@dataclass(frozen=True)
+class HfRouterAdapter(ProviderAdapter):
+    """HuggingFace router — OpenAI-compat endpoint with HF token chain."""
+
+    _EFFORTS: ClassVar[frozenset[str]] = frozenset({"low", "medium", "high"})
+    supports_custom_model: bool = True
+    custom_model_hint: str | None = (
+        "Paste any Hugging Face model id, optionally with :fastest/:cheapest/:preferred"
+    )
+    custom_model_mode: str | None = "raw"
+
+    def suggested_models(self) -> tuple[SuggestedModel, ...]:
+        return (
+            SuggestedModel(
+                id="moonshotai/Kimi-K2.6",
+                label="Kimi K2.6",
+                description="HF Router",
+                provider="huggingface",
+                provider_label="Hugging Face Router",
+                avatar_url="https://huggingface.co/api/avatars/moonshotai",
+                recommended=True,
+            ),
+            SuggestedModel(
+                id="MiniMaxAI/MiniMax-M2.7",
+                label="MiniMax M2.7",
+                description="HF Router",
+                provider="huggingface",
+                provider_label="Hugging Face Router",
+                avatar_url="https://huggingface.co/api/avatars/MiniMaxAI",
+            ),
+            SuggestedModel(
+                id="zai-org/GLM-5.1",
+                label="GLM 5.1",
+                description="HF Router",
+                provider="huggingface",
+                provider_label="Hugging Face Router",
+                avatar_url="https://huggingface.co/api/avatars/zai-org",
+            ),
+        )
+
+    def matches(self, model_name: str) -> bool:
+        return _is_hf_model_name(model_name)
+
+    def allows_model_name(self, model_name: str) -> bool:
+        return _is_hf_model_name(model_name)
+
+    def build_params(
+        self,
+        model_name: str,
+        *,
+        session_hf_token: str | None = None,
+        reasoning_effort: str | None = None,
+        strict: bool = False,
+    ) -> dict:
+        hf_model = model_name.removeprefix("huggingface/")
+        inference_token = os.environ.get("INFERENCE_TOKEN")
+        api_key = inference_token or session_hf_token or os.environ.get("HF_TOKEN")
+
+        params: dict[str, Any] = {
+            "model": f"openai/{hf_model}",
+            "api_base": "https://router.huggingface.co/v1",
+            "api_key": api_key,
+        }
+
+        if inference_token:
+            bill_to = os.environ.get("HF_BILL_TO", "smolagents")
+            params["extra_headers"] = {"X-HF-Bill-To": bill_to}
+
+        if reasoning_effort:
+            hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort
+            if hf_level not in self._EFFORTS:
+                if strict:
+                    raise UnsupportedEffortError(
+                        f"HF router doesn't accept effort={hf_level!r}"
+                    )
+            else:
+                params["extra_body"] = {"reasoning_effort": hf_level}
+
+        return params
+
+
+ADAPTERS: tuple[ProviderAdapter, ...] = (
+    AnthropicAdapter(provider_id="anthropic", provider_label="Anthropic"),
+    BedrockAdapter(provider_id="bedrock", provider_label="AWS Bedrock"),
+    GeminiAdapter(provider_id="gemini", provider_label="Google Gemini"),
+    OpenAIAdapter(provider_id="openai", provider_label="OpenAI"),
+    OllamaAdapter(provider_id="ollama", provider_label="Ollama"),
+    LmStudioAdapter(provider_id="lm_studio", provider_label="LM Studio"),
+    VllmAdapter(provider_id="vllm", provider_label="vLLM"),
+    OpenRouterAdapter(provider_id="openrouter", provider_label="OpenRouter"),
+    OpenCodeZenAdapter(provider_id="opencode_zen", provider_label="OpenCode Zen"),
+    OpenCodeGoAdapter(provider_id="opencode_go", provider_label="OpenCode Go"),
+    GenericOpenAICompatAdapter(
+        provider_id="openai_compat",
+        provider_label="OpenAI-Compatible",
+    ),
+    HfRouterAdapter(provider_id="huggingface", provider_label="Hugging Face Router"),
+)
+
+
+def resolve_adapter(model_name: str) -> ProviderAdapter | None:
+    for adapter in ADAPTERS:
+        if adapter.matches(model_name):
+            return adapter
+    return None
+
+
+def is_valid_model_name(model_name: str) -> bool:
+    adapter = resolve_adapter(model_name)
+    return adapter is not None and adapter.allows_model_name(model_name)
+
+
+def _serialized_model(model: SuggestedModel) -> dict[str, Any]:
+    return {
+        "id": model.id,
+        "label": model.label,
+        "description": model.description,
+        "provider": model.provider,
+        "providerLabel": model.provider_label,
+        "avatarUrl": model.avatar_url,
+        "recommended": model.recommended,
+        "source": model.source,
+    }
+
+
+def get_available_models() -> list[dict[str, Any]]:
+    available: list[dict[str, Any]] = []
+    for adapter in ADAPTERS:
+        if not adapter.should_show():
+            continue
+        for model in adapter.available_models():
+            available.append(_serialized_model(model))
+    return available
+
+
+def get_provider_summaries() -> list[dict[str, Any]]:
+    providers: list[dict[str, Any]] = []
+    for adapter in ADAPTERS:
+        if not adapter.should_show():
+            continue
+        providers.append(adapter.to_summary())
+    return providers
+
+
+def find_model_option(model_name: str) -> dict[str, Any] | None:
+    for model in get_available_models():
+        if model["id"] == model_name:
+            return model
+
+    adapter = resolve_adapter(model_name)
+    if not adapter or not adapter.allows_model_name(model_name):
+        return None
+
+    label = model_name
+    if adapter.provider_id == "huggingface":
+        label = model_name.removeprefix("huggingface/")
+    elif adapter.prefixes:
+        label = model_name.removeprefix(adapter.prefixes[0])
+
+    return {
+        "id": model_name,
+        "label": label,
+        "description": f"Custom {adapter.provider_label} model",
+        "provider": adapter.provider_id,
+        "providerLabel": adapter.provider_label,
+        "avatarUrl": _provider_avatar_url(adapter.provider_id),
+        "recommended": False,
+        "source": "custom",
+    }
+
+
+def build_model_catalog(current_model: str) -> dict[str, Any]:
+    return {
+        "current": current_model,
+        "available": get_available_models(),
+        "providers": get_provider_summaries(),
+        "currentInfo": find_model_option(current_model),
+    }
diff --git a/backend/routes/agent.py b/backend/routes/agent.py
index 7f577995..8c5e4504 100644
--- a/backend/routes/agent.py
+++ b/backend/routes/agent.py
@@ -32,52 +32,29 @@
 import user_quotas
 
+from agent.core.llm_errors import health_error_type, render_llm_error_message
 from agent.core.llm_params import _resolve_llm_params
+from agent.core.provider_adapters import (
+    build_model_catalog,
+    is_valid_model_name,
+    resolve_adapter,
+)
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter(prefix="/api", tags=["agent"])
 
-AVAILABLE_MODELS = [
-    {
-        "id": "moonshotai/Kimi-K2.6",
-        "label": "Kimi K2.6",
-        "provider": "huggingface",
-        "tier": "free",
-        "recommended": True,
-    },
-    {
-        "id": "bedrock/us.anthropic.claude-opus-4-6-v1",
-        "label": "Claude Opus 4.6",
-        "provider": "anthropic",
-        "tier": "pro",
-        "recommended": True,
-    },
-    {
-        "id": "MiniMaxAI/MiniMax-M2.7",
-        "label": "MiniMax M2.7",
-        "provider": "huggingface",
-        "tier": "free",
-    },
-    {
-        "id": "zai-org/GLM-5.1",
-        "label": "GLM 5.1",
-        "provider": "huggingface",
-        "tier": "free",
-    },
-]
-
 
 def _is_anthropic_model(model_id: str) -> bool:
-    return "anthropic" in model_id
+    return model_id.startswith(("anthropic/", "bedrock/"))
 
 
 async def _require_hf_for_anthropic(request: Request, model_id: str) -> None:
     """403 if a non-``huggingface``-org user tries to select an Anthropic model.
 
-    Anthropic models are billed to the Space's ``ANTHROPIC_API_KEY``; every
-    other model in ``AVAILABLE_MODELS`` is routed through HF Router and
-    billed via ``X-HF-Bill-To``. The gate only fires for Anthropic so
+    Anthropic models are billed to the Space's ``ANTHROPIC_API_KEY``; other
+    providers use their own routing/billing config. The gate only fires for
+    ``anthropic/*`` so
     non-HF users can still freely switch between the free models.
 
     Pattern: https://github.com/huggingface/ml-intern/pull/63
@@ -176,34 +153,12 @@ async def llm_health_check() -> LLMHealthResponse:
         )
         return LLMHealthResponse(status="ok", model=model)
     except Exception as e:
-        err_str = str(e).lower()
-        error_type = "unknown"
-
-        if (
-            "401" in err_str
-            or "auth" in err_str
-            or "invalid" in err_str
-            or "api key" in err_str
-        ):
-            error_type = "auth"
-        elif (
-            "402" in err_str
-            or "credit" in err_str
-            or "quota" in err_str
-            or "insufficient" in err_str
-            or "billing" in err_str
-        ):
-            error_type = "credits"
-        elif "429" in err_str or "rate" in err_str:
-            error_type = "rate_limit"
-        elif "timeout" in err_str or "connect" in err_str or "network" in err_str:
-            error_type = "network"
-
+        error_type = health_error_type(e)
         logger.warning(f"LLM health check failed ({error_type}): {e}")
         return LLMHealthResponse(
             status="error",
             model=model,
-            error=str(e)[:500],
+            error=render_llm_error_message(e)[:500],
             error_type=error_type,
         )
 
@@ -211,10 +166,9 @@ async def llm_health_check() -> LLMHealthResponse:
 @router.get("/config/model")
 async def get_model() -> dict:
     """Get current model and available models. No auth required."""
-    return {
-        "current": session_manager.config.model_name,
-        "available": AVAILABLE_MODELS,
-    }
+    return await asyncio.to_thread(
+        build_model_catalog, session_manager.config.model_name
+    )
 
 
 _TITLE_STRIP_CHARS = str.maketrans("", "", "`*_~#[]()")
@@ -309,8 +263,7 @@ async def create_session(
 
     if isinstance(body, dict):
         model = body.get("model")
-        valid_ids = {m["id"] for m in AVAILABLE_MODELS}
-        if model and model not in valid_ids:
+        if model and not is_valid_model_name(model):
             raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
 
     # Opus is gated to HF staff (PR #63). Only fires when the resolved model
@@ -343,6 +296,8 @@ async def restore_session_summary(
     messages = body.get("messages")
     if not isinstance(messages, list) or not messages:
         raise HTTPException(status_code=400, detail="Missing 'messages' array")
+    if len(messages) > 2000:
+        raise HTTPException(status_code=400, detail="Too many messages (max 2000)")
 
     hf_token = None
     auth_header = request.headers.get("Authorization", "")
@@ -354,8 +309,7 @@ async def restore_session_summary(
         hf_token = os.environ.get("HF_TOKEN")
 
     model = body.get("model")
-    valid_ids = {m["id"] for m in AVAILABLE_MODELS}
-    if model and model not in valid_ids:
+    if model and not is_valid_model_name(model):
         raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
 
     resolved_model = model or session_manager.config.model_name
@@ -413,9 +367,13 @@ async def set_session_model(
     model_id = body.get("model")
     if not model_id:
         raise HTTPException(status_code=400, detail="Missing 'model' field")
-    valid_ids = {m["id"] for m in AVAILABLE_MODELS}
-    if model_id not in valid_ids:
+    if not is_valid_model_name(model_id):
         raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
+    adapter = resolve_adapter(model_id)
+    if adapter and not adapter.should_show():
+        raise HTTPException(
+            status_code=400, detail=f"Provider not configured for: {model_id}"
+        )
     await _require_hf_for_anthropic(request, model_id)
     agent_session = session_manager.sessions.get(session_id)
     if not agent_session:
diff --git a/backend/session_manager.py b/backend/session_manager.py
index 7293f9cf..e2747137 100644
--- a/backend/session_manager.py
+++ b/backend/session_manager.py
@@ -9,6 +9,7 @@
 from typing import Any, Optional
 
 from agent.config import load_config
+from agent.core.llm_errors import render_llm_error_message
 from agent.core.agent_loop import process_submission
 from agent.core.session import Event, OpType, Session
 from agent.core.tools import ToolRouter
@@ -344,7 +345,7 @@ async def _run_session(
         except Exception as e:
             logger.error(f"Error in session {session_id}: {e}")
             await session.send_event(
-                Event(event_type="error", data={"error": str(e)})
+                Event(event_type="error", data={"error": render_llm_error_message(e)})
             )
 
         finally:
diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx
index d9fe5c4d..5f86a0f5 100644
--- a/frontend/src/components/Chat/ChatInput.tsx
+++ b/frontend/src/components/Chat/ChatInput.tsx
@@ -1,65 +1,27 @@
-import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react';
-import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material';
-import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
-import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
-import StopIcon from '@mui/icons-material/Stop';
-import { apiFetch } from '@/utils/api';
-import { useUserQuota } from '@/hooks/useUserQuota';
-import ClaudeCapDialog from '@/components/ClaudeCapDialog';
-import { useAgentStore } from '@/store/agentStore';
-import { FIRST_FREE_MODEL_PATH } from '@/utils/model';
-
-// Model configuration
-interface ModelOption {
-  id: string;
-  name: string;
-  description: string;
-  modelPath: string;
-  avatarUrl: string;
-  recommended?: boolean;
-}
-
-const getHfAvatarUrl = (modelId: string) => {
-  const org = modelId.split('/')[0];
-  return `https://huggingface.co/api/avatars/${org}`;
-};
-
-const MODEL_OPTIONS: ModelOption[] = [
-  {
-    id: 'kimi-k2.6',
-    name: 'Kimi K2.6',
-    description:
'Novita', - modelPath: 'moonshotai/Kimi-K2.6', - avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.6'), - recommended: true, - }, - { - id: 'claude-opus', - name: 'Claude Opus 4.6', - description: 'Anthropic', - modelPath: 'anthropic/claude-opus-4-6', - avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', - recommended: true, - }, - { - id: 'minimax-m2.7', - name: 'MiniMax M2.7', - description: 'Novita', - modelPath: 'MiniMaxAI/MiniMax-M2.7', - avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.7'), - }, - { - id: 'glm-5.1', - name: 'GLM 5.1', - description: 'Together', - modelPath: 'zai-org/GLM-5.1', - avatarUrl: getHfAvatarUrl('zai-org/GLM-5.1'), - }, -]; - -const findModelByPath = (path: string): ModelOption | undefined => { - return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); -}; +import { useState, useCallback, useEffect, useRef, KeyboardEvent } from "react"; +import { + Box, + TextField, + IconButton, + CircularProgress, + Typography, + Menu, + MenuItem, + ListItemIcon, + ListItemText, + Chip, + ListSubheader, +} from "@mui/material"; +import ArrowUpwardIcon from "@mui/icons-material/ArrowUpward"; +import ArrowDropDownIcon from "@mui/icons-material/ArrowDropDown"; +import StopIcon from "@mui/icons-material/Stop"; +import { useUserQuota } from "@/hooks/useUserQuota"; +import { useModelCatalog } from "@/hooks/useModelCatalog"; +import type { ModelOption, ProviderOption } from "@/hooks/useModelCatalog"; +import ClaudeCapDialog from "@/components/ClaudeCapDialog"; +import { useAgentStore } from "@/store/agentStore"; +import { FIRST_FREE_MODEL_PATH } from "@/utils/model"; +import CustomModelDialog from "@/components/Chat/CustomModelDialog"; interface ChatInputProps { sessionId?: string; @@ -70,45 +32,43 @@ interface ChatInputProps { placeholder?: string; } -const isClaudeModel = (m: ModelOption) => m.modelPath.startsWith('anthropic/'); -const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0]; +const OPENAI_COMPAT_PROVIDER = "openai_compat"; -export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { - const [input, setInput] = useState(''); +const isClaudePath = (modelPath: string) => modelPath.startsWith("anthropic/"); + +const firstFreeModel = (models: ModelOption[]) => { + const byPath = models.find((m) => m.id === FIRST_FREE_MODEL_PATH); + if (byPath) return byPath; + return models.find((m) => !isClaudePath(m.id)); +}; + +export default function ChatInput({ + sessionId, + onSend, + onStop, + isProcessing = false, + disabled = false, + placeholder = "Ask anything...", +}: ChatInputProps) { + const [input, setInput] = useState(""); const inputRef = useRef(null); - const [selectedModelId, setSelectedModelId] = useState(MODEL_OPTIONS[0].id); + const { + modelOptions, + selectedModelPath, + selectedModel, + switchModel, + groups, + } = useModelCatalog(sessionId); const [modelAnchorEl, setModelAnchorEl] = useState(null); + const [customModalOpen, setCustomModalOpen] = useState(false); + const [customPrefix, setCustomPrefix] = useState("openai-compat/"); const { quota, refresh: refreshQuota } = useUserQuota(); - // The daily-cap dialog is triggered from two places: (a) a 429 returned - // from the chat transport when the user tries to send on Opus over cap — - // surfaced via the agent-store flag — and (b) nothing else right now - // (switching models is free). 
Keeping the open state in the store means - // the hook layer can flip it without threading props through. const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted); - const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted); - const lastSentRef = useRef(''); - - // Model is per-session: fetch this tab's current model every time the - // session changes. Other tabs keep their own selections independently. - useEffect(() => { - if (!sessionId) return; - let cancelled = false; - apiFetch(`/api/session/${sessionId}`) - .then((res) => (res.ok ? res.json() : null)) - .then((data) => { - if (cancelled) return; - if (data?.model) { - const model = findModelByPath(data.model); - if (model) setSelectedModelId(model.id); - } - }) - .catch(() => { /* ignore */ }); - return () => { cancelled = true; }; - }, [sessionId]); - - const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; + const setClaudeQuotaExhausted = useAgentStore( + (s) => s.setClaudeQuotaExhausted, + ); + const lastSentRef = useRef(""); - // Auto-focus the textarea when the session becomes ready useEffect(() => { if (!disabled && !isProcessing && inputRef.current) { inputRef.current.focus(); @@ -119,33 +79,28 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa if (input.trim() && !disabled) { lastSentRef.current = input; onSend(input); - setInput(''); + setInput(""); } }, [input, disabled, onSend]); - // When the chat transport reports a Claude-quota 429, restore the typed - // text so the user doesn't lose their message. useEffect(() => { if (claudeQuotaExhausted && lastSentRef.current) { setInput(lastSentRef.current); } }, [claudeQuotaExhausted]); - // Refresh the quota display whenever the session changes (user might - // have started another tab that spent quota). useEffect(() => { if (sessionId) refreshQuota(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [sessionId]); + }, [sessionId, refreshQuota]); const handleKeyDown = useCallback( (e: KeyboardEvent) => { - if (e.key === 'Enter' && !e.shiftKey) { + if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); handleSend(); } }, - [handleSend] + [handleSend], ); const handleModelClick = (event: React.MouseEvent) => { @@ -158,51 +113,59 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa const handleSelectModel = async (model: ModelOption) => { handleModelClose(); - if (!sessionId) return; try { - const res = await apiFetch(`/api/session/${sessionId}/model`, { - method: 'POST', - body: JSON.stringify({ model: model.modelPath }), - }); - if (res.ok) setSelectedModelId(model.id); - } catch { /* ignore */ } + await switchModel(model.id, model); + } catch { + // ignore + } + }; + + const handleOpenCustomModal = (provider: ProviderOption) => { + handleModelClose(); + if (!sessionId || !provider.prefix) return; + setCustomPrefix(provider.prefix); + setCustomModalOpen(true); + }; + + const handleCustomSubmit = async (modelId: string) => { + const full = `${customPrefix}${modelId}`; + const info: ModelOption = { + id: full, + label: modelId, + description: "Custom OpenAI-compatible model", + provider: OPENAI_COMPAT_PROVIDER, + providerLabel: "OpenAI-Compatible", + }; + await switchModel(full, info); + setCustomModalOpen(false); }; - // Dialog close: just clear the flag. The typed text is already restored. 
const handleCapDialogClose = useCallback(() => { setClaudeQuotaExhausted(false); }, [setClaudeQuotaExhausted]); - // "Use a free model" — switch the current session to Kimi (or the first - // non-Anthropic option) and auto-retry the send that tripped the cap. const handleUseFreeModel = useCallback(async () => { setClaudeQuotaExhausted(false); if (!sessionId) return; - const free = MODEL_OPTIONS.find(m => m.modelPath === FIRST_FREE_MODEL_PATH) - ?? firstFreeModel(); + const free = firstFreeModel(modelOptions); + if (!free) return; try { - const res = await apiFetch(`/api/session/${sessionId}/model`, { - method: 'POST', - body: JSON.stringify({ model: free.modelPath }), - }); - if (res.ok) { - setSelectedModelId(free.id); - const retryText = lastSentRef.current; - if (retryText) { - onSend(retryText); - setInput(''); - lastSentRef.current = ''; - } + await switchModel(free.id, free); + const retryText = lastSentRef.current; + if (retryText) { + onSend(retryText); + setInput(""); + lastSentRef.current = ""; } - } catch { /* ignore */ } - }, [sessionId, onSend, setClaudeQuotaExhausted]); + } catch { + // ignore + } + }, [sessionId, modelOptions, onSend, setClaudeQuotaExhausted, switchModel]); - // Hide the chip until the user has actually burned quota — an unused - // Opus session shouldn't populate a counter. const claudeChip = (() => { if (!quota || quota.claudeUsedToday === 0) return null; - if (quota.plan === 'free') { - return quota.claudeRemaining > 0 ? 'Free today' : 'Pro only'; + if (quota.plan === "free") { + return quota.claudeRemaining > 0 ? "Free today" : "Pro only"; } return `${quota.claudeUsedToday}/${quota.claudeDailyCap} today`; })(); @@ -212,26 +175,33 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa sx={{ pb: { xs: 2, md: 4 }, pt: { xs: 1, md: 2 }, - position: 'relative', + position: "relative", zIndex: 10, }} > - + {isProcessing ? 
( @@ -275,18 +245,29 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa sx={{ mt: 1, p: 1.5, - borderRadius: '10px', - color: 'var(--muted-text)', - transition: 'all 0.2s', - position: 'relative', - '&:hover': { - bgcolor: 'var(--hover-bg)', - color: 'var(--accent-red)', + borderRadius: "10px", + color: "var(--muted-text)", + transition: "all 0.2s", + position: "relative", + "&:hover": { + bgcolor: "var(--hover-bg)", + color: "var(--accent-red)", }, }} > - - + + @@ -297,14 +278,14 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa sx={{ mt: 1, p: 1, - borderRadius: '10px', - color: 'var(--muted-text)', - transition: 'all 0.2s', - '&:hover': { - color: 'var(--accent-yellow)', - bgcolor: 'var(--hover-bg)', + borderRadius: "10px", + color: "var(--muted-text)", + transition: "all 0.2s", + "&:hover": { + color: "var(--accent-yellow)", + bgcolor: "var(--hover-bg)", }, - '&.Mui-disabled': { + "&.Mui-disabled": { opacity: 0.3, }, }} @@ -314,124 +295,229 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa )} - {/* Powered By Badge */} - + powered by - {selectedModel.name} - - {selectedModel.name} + {selectedModel?.avatarUrl && ( + {selectedModel.label} + )} + + {selectedModel?.label || "Model"} - + - {/* Model Selection Menu */} - {MODEL_OPTIONS.map((model) => ( - handleSelectModel(model)} - selected={selectedModelId === model.id} - sx={{ - py: 1.5, - '&.Mui-selected': { - bgcolor: 'rgba(255,255,255,0.05)', - } - }} - > - - {model.name} - - - {model.name} - {model.recommended && ( - ( + + + {provider.label} + + + {models.map((model) => ( + void handleSelectModel(model)} + disabled={!sessionId} + selected={selectedModelPath === model.id} + sx={{ + py: 1.5, + "&.Mui-selected": { + bgcolor: "rgba(255,255,255,0.05)", + }, + }} + > + + {model.avatarUrl ? ( + {model.label} - )} - {isClaudeModel(model) && claudeChip && ( - + > + {provider.label.slice(0, 2)} + )} - - } - secondary={model.description} - secondaryTypographyProps={{ - sx: { fontSize: '12px', color: 'var(--muted-text)' } - }} - /> - + + + {model.label} + {model.recommended && ( + + )} + {isClaudePath(model.id) && claudeChip && ( + + )} + + } + secondary={model.description} + secondaryTypographyProps={{ + sx: { fontSize: "12px", color: "var(--muted-text)" }, + }} + /> + + ))} + + {provider.id === OPENAI_COMPAT_PROVIDER && + provider.supportsCustomModel && ( + handleOpenCustomModal(provider)} + disabled={!sessionId} + sx={{ py: 1.5 }} + > + " + } + secondaryTypographyProps={{ + sx: { fontSize: "12px", color: "var(--muted-text)" }, + }} + /> + + )} + ))} + {!sessionId && ( + + + Start a session to switch models. 
diff --git a/frontend/src/components/Chat/CustomModelDialog.tsx b/frontend/src/components/Chat/CustomModelDialog.tsx
new file mode 100644
--- /dev/null
+++ b/frontend/src/components/Chat/CustomModelDialog.tsx
+import { useEffect, useState } from "react";
+
+interface CustomModelDialogProps {
+  open: boolean;
+  prefix: string;
+  onClose: () => void;
+  onSubmit: (modelId: string) => Promise<void>;
+}
+
+export default function CustomModelDialog({
+  open,
+  prefix,
+  onClose,
+  onSubmit,
+}: CustomModelDialogProps) {
+  const [value, setValue] = useState("");
+  const [error, setError] = useState("");
+  const [submitting, setSubmitting] = useState(false);
+
+  useEffect(() => {
+    if (!open) {
+      setValue("");
+      setError("");
+      setSubmitting(false);
+    }
+  }, [open]);
+
+  const handleSubmit = async () => {
+    const trimmed = value.trim();
+    if (!trimmed) {
+      setError("Model id is required");
+      return;
+    }
+    setError("");
+    setSubmitting(true);
+    try {
+      await onSubmit(trimmed);
+    } catch {
+      setError("Failed to switch model");
+    } finally {
+      setSubmitting(false);
+    }
+  };
+
+  return (
+
+      Custom OpenAI-compatible model
+
+
+
+        Enter the model id only. The server's env config supplies the base
+        URL and API key.
+
+        onChange={(e) => setValue(e.target.value)}
+        onKeyDown={(e) => {
+          if (e.key === "Enter" && !submitting) void handleSubmit();
+        }}
+        placeholder="e.g. my-model"
+        disabled={submitting}
+        sx={{
+          "& .MuiOutlinedInput-root": {
+            bgcolor: "transparent",
+            color: "var(--text)",
+          },
+        }}
+      />
+
+        Final id:{" "}
+
+          {prefix}
+          {value.trim() || ""}
+
+
+      {error && (
+
+          {error}
+
+      )}
+
+
+
+
+
+
+  );
+}
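`handleOpenCustomModal` and `handleCustomSubmit` are referenced in the menu above but sit outside the visible hunks. Plausible wiring, sketched under the assumption that the parent joins `provider.prefix` with the bare id before calling `switchModel`; the state names here are hypothetical:

// Hypothetical wiring in ChatInput.tsx; names other than switchModel
// and setCustomModalOpen are invented for illustration.
const [customModalOpen, setCustomModalOpen] = useState(false);
const [customPrefix, setCustomPrefix] = useState("");

const handleOpenCustomModal = (provider: ProviderOption) => {
  setCustomPrefix(provider.prefix ?? "");
  setCustomModalOpen(true);
};

const handleCustomSubmit = async (modelId: string) => {
  // The dialog previews "Final id: {prefix}{value}", so the full id is
  // prefix + bare id; a rejected promise keeps the dialog open with its
  // "Failed to switch model" error.
  await switchModel(`${customPrefix}${modelId}`);
  setCustomModalOpen(false);
};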
diff --git a/frontend/src/hooks/useModelCatalog.ts b/frontend/src/hooks/useModelCatalog.ts
new file mode 100644
index 00000000..42d8f7bc
--- /dev/null
+++ b/frontend/src/hooks/useModelCatalog.ts
@@ -0,0 +1,195 @@
+import { useState, useCallback, useEffect, useMemo } from "react";
+import { apiFetch } from "@/utils/api";
+
+export interface ModelOption {
+  id: string;
+  label: string;
+  description: string;
+  provider: string;
+  providerLabel: string;
+  avatarUrl?: string;
+  recommended?: boolean;
+  source?: string;
+}
+
+export interface ProviderOption {
+  id: string;
+  label: string;
+  avatarUrl?: string;
+  supportsCustomModel?: boolean;
+  customModelHint?: string;
+  customModelMode?: string;
+  prefix?: string;
+}
+
+interface ModelCatalogResponse {
+  current?: string;
+  currentInfo?: ModelOption | null;
+  available?: ModelOption[];
+  providers?: ProviderOption[];
+}
+
+interface SessionResponse {
+  model?: string;
+}
+
+const toModelOption = (value: unknown): ModelOption | null => {
+  if (!value || typeof value !== "object") return null;
+  const v = value as Record<string, unknown>;
+  if (typeof v.id !== "string" || typeof v.label !== "string") return null;
+  return {
+    id: v.id,
+    label: v.label,
+    description: typeof v.description === "string" ? v.description : "",
+    provider: typeof v.provider === "string" ? v.provider : "",
+    providerLabel: typeof v.providerLabel === "string" ? v.providerLabel : "",
+    avatarUrl: typeof v.avatarUrl === "string" ? v.avatarUrl : undefined,
+    recommended: Boolean(v.recommended),
+    source: typeof v.source === "string" ? v.source : undefined,
+  };
+};
+
+const toProviderOption = (value: unknown): ProviderOption | null => {
+  if (!value || typeof value !== "object") return null;
+  const v = value as Record<string, unknown>;
+  if (typeof v.id !== "string" || typeof v.label !== "string") return null;
+  return {
+    id: v.id,
+    label: v.label,
+    avatarUrl: typeof v.avatarUrl === "string" ? v.avatarUrl : undefined,
+    supportsCustomModel: Boolean(v.supportsCustomModel),
+    customModelHint:
+      typeof v.customModelHint === "string" ? v.customModelHint : undefined,
+    customModelMode:
+      typeof v.customModelMode === "string" ? v.customModelMode : undefined,
+    prefix: typeof v.prefix === "string" ? v.prefix : undefined,
+  };
+};
+
+export function useModelCatalog(sessionId?: string) {
+  const [modelOptions, setModelOptions] = useState<ModelOption[]>([]);
+  const [providerOptions, setProviderOptions] = useState<ProviderOption[]>([]);
+  const [selectedModelPath, setSelectedModelPath] = useState("");
+  const [selectedModelInfo, setSelectedModelInfo] =
+    useState<ModelOption | null>(null);
+
+  useEffect(() => {
+    let cancelled = false;
+
+    const loadCatalog = async () => {
+      try {
+        const res = await apiFetch("/api/config/model");
+        if (!res.ok || cancelled) return;
+        const data = (await res.json()) as ModelCatalogResponse;
+        const available = (data.available || [])
+          .map(toModelOption)
+          .filter((v): v is ModelOption => v !== null);
+        const providers = (data.providers || [])
+          .map(toProviderOption)
+          .filter((v): v is ProviderOption => v !== null);
+        const currentInfo = toModelOption(data.currentInfo ?? null);
+        if (cancelled) return;
+
+        setModelOptions(available);
+        setProviderOptions(providers);
+        setSelectedModelPath(data.current || "");
+        setSelectedModelInfo(currentInfo);
+      } catch {
+        // ignore — catalog is optional, models may not be configured
+      }
+    };
+
+    void loadCatalog();
+    return () => {
+      cancelled = true;
+    };
+  }, []);
+
+  useEffect(() => {
+    if (!sessionId) return;
+    let cancelled = false;
+    apiFetch(`/api/session/${sessionId}`)
+      .then((res) => (res.ok ? res.json() : null))
+      .then((data: SessionResponse | null) => {
+        if (cancelled || !data?.model) return;
+        setSelectedModelPath(data.model);
+        const model = modelOptions.find((m) => m.id === data.model);
+        if (model) {
+          setSelectedModelInfo(model);
+          return;
+        }
+        setSelectedModelInfo((prev) =>
+          prev && prev.id === data.model
+            ? prev
+            : {
+                id: data.model,
+                label: data.model,
+                description: "Custom model",
+                provider: "",
+                providerLabel: "",
+              },
+        );
+      })
+      .catch(() => {
+        // ignore
+      });
+    return () => {
+      cancelled = true;
+    };
+  }, [sessionId, modelOptions]);
+
+  const selectedModel =
+    selectedModelInfo ||
+    modelOptions.find((m) => m.id === selectedModelPath) ||
+    (selectedModelPath
+      ? {
+          id: selectedModelPath,
+          label: selectedModelPath,
+          description: "Custom model",
+          provider: "",
+          providerLabel: "",
+        }
+      : null);
+
+  const switchModel = useCallback(
+    async (modelPath: string, info?: ModelOption) => {
+      if (!sessionId) return;
+      const res = await apiFetch(`/api/session/${sessionId}/model`, {
+        method: "POST",
+        body: JSON.stringify({ model: modelPath }),
+      });
+      if (res.ok) {
+        setSelectedModelPath(modelPath);
+        setSelectedModelInfo(
+          info || modelOptions.find((m) => m.id === modelPath) || null,
+        );
+      }
+    },
+    [sessionId, modelOptions],
+  );
+
+  const groups = useMemo(
+    () =>
+      providerOptions.map((provider) => ({
+        provider,
+        models: modelOptions
+          .filter((m) => m.provider === provider.id)
+          .sort((a, b) => {
+            const ar = a.recommended ? 0 : 1;
+            const br = b.recommended ? 0 : 1;
+            if (ar !== br) return ar - br;
+            return a.label.localeCompare(b.label);
+          }),
+      })),
+    [providerOptions, modelOptions],
+  );
+
+  return {
+    modelOptions,
+    providerOptions,
+    selectedModelPath,
+    selectedModel,
+    switchModel,
+    groups,
+  };
+}
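A compact usage sketch of the hook; the component context is illustrative only:

// Illustrative consumer; ChatInput uses the same surface.
import { useModelCatalog } from "@/hooks/useModelCatalog";

export function ModelDebugPanel({ sessionId }: { sessionId?: string }) {
  const { groups, selectedModel, switchModel } = useModelCatalog(sessionId);

  for (const { provider, models } of groups) {
    // Models arrive recommended-first, then alphabetical by label.
    console.log(provider.label, models.map((m) => m.label).join(", "));
  }

  console.log("current:", selectedModel?.label);

  // Persists the choice via POST /api/session/:id/model, then mirrors it
  // in local state; it is a no-op when there is no session yet.
  void switchModel("openai/gpt-5");

  return null;
}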
diff --git a/frontend/src/utils/model.ts b/frontend/src/utils/model.ts
index 89f23fe7..8ee4e297 100644
--- a/frontend/src/utils/model.ts
+++ b/frontend/src/utils/model.ts
@@ -2,9 +2,8 @@
  * Shared model-id constants used by session-create call sites and the
  * ClaudeCapDialog "Use a free model" escape hatch.
  *
- * Keep in sync with MODEL_OPTIONS in components/Chat/ChatInput.tsx and
- * AVAILABLE_MODELS in backend/routes/agent.py. Bare HF ids (no
- * `huggingface/` prefix) — matches upstream's auto-router.
+ * ChatInput now loads the model catalog from /api/config/model; this
+ * fallback free-model constant is still used for the Claude-cap escape hatch.
  */
 
 export const CLAUDE_MODEL_PATH = 'anthropic/claude-opus-4-6';
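`isClaudePath`, used for the quota chip in ChatInput, plausibly lives in this module as well; its implementation is not shown in the diff. A one-line sketch, assuming Claude models are identified by the `anthropic/` prefix as in `CLAUDE_MODEL_PATH`:

// Assumed helper; the real implementation may differ.
export const isClaudePath = (id: string): boolean =>
  id.startsWith('anthropic/');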
diff --git a/tests/test_llm_errors.py b/tests/test_llm_errors.py
new file mode 100644
index 00000000..d36ebed7
--- /dev/null
+++ b/tests/test_llm_errors.py
@@ -0,0 +1,124 @@
+import asyncio
+from types import SimpleNamespace
+
+from rich.console import Console
+
+import agent.core.model_switcher as model_switcher
+from agent.core.effort_probe import ProbeInconclusive
+from agent.core.llm_errors import (
+    classify_llm_error,
+    friendly_llm_error_message,
+    health_error_type,
+    render_llm_error_message,
+)
+
+
+def test_auth_errors_get_clean_message() -> None:
+    error = Exception("401 unauthorized: invalid api key")
+
+    assert classify_llm_error(error) == "auth"
+    assert "Authentication failed" in friendly_llm_error_message(error)
+
+
+def test_missing_api_key_header_gets_clean_message() -> None:
+    error = Exception("authentication_error: x-api-key header is required")
+
+    assert classify_llm_error(error) == "auth"
+    assert render_llm_error_message(error).startswith("Authentication failed")
+
+
+def test_openai_missing_api_key_gets_clean_message() -> None:
+    error = Exception(
+        "You didn't provide an API key. You need to provide your API key in an Authorization header."
+    )
+
+    assert classify_llm_error(error) == "auth"
+    assert render_llm_error_message(error).startswith("Authentication failed")
+
+
+def test_anthropic_low_credit_error_gets_clean_message() -> None:
+    error = Exception(
+        "Your credit balance is too low to access the Anthropic API. "
+        "Please go to Plans & Billing to upgrade or purchase credits."
+    )
+
+    assert classify_llm_error(error) == "credits"
+    assert render_llm_error_message(error).startswith(
+        "Insufficient API credits or quota"
+    )
+
+
+def test_model_not_found_error_gets_clean_message() -> None:
+    error = Exception("model_not_found: requested model does not exist")
+
+    assert classify_llm_error(error) == "model"
+    assert render_llm_error_message(error).startswith("Model not found")
+
+
+def test_unknown_errors_fall_back_to_plain_exception_text() -> None:
+    error = RuntimeError("boom")
+
+    assert classify_llm_error(error) == "unknown"
+    assert render_llm_error_message(error) == "boom"
+
+
+def test_health_error_type_keeps_public_categories_stable() -> None:
+    assert health_error_type(Exception("invalid api key")) == "auth"
+    assert health_error_type(Exception("credit balance is too low")) == "credits"
+    assert health_error_type(Exception("rate limit exceeded")) == "rate_limit"
+    assert health_error_type(Exception("model_not_found")) == "unknown"
+
+
+def test_model_switcher_shows_clean_hard_failure(monkeypatch) -> None:
+    async def fake_probe_effort(*args, **kwargs):
+        raise Exception(
+            "Your credit balance is too low to access the Anthropic API. "
+            "Please go to Plans & Billing to upgrade or purchase credits."
+        )
+
+    monkeypatch.setattr(model_switcher, "probe_effort", fake_probe_effort)
+    console = Console(record=True, width=120)
+    config = SimpleNamespace(
+        reasoning_effort="high",
+        model_name="anthropic/claude-opus-4-6",
+    )
+
+    asyncio.run(
+        model_switcher.probe_and_switch_model(
+            "anthropic/claude-opus-4-7",
+            config,
+            None,
+            console,
+            None,
+        )
+    )
+
+    output = console.export_text()
+    assert "Insufficient API credits or quota" in output
+    assert "credit balance is too low" not in output.lower()
+
+
+def test_model_switcher_shows_clean_inconclusive_warning(monkeypatch) -> None:
+    async def fake_probe_effort(*args, **kwargs):
+        raise ProbeInconclusive("timeout talking to provider")
+
+    monkeypatch.setattr(model_switcher, "probe_effort", fake_probe_effort)
+    console = Console(record=True, width=120)
+    config = SimpleNamespace(
+        reasoning_effort="high",
+        model_name="anthropic/claude-opus-4-6",
+    )
+
+    asyncio.run(
+        model_switcher.probe_and_switch_model(
+            "anthropic/claude-opus-4-7",
+            config,
+            None,
+            console,
+            None,
+        )
+    )
+
+    output = console.export_text()
+    assert "The model provider is unavailable or timed out" in output
+    assert "timeout talking to provider" not in output.lower()
== {"model": "bedrock/us.anthropic.claude-opus-4-7"} + + +def test_bedrock_adapter_ignores_effort(): + params = _resolve_llm_params( + "bedrock/us.anthropic.claude-opus-4-6-v1", reasoning_effort="high" + ) + assert params == {"model": "bedrock/us.anthropic.claude-opus-4-6-v1"} + + +def test_bedrock_validation(): + assert is_valid_model_name("bedrock/us.anthropic.claude-opus-4-7") is True + assert is_valid_model_name("bedrock/") is False + + +# -- Gemini adapter ----------------------------------------------------------- + + +def test_gemini_adapter_passes_reasoning_effort(): + params = _resolve_llm_params("gemini/gemini-2.5-pro", reasoning_effort="medium") + assert params == {"model": "gemini/gemini-2.5-pro", "reasoning_effort": "medium"} + + +def test_gemini_adapter_normalizes_minimal(): + params = _resolve_llm_params("gemini/gemini-2.5-flash", reasoning_effort="minimal") + assert params == {"model": "gemini/gemini-2.5-flash", "reasoning_effort": "low"} + + +def test_gemini_adapter_no_effort(): + params = _resolve_llm_params("gemini/gemini-2.5-pro") + assert params == {"model": "gemini/gemini-2.5-pro"} + + +def test_gemini_adapter_strict_rejects_invalid(): + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params("gemini/gemini-2.5-pro", reasoning_effort="max", strict=True) + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params("gemini/gemini-2.5-pro", reasoning_effort="xhigh", strict=True) + + +def test_gemini_validation(): + assert is_valid_model_name("gemini/gemini-2.5-pro") is True + assert is_valid_model_name("gemini/") is False + + +# -- HF Router adapter -------------------------------------------------------- + + +def test_hf_adapter_builds_router_params(monkeypatch): + monkeypatch.setenv("HF_TOKEN", "hf-test") + + params = _resolve_llm_params( + "moonshotai/Kimi-K2.6:novita", reasoning_effort="minimal" + ) + + assert params == { + "model": "openai/moonshotai/Kimi-K2.6:novita", + "api_base": "https://router.huggingface.co/v1", + "api_key": "hf-test", + "extra_body": {"reasoning_effort": "low"}, + } + + +def test_hf_adapter_adds_bill_to_header(monkeypatch): + monkeypatch.setenv("INFERENCE_TOKEN", "hf-space-token") + monkeypatch.delenv("HF_TOKEN", raising=False) + + params = _resolve_llm_params("MiniMaxAI/MiniMax-M2.7") + + assert params["extra_headers"] == {"X-HF-Bill-To": "smolagents"} + assert params["api_key"] == "hf-space-token" + + +def test_hf_adapter_strict_rejects_max(): + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params( + "MiniMaxAI/MiniMax-M2.7", reasoning_effort="max", strict=True + ) + + +# -- OpenAI-compatible adapters ------------------------------------------------ + + +def test_ollama_adapter_builds_params(monkeypatch): + monkeypatch.delenv("OLLAMA_API_BASE", raising=False) + monkeypatch.delenv("OLLAMA_API_KEY", raising=False) + + params = _resolve_llm_params("ollama/llama3.1") + + assert params == { + "model": "openai/llama3.1", + "api_base": "http://localhost:11434/v1", + "api_key": "ollama", + } + + +def test_ollama_adapter_normalizes_base_url(monkeypatch): + monkeypatch.setenv("OLLAMA_API_BASE", "http://localhost:11434") + + params = _resolve_llm_params("ollama/llama3.1") + + assert params["api_base"] == "http://localhost:11434/v1" + + +def test_ollama_adapter_strict_rejects_effort(): + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params("ollama/llama3.1", reasoning_effort="high", strict=True) + + +def test_lm_studio_adapter_uses_raw_model_name(monkeypatch): + monkeypatch.delenv("LMSTUDIO_BASE_URL", 
+
+
+# -- OpenAI-compatible adapters ------------------------------------------------
+
+
+def test_ollama_adapter_builds_params(monkeypatch):
+    monkeypatch.delenv("OLLAMA_API_BASE", raising=False)
+    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)
+
+    params = _resolve_llm_params("ollama/llama3.1")
+
+    assert params == {
+        "model": "openai/llama3.1",
+        "api_base": "http://localhost:11434/v1",
+        "api_key": "ollama",
+    }
+
+
+def test_ollama_adapter_normalizes_base_url(monkeypatch):
+    monkeypatch.setenv("OLLAMA_API_BASE", "http://localhost:11434")
+
+    params = _resolve_llm_params("ollama/llama3.1")
+
+    assert params["api_base"] == "http://localhost:11434/v1"
+
+
+def test_ollama_adapter_strict_rejects_effort():
+    with pytest.raises(UnsupportedEffortError):
+        _resolve_llm_params("ollama/llama3.1", reasoning_effort="high", strict=True)
+
+
+def test_lm_studio_adapter_uses_raw_model_name(monkeypatch):
+    monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False)
+    monkeypatch.delenv("LMSTUDIO_API_KEY", raising=False)
+
+    params = _resolve_llm_params("lm_studio/google/gemma-3-12b")
+
+    assert params == {
+        "model": "google/gemma-3-12b",
+        "api_base": "http://127.0.0.1:1234/v1",
+        "api_key": "lm-studio",
+    }
+
+
+def test_vllm_adapter_uses_env_override(monkeypatch):
+    monkeypatch.setenv("VLLM_BASE_URL", "http://127.0.0.1:8000")
+    monkeypatch.setenv("VLLM_API_KEY", "secret")
+
+    params = _resolve_llm_params("vllm/Qwen3-32B")
+
+    assert params == {
+        "model": "openai/Qwen3-32B",
+        "api_base": "http://127.0.0.1:8000/v1",
+        "api_key": "secret",
+    }
+
+
+def test_openrouter_adapter_uses_api_key(monkeypatch):
+    monkeypatch.setenv("OPENROUTER_API_KEY", "router-key")
+
+    params = _resolve_llm_params(
+        "openrouter/anthropic/claude-sonnet-4.5", reasoning_effort="medium"
+    )
+
+    assert params == {
+        "model": "openai/anthropic/claude-sonnet-4.5",
+        "api_base": "https://openrouter.ai/api/v1",
+        "api_key": "router-key",
+        "reasoning_effort": "medium",
+    }
+
+
+def test_opencode_zen_adapter_uses_api_key(monkeypatch):
+    monkeypatch.setenv("OPENCODE_ZEN_API_KEY", "zen-key")
+
+    params = _resolve_llm_params("opencode/kimi-k2.6")
+
+    assert params == {
+        "model": "openai/kimi-k2.6",
+        "api_base": "https://opencode.ai/zen/v1",
+        "api_key": "zen-key",
+    }
+
+
+def test_opencode_go_adapter_uses_api_key(monkeypatch):
+    monkeypatch.setenv("OPENCODE_GO_API_KEY", "go-key")
+
+    params = _resolve_llm_params("opencode-go/kimi-k2.6")
+
+    assert params == {
+        "model": "openai/kimi-k2.6",
+        "api_base": "https://opencode.ai/zen/go/v1",
+        "api_key": "go-key",
+    }
+
+
+def test_openai_compat_requires_base_url(monkeypatch):
+    monkeypatch.delenv("OPENAI_COMPAT_BASE_URL", raising=False)
+
+    with pytest.raises(ValueError, match="OPENAI_COMPAT_BASE_URL"):
+        _resolve_llm_params("openai-compat/my-model")
+
+
+def test_openai_compat_uses_required_base_url(monkeypatch):
+    monkeypatch.setenv("OPENAI_COMPAT_BASE_URL", "http://localhost:8080")
+    monkeypatch.setenv("OPENAI_COMPAT_API_KEY", "compat-key")
+
+    params = _resolve_llm_params("openai-compat/my-model")
+
+    assert params == {
+        "model": "openai/my-model",
+        "api_base": "http://localhost:8080/v1",
+        "api_key": "compat-key",
+    }
is_valid_model_name("gemini/") is False + assert is_valid_model_name("ollama/") is False + assert is_valid_model_name("lm_studio/") is False + assert is_valid_model_name("vllm/") is False + assert is_valid_model_name("openrouter/") is False + assert is_valid_model_name("opencode/") is False + assert is_valid_model_name("opencode-go/") is False + assert is_valid_model_name("openai-compat/") is False + assert is_valid_model_name("huggingface/nope") is False + assert is_valid_model_name("moonshotai/") is False + + +def test_hf_validation_excludes_new_provider_prefixes(): + hf = providers.resolve_adapter("openrouter/google/gemini-2.5-pro") + + assert hf is not None + assert hf.provider_id == "openrouter" + + +def test_model_catalog_hides_local_providers_when_unreachable(monkeypatch): + monkeypatch.setattr(providers.OllamaAdapter, "available_models", lambda self: ()) + monkeypatch.setattr(providers.LmStudioAdapter, "available_models", lambda self: ()) + monkeypatch.setattr(providers.VllmAdapter, "available_models", lambda self: ()) + + catalog = providers.build_model_catalog("anthropic/claude-opus-4-6") + provider_ids = {p["id"] for p in catalog["providers"]} + + assert "ollama" not in provider_ids + assert "lm_studio" not in provider_ids + assert "vllm" not in provider_ids + + +def test_model_catalog_includes_local_providers_when_discovered(monkeypatch): + dynamic = ( + providers.SuggestedModel( + id="ollama/llama3.1", + label="llama3.1", + description="Ollama", + provider="ollama", + provider_label="Ollama", + avatar_url="avatar", + source="dynamic", + ), + ) + monkeypatch.setattr( + providers.OllamaAdapter, "available_models", lambda self: dynamic + ) + + catalog = providers.build_model_catalog("ollama/llama3.1") + + assert any(m["id"] == "ollama/llama3.1" for m in catalog["available"]) + assert any(p["id"] == "ollama" for p in catalog["providers"]) + + +def test_model_catalog_openai_compat_visibility_depends_on_env(monkeypatch): + monkeypatch.delenv("OPENAI_COMPAT_BASE_URL", raising=False) + hidden = providers.build_model_catalog("anthropic/claude-opus-4-6") + assert not any(p["id"] == "openai_compat" for p in hidden["providers"]) + + monkeypatch.setenv("OPENAI_COMPAT_BASE_URL", "http://localhost:8080/v1") + shown = providers.build_model_catalog("anthropic/claude-opus-4-6") + assert any(p["id"] == "openai_compat" for p in shown["providers"]) + + +def test_model_catalog_includes_current_info_for_custom_model(monkeypatch): + monkeypatch.setenv("OPENAI_COMPAT_BASE_URL", "http://localhost:8080/v1") + + catalog = providers.build_model_catalog("openai-compat/my-model") + + assert catalog["currentInfo"] is not None + assert catalog["currentInfo"]["id"] == "openai-compat/my-model" + assert catalog["currentInfo"]["source"] == "custom" + + +def test_cli_validation_matches_provider_validation(): + assert is_valid_model_id("openai/gpt-5") is True + assert is_valid_model_id("moonshotai/Kimi-K2.6:fastest") is True + assert is_valid_model_id("ollama/llama3.1") is True + assert is_valid_model_id("openrouter/anthropic/claude-sonnet-4.5") is True + assert is_valid_model_id("openai-compat/my-model") is True + assert is_valid_model_id("openai/") is False + assert is_valid_model_id("anthropic/") is False + + +def test_resolve_raises_on_no_adapter(monkeypatch): + from agent.core import llm_params + + monkeypatch.setattr(llm_params, "resolve_adapter", lambda _: None) + with pytest.raises(ValueError, match="No provider adapter"): + _resolve_llm_params("anything") + + +def test_unsupported_effort_reexport(): + 
"""UnsupportedEffortError must be importable from llm_params (backward compat).""" + from agent.core.llm_params import UnsupportedEffortError as FromLlm + from agent.core.provider_adapters import UnsupportedEffortError as FromAdapters + + assert FromLlm is FromAdapters