From 267618d9692d48d95576a981ec83c8eec0dc96db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Wed, 22 Apr 2026 15:58:57 +0200 Subject: [PATCH 01/15] Refactor LLM param resolution into adapters --- agent/core/llm_params.py | 42 +++++------------- agent/core/provider_adapters.py | 79 +++++++++++++++++++++++++++++++++ tests/test_provider_adapters.py | 35 +++++++++++++++ 3 files changed, 124 insertions(+), 32 deletions(-) create mode 100644 agent/core/provider_adapters.py create mode 100644 tests/test_provider_adapters.py diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 7aa68a63..980a043e 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -5,14 +5,7 @@ creating circular imports. """ -import os - - -# HF router reasoning models only accept "low" | "medium" | "high" (e.g. -# MiniMax M2 actually *requires* reasoning to be enabled). OpenAI's GPT-5 -# also accepts "minimal" for near-zero thinking. We map "minimal" to "low" -# for HF so the user doesn't get a 400. -_HF_ALLOWED_EFFORTS = {"low", "medium", "high"} +from agent.core.provider_adapters import ADAPTERS def _resolve_llm_params( @@ -50,27 +43,12 @@ def _resolve_llm_params( 2. session.hf_token — the user's own token (CLI / OAuth / cache file). 3. HF_TOKEN env — belt-and-suspenders fallback for CLI users. """ - if model_name.startswith(("anthropic/", "openai/")): - params: dict = {"model": model_name} - if reasoning_effort: - params["reasoning_effort"] = reasoning_effort - return params - - hf_model = model_name.removeprefix("huggingface/") - api_key = ( - os.environ.get("INFERENCE_TOKEN") - or session_hf_token - or os.environ.get("HF_TOKEN") - ) - params = { - "model": f"openai/{hf_model}", - "api_base": "https://router.huggingface.co/v1", - "api_key": api_key, - } - if os.environ.get("INFERENCE_TOKEN"): - params["extra_headers"] = {"X-HF-Bill-To": "huggingface"} - if reasoning_effort: - hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort - if hf_level in _HF_ALLOWED_EFFORTS: - params["extra_body"] = {"reasoning_effort": hf_level} - return params + for adapter in ADAPTERS: + if adapter.matches(model_name): + return adapter.build_params( + model_name, + session_hf_token=session_hf_token, + reasoning_effort=reasoning_effort, + ) + + raise ValueError(f"Unsupported model id: {model_name}") diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py new file mode 100644 index 00000000..8cf5e84a --- /dev/null +++ b/agent/core/provider_adapters.py @@ -0,0 +1,79 @@ +"""Provider-specific LiteLLM parameter builders.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + + +class ProviderAdapter: + """Build LiteLLM kwargs for one family of model ids.""" + + def matches(self, model_name: str) -> bool: + raise NotImplementedError + + def build_params( + self, + model_name: str, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + ) -> dict: + raise NotImplementedError + + +@dataclass(frozen=True) +class NativeAdapter(ProviderAdapter): + prefixes: tuple[str, ...] 
= ("anthropic/", "openai/") + + def matches(self, model_name: str) -> bool: + return model_name.startswith(self.prefixes) + + def build_params( + self, + model_name: str, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + ) -> dict: + del session_hf_token + params: dict = {"model": model_name} + if reasoning_effort: + params["reasoning_effort"] = reasoning_effort + return params + + +@dataclass(frozen=True) +class HfRouterAdapter(ProviderAdapter): + allowed_efforts: tuple[str, ...] = ("low", "medium", "high") + + def matches(self, model_name: str) -> bool: + return "/" in model_name and not model_name.startswith( + ("anthropic/", "openai/") + ) + + def build_params( + self, + model_name: str, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + ) -> dict: + hf_model = model_name.removeprefix("huggingface/") + inference_token = os.environ.get("INFERENCE_TOKEN") + api_key = inference_token or session_hf_token or os.environ.get("HF_TOKEN") + params = { + "model": f"openai/{hf_model}", + "api_base": "https://router.huggingface.co/v1", + "api_key": api_key, + } + if inference_token: + params["extra_headers"] = {"X-HF-Bill-To": "huggingface"} + if reasoning_effort: + hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort + if hf_level in self.allowed_efforts: + params["extra_body"] = {"reasoning_effort": hf_level} + return params + + +ADAPTERS: tuple[ProviderAdapter, ...] = ( + NativeAdapter(), + HfRouterAdapter(), +) diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py new file mode 100644 index 00000000..147ea7eb --- /dev/null +++ b/tests/test_provider_adapters.py @@ -0,0 +1,35 @@ +from agent.core.llm_params import _resolve_llm_params + + +def test_native_adapter_keeps_model_name(): + params = _resolve_llm_params("anthropic/claude-opus-4-6", reasoning_effort="high") + + assert params == { + "model": "anthropic/claude-opus-4-6", + "reasoning_effort": "high", + } + + +def test_hf_adapter_builds_router_params(monkeypatch): + monkeypatch.setenv("HF_TOKEN", "hf-test") + + params = _resolve_llm_params( + "moonshotai/Kimi-K2.6:novita", reasoning_effort="minimal" + ) + + assert params == { + "model": "openai/moonshotai/Kimi-K2.6:novita", + "api_base": "https://router.huggingface.co/v1", + "api_key": "hf-test", + "extra_body": {"reasoning_effort": "low"}, + } + + +def test_hf_adapter_adds_bill_to_header(monkeypatch): + monkeypatch.setenv("INFERENCE_TOKEN", "hf-space-token") + monkeypatch.delenv("HF_TOKEN", raising=False) + + params = _resolve_llm_params("MiniMaxAI/MiniMax-M2.7") + + assert params["extra_headers"] == {"X-HF-Bill-To": "huggingface"} + assert params["api_key"] == "hf-space-token" From dcb4aff1f1c404fb1800f0c84acc401ad18f22c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Wed, 22 Apr 2026 18:32:24 +0200 Subject: [PATCH 02/15] Remove unused session_hf_token parameter from NativeAdapter.build_params --- agent/core/provider_adapters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index 3a701a50..a876a536 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -19,7 +19,6 @@ def build_params( session_hf_token: str | None = None, reasoning_effort: str | None = None, ) -> dict: - del session_hf_token params: dict = {"model": model_name} if reasoning_effort: params["reasoning_effort"] = reasoning_effort From 2f8565026080c6986487a4059c1f3c1e4f3497c7 Mon Sep 17 00:00:00 2001 
From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Wed, 22 Apr 2026 20:08:17 +0200 Subject: [PATCH 03/15] Unify model catalog around provider adapters --- agent/core/llm_params.py | 16 +- agent/core/provider_adapters.py | 198 +++++++++++++++++++-- agent/main.py | 150 ++++++++++------ backend/routes/agent.py | 56 ++---- frontend/src/components/Chat/ChatInput.tsx | 195 +++++++++++--------- tests/test_provider_adapters.py | 15 ++ 6 files changed, 437 insertions(+), 193 deletions(-) diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 980a043e..7f1f638d 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -5,7 +5,7 @@ creating circular imports. """ -from agent.core.provider_adapters import ADAPTERS +from agent.core.provider_adapters import resolve_adapter def _resolve_llm_params( @@ -43,12 +43,12 @@ def _resolve_llm_params( 2. session.hf_token — the user's own token (CLI / OAuth / cache file). 3. HF_TOKEN env — belt-and-suspenders fallback for CLI users. """ - for adapter in ADAPTERS: - if adapter.matches(model_name): - return adapter.build_params( - model_name, - session_hf_token=session_hf_token, - reasoning_effort=reasoning_effort, - ) + adapter = resolve_adapter(model_name) + if adapter: + return adapter.build_params( + model_name, + session_hf_token=session_hf_token, + reasoning_effort=reasoning_effort, + ) raise ValueError(f"Unsupported model id: {model_name}") diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index a876a536..adb255f6 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -1,36 +1,135 @@ -"""Provider-specific LiteLLM parameter builders.""" +"""Provider adapters for runtime params and model catalog metadata.""" import os from dataclasses import dataclass +from typing import Any -_NATIVE_PREFIXES = ("anthropic/", "openai/") + +@dataclass(frozen=True) +class SuggestedModel: + id: str + label: str + description: str + provider: str + provider_label: str + avatar_url: str + recommended: bool = False @dataclass(frozen=True) -class NativeAdapter: - prefixes: tuple[str, ...] = _NATIVE_PREFIXES +class ProviderAdapter: + provider_id: str + provider_label: str + prefixes: tuple[str, ...] = () + supports_custom_model: bool = False + custom_model_hint: str | None = None + + def matches(self, model_name: str) -> bool: + return bool(self.prefixes) and model_name.startswith(self.prefixes) + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return () + + def build_params( + self, + model_name: str, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + ) -> dict: + raise NotImplementedError + + def allows_model_name(self, model_name: str) -> bool: + if any(model.id == model_name for model in self.suggested_models()): + return True + return self.supports_custom_model and self.matches(model_name) + + def to_summary(self) -> dict[str, Any]: + return { + "id": self.provider_id, + "label": self.provider_label, + "supportsCustomModel": self.supports_custom_model, + "customModelHint": self.custom_model_hint, + } + + +@dataclass(frozen=True) +class NativeAdapter(ProviderAdapter): + prefixes: tuple[str, ...] 
= ("anthropic/", "openai/") def matches(self, model_name: str) -> bool: return model_name.startswith(self.prefixes) + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="anthropic/claude-opus-4-6", + label="Claude Opus 4.6", + description="Anthropic", + provider="anthropic", + provider_label="Anthropic", + avatar_url="https://huggingface.co/api/avatars/Anthropic", + recommended=True, + ), + ) + def build_params( self, model_name: str, session_hf_token: str | None = None, reasoning_effort: str | None = None, ) -> dict: - params: dict = {"model": model_name} + params: dict[str, Any] = {"model": model_name} if reasoning_effort: params["reasoning_effort"] = reasoning_effort return params @dataclass(frozen=True) -class HfRouterAdapter: +class HfRouterAdapter(ProviderAdapter): allowed_efforts: tuple[str, ...] = ("low", "medium", "high") + def _is_hf_model_name(self, model_name: str) -> bool: + if model_name.startswith(("anthropic/", "openai/")): + return False + + bare = model_name.removeprefix("huggingface/").split(":", 1)[0] + parts = bare.split("/") + return len(parts) >= 2 and all(parts) + def matches(self, model_name: str) -> bool: - return "/" in model_name and not model_name.startswith(_NATIVE_PREFIXES) + return self._is_hf_model_name(model_name) + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="MiniMaxAI/MiniMax-M2.7", + label="MiniMax M2.7", + description="HF Router", + provider="huggingface", + provider_label="Hugging Face Router", + avatar_url="https://huggingface.co/api/avatars/MiniMaxAI", + recommended=True, + ), + SuggestedModel( + id="moonshotai/Kimi-K2.6", + label="Kimi K2.6", + description="HF Router", + provider="huggingface", + provider_label="Hugging Face Router", + avatar_url="https://huggingface.co/api/avatars/moonshotai", + ), + SuggestedModel( + id="zai-org/GLM-5.1", + label="GLM 5.1", + description="HF Router", + provider="huggingface", + provider_label="Hugging Face Router", + avatar_url="https://huggingface.co/api/avatars/zai-org", + ), + ) + + def allows_model_name(self, model_name: str) -> bool: + return self._is_hf_model_name(model_name) def build_params( self, @@ -41,22 +140,99 @@ def build_params( hf_model = model_name.removeprefix("huggingface/") inference_token = os.environ.get("INFERENCE_TOKEN") api_key = inference_token or session_hf_token or os.environ.get("HF_TOKEN") - params = { + + params: dict[str, Any] = { "model": f"openai/{hf_model}", "api_base": "https://router.huggingface.co/v1", "api_key": api_key, } + if inference_token: bill_to = os.environ.get("HF_BILL_TO", "smolagents") params["extra_headers"] = {"X-HF-Bill-To": bill_to} + if reasoning_effort: hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort if hf_level in self.allowed_efforts: params["extra_body"] = {"reasoning_effort": hf_level} + return params -ADAPTERS = ( - NativeAdapter(), - HfRouterAdapter(), +ADAPTERS: tuple[ProviderAdapter, ...] 
= ( + NativeAdapter(provider_id="native", provider_label="Native"), + HfRouterAdapter( + provider_id="huggingface", + provider_label="Hugging Face Router", + supports_custom_model=True, + custom_model_hint=( + "Paste any Hugging Face model id, optionally with " + ":fastest, :cheapest, :preferred, or :" + ), + ), ) + + +def resolve_adapter(model_name: str) -> ProviderAdapter | None: + for adapter in ADAPTERS: + if adapter.matches(model_name): + return adapter + return None + + +def is_valid_model_name(model_name: str) -> bool: + adapter = resolve_adapter(model_name) + if not adapter: + return False + return adapter.allows_model_name(model_name) + + +def get_available_models() -> list[dict[str, Any]]: + available: list[dict[str, Any]] = [] + for adapter in ADAPTERS: + for model in adapter.suggested_models(): + available.append( + { + "id": model.id, + "label": model.label, + "description": model.description, + "provider": model.provider, + "providerLabel": model.provider_label, + "avatarUrl": model.avatar_url, + "recommended": model.recommended, + } + ) + return available + + +def get_provider_summaries() -> list[dict[str, Any]]: + return [adapter.to_summary() for adapter in ADAPTERS] + + +def find_model_option(model_name: str) -> dict[str, Any] | None: + for model in get_available_models(): + if model["id"] == model_name: + return model + + adapter = resolve_adapter(model_name) + if not adapter or not adapter.supports_custom_model: + return None + + return { + "id": model_name, + "label": model_name.removeprefix("huggingface/"), + "description": f"Custom {adapter.provider_label} model", + "provider": adapter.provider_id, + "providerLabel": adapter.provider_label, + "avatarUrl": "https://huggingface.co/api/avatars/huggingface", + "recommended": False, + } + + +def build_model_catalog(current_model: str) -> dict[str, Any]: + return { + "current": current_model, + "available": get_available_models(), + "providers": get_provider_summaries(), + "currentInfo": find_model_option(current_model), + } diff --git a/agent/main.py b/agent/main.py index 581979fe..e81c3004 100644 --- a/agent/main.py +++ b/agent/main.py @@ -22,6 +22,7 @@ from agent.config import load_config from agent.core.agent_loop import submission_loop +from agent.core.provider_adapters import get_available_models, is_valid_model_name from agent.core.session import OpType from agent.core.tools import ToolRouter from agent.utils.reliability_checks import check_training_script_save_pattern @@ -49,39 +50,28 @@ # on every error — users don't need it, and our friendly errors cover the case. litellm.suppress_debug_info = True + # ── Suggested models shown by `/model` (not a gate) ────────────────────── # Users can paste any HF model id (e.g. "MiniMaxAI/MiniMax-M2.7") or use one # of the `anthropic/` / `openai/` prefixes for direct API access. For HF ids, # append ":fastest" / ":cheapest" / ":preferred" / ":" to override # the default routing policy (auto = fastest with failover). -SUGGESTED_MODELS = [ - {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"}, - {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"}, - {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"}, - {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"}, -] - +def _suggested_models() -> list[dict[str, Any]]: + return get_available_models() -def _is_valid_model_id(model_id: str) -> bool: - """Loose format check — lets users pick any model id. 
- Accepts: - • anthropic/ - • openai/ - • /[:] (HF router; tag = provider or policy) - • huggingface//[:] (same, accepts legacy prefix) - - Actual availability is verified against the HF router catalog on switch, - or by the provider on first call. - """ - if not model_id or "/" not in model_id: +def _looks_like_hf_model_id(model_id: str) -> bool: + if model_id.startswith(("anthropic/", "openai/")): return False - # Strip :tag suffix before structural check - head = model_id.split(":", 1)[0] - parts = head.split("/") + bare = model_id.removeprefix("huggingface/").split(":", 1)[0] + parts = bare.split("/") return len(parts) >= 2 and all(parts) +def _is_valid_model_id(model_id: str) -> bool: + return is_valid_model_name(model_id) or _looks_like_hf_model_id(model_id) + + def _safe_get_args(arguments: dict) -> dict: """Safely extract args dict from arguments, handling cases where LLM passes string.""" args = arguments.get("args", {}) @@ -160,9 +150,7 @@ def _print_model_preflight(model_id: str, console) -> None: ) ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a" tools = "tools" if p.supports_tools else "no tools" - console.print( - f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]" - ) + console.print(f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]") def _get_hf_token() -> str | None: @@ -172,6 +160,7 @@ def _get_hf_token() -> str | None: return token try: from huggingface_hub import HfApi + api = HfApi() token = api.token if token: @@ -224,10 +213,13 @@ async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str: login(token=token, add_to_git_credential=False) print("Token saved to ~/.cache/huggingface/token") except Exception as e: - print(f"Warning: could not persist token ({e}), using for this session only.") + print( + f"Warning: could not persist token ({e}), using for this session only." + ) return token + @dataclass class Operation: """Operation to be executed by the agent""" @@ -252,9 +244,9 @@ def _create_rich_console(): class _ThinkingShimmer: """Animated shiny/shimmer thinking indicator — a bright gradient sweeps across the text.""" - _BASE = (90, 90, 110) # dim base color - _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold) - _WIDTH = 5 # shimmer width in characters + _BASE = (90, 90, 110) # dim base color + _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold) + _WIDTH = 5 # shimmer width in characters _FPS = 24 def __init__(self, console): @@ -335,7 +327,7 @@ def _pop_block(self) -> str | None: if idx == -1: return None block = self._buffer[:idx] - self._buffer = self._buffer[idx + 2:] + self._buffer = self._buffer[idx + 2 :] return block async def flush_ready( @@ -361,7 +353,9 @@ async def finish( """Flush complete blocks, then render whatever incomplete tail remains.""" await self.flush_ready(cancel_event=cancel_event, instant=instant) if self._buffer.strip(): - await print_markdown(self._buffer, cancel_event=cancel_event, instant=instant) + await print_markdown( + self._buffer, cancel_event=cancel_event, instant=instant + ) self._buffer = "" def discard(self): @@ -459,7 +453,11 @@ def _cancel_event(): elif event.event_type == "error": shimmer.stop() stream_buf.discard() - error = event.data.get("error", "Unknown error") if event.data else "Unknown error" + error = ( + event.data.get("error", "Unknown error") + if event.data + else "Unknown error" + ) print_error(error) turn_complete_event.set() elif event.event_type == "shutdown": @@ -721,7 +719,9 @@ def _cancel_event(): f"Approve item {i}? 
(y=yes, yolo=approve all, n=no, or provide feedback): " ) except (KeyboardInterrupt, EOFError): - get_console().print("[dim]Approval cancelled — rejecting remaining items[/dim]") + get_console().print( + "[dim]Approval cancelled — rejecting remaining items[/dim]" + ) approvals.append( { "tool_call_id": tool_call_id, @@ -847,9 +847,12 @@ def _handle_slash_command( console.print("[bold]Current model:[/bold]") console.print(f" {current}") console.print("\n[bold]Suggested:[/bold]") - for m in SUGGESTED_MODELS: + for m in _suggested_models(): marker = " [dim]<-- current[/dim]" if m["id"] == current else "" - console.print(f" {m['id']} [dim]({m['label']})[/dim]{marker}") + provider = m.get("providerLabel") or m.get("provider") or "provider" + console.print( + f" {m['id']} [dim]({m['label']} · {provider})[/dim]{marker}" + ) console.print( "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" @@ -865,7 +868,9 @@ def _handle_slash_command( " • openai/[/dim]" ) return None - normalized = arg.removeprefix("huggingface/") + normalized = ( + arg.removeprefix("huggingface/") if _looks_like_hf_model_id(arg) else arg + ) _print_model_preflight(normalized, console) session = session_holder[0] if session_holder else None if session: @@ -932,6 +937,7 @@ async def main(): hf_user = None try: from huggingface_hub import HfApi + hf_user = HfApi(token=hf_token).whoami().get("name") except Exception: pass @@ -941,6 +947,7 @@ async def main(): # Pre-warm the HF router catalog in the background so /model switches # don't block on a network fetch. from agent.core import hf_router_catalog + asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm)) # Create queues for communication @@ -1084,7 +1091,11 @@ def _install_sigint() -> bool: # Handle slash commands if user_input.strip().startswith("/"): sub = _handle_slash_command( - user_input.strip(), config, session_holder, submission_queue, submission_id + user_input.strip(), + config, + session_holder, + submission_queue, + submission_id, ) if sub is None: # Command handled locally, loop back for input @@ -1147,7 +1158,10 @@ async def headless_main( hf_token = _get_hf_token() if not hf_token: - print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr) + print( + "ERROR: No HF token found. 
Set HF_TOKEN or run `huggingface-cli login`.", + file=sys.stderr, + ) sys.exit(1) print(f"HF token loaded", file=sys.stderr) @@ -1288,26 +1302,35 @@ async def headless_main( for t in tools_data ] _hl_sub_id[0] += 1 - await submission_queue.put(Submission( - id=f"hl_approval_{_hl_sub_id[0]}", - operation=Operation( - op_type=OpType.EXEC_APPROVAL, - data={"approvals": approvals}, - ), - )) + await submission_queue.put( + Submission( + id=f"hl_approval_{_hl_sub_id[0]}", + operation=Operation( + op_type=OpType.EXEC_APPROVAL, + data={"approvals": approvals}, + ), + ) + ) elif event.event_type == "compacted": old_tokens = event.data.get("old_tokens", 0) if event.data else 0 new_tokens = event.data.get("new_tokens", 0) if event.data else 0 print_compacted(old_tokens, new_tokens) elif event.event_type == "error": stream_buf.discard() - error = event.data.get("error", "Unknown error") if event.data else "Unknown error" + error = ( + event.data.get("error", "Unknown error") + if event.data + else "Unknown error" + ) print_error(error) break elif event.event_type in ("turn_complete", "interrupted"): stream_buf.discard() history_size = event.data.get("history_size", "?") if event.data else "?" - print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr) + print( + f"\n--- Agent {event.event_type} (history_size={history_size}) ---", + file=sys.stderr, + ) break # Shutdown @@ -1327,6 +1350,7 @@ def cli(): """Entry point for the ml-intern CLI command.""" import logging as _logging import warnings + # Suppress aiohttp "Unclosed client session" noise during event loop teardown _logging.getLogger("asyncio").setLevel(_logging.CRITICAL) # Suppress litellm pydantic deprecation warnings @@ -1335,12 +1359,23 @@ def cli(): warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh") parser = argparse.ArgumentParser(description="Hugging Face Agent CLI") - parser.add_argument("prompt", nargs="?", default=None, help="Run headlessly with this prompt") - parser.add_argument("--model", "-m", default=None, help=f"Model to use (default: from config)") - parser.add_argument("--max-iterations", type=int, default=None, - help="Max LLM requests per turn (default: 50, use -1 for unlimited)") - parser.add_argument("--no-stream", action="store_true", - help="Disable token streaming (use non-streaming LLM calls)") + parser.add_argument( + "prompt", nargs="?", default=None, help="Run headlessly with this prompt" + ) + parser.add_argument( + "--model", "-m", default=None, help=f"Model to use (default: from config)" + ) + parser.add_argument( + "--max-iterations", + type=int, + default=None, + help="Max LLM requests per turn (default: 50, use -1 for unlimited)", + ) + parser.add_argument( + "--no-stream", + action="store_true", + help="Disable token streaming (use non-streaming LLM calls)", + ) args = parser.parse_args() try: @@ -1348,7 +1383,14 @@ def cli(): max_iter = args.max_iterations if max_iter is not None and max_iter < 0: max_iter = 10_000 # effectively unlimited - asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream)) + asyncio.run( + headless_main( + args.prompt, + model=args.model, + max_iterations=max_iter, + stream=not args.no_stream, + ) + ) else: asyncio.run(main()) except KeyboardInterrupt: diff --git a/backend/routes/agent.py b/backend/routes/agent.py index e76f3727..ba003b21 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -31,36 +31,12 @@ from session_manager import MAX_SESSIONS, 
SessionCapacityError, session_manager from agent.core.llm_params import _resolve_llm_params +from agent.core.provider_adapters import build_model_catalog, is_valid_model_name logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["agent"]) -AVAILABLE_MODELS = [ - { - "id": "anthropic/claude-opus-4-6", - "label": "Claude Opus 4.6", - "provider": "anthropic", - "recommended": True, - }, - { - "id": "MiniMaxAI/MiniMax-M2.7", - "label": "MiniMax M2.7", - "provider": "huggingface", - "recommended": True, - }, - { - "id": "moonshotai/Kimi-K2.6", - "label": "Kimi K2.6", - "provider": "huggingface", - }, - { - "id": "zai-org/GLM-5.1", - "label": "GLM 5.1", - "provider": "huggingface", - }, -] - def _check_session_access(session_id: str, user: dict[str, Any]) -> None: """Verify the user has access to the given session. Raises 403 or 404.""" @@ -137,10 +113,7 @@ async def llm_health_check() -> LLMHealthResponse: @router.get("/config/model") async def get_model() -> dict: """Get current model and available models. No auth required.""" - return { - "current": session_manager.config.model_name, - "available": AVAILABLE_MODELS, - } + return build_model_catalog(session_manager.config.model_name) @router.post("/config/model") @@ -149,8 +122,7 @@ async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict: model_id = body.get("model") if not model_id: raise HTTPException(status_code=400, detail="Missing 'model' field") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model_id not in valid_ids: + if not is_valid_model_name(model_id): raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") session_manager.config.model_name = model_id logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}") @@ -313,8 +285,7 @@ async def set_session_model( model_id = body.get("model") if not model_id: raise HTTPException(status_code=400, detail="Missing 'model' field") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model_id not in valid_ids: + if not is_valid_model_name(model_id): raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") agent_session = session_manager.sessions.get(session_id) if not agent_session: @@ -420,7 +391,9 @@ async def chat_sse( success = await session_manager.submit_user_input(session_id, text) else: broadcaster.unsubscribe(sub_id) - raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") + raise HTTPException( + status_code=400, detail="Must provide 'text' or 'approvals'" + ) if not success: broadcaster.unsubscribe(sub_id) @@ -437,7 +410,13 @@ async def chat_sse( # --------------------------------------------------------------------------- # Shared SSE helpers # --------------------------------------------------------------------------- -_TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"} +_TERMINAL_EVENTS = { + "turn_complete", + "approval_required", + "error", + "interrupted", + "shutdown", +} _SSE_KEEPALIVE_SECONDS = 15 @@ -537,7 +516,10 @@ async def truncate_session( _check_session_access(session_id, user) success = await session_manager.truncate(session_id, body.user_message_index) if not success: - raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range") + raise HTTPException( + status_code=404, + detail="Session not found, inactive, or message index out of range", + ) return {"status": "truncated", "session_id": session_id} @@ -563,5 +545,3 @@ async def 
shutdown_session( if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} - - diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index f61c3bfe..11abf18b 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -5,56 +5,58 @@ import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; import StopIcon from '@mui/icons-material/Stop'; import { apiFetch } from '@/utils/api'; -// Model configuration interface ModelOption { id: string; - name: string; + label: string; description: string; - modelPath: string; avatarUrl: string; + providerLabel?: string; recommended?: boolean; } -const getHfAvatarUrl = (modelId: string) => { - const org = modelId.split('/')[0]; - return `https://huggingface.co/api/avatars/${org}`; -}; - -const MODEL_OPTIONS: ModelOption[] = [ +const FALLBACK_MODELS: ModelOption[] = [ { - id: 'claude-opus', - name: 'Claude Opus 4.6', + id: 'anthropic/claude-opus-4-6', + label: 'Claude Opus 4.6', description: 'Anthropic', - modelPath: 'anthropic/claude-opus-4-6', avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', + providerLabel: 'Anthropic', recommended: true, }, { - id: 'minimax-m2.7', - name: 'MiniMax M2.7', - description: 'Novita', - modelPath: 'MiniMaxAI/MiniMax-M2.7', - avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.7'), + id: 'MiniMaxAI/MiniMax-M2.7', + label: 'MiniMax M2.7', + description: 'HF Router', + avatarUrl: 'https://huggingface.co/api/avatars/MiniMaxAI', + providerLabel: 'Hugging Face Router', recommended: true, }, { - id: 'kimi-k2.6', - name: 'Kimi K2.6', - description: 'Novita', - modelPath: 'moonshotai/Kimi-K2.6', - avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.6'), + id: 'moonshotai/Kimi-K2.6', + label: 'Kimi K2.6', + description: 'HF Router', + avatarUrl: 'https://huggingface.co/api/avatars/moonshotai', + providerLabel: 'Hugging Face Router', }, { - id: 'glm-5.1', - name: 'GLM 5.1', - description: 'Together', - modelPath: 'zai-org/GLM-5.1', - avatarUrl: getHfAvatarUrl('zai-org/GLM-5.1'), + id: 'zai-org/GLM-5.1', + label: 'GLM 5.1', + description: 'HF Router', + avatarUrl: 'https://huggingface.co/api/avatars/zai-org', + providerLabel: 'Hugging Face Router', }, ]; -const findModelByPath = (path: string): ModelOption | undefined => { - return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); +const toModelOption = (value: any): ModelOption | null => { + if (!value || !value.id || !value.label) return null; + return { + id: String(value.id), + label: String(value.label), + description: String(value.description || value.providerLabel || ''), + avatarUrl: String(value.avatarUrl || 'https://huggingface.co/api/avatars/huggingface'), + providerLabel: value.providerLabel ? String(value.providerLabel) : undefined, + recommended: Boolean(value.recommended), + }; }; interface ChatInputProps { @@ -69,30 +71,56 @@ interface ChatInputProps { export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' 
}: ChatInputProps) { const [input, setInput] = useState(''); const inputRef = useRef(null); - const [selectedModelId, setSelectedModelId] = useState(MODEL_OPTIONS[0].id); + const [modelOptions, setModelOptions] = useState(FALLBACK_MODELS); + const [selectedModelPath, setSelectedModelPath] = useState(FALLBACK_MODELS[0].id); const [modelAnchorEl, setModelAnchorEl] = useState(null); - // Model is per-session: fetch this tab's current model every time the - // session changes. Other tabs keep their own selections independently. + useEffect(() => { + let cancelled = false; + + apiFetch('/api/config/model') + .then((res) => (res.ok ? res.json() : null)) + .then((data) => { + if (cancelled || !data) return; + + const rawAvailable = Array.isArray(data.available) ? data.available : []; + const available = rawAvailable + .map(toModelOption) + .filter((value: ModelOption | null): value is ModelOption => value !== null); + + if (available.length > 0) { + setModelOptions(available); + } + if (typeof data.current === 'string' && data.current) { + setSelectedModelPath(data.current); + } + }) + .catch(() => { /* ignore */ }); + + return () => { cancelled = true; }; + }, []); + useEffect(() => { if (!sessionId) return; + let cancelled = false; apiFetch(`/api/session/${sessionId}`) .then((res) => (res.ok ? res.json() : null)) .then((data) => { if (cancelled) return; - if (data?.model) { - const model = findModelByPath(data.model); - if (model) setSelectedModelId(model.id); + if (typeof data?.model === 'string' && data.model) { + setSelectedModelPath(data.model); } }) .catch(() => { /* ignore */ }); + return () => { cancelled = true; }; }, [sessionId]); - const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; + const selectedModel = modelOptions.find((model) => model.id === selectedModelPath) + || toModelOption({ id: selectedModelPath, label: selectedModelPath, description: '', avatarUrl: 'https://huggingface.co/api/avatars/huggingface' }) + || modelOptions[0]; - // Auto-focus the textarea when the session becomes ready useEffect(() => { if (!disabled && !isProcessing && inputRef.current) { inputRef.current.focus(); @@ -113,7 +141,7 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa handleSend(); } }, - [handleSend] + [handleSend], ); const handleModelClick = (event: React.MouseEvent) => { @@ -124,16 +152,21 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa setModelAnchorEl(null); }; - const handleSelectModel = async (model: ModelOption) => { + const handleSelectModel = async (modelPath: string) => { handleModelClose(); if (!sessionId) return; + try { const res = await apiFetch(`/api/session/${sessionId}/model`, { method: 'POST', - body: JSON.stringify({ model: model.modelPath }), + body: JSON.stringify({ model: modelPath }), }); - if (res.ok) setSelectedModelId(model.id); - } catch { /* ignore */ } + if (res.ok) { + setSelectedModelPath(modelPath); + } + } catch { + // ignore + } }; return ( @@ -158,9 +191,9 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa border: '1px solid var(--border)', transition: 'box-shadow 0.2s ease, border-color 0.2s ease', '&:focus-within': { - borderColor: 'var(--accent-yellow)', - boxShadow: 'var(--focus)', - } + borderColor: 'var(--accent-yellow)', + boxShadow: 'var(--focus)', + }, }} > {isProcessing ? 
( @@ -243,7 +276,6 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa )} - {/* Powered By Badge */} @@ -265,16 +297,15 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa {selectedModel.name} - {selectedModel.name} + {selectedModel.label} - {/* Model Selection Menu */} - {MODEL_OPTIONS.map((model) => ( + {modelOptions.map((model) => ( handleSelectModel(model)} - selected={selectedModelId === model.id} + onClick={() => handleSelectModel(model.id)} + selected={selectedModelPath === model.id} sx={{ py: 1.5, '&.Mui-selected': { bgcolor: 'rgba(255,255,255,0.05)', - } + }, }} > {model.name} - {model.name} + {model.label} {model.recommended && ( )} - } - secondary={model.description} + )} + secondary={model.description || model.providerLabel} secondaryTypographyProps={{ - sx: { fontSize: '12px', color: 'var(--muted-text)' } + sx: { fontSize: '12px', color: 'var(--muted-text)' }, }} /> diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index 78282815..650094a3 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -1,4 +1,5 @@ from agent.core.llm_params import _resolve_llm_params +from agent.core.provider_adapters import build_model_catalog, is_valid_model_name def test_native_adapter_keeps_model_name(): @@ -33,3 +34,17 @@ def test_hf_adapter_adds_bill_to_header(monkeypatch): assert params["extra_headers"] == {"X-HF-Bill-To": "smolagents"} assert params["api_key"] == "hf-space-token" + + +def test_model_catalog_comes_from_adapters(): + catalog = build_model_catalog("anthropic/claude-opus-4-6") + + assert catalog["current"] == "anthropic/claude-opus-4-6" + assert any(model["provider"] == "anthropic" for model in catalog["available"]) + assert any(model["provider"] == "huggingface" for model in catalog["available"]) + assert any(provider["id"] == "huggingface" for provider in catalog["providers"]) + + +def test_model_validation_accepts_free_form_hf_ids(): + assert is_valid_model_name("moonshotai/Kimi-K2.6:fastest") is True + assert is_valid_model_name("huggingface/moonshotai/Kimi-K2.6:novita") is True From f1f9116e8168f0d9fd6ce76243e469825289546b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Wed, 22 Apr 2026 20:19:04 +0200 Subject: [PATCH 04/15] Add OpenCode Go provider adapter --- agent/core/provider_adapters.py | 53 ++++++++++++++++++++++++++++++++- agent/main.py | 17 +++++++++-- tests/test_provider_adapters.py | 18 +++++++++++ 3 files changed, 84 insertions(+), 4 deletions(-) diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index adb255f6..6ccb2bad 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -159,8 +159,53 @@ def build_params( return params +@dataclass(frozen=True) +class OpenCodeGoAdapter(ProviderAdapter): + prefixes: tuple[str, ...] 
= ("opencode-go/",) + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="opencode-go/kimi-k2.6", + label="Kimi K2.6", + description="OpenCode Go", + provider="opencode_go", + provider_label="OpenCode Go", + avatar_url="https://huggingface.co/api/avatars/opencode-ai", + recommended=True, + ), + ) + + def allows_model_name(self, model_name: str) -> bool: + if not self.matches(model_name): + return False + return bool(model_name.removeprefix("opencode-go/")) + + def build_params( + self, + model_name: str, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + ) -> dict: + model_id = model_name.removeprefix("opencode-go/") + api_key = os.environ.get("OPENCODE_GO_API_KEY") or os.environ.get( + "OPENCODE_API_KEY" + ) + return { + "model": f"openai/{model_id}", + "api_base": "https://opencode.ai/zen/go/v1", + "api_key": api_key, + } + + ADAPTERS: tuple[ProviderAdapter, ...] = ( NativeAdapter(provider_id="native", provider_label="Native"), + OpenCodeGoAdapter( + provider_id="opencode_go", + provider_label="OpenCode Go", + supports_custom_model=True, + custom_model_hint="Use opencode-go/, for example opencode-go/kimi-k2.6", + ), HfRouterAdapter( provider_id="huggingface", provider_label="Hugging Face Router", @@ -218,9 +263,15 @@ def find_model_option(model_name: str) -> dict[str, Any] | None: if not adapter or not adapter.supports_custom_model: return None + label = model_name + if adapter.provider_id == "huggingface": + label = model_name.removeprefix("huggingface/") + elif adapter.prefixes: + label = model_name.removeprefix(adapter.prefixes[0]) + return { "id": model_name, - "label": model_name.removeprefix("huggingface/"), + "label": label, "description": f"Custom {adapter.provider_label} model", "provider": adapter.provider_id, "providerLabel": adapter.provider_label, diff --git a/agent/main.py b/agent/main.py index e81c3004..e2cef1ef 100644 --- a/agent/main.py +++ b/agent/main.py @@ -22,7 +22,11 @@ from agent.config import load_config from agent.core.agent_loop import submission_loop -from agent.core.provider_adapters import get_available_models, is_valid_model_name +from agent.core.provider_adapters import ( + get_available_models, + is_valid_model_name, + resolve_adapter, +) from agent.core.session import OpType from agent.core.tools import ToolRouter from agent.utils.reliability_checks import check_training_script_save_pattern @@ -61,8 +65,10 @@ def _suggested_models() -> list[dict[str, Any]]: def _looks_like_hf_model_id(model_id: str) -> bool: - if model_id.startswith(("anthropic/", "openai/")): - return False + adapter = resolve_adapter(model_id) + if adapter: + return adapter.provider_id == "huggingface" + bare = model_id.removeprefix("huggingface/").split(":", 1)[0] parts = bare.split("/") return len(parts) >= 2 and all(parts) @@ -96,6 +102,11 @@ def _print_model_preflight(model_id: str, console) -> None: console.print(f"[green]Model switched to {model_id}[/green]") return + adapter = resolve_adapter(model_id) + if adapter and adapter.provider_id != "huggingface": + console.print(f"[green]Model switched to {model_id}[/green]") + return + from agent.core import hf_router_catalog as cat bare, _, tag = model_id.partition(":") diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index 650094a3..64db78ab 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -36,15 +36,33 @@ def test_hf_adapter_adds_bill_to_header(monkeypatch): assert params["api_key"] == 
"hf-space-token" +def test_opencode_go_adapter_uses_api_key(monkeypatch): + monkeypatch.setenv("OPENCODE_GO_API_KEY", "go-test-key") + + params = _resolve_llm_params("opencode-go/kimi-k2.6") + + assert params == { + "model": "openai/kimi-k2.6", + "api_base": "https://opencode.ai/zen/go/v1", + "api_key": "go-test-key", + } + + def test_model_catalog_comes_from_adapters(): catalog = build_model_catalog("anthropic/claude-opus-4-6") assert catalog["current"] == "anthropic/claude-opus-4-6" assert any(model["provider"] == "anthropic" for model in catalog["available"]) assert any(model["provider"] == "huggingface" for model in catalog["available"]) + assert any(model["provider"] == "opencode_go" for model in catalog["available"]) + assert any(provider["id"] == "opencode_go" for provider in catalog["providers"]) assert any(provider["id"] == "huggingface" for provider in catalog["providers"]) def test_model_validation_accepts_free_form_hf_ids(): assert is_valid_model_name("moonshotai/Kimi-K2.6:fastest") is True assert is_valid_model_name("huggingface/moonshotai/Kimi-K2.6:novita") is True + + +def test_model_validation_accepts_free_form_opencode_go_ids(): + assert is_valid_model_name("opencode-go/glm-5.1") is True From 4550dd54390b52c17c9e6b90c970b2af4094ad17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Thu, 23 Apr 2026 10:34:51 +0200 Subject: [PATCH 05/15] Refactor LLM param resolution into provider adapters Split NativeAdapter into AnthropicAdapter (thinking config + output_config.effort) and OpenAIAdapter (reasoning_effort top-level). Each adapter owns its accepted effort set and raises UnsupportedEffortError in strict mode, preserving the effort_probe cascade with zero changes to effort_probe.py or agent_loop.py. llm_params.py becomes a thin dispatcher delegating to resolve_adapter().build_params() while keeping the litellm effort-validation patch and re-exporting UnsupportedEffortError. model_switcher.py reads suggested models from the adapter registry instead of maintaining a separate SUGGESTED_MODELS list. backend/routes/agent.py replaces AVAILABLE_MODELS with build_model_catalog(). OpenCodeGoAdapter deferred to PR #60. Co-Authored-By: Claude Opus 4.6 (1M context) --- agent/core/llm_params.py | 131 +++++--------------------- agent/core/model_switcher.py | 40 +++----- agent/core/provider_adapters.py | 162 +++++++++++++++++++++----------- tests/test_provider_adapters.py | 89 ++++++++++++++---- 4 files changed, 221 insertions(+), 201 deletions(-) diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 830f334c..adbd40d7 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -3,9 +3,20 @@ Kept separate from ``agent_loop`` so tools (research, context compaction, etc.) can import it without pulling in the whole agent loop / tool router and creating circular imports. + +Provider-specific logic (Anthropic thinking config, OpenAI reasoning_effort, +HF router extra_body) lives in ``provider_adapters.py``. This module is the +stable import surface for ``effort_probe`` and ``agent_loop``. """ -import os +from agent.core.provider_adapters import ( + UnsupportedEffortError, + resolve_adapter, +) + +# Re-export so existing ``from agent.core.llm_params import +# UnsupportedEffortError`` in effort_probe.py keeps working. 
+__all__ = ["UnsupportedEffortError", "_resolve_llm_params"] def _patch_litellm_effort_validation() -> None: @@ -64,59 +75,17 @@ def _widened(model: str) -> bool: _patch_litellm_effort_validation() -# Effort levels accepted on the wire. -# Anthropic (4.6+): low | medium | high | xhigh | max (output_config.effort) -# OpenAI direct: minimal | low | medium | high (reasoning_effort top-level) -# HF router: low | medium | high (extra_body.reasoning_effort) -# -# We validate *shape* here and let the probe cascade walk down on rejection; -# we deliberately do NOT maintain a per-model capability table. -_ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"} -_OPENAI_EFFORTS = {"minimal", "low", "medium", "high"} -_HF_EFFORTS = {"low", "medium", "high"} - - -class UnsupportedEffortError(ValueError): - """The requested effort isn't valid for this provider's API surface. - - Raised synchronously before any network call so the probe cascade can - skip levels the provider can't accept (e.g. ``max`` on HF router). - """ - - def _resolve_llm_params( model_name: str, session_hf_token: str | None = None, reasoning_effort: str | None = None, strict: bool = False, ) -> dict: - """ - Build LiteLLM kwargs for a given model id. - - • ``anthropic/`` — native thinking config. We bypass LiteLLM's - ``reasoning_effort`` → ``thinking`` mapping (which lags new Claude - releases like 4.7 and sends the wrong API shape). Instead we pass - both ``thinking={"type": "adaptive"}`` and ``output_config= - {"effort": }`` as top-level kwargs — LiteLLM's Anthropic - adapter forwards unknown top-level kwargs into the request body - verbatim (confirmed by live probe; ``extra_body`` does NOT work - here because Anthropic's API rejects it as "Extra inputs are not - permitted"). This is the stable API for 4.6 and 4.7. Older - extended-thinking models that only accept ``thinking.type.enabled`` - will reject this; the probe's cascade catches that and falls back - to no thinking. - - • ``openai/`` — ``reasoning_effort`` forwarded as a top-level - kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``. - - • Anything else is treated as a HuggingFace router id. We hit the - auto-routing OpenAI-compatible endpoint at - ``https://router.huggingface.co/v1``. The id can be bare or carry an - HF routing suffix (``:fastest`` / ``:cheapest`` / ``:``). - A leading ``huggingface/`` is stripped. ``reasoning_effort`` is - forwarded via ``extra_body`` (LiteLLM's OpenAI adapter refuses it as - a top-level kwarg for non-OpenAI models). "minimal" normalizes to - "low". + """Build LiteLLM kwargs for a given model id. + + Delegates to the matching provider adapter. See ``provider_adapters.py`` + for the per-provider logic (Anthropic thinking config, OpenAI + reasoning_effort, HF router extra_body, etc.). ``strict=True`` raises ``UnsupportedEffortError`` when the requested effort isn't in the provider's accepted set, instead of silently @@ -131,62 +100,12 @@ def _resolve_llm_params( 2. session.hf_token — the user's own token (CLI / OAuth / cache file). 3. HF_TOKEN env — belt-and-suspenders fallback for CLI users. """ - if model_name.startswith("anthropic/"): - params: dict = {"model": model_name} - if reasoning_effort: - level = reasoning_effort - if level == "minimal": - level = "low" - if level not in _ANTHROPIC_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"Anthropic doesn't accept effort={level!r}" - ) - else: - # Adaptive thinking + output_config.effort is the stable - # Anthropic API for Claude 4.6 / 4.7. 
Both kwargs are - # passed top-level: LiteLLM forwards unknown params into - # the request body for Anthropic, so ``output_config`` - # reaches the API. ``extra_body`` does NOT work here — - # Anthropic rejects it as "Extra inputs are not - # permitted". - params["thinking"] = {"type": "adaptive"} - params["output_config"] = {"effort": level} - return params - - if model_name.startswith("openai/"): - params = {"model": model_name} - if reasoning_effort: - if reasoning_effort not in _OPENAI_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"OpenAI doesn't accept effort={reasoning_effort!r}" - ) - else: - params["reasoning_effort"] = reasoning_effort - return params - - hf_model = model_name.removeprefix("huggingface/") - api_key = ( - os.environ.get("INFERENCE_TOKEN") - or session_hf_token - or os.environ.get("HF_TOKEN") + adapter = resolve_adapter(model_name) + if adapter is None: + raise ValueError(f"Unsupported model id: {model_name}") + return adapter.build_params( + model_name, + session_hf_token=session_hf_token, + reasoning_effort=reasoning_effort, + strict=strict, ) - params = { - "model": f"openai/{hf_model}", - "api_base": "https://router.huggingface.co/v1", - "api_key": api_key, - } - if os.environ.get("INFERENCE_TOKEN"): - bill_to = os.environ.get("HF_BILL_TO", "smolagents") - params["extra_headers"] = {"X-HF-Bill-To": bill_to} - if reasoning_effort: - hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort - if hf_level not in _HF_EFFORTS: - if strict: - raise UnsupportedEffortError( - f"HF router doesn't accept effort={hf_level!r}" - ) - else: - params["extra_body"] = {"reasoning_effort": hf_level} - return params diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index b30c7238..35729bcc 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -16,20 +16,11 @@ from __future__ import annotations from agent.core.effort_probe import ProbeInconclusive, probe_effort - - -# Suggested models shown by `/model` (not a gate). Users can paste any HF -# model id (e.g. "MiniMaxAI/MiniMax-M2.7") or an `anthropic/` / `openai/` -# prefix for direct API access. For HF ids, append ":fastest" / -# ":cheapest" / ":preferred" / ":" to override the default -# routing policy (auto = fastest with failover). -SUGGESTED_MODELS = [ - {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"}, - {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"}, - {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"}, - {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"}, - {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"}, -] +from agent.core.provider_adapters import ( + get_available_models, + is_valid_model_name, + resolve_adapter, +) _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} @@ -38,15 +29,12 @@ def is_valid_model_id(model_id: str) -> bool: """Loose format check — lets users pick any model id. - Accepts: - • anthropic/ - • openai/ - • /[:] (HF router; tag = provider or policy) - • huggingface//[:] (same, accepts legacy prefix) - - Actual availability is verified against the HF router catalog on - switch, and by the provider on the probe's ping call. + Checks the adapter registry first (covers all registered providers), + then falls back to the structural ``/[:]`` pattern + so unknown HF models are still accepted. 
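+
+    Illustrative examples (values mirrored from the adapter tests):
+        is_valid_model_id("anthropic/claude-opus-4-6")     -> True
+        is_valid_model_id("moonshotai/Kimi-K2.6:fastest")  -> True
+        is_valid_model_id("no-slash")                      -> False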
""" + if is_valid_model_name(model_id): + return True if not model_id or "/" not in model_id: return False head = model_id.split(":", 1)[0] @@ -63,7 +51,8 @@ def _print_hf_routing_info(model_id: str, console) -> bool: Anthropic / OpenAI ids return ``True`` without printing anything — the probe below covers "does this model exist". """ - if model_id.startswith(("anthropic/", "openai/")): + adapter = resolve_adapter(model_id) + if adapter and adapter.provider_id != "huggingface": return True from agent.core import hf_router_catalog as cat @@ -130,9 +119,10 @@ def print_model_listing(config, console) -> None: console.print("[bold]Current model:[/bold]") console.print(f" {current}") console.print("\n[bold]Suggested:[/bold]") - for m in SUGGESTED_MODELS: + for m in get_available_models(): marker = " [dim]<-- current[/dim]" if m["id"] == current else "" - console.print(f" {m['id']} [dim]({m['label']})[/dim]{marker}") + provider = m.get("providerLabel") or m.get("provider") or "" + console.print(f" {m['id']} [dim]({m['label']} · {provider})[/dim]{marker}") console.print( "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index 6ccb2bad..07e4e5a7 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -1,9 +1,31 @@ -"""Provider adapters for runtime params and model catalog metadata.""" +"""Provider adapters for runtime params and model catalog metadata. + +Each adapter owns its LiteLLM kwargs construction (``build_params``) and +the list of suggested models shown in ``/model`` and the web picker. +Adding a new provider means subclassing ``ProviderAdapter``, implementing +``build_params``, and appending to ``ADAPTERS``. +""" import os from dataclasses import dataclass -from typing import Any +from typing import Any, ClassVar + + +# --------------------------------------------------------------------------- +# Errors +# --------------------------------------------------------------------------- + +class UnsupportedEffortError(ValueError): + """The requested effort isn't valid for this provider's API surface. + Raised synchronously before any network call so the probe cascade can + skip levels the provider can't accept (e.g. ``max`` on HF router). + """ + + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- @dataclass(frozen=True) class SuggestedModel: @@ -33,8 +55,10 @@ def suggested_models(self) -> tuple[SuggestedModel, ...]: def build_params( self, model_name: str, + *, session_hf_token: str | None = None, reasoning_effort: str | None = None, + strict: bool = False, ) -> dict: raise NotImplementedError @@ -52,15 +76,30 @@ def to_summary(self) -> dict[str, Any]: } +# --------------------------------------------------------------------------- +# Concrete adapters +# --------------------------------------------------------------------------- + @dataclass(frozen=True) -class NativeAdapter(ProviderAdapter): - prefixes: tuple[str, ...] = ("anthropic/", "openai/") +class AnthropicAdapter(ProviderAdapter): + """Anthropic models via native API (thinking + output_config.effort).""" - def matches(self, model_name: str) -> bool: - return model_name.startswith(self.prefixes) + prefixes: tuple[str, ...] 
= ("anthropic/",) + _EFFORTS: ClassVar[frozenset[str]] = frozenset( + {"low", "medium", "high", "xhigh", "max"} + ) def suggested_models(self) -> tuple[SuggestedModel, ...]: return ( + SuggestedModel( + id="anthropic/claude-opus-4-7", + label="Claude Opus 4.7", + description="Anthropic", + provider="anthropic", + provider_label="Anthropic", + avatar_url="https://huggingface.co/api/avatars/Anthropic", + recommended=True, + ), SuggestedModel( id="anthropic/claude-opus-4-6", label="Claude Opus 4.6", @@ -68,30 +107,69 @@ def suggested_models(self) -> tuple[SuggestedModel, ...]: provider="anthropic", provider_label="Anthropic", avatar_url="https://huggingface.co/api/avatars/Anthropic", - recommended=True, ), ) def build_params( self, model_name: str, + *, session_hf_token: str | None = None, reasoning_effort: str | None = None, + strict: bool = False, ) -> dict: params: dict[str, Any] = {"model": model_name} if reasoning_effort: - params["reasoning_effort"] = reasoning_effort + level = "low" if reasoning_effort == "minimal" else reasoning_effort + if level not in self._EFFORTS: + if strict: + raise UnsupportedEffortError( + f"Anthropic doesn't accept effort={level!r}" + ) + else: + params["thinking"] = {"type": "adaptive"} + params["output_config"] = {"effort": level} + return params + + +@dataclass(frozen=True) +class OpenAIAdapter(ProviderAdapter): + """OpenAI models via native API (reasoning_effort top-level kwarg).""" + + prefixes: tuple[str, ...] = ("openai/",) + _EFFORTS: ClassVar[frozenset[str]] = frozenset( + {"minimal", "low", "medium", "high"} + ) + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + params: dict[str, Any] = {"model": model_name} + if reasoning_effort: + if reasoning_effort not in self._EFFORTS: + if strict: + raise UnsupportedEffortError( + f"OpenAI doesn't accept effort={reasoning_effort!r}" + ) + else: + params["reasoning_effort"] = reasoning_effort return params @dataclass(frozen=True) class HfRouterAdapter(ProviderAdapter): - allowed_efforts: tuple[str, ...] = ("low", "medium", "high") + """HuggingFace router — OpenAI-compat endpoint with HF token chain.""" + + _EFFORTS: ClassVar[frozenset[str]] = frozenset({"low", "medium", "high"}) def _is_hf_model_name(self, model_name: str) -> bool: if model_name.startswith(("anthropic/", "openai/")): return False - bare = model_name.removeprefix("huggingface/").split(":", 1)[0] parts = bare.split("/") return len(parts) >= 2 and all(parts) @@ -134,8 +212,10 @@ def allows_model_name(self, model_name: str) -> bool: def build_params( self, model_name: str, + *, session_hf_token: str | None = None, reasoning_effort: str | None = None, + strict: bool = False, ) -> dict: hf_model = model_name.removeprefix("huggingface/") inference_token = os.environ.get("INFERENCE_TOKEN") @@ -153,58 +233,28 @@ def build_params( if reasoning_effort: hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort - if hf_level in self.allowed_efforts: + if hf_level not in self._EFFORTS: + if strict: + raise UnsupportedEffortError( + f"HF router doesn't accept effort={hf_level!r}" + ) + else: params["extra_body"] = {"reasoning_effort": hf_level} return params -@dataclass(frozen=True) -class OpenCodeGoAdapter(ProviderAdapter): - prefixes: tuple[str, ...] 
= ("opencode-go/",) - - def suggested_models(self) -> tuple[SuggestedModel, ...]: - return ( - SuggestedModel( - id="opencode-go/kimi-k2.6", - label="Kimi K2.6", - description="OpenCode Go", - provider="opencode_go", - provider_label="OpenCode Go", - avatar_url="https://huggingface.co/api/avatars/opencode-ai", - recommended=True, - ), - ) - - def allows_model_name(self, model_name: str) -> bool: - if not self.matches(model_name): - return False - return bool(model_name.removeprefix("opencode-go/")) - - def build_params( - self, - model_name: str, - session_hf_token: str | None = None, - reasoning_effort: str | None = None, - ) -> dict: - model_id = model_name.removeprefix("opencode-go/") - api_key = os.environ.get("OPENCODE_GO_API_KEY") or os.environ.get( - "OPENCODE_API_KEY" - ) - return { - "model": f"openai/{model_id}", - "api_base": "https://opencode.ai/zen/go/v1", - "api_key": api_key, - } - +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- ADAPTERS: tuple[ProviderAdapter, ...] = ( - NativeAdapter(provider_id="native", provider_label="Native"), - OpenCodeGoAdapter( - provider_id="opencode_go", - provider_label="OpenCode Go", + AnthropicAdapter(provider_id="anthropic", provider_label="Anthropic"), + OpenAIAdapter( + provider_id="openai", + provider_label="OpenAI", supports_custom_model=True, - custom_model_hint="Use opencode-go/, for example opencode-go/kimi-k2.6", + custom_model_hint="Use openai/, for example openai/gpt-5", ), HfRouterAdapter( provider_id="huggingface", @@ -225,6 +275,10 @@ def resolve_adapter(model_name: str) -> ProviderAdapter | None: return None +# --------------------------------------------------------------------------- +# Catalog helpers (used by model_switcher, backend, frontend) +# --------------------------------------------------------------------------- + def is_valid_model_name(model_name: str) -> bool: adapter = resolve_adapter(model_name) if not adapter: diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index 64db78ab..101b3ada 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -1,16 +1,67 @@ +import pytest + from agent.core.llm_params import _resolve_llm_params -from agent.core.provider_adapters import build_model_catalog, is_valid_model_name +from agent.core.provider_adapters import ( + UnsupportedEffortError, + build_model_catalog, + is_valid_model_name, +) + +# -- Anthropic adapter ------------------------------------------------------- -def test_native_adapter_keeps_model_name(): +def test_anthropic_adapter_builds_thinking_config(): params = _resolve_llm_params("anthropic/claude-opus-4-6", reasoning_effort="high") assert params == { "model": "anthropic/claude-opus-4-6", - "reasoning_effort": "high", + "thinking": {"type": "adaptive"}, + "output_config": {"effort": "high"}, } +def test_anthropic_adapter_normalizes_minimal_to_low(): + params = _resolve_llm_params("anthropic/claude-opus-4-7", reasoning_effort="minimal") + + assert params["output_config"] == {"effort": "low"} + + +def test_anthropic_adapter_no_effort(): + params = _resolve_llm_params("anthropic/claude-opus-4-6") + + assert params == {"model": "anthropic/claude-opus-4-6"} + + +def test_anthropic_adapter_strict_rejects_invalid(): + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params( + "anthropic/claude-opus-4-6", reasoning_effort="turbo", strict=True + ) + + +def 
test_anthropic_adapter_nonstrict_drops_invalid(): + params = _resolve_llm_params( + "anthropic/claude-opus-4-6", reasoning_effort="turbo", strict=False + ) + assert "thinking" not in params + assert "output_config" not in params + + +# -- OpenAI adapter ----------------------------------------------------------- + +def test_openai_adapter_passes_reasoning_effort(): + params = _resolve_llm_params("openai/gpt-5", reasoning_effort="medium") + + assert params == {"model": "openai/gpt-5", "reasoning_effort": "medium"} + + +def test_openai_adapter_strict_rejects_max(): + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params("openai/gpt-5", reasoning_effort="max", strict=True) + + +# -- HF Router adapter -------------------------------------------------------- + def test_hf_adapter_builds_router_params(monkeypatch): monkeypatch.setenv("HF_TOKEN", "hf-test") @@ -36,17 +87,14 @@ def test_hf_adapter_adds_bill_to_header(monkeypatch): assert params["api_key"] == "hf-space-token" -def test_opencode_go_adapter_uses_api_key(monkeypatch): - monkeypatch.setenv("OPENCODE_GO_API_KEY", "go-test-key") +def test_hf_adapter_strict_rejects_max(): + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params( + "MiniMaxAI/MiniMax-M2.7", reasoning_effort="max", strict=True + ) - params = _resolve_llm_params("opencode-go/kimi-k2.6") - - assert params == { - "model": "openai/kimi-k2.6", - "api_base": "https://opencode.ai/zen/go/v1", - "api_key": "go-test-key", - } +# -- Catalog & validation ----------------------------------------------------- def test_model_catalog_comes_from_adapters(): catalog = build_model_catalog("anthropic/claude-opus-4-6") @@ -54,9 +102,9 @@ def test_model_catalog_comes_from_adapters(): assert catalog["current"] == "anthropic/claude-opus-4-6" assert any(model["provider"] == "anthropic" for model in catalog["available"]) assert any(model["provider"] == "huggingface" for model in catalog["available"]) - assert any(model["provider"] == "opencode_go" for model in catalog["available"]) - assert any(provider["id"] == "opencode_go" for provider in catalog["providers"]) assert any(provider["id"] == "huggingface" for provider in catalog["providers"]) + assert any(provider["id"] == "anthropic" for provider in catalog["providers"]) + assert catalog["currentInfo"] is not None def test_model_validation_accepts_free_form_hf_ids(): @@ -64,5 +112,14 @@ def test_model_validation_accepts_free_form_hf_ids(): assert is_valid_model_name("huggingface/moonshotai/Kimi-K2.6:novita") is True -def test_model_validation_accepts_free_form_opencode_go_ids(): - assert is_valid_model_name("opencode-go/glm-5.1") is True +def test_model_validation_rejects_garbage(): + assert is_valid_model_name("") is False + assert is_valid_model_name("no-slash") is False + + +def test_unsupported_effort_reexport(): + """UnsupportedEffortError must be importable from llm_params (backward compat).""" + from agent.core.llm_params import UnsupportedEffortError as FromLlm + from agent.core.provider_adapters import UnsupportedEffortError as FromAdapters + + assert FromLlm is FromAdapters From a4a75325a46bfba04e1cff13b35165916d828800 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Thu, 23 Apr 2026 13:05:09 +0200 Subject: [PATCH 06/15] Trim provider adapters to params and validation Restore the existing web model behavior so PR #55 stays a behavior-preserving refactor while keeping shared runtime and CLI validation logic. 
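
For reference, the validation surface this keeps can be exercised like so
(an illustrative sketch that mirrors the assertions in
tests/test_provider_adapters.py; the model ids are only examples):

    from agent.core.provider_adapters import is_valid_model_name

    is_valid_model_name("anthropic/claude-opus-4-7")     # True: direct Anthropic id
    is_valid_model_name("moonshotai/Kimi-K2.6:fastest")  # True: HF router id with a routing tag
    is_valid_model_name("anthropic/")                    # False: bare prefix without a model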
---
 agent/core/llm_params.py                   |   6 +-
 agent/core/model_switcher.py               |  62 +++--
 agent/core/provider_adapters.py            | 151 ++----------
 backend/routes/agent.py                    |  93 ++++---
 frontend/src/components/Chat/ChatInput.tsx | 267 ++++++++-------------
 frontend/src/utils/model.ts                |   9 +-
 tests/test_backend_routes.py               |  75 ------
 tests/test_provider_adapters.py            |  36 +--
 8 files changed, 234 insertions(+), 465 deletions(-)
 delete mode 100644 tests/test_backend_routes.py

diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py
index c60c41d4..a52032e6 100644
--- a/agent/core/llm_params.py
+++ b/agent/core/llm_params.py
@@ -1,4 +1,4 @@
-"""LiteLLM kwargs resolution for supported model ids."""
+"""LiteLLM kwargs resolution for the model ids this agent accepts."""
 
 from agent.core.provider_adapters import (
     UnsupportedEffortError,
@@ -12,7 +12,7 @@ def _patch_litellm_effort_validation() -> None:
     """Patch LiteLLM's Anthropic effort validation for Claude Opus 4.7."""
     try:
         from litellm.llms.anthropic.chat import transformation as _t
-    except ImportError:
+    except Exception:
         return
 
     cfg = getattr(_t, "AnthropicConfig", None)
@@ -70,8 +70,6 @@ def _resolve_llm_params(
     3. HF_TOKEN env — belt-and-suspenders fallback for CLI users.
     """
     adapter = resolve_adapter(model_name)
-    if adapter is None:
-        raise ValueError(f"Unsupported model id: {model_name}")
     return adapter.build_params(
         model_name,
         session_hf_token=session_hf_token,
diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py
index 35729bcc..5bf5a44c 100644
--- a/agent/core/model_switcher.py
+++ b/agent/core/model_switcher.py
@@ -16,11 +16,21 @@
 from __future__ import annotations
 
 from agent.core.effort_probe import ProbeInconclusive, probe_effort
-from agent.core.provider_adapters import (
-    get_available_models,
-    is_valid_model_name,
-    resolve_adapter,
-)
+from agent.core.provider_adapters import is_valid_model_name
+
+
+# Suggested models shown by `/model` (not a gate). Users can paste any HF
+# model id (e.g. "MiniMaxAI/MiniMax-M2.7") or an `anthropic/` / `openai/`
+# prefix for direct API access. For HF ids, append ":fastest" /
+# ":cheapest" / ":preferred" / ":<provider>" to override the default
+# routing policy (auto = fastest with failover).
+SUGGESTED_MODELS = [
+    {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"},
+    {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"},
+    {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"},
+    {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
+    {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
+]
 
 
 _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"}
@@ -29,17 +39,16 @@
 
 def is_valid_model_id(model_id: str) -> bool:
     """Loose format check — lets users pick any model id.
 
-    Checks the adapter registry first (covers all registered providers),
-    then falls back to the structural ``<org>/<model>[:<tag>]`` pattern
-    so unknown HF models are still accepted.
+    Accepts:
+      • anthropic/<model>
+      • openai/<model>
+      • <org>/<model>[:<tag>] (HF router; tag = provider or policy)
+      • huggingface/<org>/<model>[:<tag>] (same, accepts legacy prefix)
+
+    Actual availability is verified against the HF router catalog on
+    switch, and by the provider on the probe's ping call.
""" - if is_valid_model_name(model_id): - return True - if not model_id or "/" not in model_id: - return False - head = model_id.split(":", 1)[0] - parts = head.split("/") - return len(parts) >= 2 and all(parts) + return is_valid_model_name(model_id) def _print_hf_routing_info(model_id: str, console) -> bool: @@ -51,8 +60,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool: Anthropic / OpenAI ids return ``True`` without printing anything — the probe below covers "does this model exist". """ - adapter = resolve_adapter(model_id) - if adapter and adapter.provider_id != "huggingface": + if model_id.startswith(("anthropic/", "openai/")): return True from agent.core import hf_router_catalog as cat @@ -107,9 +115,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool: ) ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a" tools = "tools" if p.supports_tools else "no tools" - console.print( - f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]" - ) + console.print(f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]") return True @@ -119,10 +125,9 @@ def print_model_listing(config, console) -> None: console.print("[bold]Current model:[/bold]") console.print(f" {current}") console.print("\n[bold]Suggested:[/bold]") - for m in get_available_models(): + for m in SUGGESTED_MODELS: marker = " [dim]<-- current[/dim]" if m["id"] == current else "" - provider = m.get("providerLabel") or m.get("provider") or "" - console.print(f" {m['id']} [dim]({m['label']} · {provider})[/dim]{marker}") + console.print(f" {m['id']} [dim]({m['label']})[/dim]{marker}") console.print( "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" @@ -169,7 +174,9 @@ async def probe_and_switch_model( # Nothing to validate with a ping that we couldn't validate on the # first real call just as cheaply. Skip the probe entirely. 
_commit_switch(model_id, config, session, effective=None, cache=False) - console.print(f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]") + console.print( + f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" + ) return console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]") @@ -189,8 +196,11 @@ async def probe_and_switch_model( return _commit_switch( - model_id, config, session, - effective=outcome.effective_effort, cache=True, + model_id, + config, + session, + effective=outcome.effective_effort, + cache=True, ) effort_label = outcome.effective_effort or "off" suffix = f" — {outcome.note}" if outcome.note else "" diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index b86f50c5..26381e2a 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -1,4 +1,4 @@ -"""Provider adapters for runtime params and model metadata.""" +"""Provider adapters for runtime params and model-name validation.""" import os from dataclasses import dataclass @@ -13,31 +13,29 @@ class UnsupportedEffortError(ValueError): """ -@dataclass(frozen=True) -class SuggestedModel: - id: str - label: str - description: str - provider: str - provider_label: str - avatar_url: str - recommended: bool = False +def _has_model_suffix(model_name: str, prefix: str) -> bool: + if not model_name.startswith(prefix): + return False + tail = model_name[len(prefix) :].split(":", 1)[0] + return bool(tail) and all(tail.split("/")) + + +def _is_hf_model_name(model_name: str) -> bool: + if model_name.startswith(("anthropic/", "openai/")): + return False + bare = model_name.removeprefix("huggingface/").split(":", 1)[0] + parts = bare.split("/") + return len(parts) >= 2 and all(parts) @dataclass(frozen=True) class ProviderAdapter: provider_id: str - provider_label: str prefixes: tuple[str, ...] = () - supports_custom_model: bool = False - custom_model_hint: str | None = None def matches(self, model_name: str) -> bool: return bool(self.prefixes) and model_name.startswith(self.prefixes) - def suggested_models(self) -> tuple[SuggestedModel, ...]: - return () - def build_params( self, model_name: str, @@ -49,9 +47,7 @@ def build_params( raise NotImplementedError def allows_model_name(self, model_name: str) -> bool: - if any(model.id == model_name for model in self.suggested_models()): - return True - return self.supports_custom_model and self.matches(model_name) + return self.matches(model_name) @dataclass(frozen=True) @@ -63,26 +59,8 @@ class AnthropicAdapter(ProviderAdapter): {"low", "medium", "high", "xhigh", "max"} ) - def suggested_models(self) -> tuple[SuggestedModel, ...]: - return ( - SuggestedModel( - id="anthropic/claude-opus-4-7", - label="Claude Opus 4.7", - description="Anthropic", - provider="anthropic", - provider_label="Anthropic", - avatar_url="https://huggingface.co/api/avatars/Anthropic", - recommended=True, - ), - SuggestedModel( - id="anthropic/claude-opus-4-6", - label="Claude Opus 4.6", - description="Anthropic", - provider="anthropic", - provider_label="Anthropic", - avatar_url="https://huggingface.co/api/avatars/Anthropic", - ), - ) + def allows_model_name(self, model_name: str) -> bool: + return _has_model_suffix(model_name, "anthropic/") def build_params( self, @@ -113,6 +91,9 @@ class OpenAIAdapter(ProviderAdapter): prefixes: tuple[str, ...] 
= ("openai/",) _EFFORTS: ClassVar[frozenset[str]] = frozenset({"minimal", "low", "medium", "high"}) + def allows_model_name(self, model_name: str) -> bool: + return _has_model_suffix(model_name, "openai/") + def build_params( self, model_name: str, @@ -139,47 +120,11 @@ class HfRouterAdapter(ProviderAdapter): _EFFORTS: ClassVar[frozenset[str]] = frozenset({"low", "medium", "high"}) - def _is_hf_model_name(self, model_name: str) -> bool: - if model_name.startswith(("anthropic/", "openai/")): - return False - bare = model_name.removeprefix("huggingface/").split(":", 1)[0] - parts = bare.split("/") - return len(parts) >= 2 and all(parts) - def matches(self, model_name: str) -> bool: - return self._is_hf_model_name(model_name) - - def suggested_models(self) -> tuple[SuggestedModel, ...]: - return ( - SuggestedModel( - id="MiniMaxAI/MiniMax-M2.7", - label="MiniMax M2.7", - description="HF Router", - provider="huggingface", - provider_label="Hugging Face Router", - avatar_url="https://huggingface.co/api/avatars/MiniMaxAI", - recommended=True, - ), - SuggestedModel( - id="moonshotai/Kimi-K2.6", - label="Kimi K2.6", - description="HF Router", - provider="huggingface", - provider_label="Hugging Face Router", - avatar_url="https://huggingface.co/api/avatars/moonshotai", - ), - SuggestedModel( - id="zai-org/GLM-5.1", - label="GLM 5.1", - description="HF Router", - provider="huggingface", - provider_label="Hugging Face Router", - avatar_url="https://huggingface.co/api/avatars/zai-org", - ), - ) + return not model_name.startswith(("anthropic/", "openai/")) def allows_model_name(self, model_name: str) -> bool: - return self._is_hf_model_name(model_name) + return _is_hf_model_name(model_name) def build_params( self, @@ -217,22 +162,9 @@ def build_params( ADAPTERS: tuple[ProviderAdapter, ...] 
= (
-    AnthropicAdapter(provider_id="anthropic", provider_label="Anthropic"),
-    OpenAIAdapter(
-        provider_id="openai",
-        provider_label="OpenAI",
-        supports_custom_model=True,
-        custom_model_hint="Use openai/<model>, for example openai/gpt-5",
-    ),
-    HfRouterAdapter(
-        provider_id="huggingface",
-        provider_label="Hugging Face Router",
-        supports_custom_model=True,
-        custom_model_hint=(
-            "Paste any Hugging Face model id, optionally with "
-            ":fastest, :cheapest, :preferred, or :<provider>"
-        ),
-    ),
+    AnthropicAdapter(provider_id="anthropic"),
+    OpenAIAdapter(provider_id="openai"),
+    HfRouterAdapter(provider_id="huggingface"),
 )
 
 
@@ -245,35 +177,4 @@ def resolve_adapter(model_name: str) -> ProviderAdapter | None:
 
 def is_valid_model_name(model_name: str) -> bool:
     adapter = resolve_adapter(model_name)
-    if not adapter:
-        return False
-    return adapter.allows_model_name(model_name)
-
-
-def get_available_models() -> list[dict[str, Any]]:
-    available: list[dict[str, Any]] = []
-    for adapter in ADAPTERS:
-        for model in adapter.suggested_models():
-            available.append(
-                {
-                    "id": model.id,
-                    "label": model.label,
-                    "description": model.description,
-                    "provider": model.provider,
-                    "providerLabel": model.provider_label,
-                    "avatarUrl": model.avatar_url,
-                    "recommended": model.recommended,
-                }
-            )
-    return available
-
-
-def is_suggested_model_name(model_name: str) -> bool:
-    return any(model["id"] == model_name for model in get_available_models())
-
-
-def build_model_catalog(current_model: str) -> dict[str, Any]:
-    return {
-        "current": current_model,
-        "available": get_available_models(),
-    }
+    return adapter is not None and adapter.allows_model_name(model_name)
diff --git a/backend/routes/agent.py b/backend/routes/agent.py
index 8a49b7fe..d8b3d775 100644
--- a/backend/routes/agent.py
+++ b/backend/routes/agent.py
@@ -28,31 +28,51 @@
     SubmitRequest,
     TruncateRequest,
 )
-from session_manager import (
-    MAX_SESSIONS,
-    AgentSession,
-    SessionCapacityError,
-    session_manager,
-)
+from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, session_manager
 import user_quotas
 
 from agent.core.llm_params import _resolve_llm_params
-from agent.core.provider_adapters import (
-    build_model_catalog,
-    is_suggested_model_name,
-)
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter(prefix="/api", tags=["agent"])
 
+AVAILABLE_MODELS = [
+    {
+        "id": "moonshotai/Kimi-K2.6",
+        "label": "Kimi K2.6",
+        "provider": "huggingface",
+        "tier": "free",
+        "recommended": True,
+    },
+    {
+        "id": "anthropic/claude-opus-4-6",
+        "label": "Claude Opus 4.6",
+        "provider": "anthropic",
+        "tier": "pro",
+        "recommended": True,
+    },
+    {
+        "id": "MiniMaxAI/MiniMax-M2.7",
+        "label": "MiniMax M2.7",
+        "provider": "huggingface",
+        "tier": "free",
+    },
+    {
+        "id": "zai-org/GLM-5.1",
+        "label": "GLM 5.1",
+        "provider": "huggingface",
+        "tier": "free",
+    },
+]
+
 
 async def _require_hf_for_anthropic(request: Request, model_id: str) -> None:
     """403 if a non-``huggingface``-org user tries to select an Anthropic model.
 
     Anthropic models are billed to the Space's ``ANTHROPIC_API_KEY``; every
-    other suggested web model is routed through HF Router and
+    other model in ``AVAILABLE_MODELS`` is routed through HF Router and
     billed via ``X-HF-Bill-To``. The gate only fires for ``anthropic/*``
     so non-HF users can still freely switch between the free models.
 
@@ -187,25 +207,10 @@ async def llm_health_check() -> LLMHealthResponse:
 
 @router.get("/config/model")
 async def get_model() -> dict:
     """Get current model and available models. 
No auth required.""" - return build_model_catalog(session_manager.config.model_name) - - -@router.post("/config/model") -async def set_model( - body: dict, - request: Request, - user: dict = Depends(get_current_user), -) -> dict: - """Set the LLM model. Applies to new conversations.""" - model_id = body.get("model") - if not model_id: - raise HTTPException(status_code=400, detail="Missing 'model' field") - if not is_suggested_model_name(model_id): - raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") - await _require_hf_for_anthropic(request, model_id) - session_manager.config.model_name = model_id - logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}") - return {"model": model_id} + return { + "current": session_manager.config.model_name, + "available": AVAILABLE_MODELS, + } _TITLE_STRIP_CHARS = str.maketrans("", "", "`*_~#[]()") @@ -300,7 +305,8 @@ async def create_session( if isinstance(body, dict): model = body.get("model") - if model and not is_suggested_model_name(model): + valid_ids = {m["id"] for m in AVAILABLE_MODELS} + if model and model not in valid_ids: raise HTTPException(status_code=400, detail=f"Unknown model: {model}") # Opus is gated to HF staff (PR #63). Only fires when the resolved model @@ -344,7 +350,8 @@ async def restore_session_summary( hf_token = os.environ.get("HF_TOKEN") model = body.get("model") - if model and not is_suggested_model_name(model): + valid_ids = {m["id"] for m in AVAILABLE_MODELS} + if model and model not in valid_ids: raise HTTPException(status_code=400, detail=f"Unknown model: {model}") resolved_model = model or session_manager.config.model_name @@ -402,7 +409,8 @@ async def set_session_model( model_id = body.get("model") if not model_id: raise HTTPException(status_code=400, detail="Missing 'model' field") - if not is_suggested_model_name(model_id): + valid_ids = {m["id"] for m in AVAILABLE_MODELS} + if model_id not in valid_ids: raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") await _require_hf_for_anthropic(request, model_id) agent_session = session_manager.sessions.get(session_id) @@ -536,9 +544,7 @@ async def chat_sse( success = await session_manager.submit_user_input(session_id, text) else: broadcaster.unsubscribe(sub_id) - raise HTTPException( - status_code=400, detail="Must provide 'text' or 'approvals'" - ) + raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") if not success: broadcaster.unsubscribe(sub_id) @@ -555,13 +561,7 @@ async def chat_sse( # --------------------------------------------------------------------------- # Shared SSE helpers # --------------------------------------------------------------------------- -_TERMINAL_EVENTS = { - "turn_complete", - "approval_required", - "error", - "interrupted", - "shutdown", -} +_TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"} _SSE_KEEPALIVE_SECONDS = 15 @@ -661,10 +661,7 @@ async def truncate_session( _check_session_access(session_id, user) success = await session_manager.truncate(session_id, body.user_message_index) if not success: - raise HTTPException( - status_code=404, - detail="Session not found, inactive, or message index out of range", - ) + raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range") return {"status": "truncated", "session_id": session_id} @@ -690,3 +687,5 @@ async def shutdown_session( if not success: raise HTTPException(status_code=404, detail="Session not found or 
inactive") return {"status": "shutdown_requested", "session_id": session_id} + + diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index 3f28fb52..d9fe5c4d 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -1,4 +1,4 @@ -import { useState, useCallback, useEffect, useMemo, useRef, type KeyboardEvent, type MouseEvent } from 'react'; +import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material'; import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; @@ -9,90 +9,58 @@ import ClaudeCapDialog from '@/components/ClaudeCapDialog'; import { useAgentStore } from '@/store/agentStore'; import { FIRST_FREE_MODEL_PATH } from '@/utils/model'; +// Model configuration interface ModelOption { id: string; - label: string; + name: string; description: string; + modelPath: string; avatarUrl: string; - providerLabel?: string; recommended?: boolean; } -interface ModelCatalogResponse { - current?: string; - available?: unknown; -} - -interface SessionResponse { - model?: string; -} +const getHfAvatarUrl = (modelId: string) => { + const org = modelId.split('/')[0]; + return `https://huggingface.co/api/avatars/${org}`; +}; -const FALLBACK_MODELS: ModelOption[] = [ +const MODEL_OPTIONS: ModelOption[] = [ { - id: 'anthropic/claude-opus-4-6', - label: 'Claude Opus 4.6', - description: 'Anthropic', - avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', - providerLabel: 'Anthropic', + id: 'kimi-k2.6', + name: 'Kimi K2.6', + description: 'Novita', + modelPath: 'moonshotai/Kimi-K2.6', + avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.6'), recommended: true, }, { - id: 'MiniMaxAI/MiniMax-M2.7', - label: 'MiniMax M2.7', - description: 'HF Router', - avatarUrl: 'https://huggingface.co/api/avatars/MiniMaxAI', - providerLabel: 'Hugging Face Router', + id: 'claude-opus', + name: 'Claude Opus 4.6', + description: 'Anthropic', + modelPath: 'anthropic/claude-opus-4-6', + avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', recommended: true, }, { - id: 'moonshotai/Kimi-K2.6', - label: 'Kimi K2.6', - description: 'HF Router', - avatarUrl: 'https://huggingface.co/api/avatars/moonshotai', - providerLabel: 'Hugging Face Router', + id: 'minimax-m2.7', + name: 'MiniMax M2.7', + description: 'Novita', + modelPath: 'MiniMaxAI/MiniMax-M2.7', + avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.7'), }, { - id: 'zai-org/GLM-5.1', - label: 'GLM 5.1', - description: 'HF Router', - avatarUrl: 'https://huggingface.co/api/avatars/zai-org', - providerLabel: 'Hugging Face Router', + id: 'glm-5.1', + name: 'GLM 5.1', + description: 'Together', + modelPath: 'zai-org/GLM-5.1', + avatarUrl: getHfAvatarUrl('zai-org/GLM-5.1'), }, ]; -const isRecord = (value: unknown): value is Record => ( - typeof value === 'object' && value !== null -); - -const toModelOption = (value: unknown): ModelOption | null => { - if (!isRecord(value)) return null; - if (typeof value.id !== 'string' || typeof value.label !== 'string') return null; - - const description = typeof value.description === 'string' - ? value.description - : typeof value.providerLabel === 'string' - ? value.providerLabel - : ''; - - return { - id: value.id, - label: value.label, - description, - avatarUrl: typeof value.avatarUrl === 'string' - ? 
value.avatarUrl - : 'https://huggingface.co/api/avatars/huggingface', - providerLabel: typeof value.providerLabel === 'string' ? value.providerLabel : undefined, - recommended: Boolean(value.recommended), - }; +const findModelByPath = (path: string): ModelOption | undefined => { + return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); }; -const makeUnknownModelOption = (modelId: string): ModelOption => ({ - id: modelId, - label: modelId, - description: 'Custom model', - avatarUrl: 'https://huggingface.co/api/avatars/huggingface', -}); - interface ChatInputProps { sessionId?: string; onSend: (text: string) => void; @@ -102,14 +70,13 @@ interface ChatInputProps { placeholder?: string; } -const isClaudeModel = (model: ModelOption) => model.id.startsWith('anthropic/'); +const isClaudeModel = (m: ModelOption) => m.modelPath.startsWith('anthropic/'); +const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0]; export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { const [input, setInput] = useState(''); const inputRef = useRef(null); - const [modelOptions, setModelOptions] = useState(FALLBACK_MODELS); - const [catalogCurrent, setCatalogCurrent] = useState(FALLBACK_MODELS[0].id); - const [sessionModel, setSessionModel] = useState(null); + const [selectedModelId, setSelectedModelId] = useState(MODEL_OPTIONS[0].id); const [modelAnchorEl, setModelAnchorEl] = useState(null); const { quota, refresh: refreshQuota } = useUserQuota(); // The daily-cap dialog is triggered from two places: (a) a 429 returned @@ -121,58 +88,27 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted); const lastSentRef = useRef(''); + // Model is per-session: fetch this tab's current model every time the + // session changes. Other tabs keep their own selections independently. useEffect(() => { - let cancelled = false; - - apiFetch('/api/config/model') - .then((res) => (res.ok ? res.json() as Promise : null)) - .then((data) => { - if (cancelled || !data) return; - - const rawAvailable = Array.isArray(data.available) ? data.available : []; - const available = rawAvailable - .map(toModelOption) - .filter((value: ModelOption | null): value is ModelOption => value !== null); - - if (available.length > 0) { - setModelOptions(available); - } - if (typeof data.current === 'string' && data.current) { - setCatalogCurrent(data.current); - } - }) - .catch(() => { /* ignore */ }); - - return () => { cancelled = true; }; - }, []); - - useEffect(() => { - if (!sessionId) { - setSessionModel(null); - return; - } - + if (!sessionId) return; let cancelled = false; apiFetch(`/api/session/${sessionId}`) - .then((res) => (res.ok ? res.json() as Promise : null)) + .then((res) => (res.ok ? res.json() : null)) .then((data) => { if (cancelled) return; - setSessionModel(typeof data?.model === 'string' && data.model ? data.model : null); + if (data?.model) { + const model = findModelByPath(data.model); + if (model) setSelectedModelId(model.id); + } }) .catch(() => { /* ignore */ }); - return () => { cancelled = true; }; }, [sessionId]); - const selectedModelPath = sessionModel ?? catalogCurrent; - const selectedModel = useMemo( - () => modelOptions.find((model) => model.id === selectedModelPath) - ?? (selectedModelPath ? makeUnknownModelOption(selectedModelPath) : null) - ?? 
modelOptions[0] - ?? FALLBACK_MODELS[0], - [modelOptions, selectedModelPath], - ); + const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; + // Auto-focus the textarea when the session becomes ready useEffect(() => { if (!disabled && !isProcessing && inputRef.current) { inputRef.current.focus(); @@ -199,7 +135,8 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa // have started another tab that spent quota). useEffect(() => { if (sessionId) refreshQuota(); - }, [refreshQuota, sessionId]); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [sessionId]); const handleKeyDown = useCallback( (e: KeyboardEvent) => { @@ -208,10 +145,10 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa handleSend(); } }, - [handleSend], + [handleSend] ); - const handleModelClick = (event: MouseEvent) => { + const handleModelClick = (event: React.MouseEvent) => { setModelAnchorEl(event.currentTarget); }; @@ -219,21 +156,16 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa setModelAnchorEl(null); }; - const handleSelectModel = async (modelPath: string) => { + const handleSelectModel = async (model: ModelOption) => { handleModelClose(); if (!sessionId) return; - try { const res = await apiFetch(`/api/session/${sessionId}/model`, { method: 'POST', - body: JSON.stringify({ model: modelPath }), + body: JSON.stringify({ model: model.modelPath }), }); - if (res.ok) { - setSessionModel(modelPath); - } - } catch { - // ignore - } + if (res.ok) setSelectedModelId(model.id); + } catch { /* ignore */ } }; // Dialog close: just clear the flag. The typed text is already restored. @@ -246,18 +178,15 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa const handleUseFreeModel = useCallback(async () => { setClaudeQuotaExhausted(false); if (!sessionId) return; - const free = modelOptions.find((model) => model.id === FIRST_FREE_MODEL_PATH) - ?? modelOptions.find((model) => !isClaudeModel(model)) - ?? modelOptions[0]; - if (!free) return; - + const free = MODEL_OPTIONS.find(m => m.modelPath === FIRST_FREE_MODEL_PATH) + ?? firstFreeModel(); try { const res = await apiFetch(`/api/session/${sessionId}/model`, { method: 'POST', - body: JSON.stringify({ model: free.id }), + body: JSON.stringify({ model: free.modelPath }), }); if (res.ok) { - setSessionModel(free.id); + setSelectedModelId(free.id); const retryText = lastSentRef.current; if (retryText) { onSend(retryText); @@ -265,10 +194,8 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa lastSentRef.current = ''; } } - } catch { - // ignore - } - }, [modelOptions, onSend, sessionId, setClaudeQuotaExhausted]); + } catch { /* ignore */ } + }, [sessionId, onSend, setClaudeQuotaExhausted]); // Hide the chip until the user has actually burned quota — an unused // Opus session shouldn't populate a counter. @@ -302,9 +229,9 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa border: '1px solid var(--border)', transition: 'box-shadow 0.2s ease, border-color 0.2s ease', '&:focus-within': { - borderColor: 'var(--accent-yellow)', - boxShadow: 'var(--focus)', - }, + borderColor: 'var(--accent-yellow)', + boxShadow: 'var(--focus)', + } }} > {isProcessing ? 
( @@ -387,6 +314,7 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa )} + {/* Powered By Badge */} @@ -408,15 +336,16 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa {selectedModel.label} - {selectedModel.label} + {selectedModel.name} + {/* Model Selection Menu */} - {modelOptions.map((model) => ( + {MODEL_OPTIONS.map((model) => ( handleSelectModel(model.id)} - selected={selectedModelPath === model.id} + onClick={() => handleSelectModel(model)} + selected={selectedModelId === model.id} sx={{ py: 1.5, '&.Mui-selected': { bgcolor: 'rgba(255,255,255,0.05)', - }, + } }} > {model.label} - {model.label} + {model.name} {model.recommended && ( )} - )} - secondary={model.description || model.providerLabel} + } + secondary={model.description} secondaryTypographyProps={{ - sx: { fontSize: '12px', color: 'var(--muted-text)' }, + sx: { fontSize: '12px', color: 'var(--muted-text)' } }} /> diff --git a/frontend/src/utils/model.ts b/frontend/src/utils/model.ts index 2358fe68..89f23fe7 100644 --- a/frontend/src/utils/model.ts +++ b/frontend/src/utils/model.ts @@ -1,4 +1,11 @@ -/** Shared model-id constants used by the web UI. */ +/** + * Shared model-id constants used by session-create call sites and the + * ClaudeCapDialog "Use a free model" escape hatch. + * + * Keep in sync with MODEL_OPTIONS in components/Chat/ChatInput.tsx and + * AVAILABLE_MODELS in backend/routes/agent.py. Bare HF ids (no + * `huggingface/` prefix) — matches upstream's auto-router. + */ export const CLAUDE_MODEL_PATH = 'anthropic/claude-opus-4-6'; export const FIRST_FREE_MODEL_PATH = 'moonshotai/Kimi-K2.6'; diff --git a/tests/test_backend_routes.py b/tests/test_backend_routes.py deleted file mode 100644 index 12f5a30f..00000000 --- a/tests/test_backend_routes.py +++ /dev/null @@ -1,75 +0,0 @@ -from __future__ import annotations - -import sys -from pathlib import Path -from types import SimpleNamespace - -from fastapi.testclient import TestClient - - -ROOT = Path(__file__).resolve().parent.parent -BACKEND_ROOT = ROOT / "backend" -if str(BACKEND_ROOT) not in sys.path: - sys.path.insert(0, str(BACKEND_ROOT)) - -from main import app # noqa: E402 -from routes import agent as agent_routes # noqa: E402 - - -client = TestClient(app) - - -class DummySession: - def __init__(self, model_name: str) -> None: - self.config = SimpleNamespace(model_name=model_name) - self.context_manager = SimpleNamespace(items=[]) - self.pending_approval = None - - def update_model(self, model_name: str) -> None: - self.config.model_name = model_name - - -def test_get_model_returns_current_and_available_only() -> None: - original = agent_routes.session_manager.config.model_name - agent_routes.session_manager.config.model_name = "anthropic/claude-opus-4-6" - try: - response = client.get("/api/config/model") - finally: - agent_routes.session_manager.config.model_name = original - - assert response.status_code == 200 - data = response.json() - assert set(data) == {"current", "available"} - assert data["current"] == "anthropic/claude-opus-4-6" - assert any(model["id"] == "moonshotai/Kimi-K2.6" for model in data["available"]) - - -def test_set_model_rejects_custom_hf_id() -> None: - response = client.post( - "/api/config/model", json={"model": "moonshotai/Kimi-K2.6:fastest"} - ) - - assert response.status_code == 400 - assert response.json()["detail"] == "Unknown model: moonshotai/Kimi-K2.6:fastest" - - -def test_set_session_model_rejects_custom_hf_id() -> None: - session_id = "test-session" - 
sessions = agent_routes.session_manager.sessions - sessions[session_id] = SimpleNamespace( - user_id="dev", - is_active=True, - is_processing=False, - created_at=SimpleNamespace(isoformat=lambda: "2026-01-01T00:00:00"), - session=DummySession("moonshotai/Kimi-K2.6"), - ) - try: - response = client.post( - f"/api/session/{session_id}/model", - json={"model": "moonshotai/Kimi-K2.6:fastest"}, - ) - finally: - sessions.pop(session_id, None) - - assert response.status_code == 400 - assert response.json()["detail"] == "Unknown model: moonshotai/Kimi-K2.6:fastest" diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index 0b49dad4..2b6de9dd 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -1,10 +1,9 @@ import pytest from agent.core.llm_params import _resolve_llm_params +from agent.core.model_switcher import is_valid_model_id from agent.core.provider_adapters import ( UnsupportedEffortError, - build_model_catalog, - is_suggested_model_name, is_valid_model_name, ) @@ -100,22 +99,7 @@ def test_hf_adapter_strict_rejects_max(): ) -# -- Catalog & validation ----------------------------------------------------- - - -def test_model_catalog_comes_from_adapters(): - catalog = build_model_catalog("anthropic/claude-opus-4-6") - - assert catalog["current"] == "anthropic/claude-opus-4-6" - assert any(model["provider"] == "anthropic" for model in catalog["available"]) - assert any(model["provider"] == "huggingface" for model in catalog["available"]) - assert set(catalog) == {"current", "available"} - - -def test_suggested_model_validation_is_strict(): - assert is_suggested_model_name("anthropic/claude-opus-4-6") is True - assert is_suggested_model_name("moonshotai/Kimi-K2.6") is True - assert is_suggested_model_name("moonshotai/Kimi-K2.6:fastest") is False +# -- Validation --------------------------------------------------------------- def test_model_validation_accepts_free_form_hf_ids(): @@ -123,9 +107,25 @@ def test_model_validation_accepts_free_form_hf_ids(): assert is_valid_model_name("huggingface/moonshotai/Kimi-K2.6:novita") is True +def test_model_validation_accepts_direct_provider_ids(): + assert is_valid_model_name("anthropic/claude-opus-4-7") is True + assert is_valid_model_name("openai/gpt-5") is True + + def test_model_validation_rejects_garbage(): assert is_valid_model_name("") is False assert is_valid_model_name("no-slash") is False + assert is_valid_model_name("anthropic/") is False + assert is_valid_model_name("openai/") is False + assert is_valid_model_name("huggingface/nope") is False + assert is_valid_model_name("moonshotai/") is False + + +def test_cli_validation_matches_provider_validation(): + assert is_valid_model_id("openai/gpt-5") is True + assert is_valid_model_id("moonshotai/Kimi-K2.6:fastest") is True + assert is_valid_model_id("openai/") is False + assert is_valid_model_id("anthropic/") is False def test_unsupported_effort_reexport(): From d559a37cb368f82e552582ac39c1fe9f8c0fbd18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Thu, 23 Apr 2026 13:30:13 +0200 Subject: [PATCH 07/15] Clean provider error handling Classify auth, credits, and missing-model failures once so the CLI, model switcher, and health checks show clean user-facing errors instead of raw provider traces. 
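
At its core the classifier is substring matching over the stringified
exception. A quick illustration (hypothetical inputs; the authoritative
cases are the marker tuples in agent/core/llm_errors.py and the tests in
tests/test_llm_errors.py):

    from agent.core.llm_errors import classify_llm_error

    classify_llm_error(Exception("invalid x-api-key"))          # -> "auth"
    classify_llm_error(Exception("credit balance is too low"))  # -> "credits"
    classify_llm_error(Exception("model_not_found"))            # -> "model"
    classify_llm_error(Exception("429 too many requests"))      # -> "rate_limit"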
--- agent/core/agent_loop.py | 376 +++++++++++++++++++++-------------- agent/core/llm_errors.py | 147 ++++++++++++++ agent/core/model_switcher.py | 8 +- backend/routes/agent.py | 53 +++-- backend/session_manager.py | 35 ++-- tests/test_llm_errors.py | 124 ++++++++++++ 6 files changed, 555 insertions(+), 188 deletions(-) create mode 100644 agent/core/llm_errors.py create mode 100644 tests/test_llm_errors.py diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index c3fd88bc..3129716e 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -13,6 +13,7 @@ from agent.config import Config from agent.core.doom_loop import check_for_doom_loop +from agent.core.llm_errors import friendly_llm_error_message, render_llm_error_message from agent.core.llm_params import _resolve_llm_params from agent.core.prompt_caching import with_prompt_caching from agent.core.session import Event, OpType, Session @@ -125,14 +126,24 @@ def _is_transient_error(error: Exception) -> bool: """Return True for errors that are likely transient and worth retrying.""" err_str = str(error).lower() transient_patterns = [ - "timeout", "timed out", - "429", "rate limit", "rate_limit", - "503", "service unavailable", - "502", "bad gateway", - "500", "internal server error", - "overloaded", "capacity", - "connection reset", "connection refused", "connection error", - "eof", "broken pipe", + "timeout", + "timed out", + "429", + "rate limit", + "rate_limit", + "503", + "service unavailable", + "502", + "bad gateway", + "500", + "internal server error", + "overloaded", + "capacity", + "connection reset", + "connection refused", + "connection error", + "eof", + "broken pipe", ] return any(pattern in err_str for pattern in transient_patterns) @@ -146,11 +157,14 @@ def _is_effort_config_error(error: Exception) -> bool: doesn't work for the current model. We heal the cache and retry once. """ from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported + return _is_thinking_unsupported(error) or _is_invalid_effort(error) async def _heal_effort_and_rebuild_params( - session: Session, error: Exception, llm_params: dict, + session: Session, + error: Exception, + llm_params: dict, ) -> dict: """Update the session's effort cache based on ``error`` and return new llm_params. Called only when ``_is_effort_config_error(error)`` is True. 
@@ -161,7 +175,11 @@ async def _heal_effort_and_rebuild_params(
     • invalid-effort → re-run the full cascade probe; the result lands
       in the cache
     """
-    from agent.core.effort_probe import ProbeInconclusive, _is_thinking_unsupported, probe_effort
+    from agent.core.effort_probe import (
+        ProbeInconclusive,
+        _is_thinking_unsupported,
+        probe_effort,
+    )
 
     model = session.config.model_name
     if _is_thinking_unsupported(error):
@@ -170,11 +188,15 @@ async def _heal_effort_and_rebuild_params(
     else:
         try:
             outcome = await probe_effort(
-                model, session.config.reasoning_effort, session.hf_token,
+                model,
+                session.config.reasoning_effort,
+                session.hf_token,
             )
             session.model_effective_effort[model] = outcome.effective_effort
             logger.info(
-                "healed: %s effort cascade → %s", model, outcome.effective_effort,
+                "healed: %s effort cascade → %s",
+                model,
+                outcome.effective_effort,
             )
         except ProbeInconclusive:
             # Transient during healing — strip thinking for safety, next
@@ -191,44 +213,7 @@ async def _heal_effort_and_rebuild_params(
 
 def _friendly_error_message(error: Exception) -> str | None:
     """Return a user-friendly message for known error types, or None to fall back to traceback."""
-    err_str = str(error).lower()
-
-    if "authentication" in err_str or "unauthorized" in err_str or "invalid x-api-key" in err_str:
-        return (
-            "Authentication failed — your API key is missing or invalid.\n\n"
-            "To fix this, set the API key for your model provider:\n"
-            "  • Anthropic: export ANTHROPIC_API_KEY=sk-...\n"
-            "  • OpenAI: export OPENAI_API_KEY=sk-...\n"
-            "  • HF Router: export HF_TOKEN=hf_...\n\n"
-            "You can also add it to a .env file in the project root.\n"
-            "To switch models, use the /model command."
-        )
-
-    if "insufficient" in err_str and "credit" in err_str:
-        return (
-            "Insufficient API credits. Please check your account balance "
-            "at your model provider's dashboard."
-        )
-
-    if "not supported by provider" in err_str or "no provider supports" in err_str:
-        return (
-            "The model isn't served by the provider you pinned.\n\n"
-            "Drop the ':<provider>' suffix to let the HF router auto-pick a "
-            "provider, or use '/model' (no arg) to see which providers host "
-            "which models."
-        )
-
-    if "model_not_found" in err_str or (
-        "model" in err_str
-        and ("not found" in err_str or "does not exist" in err_str)
-    ):
-        return (
-            "Model not found. Use '/model' to list suggestions, or paste an "
-            "HF model id like 'MiniMaxAI/MiniMax-M2.7'. Availability is shown "
-            "when you switch." 
- ) - - return None + return friendly_llm_error_message(error) async def _compact_and_notify(session: Session) -> None: @@ -237,7 +222,10 @@ async def _compact_and_notify(session: Session) -> None: old_usage = cm.running_context_usage logger.debug( "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s", - old_usage, cm.model_max_tokens, cm.compaction_threshold, cm.needs_compaction, + old_usage, + cm.model_max_tokens, + cm.compaction_threshold, + cm.needs_compaction, ) await cm.compact( model_name=session.config.model_name, @@ -248,7 +236,10 @@ async def _compact_and_notify(session: Session) -> None: if new_usage != old_usage: logger.warning( "Context compacted: %d -> %d tokens (max=%d, %d messages)", - old_usage, new_usage, cm.model_max_tokens, len(cm.items), + old_usage, + new_usage, + cm.model_max_tokens, + len(cm.items), ) await session.send_event( Event( @@ -287,13 +278,16 @@ async def _cleanup_on_cancel(session: Session) -> None: @dataclass class LLMResult: """Result from an LLM call (streaming or non-streaming).""" + content: str | None tool_calls_acc: dict[int, dict] token_count: int finish_reason: str | None -async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult: +async def _call_llm_streaming( + session: Session, messages, tools, llm_params +) -> LLMResult: """Call the LLM with streaming, emitting assistant_chunk events.""" response = None _healed_effort = False # one-shot safety net per call @@ -315,22 +309,37 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> except Exception as e: if not _healed_effort and _is_effort_config_error(e): _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params) - await session.send_event(Event( - event_type="tool_log", - data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, - )) + llm_params = await _heal_effort_and_rebuild_params( + session, e, llm_params + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": "Reasoning effort not supported for this model — adjusting and retrying.", + }, + ) + ) continue if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( "Transient LLM error (attempt %d/%d): %s — retrying in %ds", - _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, + _llm_attempt + 1, + _MAX_LLM_RETRIES, + e, + _delay, + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": f"LLM connection error, retrying in {_delay}s...", + }, + ) ) - await session.send_event(Event( - event_type="tool_log", - data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, - )) await asyncio.sleep(_delay) continue raise @@ -366,16 +375,21 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> idx = tc_delta.index if idx not in tool_calls_acc: tool_calls_acc[idx] = { - "id": "", "type": "function", + "id": "", + "type": "function", "function": {"name": "", "arguments": ""}, } if tc_delta.id: tool_calls_acc[idx]["id"] = tc_delta.id if tc_delta.function: if tc_delta.function.name: - tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name + tool_calls_acc[idx]["function"]["name"] += ( + tc_delta.function.name + ) if tc_delta.function.arguments: - tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments + tool_calls_acc[idx]["function"]["arguments"] += ( + 
tc_delta.function.arguments + ) if hasattr(chunk, "usage") and chunk.usage: token_count = chunk.usage.total_tokens @@ -388,7 +402,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> ) -async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) -> LLMResult: +async def _call_llm_non_streaming( + session: Session, messages, tools, llm_params +) -> LLMResult: """Call the LLM without streaming, emit assistant_message at the end.""" response = None _healed_effort = False @@ -409,22 +425,37 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) except Exception as e: if not _healed_effort and _is_effort_config_error(e): _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params) - await session.send_event(Event( - event_type="tool_log", - data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, - )) + llm_params = await _heal_effort_and_rebuild_params( + session, e, llm_params + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": "Reasoning effort not supported for this model — adjusting and retrying.", + }, + ) + ) continue if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( "Transient LLM error (attempt %d/%d): %s — retrying in %ds", - _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, + _llm_attempt + 1, + _MAX_LLM_RETRIES, + e, + _delay, + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": f"LLM connection error, retrying in {_delay}s...", + }, + ) ) - await session.send_event(Event( - event_type="tool_log", - data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, - )) await asyncio.sleep(_delay) continue raise @@ -505,7 +536,8 @@ async def _abandon_pending_approval(session: Session) -> None: @staticmethod async def run_agent( - session: Session, text: str, + session: Session, + text: str, ) -> str | None: """ Handle user input (like user_input_or_turn in codex.rs:1291) @@ -569,12 +601,18 @@ async def run_agent( llm_params = _resolve_llm_params( session.config.model_name, session.hf_token, - reasoning_effort=session.effective_effort_for(session.config.model_name), + reasoning_effort=session.effective_effort_for( + session.config.model_name + ), ) if session.stream: - llm_result = await _call_llm_streaming(session, messages, tools, llm_params) + llm_result = await _call_llm_streaming( + session, messages, tools, llm_params + ) else: - llm_result = await _call_llm_non_streaming(session, messages, tools, llm_params) + llm_result = await _call_llm_non_streaming( + session, messages, tools, llm_params + ) content = llm_result.content tool_calls_acc = llm_result.tool_calls_acc @@ -618,7 +656,10 @@ async def run_agent( await session.send_event( Event( event_type="tool_log", - data={"tool": "system", "log": f"Output truncated — retrying with smaller content ({dropped_names})"}, + data={ + "tool": "system", + "log": f"Output truncated — retrying with smaller content ({dropped_names})", + }, ) ) iteration += 1 @@ -678,7 +719,8 @@ async def run_agent( except (json.JSONDecodeError, TypeError, ValueError): logger.warning( "Malformed arguments for tool_call %s (%s) — skipping", - tc.id, tc.function.name, + tc.id, + tc.function.name, ) tc.function.arguments = "{}" bad_tools.append(tc) @@ -699,20 +741,35 @@ async def run_agent( f"arguments and 
was NOT executed. Retry with smaller content — " f"for 'write', split into multiple smaller writes using 'edit'." ) - session.context_manager.add_message(Message( - role="tool", - content=error_msg, - tool_call_id=tc.id, - name=tc.function.name, - )) - await session.send_event(Event( - event_type="tool_call", - data={"tool": tc.function.name, "arguments": {}, "tool_call_id": tc.id}, - )) - await session.send_event(Event( - event_type="tool_output", - data={"tool": tc.function.name, "tool_call_id": tc.id, "output": error_msg, "success": False}, - )) + session.context_manager.add_message( + Message( + role="tool", + content=error_msg, + tool_call_id=tc.id, + name=tc.function.name, + ) + ) + await session.send_event( + Event( + event_type="tool_call", + data={ + "tool": tc.function.name, + "arguments": {}, + "tool_call_id": tc.id, + }, + ) + ) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tc.function.name, + "tool_call_id": tc.id, + "output": error_msg, + "success": False, + }, + ) + ) # ── Cancellation check: before tool execution ── if session.is_cancelled: @@ -730,9 +787,7 @@ async def run_agent( # Execute non-approval tools (in parallel when possible) if non_approval_tools: # 1. Validate args upfront - parsed_tools: list[ - tuple[ToolCall, str, dict, bool, str] - ] = [] + parsed_tools: list[tuple[ToolCall, str, dict, bool, str]] = [] for tc, tool_name, tool_args in non_approval_tools: args_valid, error_msg = _validate_tool_args(tool_args) parsed_tools.append( @@ -768,12 +823,14 @@ async def _exec_tool( ) return (tc, name, args, out, ok) - gather_task = asyncio.ensure_future(asyncio.gather( - *[ - _exec_tool(tc, name, args, valid, err) - for tc, name, args, valid, err in parsed_tools - ] - )) + gather_task = asyncio.ensure_future( + asyncio.gather( + *[ + _exec_tool(tc, name, args, valid, err) + for tc, name, args, valid, err in parsed_tools + ] + ) + ) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( @@ -790,10 +847,16 @@ async def _exec_tool( # Notify frontend that in-flight tools were cancelled for tc, name, _args, valid, _ in parsed_tools: if valid: - await session.send_event(Event( - event_type="tool_state_change", - data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"}, - )) + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": name, + "state": "cancelled", + }, + ) + ) await _cleanup_on_cancel(session) break @@ -829,23 +892,32 @@ async def _exec_tool( for tc, tool_name, tool_args in approval_required_tools: # Resolve sandbox file paths for hf_jobs scripts so the # frontend can display & edit the actual file content. 
- if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str): + if tool_name == "hf_jobs" and isinstance( + tool_args.get("script"), str + ): from agent.tools.sandbox_tool import resolve_sandbox_script + sandbox = getattr(session, "sandbox", None) - resolved, _ = await resolve_sandbox_script(sandbox, tool_args["script"]) + resolved, _ = await resolve_sandbox_script( + sandbox, tool_args["script"] + ) if resolved: tool_args = {**tool_args, "script": resolved} - tools_data.append({ - "tool": tool_name, - "arguments": tool_args, - "tool_call_id": tc.id, - }) + tools_data.append( + { + "tool": tool_name, + "arguments": tool_args, + "tool_call_id": tc.id, + } + ) - await session.send_event(Event( - event_type="approval_required", - data={"tools": tools_data, "count": len(tools_data)}, - )) + await session.send_event( + Event( + event_type="approval_required", + data={"tools": tools_data, "count": len(tools_data)}, + ) + ) # Store all approval-requiring tools (ToolCall objects for execution) session.pending_approval = { @@ -863,18 +935,19 @@ async def _exec_tool( logger.warning( "ContextWindowExceededError at iteration %d — forcing compaction " "(usage=%d, model_max_tokens=%d, messages=%d)", - iteration, cm.running_context_usage, cm.model_max_tokens, len(cm.items), + iteration, + cm.running_context_usage, + cm.model_max_tokens, + len(cm.items), ) cm.running_context_usage = cm.model_max_tokens + 1 await _compact_and_notify(session) continue except Exception as e: - import traceback - - error_msg = _friendly_error_message(e) - if error_msg is None: - error_msg = str(e) + "\n" + traceback.format_exc() + logger.info("Agent turn failed: %s", e) + logger.debug("Agent turn failed", exc_info=True) + error_msg = render_llm_error_message(e) await session.send_event( Event( @@ -1039,13 +1112,15 @@ async def execute_tool(tc, tool_name, tool_args, was_edited): # Execute all approved tools concurrently (cancellable) if approved_tasks: - gather_task = asyncio.ensure_future(asyncio.gather( - *[ - execute_tool(tc, tool_name, tool_args, was_edited) - for tc, tool_name, tool_args, was_edited in approved_tasks - ], - return_exceptions=True, - )) + gather_task = asyncio.ensure_future( + asyncio.gather( + *[ + execute_tool(tc, tool_name, tool_args, was_edited) + for tc, tool_name, tool_args, was_edited in approved_tasks + ], + return_exceptions=True, + ) + ) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( @@ -1061,10 +1136,16 @@ async def execute_tool(tc, tool_name, tool_args, was_edited): pass # Notify frontend that approved tools were cancelled for tc, tool_name, _args, _was_edited in approved_tasks: - await session.send_event(Event( - event_type="tool_state_change", - data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"}, - )) + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "cancelled", + }, + ) + ) await _cleanup_on_cancel(session) await session.send_event(Event(event_type="interrupted")) session.increment_turn() @@ -1212,8 +1293,12 @@ async def submission_loop( # Create session with tool router session = Session( - event_queue, config=config, tool_router=tool_router, hf_token=hf_token, - local_mode=local_mode, stream=stream, + event_queue, + config=config, + tool_router=tool_router, + hf_token=hf_token, + local_mode=local_mode, + stream=stream, ) if session_holder is not None: session_holder[0] = session @@ -1230,10 +1315,13 @@ async def 
submission_loop( async with tool_router: # Emit ready event after initialization await session.send_event( - Event(event_type="ready", data={ - "message": "Agent initialized", - "tool_count": len(tool_router.tools), - }) + Event( + event_type="ready", + data={ + "message": "Agent initialized", + "tool_count": len(tool_router.tools), + }, + ) ) while session.is_running: diff --git a/agent/core/llm_errors.py b/agent/core/llm_errors.py new file mode 100644 index 00000000..a41fe4b5 --- /dev/null +++ b/agent/core/llm_errors.py @@ -0,0 +1,147 @@ +"""Shared LLM error classification and user-facing messages.""" + +from __future__ import annotations + +from typing import Literal + +LlmErrorType = Literal[ + "auth", + "credits", + "model", + "provider", + "rate_limit", + "network", + "unknown", +] + +_AUTH_MARKERS = ( + "authentication failed", + "authentication_error", + "authentication error", + "unauthorized", + "invalid x-api-key", + "invalid api key", + "incorrect api key", + "didn't provide an api key", + "did not provide an api key", + "no api key provided", + "provide your api key", + "x-api-key header is required", + "api key header is required", + "api key required", + "api key is missing or invalid", + "api_key_invalid", + "401", +) +_CREDITS_MARKERS = ( + "insufficient credit", + "insufficient credits", + "out of credits", + "insufficient_quota", + "credit balance is too low", + "balance is too low", + "purchase credits", + "plans & billing", + "quota", + "billing", + "payment required", + "402", +) +_RATE_LIMIT_MARKERS = ("429", "rate limit", "too many requests") +_NETWORK_MARKERS = ( + "timeout", + "timed out", + "connect", + "connection error", + "connection refused", + "connection reset", + "network", + "service unavailable", + "bad gateway", + "overloaded", + "capacity", +) + + +def _has_any(err_str: str, markers: tuple[str, ...]) -> bool: + return any(marker in err_str for marker in markers) + + +def classify_llm_error(error: Exception) -> LlmErrorType: + """Classify common provider/API failures from the exception text.""" + err_str = str(error).lower() + + if _has_any(err_str, _AUTH_MARKERS): + return "auth" + if _has_any(err_str, _CREDITS_MARKERS): + return "credits" + if "not supported by provider" in err_str or "no provider supports" in err_str: + return "provider" + if "model_not_found" in err_str or "unknown model" in err_str: + return "model" + if "model" in err_str and ( + "not found" in err_str + or "does not exist" in err_str + or "not available" in err_str + ): + return "model" + if _has_any(err_str, _RATE_LIMIT_MARKERS): + return "rate_limit" + if _has_any(err_str, _NETWORK_MARKERS): + return "network" + return "unknown" + + +def friendly_llm_error_message(error: Exception) -> str | None: + """Return a clean user-facing message for common LLM failures.""" + error_type = classify_llm_error(error) + + if error_type == "auth": + return ( + "Authentication failed — your API key is missing or invalid.\n\n" + "To fix this, set the API key for your model provider:\n" + " • Anthropic: export ANTHROPIC_API_KEY=sk-...\n" + " • OpenAI: export OPENAI_API_KEY=sk-...\n" + " • HF Router: export HF_TOKEN=hf_...\n\n" + "You can also add it to a .env file in the project root.\n" + "To switch models, use the /model command." + ) + if error_type == "credits": + return ( + "Insufficient API credits or quota for this model/provider.\n\n" + "Check billing for the current provider, or switch models with /model." 
+ ) + if error_type == "provider": + return ( + "The model isn't served by the provider you pinned.\n\n" + "Drop the ':' suffix to let the HF router auto-pick a " + "provider, or use '/model' (no arg) to see which providers host " + "which models." + ) + if error_type == "model": + return ( + "Model not found. Use '/model' to list suggestions, or paste an " + "HF model id like 'MiniMaxAI/MiniMax-M2.7'. Availability is shown " + "when you switch." + ) + if error_type == "rate_limit": + return ( + "Rate limit reached. Wait a moment and retry, or switch models/providers " + "with /model." + ) + if error_type == "network": + return "The model provider is unavailable or timed out. Retry in a moment." + return None + + +def render_llm_error_message(error: Exception) -> str: + """Return the message safe to show to users.""" + return friendly_llm_error_message(error) or str(error) + + +def health_error_type(error: Exception) -> str: + """Map LLM failures to the backend health endpoint error_type values.""" + error_type = classify_llm_error(error) + if error_type in {"auth", "credits", "rate_limit", "network"}: + return error_type + return "unknown" diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index 5bf5a44c..9437d87c 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -16,6 +16,7 @@ from __future__ import annotations from agent.core.effort_probe import ProbeInconclusive, probe_effort +from agent.core.llm_errors import render_llm_error_message from agent.core.provider_adapters import is_valid_model_name @@ -184,14 +185,17 @@ async def probe_and_switch_model( outcome = await probe_effort(model_id, preference, hf_token) except ProbeInconclusive as e: _commit_switch(model_id, config, session, effective=None, cache=False) + warning = render_llm_error_message(e) console.print( f"[yellow]Model switched to {model_id}[/yellow] " - f"[dim](couldn't validate: {e}; will verify on first message)[/dim]" + f"[dim](couldn't validate: {warning}; will verify on first message)[/dim]" ) return except Exception as e: # Hard persistent error — auth, unknown model, quota. Don't switch. 
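         # render_llm_error_message() falls back to str(e) for unclassified
         # errors, so the raw exception text is never lost; known failures
         # just get the friendlier copy.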
- console.print(f"[bold red]Switch failed:[/bold red] {e}") + console.print( + f"[bold red]Switch failed:[/bold red] {render_llm_error_message(e)}" + ) console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") return diff --git a/backend/routes/agent.py b/backend/routes/agent.py index d8b3d775..0b164d49 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -28,10 +28,16 @@ SubmitRequest, TruncateRequest, ) -from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, session_manager +from session_manager import ( + MAX_SESSIONS, + AgentSession, + SessionCapacityError, + session_manager, +) import user_quotas +from agent.core.llm_errors import health_error_type, render_llm_error_message from agent.core.llm_params import _resolve_llm_params logger = logging.getLogger(__name__) @@ -172,34 +178,12 @@ async def llm_health_check() -> LLMHealthResponse: ) return LLMHealthResponse(status="ok", model=model) except Exception as e: - err_str = str(e).lower() - error_type = "unknown" - - if ( - "401" in err_str - or "auth" in err_str - or "invalid" in err_str - or "api key" in err_str - ): - error_type = "auth" - elif ( - "402" in err_str - or "credit" in err_str - or "quota" in err_str - or "insufficient" in err_str - or "billing" in err_str - ): - error_type = "credits" - elif "429" in err_str or "rate" in err_str: - error_type = "rate_limit" - elif "timeout" in err_str or "connect" in err_str or "network" in err_str: - error_type = "network" - + error_type = health_error_type(e) logger.warning(f"LLM health check failed ({error_type}): {e}") return LLMHealthResponse( status="error", model=model, - error=str(e)[:500], + error=render_llm_error_message(e)[:500], error_type=error_type, ) @@ -544,7 +528,9 @@ async def chat_sse( success = await session_manager.submit_user_input(session_id, text) else: broadcaster.unsubscribe(sub_id) - raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") + raise HTTPException( + status_code=400, detail="Must provide 'text' or 'approvals'" + ) if not success: broadcaster.unsubscribe(sub_id) @@ -561,7 +547,13 @@ async def chat_sse( # --------------------------------------------------------------------------- # Shared SSE helpers # --------------------------------------------------------------------------- -_TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"} +_TERMINAL_EVENTS = { + "turn_complete", + "approval_required", + "error", + "interrupted", + "shutdown", +} _SSE_KEEPALIVE_SECONDS = 15 @@ -661,7 +653,10 @@ async def truncate_session( _check_session_access(session_id, user) success = await session_manager.truncate(session_id, body.user_message_index) if not success: - raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range") + raise HTTPException( + status_code=404, + detail="Session not found, inactive, or message index out of range", + ) return {"status": "truncated", "session_id": session_id} @@ -687,5 +682,3 @@ async def shutdown_session( if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} - - diff --git a/backend/session_manager.py b/backend/session_manager.py index 7293f9cf..dbb79ed5 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -9,6 +9,7 @@ from typing import Any, Optional from agent.config import load_config +from agent.core.llm_errors import 
render_llm_error_message from agent.core.agent_loop import process_submission from agent.core.session import Event, OpType, Session from agent.core.tools import ToolRouter @@ -125,9 +126,7 @@ def __init__(self, config_path: str | None = None) -> None: def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" return sum( - 1 - for s in self.sessions.values() - if s.user_id == user_id and s.is_active + 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active ) async def create_session( @@ -191,7 +190,9 @@ def _create_session_sync(): if model: session_config.model_name = model session = Session( - event_queue, config=session_config, tool_router=tool_router, + event_queue, + config=session_config, + tool_router=tool_router, hf_token=hf_token, ) t1 = _time.monotonic() @@ -331,7 +332,9 @@ async def _run_session( ) agent_session.is_processing = True try: - should_continue = await process_submission(session, submission) + should_continue = await process_submission( + session, submission + ) finally: agent_session.is_processing = False if not should_continue: @@ -344,7 +347,10 @@ async def _run_session( except Exception as e: logger.error(f"Error in session {session_id}: {e}") await session.send_event( - Event(event_type="error", data={"error": str(e)}) + Event( + event_type="error", + data={"error": render_llm_error_message(e)}, + ) ) finally: @@ -408,7 +414,9 @@ async def truncate(self, session_id: str, user_message_index: int) -> bool: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False - return agent_session.session.context_manager.truncate_to_user_message(user_message_index) + return agent_session.session.context_manager.truncate_to_user_message( + user_message_index + ) async def compact(self, session_id: str) -> bool: """Compact context in a session.""" @@ -487,15 +495,18 @@ def get_session_info(self, session_id: str) -> dict[str, Any] | None: pending_approval = [] for tc in pa["tool_calls"]: import json + try: args = json.loads(tc.function.arguments) except (json.JSONDecodeError, AttributeError): args = {} - pending_approval.append({ - "tool": tc.function.name, - "tool_call_id": tc.id, - "arguments": args, - }) + pending_approval.append( + { + "tool": tc.function.name, + "tool_call_id": tc.id, + "arguments": args, + } + ) return { "session_id": session_id, diff --git a/tests/test_llm_errors.py b/tests/test_llm_errors.py new file mode 100644 index 00000000..d36ebed7 --- /dev/null +++ b/tests/test_llm_errors.py @@ -0,0 +1,124 @@ +import asyncio +from types import SimpleNamespace + +from rich.console import Console + +import agent.core.model_switcher as model_switcher +from agent.core.effort_probe import ProbeInconclusive +from agent.core.llm_errors import ( + classify_llm_error, + friendly_llm_error_message, + health_error_type, + render_llm_error_message, +) + + +def test_auth_errors_get_clean_message() -> None: + error = Exception("401 unauthorized: invalid api key") + + assert classify_llm_error(error) == "auth" + assert "Authentication failed" in friendly_llm_error_message(error) + + +def test_missing_api_key_header_gets_clean_message() -> None: + error = Exception("authentication_error: x-api-key header is required") + + assert classify_llm_error(error) == "auth" + assert render_llm_error_message(error).startswith("Authentication failed") + + +def test_openai_missing_api_key_gets_clean_message() -> None: + error = Exception( + "You didn't provide an API key. 
You need to provide your API key in an Authorization header." + ) + + assert classify_llm_error(error) == "auth" + assert render_llm_error_message(error).startswith("Authentication failed") + + +def test_anthropic_low_credit_error_gets_clean_message() -> None: + error = Exception( + "Your credit balance is too low to access the Anthropic API. " + "Please go to Plans & Billing to upgrade or purchase credits." + ) + + assert classify_llm_error(error) == "credits" + assert render_llm_error_message(error).startswith( + "Insufficient API credits or quota" + ) + + +def test_model_not_found_error_gets_clean_message() -> None: + error = Exception("model_not_found: requested model does not exist") + + assert classify_llm_error(error) == "model" + assert render_llm_error_message(error).startswith("Model not found") + + +def test_unknown_errors_fall_back_to_plain_exception_text() -> None: + error = RuntimeError("boom") + + assert classify_llm_error(error) == "unknown" + assert render_llm_error_message(error) == "boom" + + +def test_health_error_type_keeps_public_categories_stable() -> None: + assert health_error_type(Exception("invalid api key")) == "auth" + assert health_error_type(Exception("credit balance is too low")) == "credits" + assert health_error_type(Exception("rate limit exceeded")) == "rate_limit" + assert health_error_type(Exception("model_not_found")) == "unknown" + + +def test_model_switcher_shows_clean_hard_failure(monkeypatch) -> None: + async def fake_probe_effort(*args, **kwargs): + raise Exception( + "Your credit balance is too low to access the Anthropic API. " + "Please go to Plans & Billing to upgrade or purchase credits." + ) + + monkeypatch.setattr(model_switcher, "probe_effort", fake_probe_effort) + console = Console(record=True, width=120) + config = SimpleNamespace( + reasoning_effort="high", + model_name="anthropic/claude-opus-4-6", + ) + + asyncio.run( + model_switcher.probe_and_switch_model( + "anthropic/claude-opus-4-7", + config, + None, + console, + None, + ) + ) + + output = console.export_text() + assert "Insufficient API credits or quota" in output + assert "credit balance is too low" not in output.lower() + + +def test_model_switcher_shows_clean_inconclusive_warning(monkeypatch) -> None: + async def fake_probe_effort(*args, **kwargs): + raise ProbeInconclusive("timeout talking to provider") + + monkeypatch.setattr(model_switcher, "probe_effort", fake_probe_effort) + console = Console(record=True, width=120) + config = SimpleNamespace( + reasoning_effort="high", + model_name="anthropic/claude-opus-4-6", + ) + + asyncio.run( + model_switcher.probe_and_switch_model( + "anthropic/claude-opus-4-7", + config, + None, + console, + None, + ) + ) + + output = console.export_text() + assert "The model provider is unavailable or timed out" in output + assert "timeout talking to provider" not in output.lower() From a3a1b8f8f1bdfa7aef6e2d483bfa82f5c9c17a67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Thu, 23 Apr 2026 14:27:22 +0200 Subject: [PATCH 08/15] Add Phase 1 provider adapters --- TODO.md | 12 +++ agent/core/model_switcher.py | 17 +++- agent/core/provider_adapters.py | 146 +++++++++++++++++++++++++++++++- tests/test_provider_adapters.py | 139 ++++++++++++++++++++++++++++++ 4 files changed, 308 insertions(+), 6 deletions(-) create mode 100644 TODO.md diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..2bdee354 --- /dev/null +++ b/TODO.md @@ -0,0 +1,12 @@ +# TODO + +## Phase 1 +- [x] Add CLI/runtime-only provider adapters +- 
[x] Add tests for new provider adapters +- [x] Update CLI model help/validation text +- [x] Run compile + provider adapter tests + +## Phase 2 TODO +- [ ] Add adapter-driven backend model catalog +- [ ] Add local model discovery + cache +- [ ] Expand web UI model picker diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index 9437d87c..b816e008 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -17,7 +17,7 @@ from agent.core.effort_probe import ProbeInconclusive, probe_effort from agent.core.llm_errors import render_llm_error_message -from agent.core.provider_adapters import is_valid_model_name +from agent.core.provider_adapters import is_valid_model_name, resolve_adapter # Suggested models shown by `/model` (not a gate). Users can paste any HF @@ -61,7 +61,8 @@ def _print_hf_routing_info(model_id: str, console) -> bool: Anthropic / OpenAI ids return ``True`` without printing anything — the probe below covers "does this model exist". """ - if model_id.startswith(("anthropic/", "openai/")): + adapter = resolve_adapter(model_id) + if adapter and adapter.provider_id != "huggingface": return True from agent.core import hf_router_catalog as cat @@ -132,7 +133,8 @@ def print_model_listing(config, console) -> None: console.print( "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" - "Use 'anthropic/' or 'openai/' for direct API access.[/dim]" + "Direct prefixes: 'anthropic/', 'openai/', 'openrouter/', 'opencode/',\n" + "'opencode-go/', 'ollama/', 'lm_studio/', 'vllm/', 'openai-compat/'.[/dim]" ) @@ -142,7 +144,14 @@ def print_invalid_id(arg: str, console) -> None: "[dim]Expected:\n" " • /[:tag] (HF router — paste from huggingface.co)\n" " • anthropic/\n" - " • openai/[/dim]" + " • openai/\n" + " • openrouter/\n" + " • opencode/\n" + " • opencode-go/\n" + " • ollama/\n" + " • lm_studio/\n" + " • vllm/\n" + " • openai-compat/[/dim]" ) diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index 26381e2a..d14d9f66 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -20,8 +20,22 @@ def _has_model_suffix(model_name: str, prefix: str) -> bool: return bool(tail) and all(tail.split("/")) +def _normalize_openai_api_base(api_base: str) -> str: + base = api_base.rstrip("/") + if base.endswith("/v1"): + return base + return f"{base}/v1" + + +def _all_adapter_prefixes() -> tuple[str, ...]: + prefixes: list[str] = [] + for adapter in ADAPTERS: + prefixes.extend(adapter.prefixes) + return tuple(dict.fromkeys(prefixes)) + + def _is_hf_model_name(model_name: str) -> bool: - if model_name.startswith(("anthropic/", "openai/")): + if model_name.startswith(_all_adapter_prefixes()): return False bare = model_name.removeprefix("huggingface/").split(":", 1)[0] parts = bare.split("/") @@ -114,6 +128,127 @@ def build_params( return params +@dataclass(frozen=True) +class OpenAICompatAdapter(ProviderAdapter): + api_base_url: str = "" + api_key_env: str = "" + default_api_key: str = "" + supports_reasoning_effort: bool = True + use_raw_model_name: bool = False + + def resolved_api_base(self) -> str: + return _normalize_openai_api_base(self.api_base_url) + + def resolved_api_key(self) -> str | None: + if self.api_key_env: + return os.environ.get(self.api_key_env, self.default_api_key) + return self.default_api_key or None + + def allows_model_name(self, model_name: str) -> bool: + return bool(self.prefixes) and 
_has_model_suffix(model_name, self.prefixes[0]) + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + del session_hf_token + + model_id = model_name.removeprefix(self.prefixes[0]) + params: dict[str, Any] = { + "model": model_name if self.use_raw_model_name else f"openai/{model_id}", + "api_base": self.resolved_api_base(), + "api_key": self.resolved_api_key(), + } + + if reasoning_effort: + if not self.supports_reasoning_effort: + if strict: + raise UnsupportedEffortError( + f"{self.provider_id} doesn't accept effort={reasoning_effort!r}" + ) + else: + params["reasoning_effort"] = reasoning_effort + + return params + + +@dataclass(frozen=True) +class OllamaAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("ollama/",) + api_key_env: str = "OLLAMA_API_KEY" + default_api_key: str = "ollama" + supports_reasoning_effort: bool = False + + def resolved_api_base(self) -> str: + return _normalize_openai_api_base( + os.environ.get("OLLAMA_API_BASE", "http://localhost:11434/v1") + ) + + +@dataclass(frozen=True) +class LmStudioAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("lm_studio/",) + api_key_env: str = "LMSTUDIO_API_KEY" + default_api_key: str = "lm-studio" + supports_reasoning_effort: bool = False + use_raw_model_name: bool = True + + def resolved_api_base(self) -> str: + return _normalize_openai_api_base( + os.environ.get("LMSTUDIO_BASE_URL", "http://127.0.0.1:1234/v1") + ) + + +@dataclass(frozen=True) +class VllmAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("vllm/",) + api_key_env: str = "VLLM_API_KEY" + default_api_key: str = "vllm" + supports_reasoning_effort: bool = False + + def resolved_api_base(self) -> str: + return _normalize_openai_api_base( + os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1") + ) + + +@dataclass(frozen=True) +class OpenRouterAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("openrouter/",) + api_base_url: str = "https://openrouter.ai/api/v1" + api_key_env: str = "OPENROUTER_API_KEY" + + +@dataclass(frozen=True) +class OpenCodeZenAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("opencode/",) + api_base_url: str = "https://opencode.ai/zen/v1" + api_key_env: str = "OPENCODE_ZEN_API_KEY" + + +@dataclass(frozen=True) +class OpenCodeGoAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] = ("opencode-go/",) + api_base_url: str = "https://opencode.ai/zen/go/v1" + api_key_env: str = "OPENCODE_GO_API_KEY" + + +@dataclass(frozen=True) +class GenericOpenAICompatAdapter(OpenAICompatAdapter): + prefixes: tuple[str, ...] 
= ("openai-compat/",) + api_key_env: str = "OPENAI_COMPAT_API_KEY" + + def resolved_api_base(self) -> str: + api_base = os.environ.get("OPENAI_COMPAT_BASE_URL", "") + if not api_base: + raise ValueError("OPENAI_COMPAT_BASE_URL is required for openai-compat/") + return _normalize_openai_api_base(api_base) + + @dataclass(frozen=True) class HfRouterAdapter(ProviderAdapter): """HuggingFace router — OpenAI-compat endpoint with HF token chain.""" @@ -121,7 +256,7 @@ class HfRouterAdapter(ProviderAdapter): _EFFORTS: ClassVar[frozenset[str]] = frozenset({"low", "medium", "high"}) def matches(self, model_name: str) -> bool: - return not model_name.startswith(("anthropic/", "openai/")) + return _is_hf_model_name(model_name) def allows_model_name(self, model_name: str) -> bool: return _is_hf_model_name(model_name) @@ -164,6 +299,13 @@ def build_params( ADAPTERS: tuple[ProviderAdapter, ...] = ( AnthropicAdapter(provider_id="anthropic"), OpenAIAdapter(provider_id="openai"), + OllamaAdapter(provider_id="ollama"), + LmStudioAdapter(provider_id="lm_studio"), + VllmAdapter(provider_id="vllm"), + OpenRouterAdapter(provider_id="openrouter"), + OpenCodeZenAdapter(provider_id="opencode_zen"), + OpenCodeGoAdapter(provider_id="opencode_go"), + GenericOpenAICompatAdapter(provider_id="openai_compat"), HfRouterAdapter(provider_id="huggingface"), ) diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index 2b6de9dd..6919f0c2 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -1,5 +1,6 @@ import pytest +import agent.core.provider_adapters as providers from agent.core.llm_params import _resolve_llm_params from agent.core.model_switcher import is_valid_model_id from agent.core.provider_adapters import ( @@ -99,6 +100,120 @@ def test_hf_adapter_strict_rejects_max(): ) +# -- OpenAI-compatible adapters ------------------------------------------------ + + +def test_ollama_adapter_builds_params(monkeypatch): + monkeypatch.delenv("OLLAMA_API_BASE", raising=False) + monkeypatch.delenv("OLLAMA_API_KEY", raising=False) + + params = _resolve_llm_params("ollama/llama3.1") + + assert params == { + "model": "openai/llama3.1", + "api_base": "http://localhost:11434/v1", + "api_key": "ollama", + } + + +def test_ollama_adapter_normalizes_base_url(monkeypatch): + monkeypatch.setenv("OLLAMA_API_BASE", "http://localhost:11434") + + params = _resolve_llm_params("ollama/llama3.1") + + assert params["api_base"] == "http://localhost:11434/v1" + + +def test_ollama_adapter_strict_rejects_effort(): + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params("ollama/llama3.1", reasoning_effort="high", strict=True) + + +def test_lm_studio_adapter_uses_raw_model_name(monkeypatch): + monkeypatch.delenv("LMSTUDIO_BASE_URL", raising=False) + monkeypatch.delenv("LMSTUDIO_API_KEY", raising=False) + + params = _resolve_llm_params("lm_studio/google/gemma-3-12b") + + assert params == { + "model": "lm_studio/google/gemma-3-12b", + "api_base": "http://127.0.0.1:1234/v1", + "api_key": "lm-studio", + } + + +def test_vllm_adapter_uses_env_override(monkeypatch): + monkeypatch.setenv("VLLM_BASE_URL", "http://127.0.0.1:8000") + monkeypatch.setenv("VLLM_API_KEY", "secret") + + params = _resolve_llm_params("vllm/Qwen3-32B") + + assert params == { + "model": "openai/Qwen3-32B", + "api_base": "http://127.0.0.1:8000/v1", + "api_key": "secret", + } + + +def test_openrouter_adapter_uses_api_key(monkeypatch): + monkeypatch.setenv("OPENROUTER_API_KEY", "router-key") + + params = _resolve_llm_params( + 
"openrouter/anthropic/claude-sonnet-4.5", reasoning_effort="medium" + ) + + assert params == { + "model": "openai/anthropic/claude-sonnet-4.5", + "api_base": "https://openrouter.ai/api/v1", + "api_key": "router-key", + "reasoning_effort": "medium", + } + + +def test_opencode_zen_adapter_uses_api_key(monkeypatch): + monkeypatch.setenv("OPENCODE_ZEN_API_KEY", "zen-key") + + params = _resolve_llm_params("opencode/kimi-k2.6") + + assert params == { + "model": "openai/kimi-k2.6", + "api_base": "https://opencode.ai/zen/v1", + "api_key": "zen-key", + } + + +def test_opencode_go_adapter_uses_api_key(monkeypatch): + monkeypatch.setenv("OPENCODE_GO_API_KEY", "go-key") + + params = _resolve_llm_params("opencode-go/kimi-k2.6") + + assert params == { + "model": "openai/kimi-k2.6", + "api_base": "https://opencode.ai/zen/go/v1", + "api_key": "go-key", + } + + +def test_openai_compat_requires_base_url(monkeypatch): + monkeypatch.delenv("OPENAI_COMPAT_BASE_URL", raising=False) + + with pytest.raises(ValueError, match="OPENAI_COMPAT_BASE_URL"): + _resolve_llm_params("openai-compat/my-model") + + +def test_openai_compat_uses_required_base_url(monkeypatch): + monkeypatch.setenv("OPENAI_COMPAT_BASE_URL", "http://localhost:8080") + monkeypatch.setenv("OPENAI_COMPAT_API_KEY", "compat-key") + + params = _resolve_llm_params("openai-compat/my-model") + + assert params == { + "model": "openai/my-model", + "api_base": "http://localhost:8080/v1", + "api_key": "compat-key", + } + + # -- Validation --------------------------------------------------------------- @@ -110,6 +225,13 @@ def test_model_validation_accepts_free_form_hf_ids(): def test_model_validation_accepts_direct_provider_ids(): assert is_valid_model_name("anthropic/claude-opus-4-7") is True assert is_valid_model_name("openai/gpt-5") is True + assert is_valid_model_name("ollama/llama3.1") is True + assert is_valid_model_name("lm_studio/google/gemma-3-12b") is True + assert is_valid_model_name("vllm/Qwen3-32B") is True + assert is_valid_model_name("openrouter/anthropic/claude-sonnet-4.5") is True + assert is_valid_model_name("opencode/kimi-k2.6") is True + assert is_valid_model_name("opencode-go/kimi-k2.6") is True + assert is_valid_model_name("openai-compat/my-model") is True def test_model_validation_rejects_garbage(): @@ -117,13 +239,30 @@ def test_model_validation_rejects_garbage(): assert is_valid_model_name("no-slash") is False assert is_valid_model_name("anthropic/") is False assert is_valid_model_name("openai/") is False + assert is_valid_model_name("ollama/") is False + assert is_valid_model_name("lm_studio/") is False + assert is_valid_model_name("vllm/") is False + assert is_valid_model_name("openrouter/") is False + assert is_valid_model_name("opencode/") is False + assert is_valid_model_name("opencode-go/") is False + assert is_valid_model_name("openai-compat/") is False assert is_valid_model_name("huggingface/nope") is False assert is_valid_model_name("moonshotai/") is False +def test_hf_validation_excludes_new_provider_prefixes(): + hf = providers.resolve_adapter("openrouter/google/gemini-2.5-pro") + + assert hf is not None + assert hf.provider_id == "openrouter" + + def test_cli_validation_matches_provider_validation(): assert is_valid_model_id("openai/gpt-5") is True assert is_valid_model_id("moonshotai/Kimi-K2.6:fastest") is True + assert is_valid_model_id("ollama/llama3.1") is True + assert is_valid_model_id("openrouter/anthropic/claude-sonnet-4.5") is True + assert is_valid_model_id("openai-compat/my-model") is True assert 
is_valid_model_id("openai/") is False assert is_valid_model_id("anthropic/") is False From 8b1abd9e5e0ffed67c84fae836e897a852162664 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Thu, 23 Apr 2026 15:20:42 +0200 Subject: [PATCH 09/15] Build adapter-driven model catalog and picker --- TODO.md | 7 +- agent/core/model_switcher.py | 62 +- agent/core/provider_adapters.py | 333 +++++++++- backend/routes/agent.py | 50 +- frontend/src/components/Chat/ChatInput.tsx | 567 ++++++++++++------ .../src/components/Chat/CustomModelDialog.tsx | 140 +++++ frontend/src/utils/model.ts | 5 +- tests/test_provider_adapters.py | 55 ++ 8 files changed, 951 insertions(+), 268 deletions(-) create mode 100644 frontend/src/components/Chat/CustomModelDialog.tsx diff --git a/TODO.md b/TODO.md index 2bdee354..26d949fd 100644 --- a/TODO.md +++ b/TODO.md @@ -7,6 +7,7 @@ - [x] Run compile + provider adapter tests ## Phase 2 TODO -- [ ] Add adapter-driven backend model catalog -- [ ] Add local model discovery + cache -- [ ] Expand web UI model picker +- [x] Add adapter-driven backend model catalog +- [x] Add local model discovery + cache +- [x] Expand web UI model picker +- [x] Add OpenAI-compat custom model modal diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index b816e008..8fc1a693 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -3,7 +3,7 @@ Split out of ``agent.main`` so the REPL dispatcher stays focused on input parsing. Exposes: -* ``SUGGESTED_MODELS`` — the short list shown by ``/model`` with no arg. +* Adapter-driven model catalog rendered by ``/model`` with no arg. * ``is_valid_model_id`` — loose format check on user input. * ``probe_and_switch_model`` — async: checks routing, fires a 1-token probe to resolve the effort cascade, then commits the switch (or @@ -17,21 +17,13 @@ from agent.core.effort_probe import ProbeInconclusive, probe_effort from agent.core.llm_errors import render_llm_error_message -from agent.core.provider_adapters import is_valid_model_name, resolve_adapter - - -# Suggested models shown by `/model` (not a gate). Users can paste any HF -# model id (e.g. "MiniMaxAI/MiniMax-M2.7") or an `anthropic/` / `openai/` -# prefix for direct API access. For HF ids, append ":fastest" / -# ":cheapest" / ":preferred" / ":" to override the default -# routing policy (auto = fastest with failover). 
-SUGGESTED_MODELS = [ - {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"}, - {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"}, - {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"}, - {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"}, - {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"}, -] +from agent.core.provider_adapters import ( + ADAPTERS, + find_model_option, + get_available_models, + is_valid_model_name, + resolve_adapter, +) _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} @@ -122,14 +114,40 @@ def _print_hf_routing_info(model_id: str, console) -> bool: def print_model_listing(config, console) -> None: - """Render the default ``/model`` (no-arg) view: current + suggested.""" + """Render the default ``/model`` (no-arg) view: current + available.""" current = config.model_name if config else "" + current_info = find_model_option(current) + available = get_available_models() + console.print("[bold]Current model:[/bold]") - console.print(f" {current}") - console.print("\n[bold]Suggested:[/bold]") - for m in SUGGESTED_MODELS: - marker = " [dim]<-- current[/dim]" if m["id"] == current else "" - console.print(f" {m['id']} [dim]({m['label']})[/dim]{marker}") + if current_info: + console.print(f" {current} [dim]({current_info['label']})[/dim]") + else: + console.print(f" {current}") + + console.print("\n[bold]Available:[/bold]") + for adapter in ADAPTERS: + section = [m for m in available if m.get("provider") == adapter.provider_id] + if not section: + if adapter.provider_id == "openai_compat" and adapter.should_show(): + console.print(f"\n[bold]{adapter.provider_label}[/bold]") + console.print(" [dim]Use openai-compat/[/dim]") + continue + + console.print(f"\n[bold]{adapter.provider_label}[/bold]") + section_sorted = sorted( + section, + key=lambda m: ( + not bool(m.get("recommended")), + str(m.get("label", "")).lower(), + ), + ) + for m in section_sorted: + marker = " [dim]<-- current[/dim]" if m["id"] == current else "" + source = m.get("source") + source_tag = " [dim](dynamic)[/dim]" if source == "dynamic" else "" + console.print(f" {m['id']} [dim]({m['label']})[/dim]{source_tag}{marker}") + console.print( "\n[dim]Paste any HF model id (e.g. 
'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index d14d9f66..673252c1 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -1,8 +1,16 @@ -"""Provider adapters for runtime params and model-name validation.""" +"""Provider adapters for runtime params and model catalog metadata.""" +import json import os +import time from dataclasses import dataclass from typing import Any, ClassVar +from urllib.error import URLError +from urllib.request import urlopen + +_DISCOVERY_TIMEOUT_SECONDS = 2.0 +_DISCOVERY_CACHE_TTL_SECONDS = 30.0 +_discovery_cache: dict[str, tuple[float, tuple["SuggestedModel", ...]]] = {} class UnsupportedEffortError(ValueError): @@ -13,6 +21,18 @@ class UnsupportedEffortError(ValueError): """ +@dataclass(frozen=True) +class SuggestedModel: + id: str + label: str + description: str + provider: str + provider_label: str + avatar_url: str + recommended: bool = False + source: str = "static" + + def _has_model_suffix(model_name: str, prefix: str) -> bool: if not model_name.startswith(prefix): return False @@ -27,6 +47,22 @@ def _normalize_openai_api_base(api_base: str) -> str: return f"{base}/v1" +def _provider_avatar_url(provider_id: str) -> str: + avatars = { + "anthropic": "https://huggingface.co/api/avatars/Anthropic", + "openai": "https://openai.com/favicon.ico", + "ollama": "https://ollama.com/public/ollama.png", + "lm_studio": "https://avatars.githubusercontent.com/u/16906759?s=200&v=4", + "vllm": "https://avatars.githubusercontent.com/u/132129714?s=200&v=4", + "openrouter": "https://openrouter.ai/favicon.ico", + "opencode_zen": "https://huggingface.co/api/avatars/opencode-ai", + "opencode_go": "https://huggingface.co/api/avatars/opencode-ai", + "openai_compat": "https://openai.com/favicon.ico", + "huggingface": "https://huggingface.co/api/avatars/huggingface", + } + return avatars.get(provider_id, "https://huggingface.co/api/avatars/huggingface") + + def _all_adapter_prefixes() -> tuple[str, ...]: prefixes: list[str] = [] for adapter in ADAPTERS: @@ -42,14 +78,71 @@ def _is_hf_model_name(model_name: str) -> bool: return len(parts) >= 2 and all(parts) +def _discover_models( + *, + provider_id: str, + provider_label: str, + prefix: str, + api_base: str, +) -> tuple[SuggestedModel, ...]: + now = time.monotonic() + cached = _discovery_cache.get(api_base) + if cached and cached[0] > now: + return cached[1] + + models: list[SuggestedModel] = [] + try: + with urlopen( + f"{api_base}/models", timeout=_DISCOVERY_TIMEOUT_SECONDS + ) as response: + payload = json.load(response) + except (OSError, URLError, TimeoutError, ValueError): + payload = {"data": []} + + for item in payload.get("data", []): + if not isinstance(item, dict): + continue + model_id = item.get("id") + if not isinstance(model_id, str) or not model_id: + continue + models.append( + SuggestedModel( + id=f"{prefix}{model_id}", + label=model_id, + description=provider_label, + provider=provider_id, + provider_label=provider_label, + avatar_url=_provider_avatar_url(provider_id), + source="dynamic", + ) + ) + + resolved = tuple(sorted(models, key=lambda m: m.label.lower())) + _discovery_cache[api_base] = (now + _DISCOVERY_CACHE_TTL_SECONDS, resolved) + return resolved + + @dataclass(frozen=True) class ProviderAdapter: provider_id: str + provider_label: str prefixes: tuple[str, ...] 
= () + supports_custom_model: bool = False + custom_model_hint: str | None = None + custom_model_mode: str | None = None def matches(self, model_name: str) -> bool: return bool(self.prefixes) and model_name.startswith(self.prefixes) + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return () + + def available_models(self) -> tuple[SuggestedModel, ...]: + return self.suggested_models() + + def should_show(self) -> bool: + return True + def build_params( self, model_name: str, @@ -63,6 +156,17 @@ def build_params( def allows_model_name(self, model_name: str) -> bool: return self.matches(model_name) + def to_summary(self) -> dict[str, Any]: + return { + "id": self.provider_id, + "label": self.provider_label, + "avatarUrl": _provider_avatar_url(self.provider_id), + "supportsCustomModel": self.supports_custom_model, + "customModelHint": self.custom_model_hint, + "customModelMode": self.custom_model_mode, + "prefix": self.prefixes[0] if self.prefixes else "", + } + @dataclass(frozen=True) class AnthropicAdapter(ProviderAdapter): @@ -73,6 +177,27 @@ class AnthropicAdapter(ProviderAdapter): {"low", "medium", "high", "xhigh", "max"} ) + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="anthropic/claude-opus-4-7", + label="Claude Opus 4.7", + description="Anthropic", + provider="anthropic", + provider_label="Anthropic", + avatar_url=_provider_avatar_url("anthropic"), + recommended=True, + ), + SuggestedModel( + id="anthropic/claude-opus-4-6", + label="Claude Opus 4.6", + description="Anthropic", + provider="anthropic", + provider_label="Anthropic", + avatar_url=_provider_avatar_url("anthropic"), + ), + ) + def allows_model_name(self, model_name: str) -> bool: return _has_model_suffix(model_name, "anthropic/") @@ -105,6 +230,19 @@ class OpenAIAdapter(ProviderAdapter): prefixes: tuple[str, ...] 
= ("openai/",) _EFFORTS: ClassVar[frozenset[str]] = frozenset({"minimal", "low", "medium", "high"}) + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="openai/gpt-5", + label="GPT-5", + description="OpenAI", + provider="openai", + provider_label="OpenAI", + avatar_url=_provider_avatar_url("openai"), + recommended=True, + ), + ) + def allows_model_name(self, model_name: str) -> bool: return _has_model_suffix(model_name, "openai/") @@ -144,6 +282,24 @@ def resolved_api_key(self) -> str | None: return os.environ.get(self.api_key_env, self.default_api_key) return self.default_api_key or None + def suggested_model_defs(self) -> tuple[tuple[str, str, bool], ...]: + return () + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + prefix = self.prefixes[0] + return tuple( + SuggestedModel( + id=f"{prefix}{model_id}", + label=label, + description=self.provider_label, + provider=self.provider_id, + provider_label=self.provider_label, + avatar_url=_provider_avatar_url(self.provider_id), + recommended=recommended, + ) + for model_id, label, recommended in self.suggested_model_defs() + ) + def allows_model_name(self, model_name: str) -> bool: return bool(self.prefixes) and _has_model_suffix(model_name, self.prefixes[0]) @@ -188,6 +344,17 @@ def resolved_api_base(self) -> str: os.environ.get("OLLAMA_API_BASE", "http://localhost:11434/v1") ) + def available_models(self) -> tuple[SuggestedModel, ...]: + return _discover_models( + provider_id=self.provider_id, + provider_label=self.provider_label, + prefix=self.prefixes[0], + api_base=self.resolved_api_base(), + ) + + def should_show(self) -> bool: + return bool(self.available_models()) + @dataclass(frozen=True) class LmStudioAdapter(OpenAICompatAdapter): @@ -202,6 +369,17 @@ def resolved_api_base(self) -> str: os.environ.get("LMSTUDIO_BASE_URL", "http://127.0.0.1:1234/v1") ) + def available_models(self) -> tuple[SuggestedModel, ...]: + return _discover_models( + provider_id=self.provider_id, + provider_label=self.provider_label, + prefix=self.prefixes[0], + api_base=self.resolved_api_base(), + ) + + def should_show(self) -> bool: + return bool(self.available_models()) + @dataclass(frozen=True) class VllmAdapter(OpenAICompatAdapter): @@ -215,6 +393,17 @@ def resolved_api_base(self) -> str: os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1") ) + def available_models(self) -> tuple[SuggestedModel, ...]: + return _discover_models( + provider_id=self.provider_id, + provider_label=self.provider_label, + prefix=self.prefixes[0], + api_base=self.resolved_api_base(), + ) + + def should_show(self) -> bool: + return bool(self.available_models()) + @dataclass(frozen=True) class OpenRouterAdapter(OpenAICompatAdapter): @@ -222,6 +411,9 @@ class OpenRouterAdapter(OpenAICompatAdapter): api_base_url: str = "https://openrouter.ai/api/v1" api_key_env: str = "OPENROUTER_API_KEY" + def suggested_model_defs(self) -> tuple[tuple[str, str, bool], ...]: + return (("anthropic/claude-sonnet-4.5", "Claude Sonnet 4.5", True),) + @dataclass(frozen=True) class OpenCodeZenAdapter(OpenAICompatAdapter): @@ -229,6 +421,9 @@ class OpenCodeZenAdapter(OpenAICompatAdapter): api_base_url: str = "https://opencode.ai/zen/v1" api_key_env: str = "OPENCODE_ZEN_API_KEY" + def suggested_model_defs(self) -> tuple[tuple[str, str, bool], ...]: + return (("kimi-k2.6", "Kimi K2.6", True),) + @dataclass(frozen=True) class OpenCodeGoAdapter(OpenAICompatAdapter): @@ -236,11 +431,19 @@ class OpenCodeGoAdapter(OpenAICompatAdapter): api_base_url: str = 
"https://opencode.ai/zen/go/v1" api_key_env: str = "OPENCODE_GO_API_KEY" + def suggested_model_defs(self) -> tuple[tuple[str, str, bool], ...]: + return (("kimi-k2.6", "Kimi K2.6", True),) + @dataclass(frozen=True) class GenericOpenAICompatAdapter(OpenAICompatAdapter): prefixes: tuple[str, ...] = ("openai-compat/",) api_key_env: str = "OPENAI_COMPAT_API_KEY" + supports_custom_model: bool = True + custom_model_hint: str | None = ( + "Use openai-compat/. Configure OPENAI_COMPAT_BASE_URL on server." + ) + custom_model_mode: str | None = "suffix" def resolved_api_base(self) -> str: api_base = os.environ.get("OPENAI_COMPAT_BASE_URL", "") @@ -248,12 +451,49 @@ def resolved_api_base(self) -> str: raise ValueError("OPENAI_COMPAT_BASE_URL is required for openai-compat/") return _normalize_openai_api_base(api_base) + def should_show(self) -> bool: + return bool(os.environ.get("OPENAI_COMPAT_BASE_URL")) + @dataclass(frozen=True) class HfRouterAdapter(ProviderAdapter): """HuggingFace router — OpenAI-compat endpoint with HF token chain.""" _EFFORTS: ClassVar[frozenset[str]] = frozenset({"low", "medium", "high"}) + supports_custom_model: bool = True + custom_model_hint: str | None = ( + "Paste any Hugging Face model id, optionally with :fastest/:cheapest/:preferred" + ) + custom_model_mode: str | None = "raw" + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="moonshotai/Kimi-K2.6", + label="Kimi K2.6", + description="HF Router", + provider="huggingface", + provider_label="Hugging Face Router", + avatar_url="https://huggingface.co/api/avatars/moonshotai", + recommended=True, + ), + SuggestedModel( + id="MiniMaxAI/MiniMax-M2.7", + label="MiniMax M2.7", + description="HF Router", + provider="huggingface", + provider_label="Hugging Face Router", + avatar_url="https://huggingface.co/api/avatars/MiniMaxAI", + ), + SuggestedModel( + id="zai-org/GLM-5.1", + label="GLM 5.1", + description="HF Router", + provider="huggingface", + provider_label="Hugging Face Router", + avatar_url="https://huggingface.co/api/avatars/zai-org", + ), + ) def matches(self, model_name: str) -> bool: return _is_hf_model_name(model_name) @@ -297,16 +537,19 @@ def build_params( ADAPTERS: tuple[ProviderAdapter, ...] 
= ( - AnthropicAdapter(provider_id="anthropic"), - OpenAIAdapter(provider_id="openai"), - OllamaAdapter(provider_id="ollama"), - LmStudioAdapter(provider_id="lm_studio"), - VllmAdapter(provider_id="vllm"), - OpenRouterAdapter(provider_id="openrouter"), - OpenCodeZenAdapter(provider_id="opencode_zen"), - OpenCodeGoAdapter(provider_id="opencode_go"), - GenericOpenAICompatAdapter(provider_id="openai_compat"), - HfRouterAdapter(provider_id="huggingface"), + AnthropicAdapter(provider_id="anthropic", provider_label="Anthropic"), + OpenAIAdapter(provider_id="openai", provider_label="OpenAI"), + OllamaAdapter(provider_id="ollama", provider_label="Ollama"), + LmStudioAdapter(provider_id="lm_studio", provider_label="LM Studio"), + VllmAdapter(provider_id="vllm", provider_label="vLLM"), + OpenRouterAdapter(provider_id="openrouter", provider_label="OpenRouter"), + OpenCodeZenAdapter(provider_id="opencode_zen", provider_label="OpenCode Zen"), + OpenCodeGoAdapter(provider_id="opencode_go", provider_label="OpenCode Go"), + GenericOpenAICompatAdapter( + provider_id="openai_compat", + provider_label="OpenAI-Compatible", + ), + HfRouterAdapter(provider_id="huggingface", provider_label="Hugging Face Router"), ) @@ -320,3 +563,71 @@ def resolve_adapter(model_name: str) -> ProviderAdapter | None: def is_valid_model_name(model_name: str) -> bool: adapter = resolve_adapter(model_name) return adapter is not None and adapter.allows_model_name(model_name) + + +def _serialized_model(model: SuggestedModel) -> dict[str, Any]: + return { + "id": model.id, + "label": model.label, + "description": model.description, + "provider": model.provider, + "providerLabel": model.provider_label, + "avatarUrl": model.avatar_url, + "recommended": model.recommended, + "source": model.source, + } + + +def get_available_models() -> list[dict[str, Any]]: + available: list[dict[str, Any]] = [] + for adapter in ADAPTERS: + if not adapter.should_show(): + continue + for model in adapter.available_models(): + available.append(_serialized_model(model)) + return available + + +def get_provider_summaries() -> list[dict[str, Any]]: + providers: list[dict[str, Any]] = [] + for adapter in ADAPTERS: + if not adapter.should_show(): + continue + providers.append(adapter.to_summary()) + return providers + + +def find_model_option(model_name: str) -> dict[str, Any] | None: + for model in get_available_models(): + if model["id"] == model_name: + return model + + adapter = resolve_adapter(model_name) + if not adapter or not adapter.allows_model_name(model_name): + return None + + label = model_name + if adapter.provider_id == "huggingface": + label = model_name.removeprefix("huggingface/") + elif adapter.prefixes: + label = model_name.removeprefix(adapter.prefixes[0]) + + return { + "id": model_name, + "label": label, + "description": f"Custom {adapter.provider_label} model", + "provider": adapter.provider_id, + "providerLabel": adapter.provider_label, + "avatarUrl": _provider_avatar_url(adapter.provider_id), + "recommended": False, + "source": "custom", + } + + +def build_model_catalog(current_model: str) -> dict[str, Any]: + return { + "current": current_model, + "available": get_available_models(), + "providers": get_provider_summaries(), + "currentInfo": find_model_option(current_model), + } diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 0b164d49..0261c074 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -39,47 +39,19 @@ from agent.core.llm_errors import health_error_type, render_llm_error_message from 
agent.core.llm_params import _resolve_llm_params +from agent.core.provider_adapters import build_model_catalog, is_valid_model_name logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["agent"]) -AVAILABLE_MODELS = [ - { - "id": "moonshotai/Kimi-K2.6", - "label": "Kimi K2.6", - "provider": "huggingface", - "tier": "free", - "recommended": True, - }, - { - "id": "anthropic/claude-opus-4-6", - "label": "Claude Opus 4.6", - "provider": "anthropic", - "tier": "pro", - "recommended": True, - }, - { - "id": "MiniMaxAI/MiniMax-M2.7", - "label": "MiniMax M2.7", - "provider": "huggingface", - "tier": "free", - }, - { - "id": "zai-org/GLM-5.1", - "label": "GLM 5.1", - "provider": "huggingface", - "tier": "free", - }, -] - async def _require_hf_for_anthropic(request: Request, model_id: str) -> None: """403 if a non-``huggingface``-org user tries to select an Anthropic model. - Anthropic models are billed to the Space's ``ANTHROPIC_API_KEY``; every - other model in ``AVAILABLE_MODELS`` is routed through HF Router and - billed via ``X-HF-Bill-To``. The gate only fires for ``anthropic/*`` so + Anthropic models are billed to the Space's ``ANTHROPIC_API_KEY``; other + providers use their own routing/billing config. The gate only fires for + ``anthropic/*`` so non-HF users can still freely switch between the free models. Pattern: https://github.com/huggingface/ml-intern/pull/63 @@ -191,10 +163,7 @@ async def llm_health_check() -> LLMHealthResponse: @router.get("/config/model") async def get_model() -> dict: """Get current model and available models. No auth required.""" - return { - "current": session_manager.config.model_name, - "available": AVAILABLE_MODELS, - } + return build_model_catalog(session_manager.config.model_name) _TITLE_STRIP_CHARS = str.maketrans("", "", "`*_~#[]()") @@ -289,8 +258,7 @@ async def create_session( if isinstance(body, dict): model = body.get("model") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model and model not in valid_ids: + if model and not is_valid_model_name(model): raise HTTPException(status_code=400, detail=f"Unknown model: {model}") # Opus is gated to HF staff (PR #63). 
Only fires when the resolved model @@ -334,8 +302,7 @@ async def restore_session_summary( hf_token = os.environ.get("HF_TOKEN") model = body.get("model") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model and model not in valid_ids: + if model and not is_valid_model_name(model): raise HTTPException(status_code=400, detail=f"Unknown model: {model}") resolved_model = model or session_manager.config.model_name @@ -393,8 +360,7 @@ async def set_session_model( model_id = body.get("model") if not model_id: raise HTTPException(status_code=400, detail="Missing 'model' field") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model_id not in valid_ids: + if not is_valid_model_name(model_id): raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") await _require_hf_for_anthropic(request, model_id) agent_session = session_manager.sessions.get(session_id) diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index d9fe5c4d..db1fbaa2 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -1,5 +1,17 @@ import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; -import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material'; +import { + Box, + TextField, + IconButton, + CircularProgress, + Typography, + Menu, + MenuItem, + ListItemIcon, + ListItemText, + Chip, + ListSubheader, +} from '@mui/material'; import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; import StopIcon from '@mui/icons-material/Stop'; @@ -8,58 +20,39 @@ import { useUserQuota } from '@/hooks/useUserQuota'; import ClaudeCapDialog from '@/components/ClaudeCapDialog'; import { useAgentStore } from '@/store/agentStore'; import { FIRST_FREE_MODEL_PATH } from '@/utils/model'; +import CustomModelDialog from '@/components/Chat/CustomModelDialog'; -// Model configuration interface ModelOption { id: string; - name: string; + label: string; description: string; - modelPath: string; - avatarUrl: string; + provider: string; + providerLabel: string; + avatarUrl?: string; recommended?: boolean; + source?: string; } -const getHfAvatarUrl = (modelId: string) => { - const org = modelId.split('/')[0]; - return `https://huggingface.co/api/avatars/${org}`; -}; +interface ProviderOption { + id: string; + label: string; + avatarUrl?: string; + supportsCustomModel?: boolean; + customModelHint?: string; + customModelMode?: string; + prefix?: string; +} -const MODEL_OPTIONS: ModelOption[] = [ - { - id: 'kimi-k2.6', - name: 'Kimi K2.6', - description: 'Novita', - modelPath: 'moonshotai/Kimi-K2.6', - avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.6'), - recommended: true, - }, - { - id: 'claude-opus', - name: 'Claude Opus 4.6', - description: 'Anthropic', - modelPath: 'anthropic/claude-opus-4-6', - avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', - recommended: true, - }, - { - id: 'minimax-m2.7', - name: 'MiniMax M2.7', - description: 'Novita', - modelPath: 'MiniMaxAI/MiniMax-M2.7', - avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.7'), - }, - { - id: 'glm-5.1', - name: 'GLM 5.1', - description: 'Together', - modelPath: 'zai-org/GLM-5.1', - avatarUrl: getHfAvatarUrl('zai-org/GLM-5.1'), - }, -]; - -const findModelByPath = (path: string): ModelOption | undefined => { - return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); -}; 
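+// Shape of GET /api/config/model; mirrors build_model_catalog() on the
+// backend: current model id, available options, and provider summaries.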
+interface ModelCatalogResponse { + current?: string; + currentInfo?: ModelOption | null; + available?: ModelOption[]; + providers?: ProviderOption[]; +} + +interface SessionResponse { + model?: string; +} interface ChatInputProps { sessionId?: string; @@ -70,45 +63,145 @@ interface ChatInputProps { placeholder?: string; } -const isClaudeModel = (m: ModelOption) => m.modelPath.startsWith('anthropic/'); -const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0]; +const OPENAI_COMPAT_PROVIDER = 'openai_compat'; + +const toModelOption = (value: unknown): ModelOption | null => { + if (!value || typeof value !== 'object') return null; + const v = value as Record; + if (typeof v.id !== 'string' || typeof v.label !== 'string') return null; + return { + id: v.id, + label: v.label, + description: typeof v.description === 'string' ? v.description : '', + provider: typeof v.provider === 'string' ? v.provider : '', + providerLabel: typeof v.providerLabel === 'string' ? v.providerLabel : '', + avatarUrl: typeof v.avatarUrl === 'string' ? v.avatarUrl : undefined, + recommended: Boolean(v.recommended), + source: typeof v.source === 'string' ? v.source : undefined, + }; +}; -export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { +const toProviderOption = (value: unknown): ProviderOption | null => { + if (!value || typeof value !== 'object') return null; + const v = value as Record; + if (typeof v.id !== 'string' || typeof v.label !== 'string') return null; + return { + id: v.id, + label: v.label, + avatarUrl: typeof v.avatarUrl === 'string' ? v.avatarUrl : undefined, + supportsCustomModel: Boolean(v.supportsCustomModel), + customModelHint: typeof v.customModelHint === 'string' ? v.customModelHint : undefined, + customModelMode: typeof v.customModelMode === 'string' ? v.customModelMode : undefined, + prefix: typeof v.prefix === 'string' ? v.prefix : undefined, + }; +}; + +const isClaudePath = (modelPath: string) => modelPath.startsWith('anthropic/'); + +const firstFreeModel = (models: ModelOption[]) => { + const byPath = models.find((m) => m.id === FIRST_FREE_MODEL_PATH); + if (byPath) return byPath; + return models.find((m) => !isClaudePath(m.id)); +}; + +export default function ChatInput({ + sessionId, + onSend, + onStop, + isProcessing = false, + disabled = false, + placeholder = 'Ask anything...', +}: ChatInputProps) { const [input, setInput] = useState(''); const inputRef = useRef(null); - const [selectedModelId, setSelectedModelId] = useState(MODEL_OPTIONS[0].id); + const [modelOptions, setModelOptions] = useState([]); + const [providerOptions, setProviderOptions] = useState([]); + const [selectedModelPath, setSelectedModelPath] = useState(''); + const [selectedModelInfo, setSelectedModelInfo] = useState(null); const [modelAnchorEl, setModelAnchorEl] = useState(null); + const [customModalOpen, setCustomModalOpen] = useState(false); + const [customPrefix, setCustomPrefix] = useState('openai-compat/'); const { quota, refresh: refreshQuota } = useUserQuota(); - // The daily-cap dialog is triggered from two places: (a) a 429 returned - // from the chat transport when the user tries to send on Opus over cap — - // surfaced via the agent-store flag — and (b) nothing else right now - // (switching models is free). Keeping the open state in the store means - // the hook layer can flip it without threading props through. 
const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted); const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted); const lastSentRef = useRef(''); - // Model is per-session: fetch this tab's current model every time the - // session changes. Other tabs keep their own selections independently. + useEffect(() => { + let cancelled = false; + + const loadCatalog = async () => { + try { + const res = await apiFetch('/api/config/model'); + if (!res.ok || cancelled) return; + const data = (await res.json()) as ModelCatalogResponse; + const available = (data.available || []) + .map(toModelOption) + .filter((v): v is ModelOption => v !== null); + const providers = (data.providers || []) + .map(toProviderOption) + .filter((v): v is ProviderOption => v !== null); + const currentInfo = toModelOption(data.currentInfo ?? null); + if (cancelled) return; + + setModelOptions(available); + setProviderOptions(providers); + setSelectedModelPath(data.current || ''); + setSelectedModelInfo(currentInfo); + } catch { + // ignore + } + }; + + void loadCatalog(); + return () => { + cancelled = true; + }; + }, []); + useEffect(() => { if (!sessionId) return; let cancelled = false; apiFetch(`/api/session/${sessionId}`) .then((res) => (res.ok ? res.json() : null)) - .then((data) => { - if (cancelled) return; - if (data?.model) { - const model = findModelByPath(data.model); - if (model) setSelectedModelId(model.id); + .then((data: SessionResponse | null) => { + if (cancelled || !data?.model) return; + setSelectedModelPath(data.model); + const model = modelOptions.find((m) => m.id === data.model); + if (model) { + setSelectedModelInfo(model); + return; } + const inferred = selectedModelInfo && selectedModelInfo.id === data.model + ? selectedModelInfo + : { + id: data.model, + label: data.model, + description: 'Custom model', + provider: '', + providerLabel: '', + }; + setSelectedModelInfo(inferred); }) - .catch(() => { /* ignore */ }); - return () => { cancelled = true; }; - }, [sessionId]); + .catch(() => { + // ignore + }); + return () => { + cancelled = true; + }; + }, [sessionId, modelOptions, selectedModelInfo]); - const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; + const selectedModel = selectedModelInfo + || modelOptions.find((m) => m.id === selectedModelPath) + || (selectedModelPath + ? { + id: selectedModelPath, + label: selectedModelPath, + description: 'Custom model', + provider: '', + providerLabel: '', + } + : null); - // Auto-focus the textarea when the session becomes ready useEffect(() => { if (!disabled && !isProcessing && inputRef.current) { inputRef.current.focus(); @@ -123,20 +216,15 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa } }, [input, disabled, onSend]); - // When the chat transport reports a Claude-quota 429, restore the typed - // text so the user doesn't lose their message. useEffect(() => { if (claudeQuotaExhausted && lastSentRef.current) { setInput(lastSentRef.current); } }, [claudeQuotaExhausted]); - // Refresh the quota display whenever the session changes (user might - // have started another tab that spent quota). 
useEffect(() => { if (sessionId) refreshQuota(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [sessionId]); + }, [sessionId, refreshQuota]); const handleKeyDown = useCallback( (e: KeyboardEvent) => { @@ -145,7 +233,7 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa handleSend(); } }, - [handleSend] + [handleSend], ); const handleModelClick = (event: React.MouseEvent) => { @@ -156,49 +244,72 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa setModelAnchorEl(null); }; - const handleSelectModel = async (model: ModelOption) => { - handleModelClose(); - if (!sessionId) return; - try { + const switchModel = useCallback( + async (modelPath: string, info?: ModelOption) => { + if (!sessionId) return; const res = await apiFetch(`/api/session/${sessionId}/model`, { method: 'POST', - body: JSON.stringify({ model: model.modelPath }), + body: JSON.stringify({ model: modelPath }), }); - if (res.ok) setSelectedModelId(model.id); - } catch { /* ignore */ } + if (res.ok) { + setSelectedModelPath(modelPath); + setSelectedModelInfo(info || modelOptions.find((m) => m.id === modelPath) || null); + } + }, + [sessionId, modelOptions], + ); + + const handleSelectModel = async (model: ModelOption) => { + handleModelClose(); + try { + await switchModel(model.id, model); + } catch { + // ignore + } + }; + + const handleOpenCustomModal = (provider: ProviderOption) => { + handleModelClose(); + if (!sessionId || !provider.prefix) return; + setCustomPrefix(provider.prefix); + setCustomModalOpen(true); + }; + + const handleCustomSubmit = async (modelId: string) => { + const full = `${customPrefix}${modelId}`; + const info: ModelOption = { + id: full, + label: modelId, + description: 'Custom OpenAI-compatible model', + provider: OPENAI_COMPAT_PROVIDER, + providerLabel: 'OpenAI-Compatible', + }; + await switchModel(full, info); + setCustomModalOpen(false); }; - // Dialog close: just clear the flag. The typed text is already restored. const handleCapDialogClose = useCallback(() => { setClaudeQuotaExhausted(false); }, [setClaudeQuotaExhausted]); - // "Use a free model" — switch the current session to Kimi (or the first - // non-Anthropic option) and auto-retry the send that tripped the cap. const handleUseFreeModel = useCallback(async () => { setClaudeQuotaExhausted(false); if (!sessionId) return; - const free = MODEL_OPTIONS.find(m => m.modelPath === FIRST_FREE_MODEL_PATH) - ?? firstFreeModel(); + const free = firstFreeModel(modelOptions); + if (!free) return; try { - const res = await apiFetch(`/api/session/${sessionId}/model`, { - method: 'POST', - body: JSON.stringify({ model: free.modelPath }), - }); - if (res.ok) { - setSelectedModelId(free.id); - const retryText = lastSentRef.current; - if (retryText) { - onSend(retryText); - setInput(''); - lastSentRef.current = ''; - } + await switchModel(free.id, free); + const retryText = lastSentRef.current; + if (retryText) { + onSend(retryText); + setInput(''); + lastSentRef.current = ''; } - } catch { /* ignore */ } - }, [sessionId, onSend, setClaudeQuotaExhausted]); + } catch { + // ignore + } + }, [sessionId, modelOptions, onSend, setClaudeQuotaExhausted, switchModel]); - // Hide the chip until the user has actually burned quota — an unused - // Opus session shouldn't populate a counter. 
const claudeChip = (() => { if (!quota || quota.claudeUsedToday === 0) return null; if (quota.plan === 'free') { @@ -207,6 +318,18 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa return `${quota.claudeUsedToday}/${quota.claudeDailyCap} today`; })(); + const groups = providerOptions.map((provider) => ({ + provider, + models: modelOptions + .filter((m) => m.provider === provider.id) + .sort((a, b) => { + const ar = a.recommended ? 0 : 1; + const br = b.recommended ? 0 : 1; + if (ar !== br) return ar - br; + return a.label.localeCompare(b.label); + }), + })); + return ( {isProcessing ? ( @@ -314,7 +437,6 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa )} - {/* Powered By Badge */} powered by - {selectedModel.name} + {selectedModel?.avatarUrl && ( + {selectedModel.label} + )} - {selectedModel.name} + {selectedModel?.label || 'Model'} - {/* Model Selection Menu */} - {MODEL_OPTIONS.map((model) => ( - handleSelectModel(model)} - selected={selectedModelId === model.id} - sx={{ - py: 1.5, - '&.Mui-selected': { - bgcolor: 'rgba(255,255,255,0.05)', - } - }} - > - - {model.name} - - - {model.name} - {model.recommended && ( - ( + + + {provider.label} + + + {models.map((model) => ( + void handleSelectModel(model)} + disabled={!sessionId} + selected={selectedModelPath === model.id} + sx={{ + py: 1.5, + '&.Mui-selected': { + bgcolor: 'rgba(255,255,255,0.05)', + }, + }} + > + + {model.avatarUrl ? ( + {model.label} - )} - {isClaudeModel(model) && claudeChip && ( - + > + {provider.label.slice(0, 2)} + )} - - } - secondary={model.description} - secondaryTypographyProps={{ - sx: { fontSize: '12px', color: 'var(--muted-text)' } - }} - /> - + + + {model.label} + {model.recommended && ( + + )} + {isClaudePath(model.id) && claudeChip && ( + + )} + + } + secondary={model.description} + secondaryTypographyProps={{ + sx: { fontSize: '12px', color: 'var(--muted-text)' }, + }} + /> + + ))} + + {provider.id === OPENAI_COMPAT_PROVIDER && provider.supportsCustomModel && ( + handleOpenCustomModal(provider)} + disabled={!sessionId} + sx={{ py: 1.5 }} + > + '} + secondaryTypographyProps={{ + sx: { fontSize: '12px', color: 'var(--muted-text)' }, + }} + /> + + )} + ))} + {!sessionId && ( + + + Start a session to switch models. + + + )} + setCustomModalOpen(false)} + onSubmit={handleCustomSubmit} + /> + void; + onSubmit: (modelId: string) => Promise; +} + +export default function CustomModelDialog({ + open, + prefix, + onClose, + onSubmit, +}: CustomModelDialogProps) { + const [value, setValue] = useState(''); + const [error, setError] = useState(''); + const [submitting, setSubmitting] = useState(false); + + useEffect(() => { + if (!open) { + setValue(''); + setError(''); + setSubmitting(false); + } + }, [open]); + + const handleSubmit = async () => { + const trimmed = value.trim(); + if (!trimmed) { + setError('Model id is required'); + return; + } + setError(''); + setSubmitting(true); + try { + await onSubmit(trimmed); + } catch { + setError('Failed to switch model'); + } finally { + setSubmitting(false); + } + }; + + return ( + + + Custom OpenAI-compatible model + + + + Enter model id only. We will use server env config for base URL and key. + + setValue(e.target.value)} + placeholder="e.g. 
my-model" + disabled={submitting} + sx={{ + '& .MuiOutlinedInput-root': { + bgcolor: 'transparent', + color: 'var(--text)', + }, + }} + /> + + Final id: {prefix}{value.trim() || ''} + + {error && ( + + {error} + + )} + + + + + + + ); +} diff --git a/frontend/src/utils/model.ts b/frontend/src/utils/model.ts index 89f23fe7..8ee4e297 100644 --- a/frontend/src/utils/model.ts +++ b/frontend/src/utils/model.ts @@ -2,9 +2,8 @@ * Shared model-id constants used by session-create call sites and the * ClaudeCapDialog "Use a free model" escape hatch. * - * Keep in sync with MODEL_OPTIONS in components/Chat/ChatInput.tsx and - * AVAILABLE_MODELS in backend/routes/agent.py. Bare HF ids (no - * `huggingface/` prefix) — matches upstream's auto-router. + * ChatInput now loads catalog from /api/config/model, but this fallback + * free-model constant is still used for the Claude-cap escape hatch. */ export const CLAUDE_MODEL_PATH = 'anthropic/claude-opus-4-6'; diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index 6919f0c2..97883059 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -257,6 +257,61 @@ def test_hf_validation_excludes_new_provider_prefixes(): assert hf.provider_id == "openrouter" +def test_model_catalog_hides_local_providers_when_unreachable(monkeypatch): + monkeypatch.setattr(providers.OllamaAdapter, "available_models", lambda self: ()) + monkeypatch.setattr(providers.LmStudioAdapter, "available_models", lambda self: ()) + monkeypatch.setattr(providers.VllmAdapter, "available_models", lambda self: ()) + + catalog = providers.build_model_catalog("anthropic/claude-opus-4-6") + provider_ids = {p["id"] for p in catalog["providers"]} + + assert "ollama" not in provider_ids + assert "lm_studio" not in provider_ids + assert "vllm" not in provider_ids + + +def test_model_catalog_includes_local_providers_when_discovered(monkeypatch): + dynamic = ( + providers.SuggestedModel( + id="ollama/llama3.1", + label="llama3.1", + description="Ollama", + provider="ollama", + provider_label="Ollama", + avatar_url="avatar", + source="dynamic", + ), + ) + monkeypatch.setattr( + providers.OllamaAdapter, "available_models", lambda self: dynamic + ) + + catalog = providers.build_model_catalog("ollama/llama3.1") + + assert any(m["id"] == "ollama/llama3.1" for m in catalog["available"]) + assert any(p["id"] == "ollama" for p in catalog["providers"]) + + +def test_model_catalog_openai_compat_visibility_depends_on_env(monkeypatch): + monkeypatch.delenv("OPENAI_COMPAT_BASE_URL", raising=False) + hidden = providers.build_model_catalog("anthropic/claude-opus-4-6") + assert not any(p["id"] == "openai_compat" for p in hidden["providers"]) + + monkeypatch.setenv("OPENAI_COMPAT_BASE_URL", "http://localhost:8080/v1") + shown = providers.build_model_catalog("anthropic/claude-opus-4-6") + assert any(p["id"] == "openai_compat" for p in shown["providers"]) + + +def test_model_catalog_includes_current_info_for_custom_model(monkeypatch): + monkeypatch.setenv("OPENAI_COMPAT_BASE_URL", "http://localhost:8080/v1") + + catalog = providers.build_model_catalog("openai-compat/my-model") + + assert catalog["currentInfo"] is not None + assert catalog["currentInfo"]["id"] == "openai-compat/my-model" + assert catalog["currentInfo"]["source"] == "custom" + + def test_cli_validation_matches_provider_validation(): assert is_valid_model_id("openai/gpt-5") is True assert is_valid_model_id("moonshotai/Kimi-K2.6:fastest") is True From 88a95a69e4c1de629d2fba7cac8a87b8efdeec19 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Fri, 24 Apr 2026 11:20:37 +0200 Subject: [PATCH 10/15] feat: add error handling for missing provider adapter in _resolve_llm_params --- agent/core/llm_params.py | 2 ++ tests/test_provider_adapters.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index a52032e6..cee55d6e 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -70,6 +70,8 @@ def _resolve_llm_params( 3. HF_TOKEN env — belt-and-suspenders fallback for CLI users. """ adapter = resolve_adapter(model_name) + if adapter is None: + raise ValueError(f"No provider adapter for model: {model_name}") return adapter.build_params( model_name, session_hf_token=session_hf_token, diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index cf6fb79f..3f809c96 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -130,6 +130,7 @@ def test_model_validation_accepts_free_form_hf_ids(): def test_model_validation_accepts_direct_provider_ids(): assert is_valid_model_name("anthropic/claude-opus-4-7") is True assert is_valid_model_name("openai/gpt-5") is True + assert is_valid_model_name("bedrock/us.anthropic.claude-opus-4-7") is True def test_model_validation_rejects_garbage(): @@ -148,6 +149,14 @@ def test_cli_validation_matches_provider_validation(): assert is_valid_model_id("anthropic/") is False +def test_resolve_raises_on_no_adapter(monkeypatch): + from agent.core import llm_params + + monkeypatch.setattr(llm_params, "resolve_adapter", lambda _: None) + with pytest.raises(ValueError, match="No provider adapter"): + _resolve_llm_params("anything") + + def test_unsupported_effort_reexport(): """UnsupportedEffortError must be importable from llm_params (backward compat).""" from agent.core.llm_params import UnsupportedEffortError as FromLlm From 590496f3c5a23d66dec00d2e3cfe8b48dd1578ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Fri, 24 Apr 2026 11:30:45 +0200 Subject: [PATCH 11/15] refactor: revert unnecessary formatting changes --- agent/core/agent_loop.py | 328 ++++++++++++--------------------------- 1 file changed, 101 insertions(+), 227 deletions(-) diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index 3129716e..5f09859a 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -126,24 +126,14 @@ def _is_transient_error(error: Exception) -> bool: """Return True for errors that are likely transient and worth retrying.""" err_str = str(error).lower() transient_patterns = [ - "timeout", - "timed out", - "429", - "rate limit", - "rate_limit", - "503", - "service unavailable", - "502", - "bad gateway", - "500", - "internal server error", - "overloaded", - "capacity", - "connection reset", - "connection refused", - "connection error", - "eof", - "broken pipe", + "timeout", "timed out", + "429", "rate limit", "rate_limit", + "503", "service unavailable", + "502", "bad gateway", + "500", "internal server error", + "overloaded", "capacity", + "connection reset", "connection refused", "connection error", + "eof", "broken pipe", ] return any(pattern in err_str for pattern in transient_patterns) @@ -157,14 +147,11 @@ def _is_effort_config_error(error: Exception) -> bool: doesn't work for the current model. We heal the cache and retry once. 
""" from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported - return _is_thinking_unsupported(error) or _is_invalid_effort(error) async def _heal_effort_and_rebuild_params( - session: Session, - error: Exception, - llm_params: dict, + session: Session, error: Exception, llm_params: dict, ) -> dict: """Update the session's effort cache based on ``error`` and return new llm_params. Called only when ``_is_effort_config_error(error)`` is True. @@ -175,11 +162,7 @@ async def _heal_effort_and_rebuild_params( • invalid-effort → re-run the full cascade probe; the result lands in the cache """ - from agent.core.effort_probe import ( - ProbeInconclusive, - _is_thinking_unsupported, - probe_effort, - ) + from agent.core.effort_probe import ProbeInconclusive, _is_thinking_unsupported, probe_effort model = session.config.model_name if _is_thinking_unsupported(error): @@ -188,15 +171,11 @@ async def _heal_effort_and_rebuild_params( else: try: outcome = await probe_effort( - model, - session.config.reasoning_effort, - session.hf_token, + model, session.config.reasoning_effort, session.hf_token, ) session.model_effective_effort[model] = outcome.effective_effort logger.info( - "healed: %s effort cascade → %s", - model, - outcome.effective_effort, + "healed: %s effort cascade → %s", model, outcome.effective_effort, ) except ProbeInconclusive: # Transient during healing — strip thinking for safety, next @@ -222,10 +201,7 @@ async def _compact_and_notify(session: Session) -> None: old_usage = cm.running_context_usage logger.debug( "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s", - old_usage, - cm.model_max_tokens, - cm.compaction_threshold, - cm.needs_compaction, + old_usage, cm.model_max_tokens, cm.compaction_threshold, cm.needs_compaction, ) await cm.compact( model_name=session.config.model_name, @@ -236,10 +212,7 @@ async def _compact_and_notify(session: Session) -> None: if new_usage != old_usage: logger.warning( "Context compacted: %d -> %d tokens (max=%d, %d messages)", - old_usage, - new_usage, - cm.model_max_tokens, - len(cm.items), + old_usage, new_usage, cm.model_max_tokens, len(cm.items), ) await session.send_event( Event( @@ -278,16 +251,13 @@ async def _cleanup_on_cancel(session: Session) -> None: @dataclass class LLMResult: """Result from an LLM call (streaming or non-streaming).""" - content: str | None tool_calls_acc: dict[int, dict] token_count: int finish_reason: str | None -async def _call_llm_streaming( - session: Session, messages, tools, llm_params -) -> LLMResult: +async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult: """Call the LLM with streaming, emitting assistant_chunk events.""" response = None _healed_effort = False # one-shot safety net per call @@ -309,37 +279,22 @@ async def _call_llm_streaming( except Exception as e: if not _healed_effort and _is_effort_config_error(e): _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params( - session, e, llm_params - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": "Reasoning effort not supported for this model — adjusting and retrying.", - }, - ) - ) + llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params) + await session.send_event(Event( + event_type="tool_log", + data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, + )) continue if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): _delay = 
_LLM_RETRY_DELAYS[_llm_attempt] logger.warning( "Transient LLM error (attempt %d/%d): %s — retrying in %ds", - _llm_attempt + 1, - _MAX_LLM_RETRIES, - e, - _delay, - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": f"LLM connection error, retrying in {_delay}s...", - }, - ) + _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, ) + await session.send_event(Event( + event_type="tool_log", + data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, + )) await asyncio.sleep(_delay) continue raise @@ -375,21 +330,16 @@ async def _call_llm_streaming( idx = tc_delta.index if idx not in tool_calls_acc: tool_calls_acc[idx] = { - "id": "", - "type": "function", + "id": "", "type": "function", "function": {"name": "", "arguments": ""}, } if tc_delta.id: tool_calls_acc[idx]["id"] = tc_delta.id if tc_delta.function: if tc_delta.function.name: - tool_calls_acc[idx]["function"]["name"] += ( - tc_delta.function.name - ) + tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name if tc_delta.function.arguments: - tool_calls_acc[idx]["function"]["arguments"] += ( - tc_delta.function.arguments - ) + tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments if hasattr(chunk, "usage") and chunk.usage: token_count = chunk.usage.total_tokens @@ -402,9 +352,7 @@ async def _call_llm_streaming( ) -async def _call_llm_non_streaming( - session: Session, messages, tools, llm_params -) -> LLMResult: +async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) -> LLMResult: """Call the LLM without streaming, emit assistant_message at the end.""" response = None _healed_effort = False @@ -425,37 +373,22 @@ async def _call_llm_non_streaming( except Exception as e: if not _healed_effort and _is_effort_config_error(e): _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params( - session, e, llm_params - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": "Reasoning effort not supported for this model — adjusting and retrying.", - }, - ) - ) + llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params) + await session.send_event(Event( + event_type="tool_log", + data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, + )) continue if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( "Transient LLM error (attempt %d/%d): %s — retrying in %ds", - _llm_attempt + 1, - _MAX_LLM_RETRIES, - e, - _delay, - ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": f"LLM connection error, retrying in {_delay}s...", - }, - ) + _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, ) + await session.send_event(Event( + event_type="tool_log", + data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, + )) await asyncio.sleep(_delay) continue raise @@ -536,8 +469,7 @@ async def _abandon_pending_approval(session: Session) -> None: @staticmethod async def run_agent( - session: Session, - text: str, + session: Session, text: str, ) -> str | None: """ Handle user input (like user_input_or_turn in codex.rs:1291) @@ -601,18 +533,12 @@ async def run_agent( llm_params = _resolve_llm_params( session.config.model_name, session.hf_token, - reasoning_effort=session.effective_effort_for( - session.config.model_name - ), + 
reasoning_effort=session.effective_effort_for(session.config.model_name), ) if session.stream: - llm_result = await _call_llm_streaming( - session, messages, tools, llm_params - ) + llm_result = await _call_llm_streaming(session, messages, tools, llm_params) else: - llm_result = await _call_llm_non_streaming( - session, messages, tools, llm_params - ) + llm_result = await _call_llm_non_streaming(session, messages, tools, llm_params) content = llm_result.content tool_calls_acc = llm_result.tool_calls_acc @@ -656,10 +582,7 @@ async def run_agent( await session.send_event( Event( event_type="tool_log", - data={ - "tool": "system", - "log": f"Output truncated — retrying with smaller content ({dropped_names})", - }, + data={"tool": "system", "log": f"Output truncated — retrying with smaller content ({dropped_names})"}, ) ) iteration += 1 @@ -719,8 +642,7 @@ async def run_agent( except (json.JSONDecodeError, TypeError, ValueError): logger.warning( "Malformed arguments for tool_call %s (%s) — skipping", - tc.id, - tc.function.name, + tc.id, tc.function.name, ) tc.function.arguments = "{}" bad_tools.append(tc) @@ -741,35 +663,20 @@ async def run_agent( f"arguments and was NOT executed. Retry with smaller content — " f"for 'write', split into multiple smaller writes using 'edit'." ) - session.context_manager.add_message( - Message( - role="tool", - content=error_msg, - tool_call_id=tc.id, - name=tc.function.name, - ) - ) - await session.send_event( - Event( - event_type="tool_call", - data={ - "tool": tc.function.name, - "arguments": {}, - "tool_call_id": tc.id, - }, - ) - ) - await session.send_event( - Event( - event_type="tool_output", - data={ - "tool": tc.function.name, - "tool_call_id": tc.id, - "output": error_msg, - "success": False, - }, - ) - ) + session.context_manager.add_message(Message( + role="tool", + content=error_msg, + tool_call_id=tc.id, + name=tc.function.name, + )) + await session.send_event(Event( + event_type="tool_call", + data={"tool": tc.function.name, "arguments": {}, "tool_call_id": tc.id}, + )) + await session.send_event(Event( + event_type="tool_output", + data={"tool": tc.function.name, "tool_call_id": tc.id, "output": error_msg, "success": False}, + )) # ── Cancellation check: before tool execution ── if session.is_cancelled: @@ -787,7 +694,9 @@ async def run_agent( # Execute non-approval tools (in parallel when possible) if non_approval_tools: # 1. 
Validate args upfront - parsed_tools: list[tuple[ToolCall, str, dict, bool, str]] = [] + parsed_tools: list[ + tuple[ToolCall, str, dict, bool, str] + ] = [] for tc, tool_name, tool_args in non_approval_tools: args_valid, error_msg = _validate_tool_args(tool_args) parsed_tools.append( @@ -823,14 +732,12 @@ async def _exec_tool( ) return (tc, name, args, out, ok) - gather_task = asyncio.ensure_future( - asyncio.gather( - *[ - _exec_tool(tc, name, args, valid, err) - for tc, name, args, valid, err in parsed_tools - ] - ) - ) + gather_task = asyncio.ensure_future(asyncio.gather( + *[ + _exec_tool(tc, name, args, valid, err) + for tc, name, args, valid, err in parsed_tools + ] + )) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( @@ -847,16 +754,10 @@ async def _exec_tool( # Notify frontend that in-flight tools were cancelled for tc, name, _args, valid, _ in parsed_tools: if valid: - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": name, - "state": "cancelled", - }, - ) - ) + await session.send_event(Event( + event_type="tool_state_change", + data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"}, + )) await _cleanup_on_cancel(session) break @@ -892,32 +793,23 @@ async def _exec_tool( for tc, tool_name, tool_args in approval_required_tools: # Resolve sandbox file paths for hf_jobs scripts so the # frontend can display & edit the actual file content. - if tool_name == "hf_jobs" and isinstance( - tool_args.get("script"), str - ): + if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str): from agent.tools.sandbox_tool import resolve_sandbox_script - sandbox = getattr(session, "sandbox", None) - resolved, _ = await resolve_sandbox_script( - sandbox, tool_args["script"] - ) + resolved, _ = await resolve_sandbox_script(sandbox, tool_args["script"]) if resolved: tool_args = {**tool_args, "script": resolved} - tools_data.append( - { - "tool": tool_name, - "arguments": tool_args, - "tool_call_id": tc.id, - } - ) + tools_data.append({ + "tool": tool_name, + "arguments": tool_args, + "tool_call_id": tc.id, + }) - await session.send_event( - Event( - event_type="approval_required", - data={"tools": tools_data, "count": len(tools_data)}, - ) - ) + await session.send_event(Event( + event_type="approval_required", + data={"tools": tools_data, "count": len(tools_data)}, + )) # Store all approval-requiring tools (ToolCall objects for execution) session.pending_approval = { @@ -935,10 +827,7 @@ async def _exec_tool( logger.warning( "ContextWindowExceededError at iteration %d — forcing compaction " "(usage=%d, model_max_tokens=%d, messages=%d)", - iteration, - cm.running_context_usage, - cm.model_max_tokens, - len(cm.items), + iteration, cm.running_context_usage, cm.model_max_tokens, len(cm.items), ) cm.running_context_usage = cm.model_max_tokens + 1 await _compact_and_notify(session) @@ -1112,15 +1001,13 @@ async def execute_tool(tc, tool_name, tool_args, was_edited): # Execute all approved tools concurrently (cancellable) if approved_tasks: - gather_task = asyncio.ensure_future( - asyncio.gather( - *[ - execute_tool(tc, tool_name, tool_args, was_edited) - for tc, tool_name, tool_args, was_edited in approved_tasks - ], - return_exceptions=True, - ) - ) + gather_task = asyncio.ensure_future(asyncio.gather( + *[ + execute_tool(tc, tool_name, tool_args, was_edited) + for tc, tool_name, tool_args, was_edited in approved_tasks + ], + return_exceptions=True, + )) cancel_task = 
asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( @@ -1136,16 +1023,10 @@ async def execute_tool(tc, tool_name, tool_args, was_edited): pass # Notify frontend that approved tools were cancelled for tc, tool_name, _args, _was_edited in approved_tasks: - await session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": tc.id, - "tool": tool_name, - "state": "cancelled", - }, - ) - ) + await session.send_event(Event( + event_type="tool_state_change", + data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"}, + )) await _cleanup_on_cancel(session) await session.send_event(Event(event_type="interrupted")) session.increment_turn() @@ -1293,12 +1174,8 @@ async def submission_loop( # Create session with tool router session = Session( - event_queue, - config=config, - tool_router=tool_router, - hf_token=hf_token, - local_mode=local_mode, - stream=stream, + event_queue, config=config, tool_router=tool_router, hf_token=hf_token, + local_mode=local_mode, stream=stream, ) if session_holder is not None: session_holder[0] = session @@ -1315,13 +1192,10 @@ async def submission_loop( async with tool_router: # Emit ready event after initialization await session.send_event( - Event( - event_type="ready", - data={ - "message": "Agent initialized", - "tool_count": len(tool_router.tools), - }, - ) + Event(event_type="ready", data={ + "message": "Agent initialized", + "tool_count": len(tool_router.tools), + }) ) while session.is_running: From cc6cff788bd66637b12eafa4d03ad182a0fe06b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Fri, 24 Apr 2026 11:55:41 +0200 Subject: [PATCH 12/15] refactor: cleanup before merge --- agent/core/llm_params.py | 16 ++++++---------- agent/core/model_switcher.py | 19 +++++++------------ backend/routes/agent.py | 26 ++++++-------------------- backend/session_manager.py | 34 ++++++++++++---------------------- 4 files changed, 31 insertions(+), 64 deletions(-) diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index cee55d6e..8bc8fc9d 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -25,17 +25,13 @@ def _patch_litellm_effort_validation() -> None: def _widened(model: str) -> bool: m = model.lower() + # Original 4.6 match plus any future Opus >= 4.6. We only need this + # to return True for families where "max" / "xhigh" are acceptable + # at the API; the cascade handles the case when they're not. return any( - v in m - for v in ( - "opus-4-6", - "opus_4_6", - "opus-4.6", - "opus_4.6", - "opus-4-7", - "opus_4_7", - "opus-4.7", - "opus_4.7", + v in m for v in ( + "opus-4-6", "opus_4_6", "opus-4.6", "opus_4.6", + "opus-4-7", "opus_4_7", "opus-4.7", "opus_4.7", ) ) diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index f14c2a4f..3e3b0a44 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -116,7 +116,9 @@ def _print_hf_routing_info(model_id: str, console) -> bool: ) ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a" tools = "tools" if p.supports_tools else "no tools" - console.print(f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]") + console.print( + f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]" + ) return True @@ -176,9 +178,7 @@ async def probe_and_switch_model( # Nothing to validate with a ping that we couldn't validate on the # first real call just as cheaply. Skip the probe entirely. 
_commit_switch(model_id, config, session, effective=None, cache=False) - console.print( - f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" - ) + console.print(f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]") return console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]") @@ -194,18 +194,13 @@ async def probe_and_switch_model( return except Exception as e: # Hard persistent error — auth, unknown model, quota. Don't switch. - console.print( - f"[bold red]Switch failed:[/bold red] {render_llm_error_message(e)}" - ) + console.print(f"[bold red]Switch failed:[/bold red] {render_llm_error_message(e)}") console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") return _commit_switch( - model_id, - config, - session, - effective=outcome.effective_effort, - cache=True, + model_id, config, session, + effective=outcome.effective_effort, cache=True, ) effort_label = outcome.effective_effort or "off" suffix = f" — {outcome.note}" if outcome.note else "" diff --git a/backend/routes/agent.py b/backend/routes/agent.py index de12ab09..5d4350dc 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -28,12 +28,7 @@ SubmitRequest, TruncateRequest, ) -from session_manager import ( - MAX_SESSIONS, - AgentSession, - SessionCapacityError, - session_manager, -) +from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, session_manager import user_quotas @@ -532,9 +527,7 @@ async def chat_sse( success = await session_manager.submit_user_input(session_id, text) else: broadcaster.unsubscribe(sub_id) - raise HTTPException( - status_code=400, detail="Must provide 'text' or 'approvals'" - ) + raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") if not success: broadcaster.unsubscribe(sub_id) @@ -551,13 +544,7 @@ async def chat_sse( # --------------------------------------------------------------------------- # Shared SSE helpers # --------------------------------------------------------------------------- -_TERMINAL_EVENTS = { - "turn_complete", - "approval_required", - "error", - "interrupted", - "shutdown", -} +_TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"} _SSE_KEEPALIVE_SECONDS = 15 @@ -657,10 +644,7 @@ async def truncate_session( _check_session_access(session_id, user) success = await session_manager.truncate(session_id, body.user_message_index) if not success: - raise HTTPException( - status_code=404, - detail="Session not found, inactive, or message index out of range", - ) + raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range") return {"status": "truncated", "session_id": session_id} @@ -686,3 +670,5 @@ async def shutdown_session( if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} + + diff --git a/backend/session_manager.py b/backend/session_manager.py index dbb79ed5..e2747137 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -126,7 +126,9 @@ def __init__(self, config_path: str | None = None) -> None: def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" return sum( - 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active + 1 + for s in self.sessions.values() + if s.user_id == user_id and s.is_active ) async def create_session( @@ -190,9 +192,7 @@ def 
_create_session_sync(): if model: session_config.model_name = model session = Session( - event_queue, - config=session_config, - tool_router=tool_router, + event_queue, config=session_config, tool_router=tool_router, hf_token=hf_token, ) t1 = _time.monotonic() @@ -332,9 +332,7 @@ async def _run_session( ) agent_session.is_processing = True try: - should_continue = await process_submission( - session, submission - ) + should_continue = await process_submission(session, submission) finally: agent_session.is_processing = False if not should_continue: @@ -347,10 +345,7 @@ async def _run_session( except Exception as e: logger.error(f"Error in session {session_id}: {e}") await session.send_event( - Event( - event_type="error", - data={"error": render_llm_error_message(e)}, - ) + Event(event_type="error", data={"error": render_llm_error_message(e)}) ) finally: @@ -414,9 +409,7 @@ async def truncate(self, session_id: str, user_message_index: int) -> bool: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False - return agent_session.session.context_manager.truncate_to_user_message( - user_message_index - ) + return agent_session.session.context_manager.truncate_to_user_message(user_message_index) async def compact(self, session_id: str) -> bool: """Compact context in a session.""" @@ -495,18 +488,15 @@ def get_session_info(self, session_id: str) -> dict[str, Any] | None: pending_approval = [] for tc in pa["tool_calls"]: import json - try: args = json.loads(tc.function.arguments) except (json.JSONDecodeError, AttributeError): args = {} - pending_approval.append( - { - "tool": tc.function.name, - "tool_call_id": tc.id, - "arguments": args, - } - ) + pending_approval.append({ + "tool": tc.function.name, + "tool_call_id": tc.id, + "arguments": args, + }) return { "session_id": session_id, From 25d933126cdf4f04fef74a0791535dc82fe21c35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Fri, 24 Apr 2026 13:26:42 +0200 Subject: [PATCH 13/15] Add GeminiAdapter for Google Gemini models Native LiteLLM adapter with reasoning_effort support (low/medium/high), minimal->low normalization, suggested models (2.5 Pro, Flash). Supersedes PR#95 approach with proper adapter pattern. 
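Illustrative param resolution, mirroring the added tests (a sketch of the
expected output, not a new API — per the adapter docstring, GEMINI_API_KEY
is picked up by LiteLLM at call time, not during resolution):

    from agent.core.llm_params import _resolve_llm_params

    # "minimal" is normalized to "low"; Gemini accepts only low/medium/high.
    params = _resolve_llm_params(
        "gemini/gemini-2.5-flash", reasoning_effort="minimal"
    )
    assert params == {
        "model": "gemini/gemini-2.5-flash",
        "reasoning_effort": "low",
    }

Out-of-range efforts ("max", "xhigh") are dropped by default and raise
UnsupportedEffortError only under strict=True, so the effort cascade can
probe downward without a hard failure.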
Co-Authored-By: Claude Opus 4.6 (1M context) --- agent/core/model_switcher.py | 1 + agent/core/provider_adapters.py | 58 +++++++++++++++++++++++++++++++++ tests/test_provider_adapters.py | 32 ++++++++++++++++++ 3 files changed, 91 insertions(+) diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index 0b8817fd..433d4236 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -166,6 +166,7 @@ def print_invalid_id(arg: str, console) -> None: " • anthropic/\n" " • openai/\n" " • bedrock/\n" + " • gemini/\n" " • openrouter/\n" " • opencode/\n" " • opencode-go/\n" diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index b4d65cc2..cfc729d2 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -50,6 +50,7 @@ def _normalize_openai_api_base(api_base: str) -> str: def _provider_avatar_url(provider_id: str) -> str: avatars = { "anthropic": "https://huggingface.co/api/avatars/Anthropic", + "gemini": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06.svg", "openai": "https://openai.com/favicon.ico", "ollama": "https://ollama.com/public/ollama.png", "lm_studio": "https://avatars.githubusercontent.com/u/16906759?s=200&v=4", @@ -290,6 +291,62 @@ def build_params( return {"model": model_name} +@dataclass(frozen=True) +class GeminiAdapter(ProviderAdapter): + """Google Gemini via LiteLLM's native Gemini adapter. + + GEMINI_API_KEY picked up automatically by LiteLLM. + reasoning_effort forwarded directly — LiteLLM maps to thinking_config. + """ + + prefixes: tuple[str, ...] = ("gemini/",) + _EFFORTS: ClassVar[frozenset[str]] = frozenset({"low", "medium", "high"}) + + def suggested_models(self) -> tuple[SuggestedModel, ...]: + return ( + SuggestedModel( + id="gemini/gemini-2.5-pro", + label="Gemini 2.5 Pro", + description="Google Gemini", + provider="gemini", + provider_label="Google Gemini", + avatar_url=_provider_avatar_url("gemini"), + recommended=True, + ), + SuggestedModel( + id="gemini/gemini-2.5-flash", + label="Gemini 2.5 Flash", + description="Google Gemini", + provider="gemini", + provider_label="Google Gemini", + avatar_url=_provider_avatar_url("gemini"), + ), + ) + + def allows_model_name(self, model_name: str) -> bool: + return _has_model_suffix(model_name, "gemini/") + + def build_params( + self, + model_name: str, + *, + session_hf_token: str | None = None, + reasoning_effort: str | None = None, + strict: bool = False, + ) -> dict: + params: dict[str, Any] = {"model": model_name} + if reasoning_effort: + level = "low" if reasoning_effort == "minimal" else reasoning_effort + if level not in self._EFFORTS: + if strict: + raise UnsupportedEffortError( + f"Gemini doesn't accept effort={level!r}" + ) + else: + params["reasoning_effort"] = level + return params + + @dataclass(frozen=True) class OpenAICompatAdapter(ProviderAdapter): api_base_url: str = "" @@ -563,6 +620,7 @@ def build_params( ADAPTERS: tuple[ProviderAdapter, ...] 
= ( AnthropicAdapter(provider_id="anthropic", provider_label="Anthropic"), BedrockAdapter(provider_id="bedrock", provider_label="AWS Bedrock"), + GeminiAdapter(provider_id="gemini", provider_label="Google Gemini"), OpenAIAdapter(provider_id="openai", provider_label="OpenAI"), OllamaAdapter(provider_id="ollama", provider_label="Ollama"), LmStudioAdapter(provider_id="lm_studio", provider_label="LM Studio"), diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index 4aef7bcb..193647ab 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -85,6 +85,36 @@ def test_bedrock_validation(): assert is_valid_model_name("bedrock/") is False +# -- Gemini adapter ----------------------------------------------------------- + + +def test_gemini_adapter_passes_reasoning_effort(): + params = _resolve_llm_params("gemini/gemini-2.5-pro", reasoning_effort="medium") + assert params == {"model": "gemini/gemini-2.5-pro", "reasoning_effort": "medium"} + + +def test_gemini_adapter_normalizes_minimal(): + params = _resolve_llm_params("gemini/gemini-2.5-flash", reasoning_effort="minimal") + assert params == {"model": "gemini/gemini-2.5-flash", "reasoning_effort": "low"} + + +def test_gemini_adapter_no_effort(): + params = _resolve_llm_params("gemini/gemini-2.5-pro") + assert params == {"model": "gemini/gemini-2.5-pro"} + + +def test_gemini_adapter_strict_rejects_invalid(): + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params("gemini/gemini-2.5-pro", reasoning_effort="max", strict=True) + with pytest.raises(UnsupportedEffortError): + _resolve_llm_params("gemini/gemini-2.5-pro", reasoning_effort="xhigh", strict=True) + + +def test_gemini_validation(): + assert is_valid_model_name("gemini/gemini-2.5-pro") is True + assert is_valid_model_name("gemini/") is False + + # -- HF Router adapter -------------------------------------------------------- @@ -246,6 +276,7 @@ def test_model_validation_accepts_direct_provider_ids(): assert is_valid_model_name("anthropic/claude-opus-4-7") is True assert is_valid_model_name("openai/gpt-5") is True assert is_valid_model_name("bedrock/us.anthropic.claude-opus-4-7") is True + assert is_valid_model_name("gemini/gemini-2.5-pro") is True assert is_valid_model_name("ollama/llama3.1") is True assert is_valid_model_name("lm_studio/google/gemma-3-12b") is True assert is_valid_model_name("vllm/Qwen3-32B") is True @@ -260,6 +291,7 @@ def test_model_validation_rejects_garbage(): assert is_valid_model_name("no-slash") is False assert is_valid_model_name("anthropic/") is False assert is_valid_model_name("openai/") is False + assert is_valid_model_name("gemini/") is False assert is_valid_model_name("ollama/") is False assert is_valid_model_name("lm_studio/") is False assert is_valid_model_name("vllm/") is False From 4ac4f1a526953791c250eca7e0c3b6ff5f2a4612 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Fri, 24 Apr 2026 13:32:07 +0200 Subject: [PATCH 14/15] chore: remove outdated TODO list --- TODO.md | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 TODO.md diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 26d949fd..00000000 --- a/TODO.md +++ /dev/null @@ -1,13 +0,0 @@ -# TODO - -## Phase 1 -- [x] Add CLI/runtime-only provider adapters -- [x] Add tests for new provider adapters -- [x] Update CLI model help/validation text -- [x] Run compile + provider adapter tests - -## Phase 2 TODO -- [x] Add adapter-driven backend model catalog -- [x] Add local model discovery + cache 
-- [x] Expand web UI model picker -- [x] Add OpenAI-compat custom model modal From 42f36ea51fe732ce18d18cd1c831070ba430dc94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrej=20Vy=C5=A1n=C3=BD?= Date: Fri, 24 Apr 2026 17:50:27 +0200 Subject: [PATCH 15/15] Fix review blockers and cleanup for PR merge - Fix blocking urlopen in async GET /api/config/model via asyncio.to_thread - Fix LM Studio use_raw_model_name sending prefixed model id to API - Fix _is_anthropic_model matching openrouter/anthropic/ routes - Add message-count limit (2000) on restore-session-summary - Add provider config check on model switch (openai-compat/ env guard) - Cache _all_adapter_prefixes with lru_cache - Extract useModelCatalog hook from ChatInput (635 -> 528 lines) - Memoize provider groups calculation - Add Enter key support to CustomModelDialog - Fix useEffect dependency loop on selectedModelInfo Co-Authored-By: Claude Opus 4.6 (1M context) --- agent/core/provider_adapters.py | 4 +- backend/routes/agent.py | 19 +- frontend/src/components/Chat/ChatInput.tsx | 487 +++++++----------- .../src/components/Chat/CustomModelDialog.tsx | 98 ++-- frontend/src/hooks/useModelCatalog.ts | 195 +++++++ tests/test_provider_adapters.py | 2 +- 6 files changed, 468 insertions(+), 337 deletions(-) create mode 100644 frontend/src/hooks/useModelCatalog.ts diff --git a/agent/core/provider_adapters.py b/agent/core/provider_adapters.py index cfc729d2..a7f8a278 100644 --- a/agent/core/provider_adapters.py +++ b/agent/core/provider_adapters.py @@ -1,5 +1,6 @@ """Provider adapters for runtime params and model catalog metadata.""" +import functools import json import os import time @@ -64,6 +65,7 @@ def _provider_avatar_url(provider_id: str) -> str: return avatars.get(provider_id, "https://huggingface.co/api/avatars/huggingface") +@functools.lru_cache(maxsize=1) def _all_adapter_prefixes() -> tuple[str, ...]: prefixes: list[str] = [] for adapter in ADAPTERS: @@ -396,7 +398,7 @@ def build_params( model_id = model_name.removeprefix(self.prefixes[0]) params: dict[str, Any] = { - "model": model_name if self.use_raw_model_name else f"openai/{model_id}", + "model": model_id if self.use_raw_model_name else f"openai/{model_id}", "api_base": self.resolved_api_base(), "api_key": self.resolved_api_key(), } diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 23b7d7bd..8c5e4504 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -34,7 +34,11 @@ from agent.core.llm_errors import health_error_type, render_llm_error_message from agent.core.llm_params import _resolve_llm_params -from agent.core.provider_adapters import build_model_catalog, is_valid_model_name +from agent.core.provider_adapters import ( + build_model_catalog, + is_valid_model_name, + resolve_adapter, +) logger = logging.getLogger(__name__) @@ -42,7 +46,7 @@ def _is_anthropic_model(model_id: str) -> bool: - return "anthropic" in model_id + return model_id.startswith(("anthropic/", "bedrock/")) async def _require_hf_for_anthropic(request: Request, model_id: str) -> None: @@ -162,7 +166,9 @@ async def llm_health_check() -> LLMHealthResponse: @router.get("/config/model") async def get_model() -> dict: """Get current model and available models. 
No auth required.""" - return build_model_catalog(session_manager.config.model_name) + return await asyncio.to_thread( + build_model_catalog, session_manager.config.model_name + ) _TITLE_STRIP_CHARS = str.maketrans("", "", "`*_~#[]()") @@ -290,6 +296,8 @@ async def restore_session_summary( messages = body.get("messages") if not isinstance(messages, list) or not messages: raise HTTPException(status_code=400, detail="Missing 'messages' array") + if len(messages) > 2000: + raise HTTPException(status_code=400, detail="Too many messages (max 2000)") hf_token = None auth_header = request.headers.get("Authorization", "") @@ -361,6 +369,11 @@ async def set_session_model( raise HTTPException(status_code=400, detail="Missing 'model' field") if not is_valid_model_name(model_id): raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") + adapter = resolve_adapter(model_id) + if adapter and not adapter.should_show(): + raise HTTPException( + status_code=400, detail=f"Provider not configured for: {model_id}" + ) await _require_hf_for_anthropic(request, model_id) agent_session = session_manager.sessions.get(session_id) if not agent_session: diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index db1fbaa2..5f86a0f5 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -1,4 +1,4 @@ -import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; +import { useState, useCallback, useEffect, useRef, KeyboardEvent } from "react"; import { Box, TextField, @@ -11,48 +11,17 @@ import { ListItemText, Chip, ListSubheader, -} from '@mui/material'; -import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; -import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; -import StopIcon from '@mui/icons-material/Stop'; -import { apiFetch } from '@/utils/api'; -import { useUserQuota } from '@/hooks/useUserQuota'; -import ClaudeCapDialog from '@/components/ClaudeCapDialog'; -import { useAgentStore } from '@/store/agentStore'; -import { FIRST_FREE_MODEL_PATH } from '@/utils/model'; -import CustomModelDialog from '@/components/Chat/CustomModelDialog'; - -interface ModelOption { - id: string; - label: string; - description: string; - provider: string; - providerLabel: string; - avatarUrl?: string; - recommended?: boolean; - source?: string; -} - -interface ProviderOption { - id: string; - label: string; - avatarUrl?: string; - supportsCustomModel?: boolean; - customModelHint?: string; - customModelMode?: string; - prefix?: string; -} - -interface ModelCatalogResponse { - current?: string; - currentInfo?: ModelOption | null; - available?: ModelOption[]; - providers?: ProviderOption[]; -} - -interface SessionResponse { - model?: string; -} +} from "@mui/material"; +import ArrowUpwardIcon from "@mui/icons-material/ArrowUpward"; +import ArrowDropDownIcon from "@mui/icons-material/ArrowDropDown"; +import StopIcon from "@mui/icons-material/Stop"; +import { useUserQuota } from "@/hooks/useUserQuota"; +import { useModelCatalog } from "@/hooks/useModelCatalog"; +import type { ModelOption, ProviderOption } from "@/hooks/useModelCatalog"; +import ClaudeCapDialog from "@/components/ClaudeCapDialog"; +import { useAgentStore } from "@/store/agentStore"; +import { FIRST_FREE_MODEL_PATH } from "@/utils/model"; +import CustomModelDialog from "@/components/Chat/CustomModelDialog"; interface ChatInputProps { sessionId?: string; @@ -63,40 +32,9 @@ interface ChatInputProps { placeholder?: string; } 
-const OPENAI_COMPAT_PROVIDER = 'openai_compat'; - -const toModelOption = (value: unknown): ModelOption | null => { - if (!value || typeof value !== 'object') return null; - const v = value as Record; - if (typeof v.id !== 'string' || typeof v.label !== 'string') return null; - return { - id: v.id, - label: v.label, - description: typeof v.description === 'string' ? v.description : '', - provider: typeof v.provider === 'string' ? v.provider : '', - providerLabel: typeof v.providerLabel === 'string' ? v.providerLabel : '', - avatarUrl: typeof v.avatarUrl === 'string' ? v.avatarUrl : undefined, - recommended: Boolean(v.recommended), - source: typeof v.source === 'string' ? v.source : undefined, - }; -}; - -const toProviderOption = (value: unknown): ProviderOption | null => { - if (!value || typeof value !== 'object') return null; - const v = value as Record; - if (typeof v.id !== 'string' || typeof v.label !== 'string') return null; - return { - id: v.id, - label: v.label, - avatarUrl: typeof v.avatarUrl === 'string' ? v.avatarUrl : undefined, - supportsCustomModel: Boolean(v.supportsCustomModel), - customModelHint: typeof v.customModelHint === 'string' ? v.customModelHint : undefined, - customModelMode: typeof v.customModelMode === 'string' ? v.customModelMode : undefined, - prefix: typeof v.prefix === 'string' ? v.prefix : undefined, - }; -}; +const OPENAI_COMPAT_PROVIDER = "openai_compat"; -const isClaudePath = (modelPath: string) => modelPath.startsWith('anthropic/'); +const isClaudePath = (modelPath: string) => modelPath.startsWith("anthropic/"); const firstFreeModel = (models: ModelOption[]) => { const byPath = models.find((m) => m.id === FIRST_FREE_MODEL_PATH); @@ -110,97 +48,26 @@ export default function ChatInput({ onStop, isProcessing = false, disabled = false, - placeholder = 'Ask anything...', + placeholder = "Ask anything...", }: ChatInputProps) { - const [input, setInput] = useState(''); + const [input, setInput] = useState(""); const inputRef = useRef(null); - const [modelOptions, setModelOptions] = useState([]); - const [providerOptions, setProviderOptions] = useState([]); - const [selectedModelPath, setSelectedModelPath] = useState(''); - const [selectedModelInfo, setSelectedModelInfo] = useState(null); + const { + modelOptions, + selectedModelPath, + selectedModel, + switchModel, + groups, + } = useModelCatalog(sessionId); const [modelAnchorEl, setModelAnchorEl] = useState(null); const [customModalOpen, setCustomModalOpen] = useState(false); - const [customPrefix, setCustomPrefix] = useState('openai-compat/'); + const [customPrefix, setCustomPrefix] = useState("openai-compat/"); const { quota, refresh: refreshQuota } = useUserQuota(); const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted); - const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted); - const lastSentRef = useRef(''); - - useEffect(() => { - let cancelled = false; - - const loadCatalog = async () => { - try { - const res = await apiFetch('/api/config/model'); - if (!res.ok || cancelled) return; - const data = (await res.json()) as ModelCatalogResponse; - const available = (data.available || []) - .map(toModelOption) - .filter((v): v is ModelOption => v !== null); - const providers = (data.providers || []) - .map(toProviderOption) - .filter((v): v is ProviderOption => v !== null); - const currentInfo = toModelOption(data.currentInfo ?? 
null); - if (cancelled) return; - - setModelOptions(available); - setProviderOptions(providers); - setSelectedModelPath(data.current || ''); - setSelectedModelInfo(currentInfo); - } catch { - // ignore - } - }; - - void loadCatalog(); - return () => { - cancelled = true; - }; - }, []); - - useEffect(() => { - if (!sessionId) return; - let cancelled = false; - apiFetch(`/api/session/${sessionId}`) - .then((res) => (res.ok ? res.json() : null)) - .then((data: SessionResponse | null) => { - if (cancelled || !data?.model) return; - setSelectedModelPath(data.model); - const model = modelOptions.find((m) => m.id === data.model); - if (model) { - setSelectedModelInfo(model); - return; - } - const inferred = selectedModelInfo && selectedModelInfo.id === data.model - ? selectedModelInfo - : { - id: data.model, - label: data.model, - description: 'Custom model', - provider: '', - providerLabel: '', - }; - setSelectedModelInfo(inferred); - }) - .catch(() => { - // ignore - }); - return () => { - cancelled = true; - }; - }, [sessionId, modelOptions, selectedModelInfo]); - - const selectedModel = selectedModelInfo - || modelOptions.find((m) => m.id === selectedModelPath) - || (selectedModelPath - ? { - id: selectedModelPath, - label: selectedModelPath, - description: 'Custom model', - provider: '', - providerLabel: '', - } - : null); + const setClaudeQuotaExhausted = useAgentStore( + (s) => s.setClaudeQuotaExhausted, + ); + const lastSentRef = useRef(""); useEffect(() => { if (!disabled && !isProcessing && inputRef.current) { @@ -212,7 +79,7 @@ export default function ChatInput({ if (input.trim() && !disabled) { lastSentRef.current = input; onSend(input); - setInput(''); + setInput(""); } }, [input, disabled, onSend]); @@ -228,7 +95,7 @@ export default function ChatInput({ const handleKeyDown = useCallback( (e: KeyboardEvent) => { - if (e.key === 'Enter' && !e.shiftKey) { + if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); handleSend(); } @@ -244,21 +111,6 @@ export default function ChatInput({ setModelAnchorEl(null); }; - const switchModel = useCallback( - async (modelPath: string, info?: ModelOption) => { - if (!sessionId) return; - const res = await apiFetch(`/api/session/${sessionId}/model`, { - method: 'POST', - body: JSON.stringify({ model: modelPath }), - }); - if (res.ok) { - setSelectedModelPath(modelPath); - setSelectedModelInfo(info || modelOptions.find((m) => m.id === modelPath) || null); - } - }, - [sessionId, modelOptions], - ); - const handleSelectModel = async (model: ModelOption) => { handleModelClose(); try { @@ -280,9 +132,9 @@ export default function ChatInput({ const info: ModelOption = { id: full, label: modelId, - description: 'Custom OpenAI-compatible model', + description: "Custom OpenAI-compatible model", provider: OPENAI_COMPAT_PROVIDER, - providerLabel: 'OpenAI-Compatible', + providerLabel: "OpenAI-Compatible", }; await switchModel(full, info); setCustomModalOpen(false); @@ -302,8 +154,8 @@ export default function ChatInput({ const retryText = lastSentRef.current; if (retryText) { onSend(retryText); - setInput(''); - lastSentRef.current = ''; + setInput(""); + lastSentRef.current = ""; } } catch { // ignore @@ -312,48 +164,43 @@ export default function ChatInput({ const claudeChip = (() => { if (!quota || quota.claudeUsedToday === 0) return null; - if (quota.plan === 'free') { - return quota.claudeRemaining > 0 ? 'Free today' : 'Pro only'; + if (quota.plan === "free") { + return quota.claudeRemaining > 0 ? 
"Free today" : "Pro only"; } return `${quota.claudeUsedToday}/${quota.claudeDailyCap} today`; })(); - const groups = providerOptions.map((provider) => ({ - provider, - models: modelOptions - .filter((m) => m.provider === provider.id) - .sort((a, b) => { - const ar = a.recommended ? 0 : 1; - const br = b.recommended ? 0 : 1; - if (ar !== br) return ar - br; - return a.label.localeCompare(b.label); - }), - })); - return ( - + @@ -371,24 +218,24 @@ export default function ChatInput({ InputProps={{ disableUnderline: true, sx: { - color: 'var(--text)', - fontSize: '15px', - fontFamily: 'inherit', + color: "var(--text)", + fontSize: "15px", + fontFamily: "inherit", padding: 0, lineHeight: 1.5, - minHeight: { xs: '44px', md: '56px' }, - alignItems: 'flex-start', + minHeight: { xs: "44px", md: "56px" }, + alignItems: "flex-start", }, }} sx={{ flex: 1, - '& .MuiInputBase-root': { + "& .MuiInputBase-root": { p: 0, - backgroundColor: 'transparent', + backgroundColor: "transparent", }, - '& textarea': { - resize: 'none', - padding: '0 !important', + "& textarea": { + resize: "none", + padding: "0 !important", }, }} /> @@ -398,18 +245,29 @@ export default function ChatInput({ sx={{ mt: 1, p: 1.5, - borderRadius: '10px', - color: 'var(--muted-text)', - transition: 'all 0.2s', - position: 'relative', - '&:hover': { - bgcolor: 'var(--hover-bg)', - color: 'var(--accent-red)', + borderRadius: "10px", + color: "var(--muted-text)", + transition: "all 0.2s", + position: "relative", + "&:hover": { + bgcolor: "var(--hover-bg)", + color: "var(--accent-red)", }, }} > - - + + @@ -420,14 +278,14 @@ export default function ChatInput({ sx={{ mt: 1, p: 1, - borderRadius: '10px', - color: 'var(--muted-text)', - transition: 'all 0.2s', - '&:hover': { - color: 'var(--accent-yellow)', - bgcolor: 'var(--hover-bg)', + borderRadius: "10px", + color: "var(--muted-text)", + transition: "all 0.2s", + "&:hover": { + color: "var(--accent-yellow)", + bgcolor: "var(--hover-bg)", }, - '&.Mui-disabled': { + "&.Mui-disabled": { opacity: 0.3, }, }} @@ -440,33 +298,57 @@ export default function ChatInput({ - + powered by {selectedModel?.avatarUrl && ( {selectedModel.label} )} - - {selectedModel?.label || 'Model'} + + {selectedModel?.label || "Model"} - + {provider.label} @@ -517,8 +399,8 @@ export default function ChatInput({ selected={selectedModelPath === model.id} sx={{ py: 1.5, - '&.Mui-selected': { - bgcolor: 'rgba(255,255,255,0.05)', + "&.Mui-selected": { + bgcolor: "rgba(255,255,255,0.05)", }, }} > @@ -527,22 +409,27 @@ export default function ChatInput({ {model.label} ) : ( {provider.label.slice(0, 2)} @@ -551,17 +438,19 @@ export default function ChatInput({ + {model.label} {model.recommended && ( @@ -571,10 +460,10 @@ export default function ChatInput({ label={claudeChip} size="small" sx={{ - height: '18px', - fontSize: '10px', - bgcolor: 'rgba(255,255,255,0.08)', - color: 'var(--muted-text)', + height: "18px", + fontSize: "10px", + bgcolor: "rgba(255,255,255,0.08)", + color: "var(--muted-text)", fontWeight: 600, }} /> @@ -583,32 +472,36 @@ export default function ChatInput({ } secondary={model.description} secondaryTypographyProps={{ - sx: { fontSize: '12px', color: 'var(--muted-text)' }, + sx: { fontSize: "12px", color: "var(--muted-text)" }, }} /> ))} - {provider.id === OPENAI_COMPAT_PROVIDER && provider.supportsCustomModel && ( - handleOpenCustomModal(provider)} - disabled={!sessionId} - sx={{ py: 1.5 }} - > - '} - secondaryTypographyProps={{ - sx: { fontSize: '12px', color: 'var(--muted-text)' }, - }} - /> - - )} + 
{provider.id === OPENAI_COMPAT_PROVIDER && + provider.supportsCustomModel && ( + handleOpenCustomModal(provider)} + disabled={!sessionId} + sx={{ py: 1.5 }} + > + " + } + secondaryTypographyProps={{ + sx: { fontSize: "12px", color: "var(--muted-text)" }, + }} + /> + + )} ))} {!sessionId && ( - + Start a session to switch models. @@ -624,7 +517,7 @@ export default function ChatInput({ { if (!open) { - setValue(''); - setError(''); + setValue(""); + setError(""); setSubmitting(false); } }, [open]); @@ -37,15 +37,15 @@ export default function CustomModelDialog({ const handleSubmit = async () => { const trimmed = value.trim(); if (!trimmed) { - setError('Model id is required'); + setError("Model id is required"); return; } - setError(''); + setError(""); setSubmitting(true); try { await onSubmit(trimmed); } catch { - setError('Failed to switch model'); + setError("Failed to switch model"); } finally { setSubmitting(false); } @@ -56,26 +56,43 @@ export default function CustomModelDialog({ open={open} onClose={submitting ? undefined : onClose} slotProps={{ - backdrop: { sx: { backgroundColor: 'rgba(0,0,0,0.5)', backdropFilter: 'blur(4px)' } }, + backdrop: { + sx: { + backgroundColor: "rgba(0,0,0,0.5)", + backdropFilter: "blur(4px)", + }, + }, }} PaperProps={{ sx: { - bgcolor: 'var(--panel)', - border: '1px solid var(--border)', - borderRadius: 'var(--radius-md)', - boxShadow: 'var(--shadow-1)', + bgcolor: "var(--panel)", + border: "1px solid var(--border)", + borderRadius: "var(--radius-md)", + boxShadow: "var(--shadow-1)", maxWidth: 520, mx: 2, - width: '100%', + width: "100%", }, }} > - + Custom OpenAI-compatible model - - Enter model id only. We will use server env config for base URL and key. + + Enter model id only. We will use server env config for base URL and + key. setValue(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter" && !submitting) void handleSubmit(); + }} placeholder="e.g. my-model" disabled={submitting} sx={{ - '& .MuiOutlinedInput-root': { - bgcolor: 'transparent', - color: 'var(--text)', + "& .MuiOutlinedInput-root": { + bgcolor: "transparent", + color: "var(--text)", }, }} /> - - Final id: {prefix}{value.trim() || ''} + + Final id:{" "} + + {prefix} + {value.trim() || ""} + {error && ( - + {error} )} @@ -107,11 +135,11 @@ export default function CustomModelDialog({ disabled={submitting} size="small" sx={{ - color: 'var(--muted-text)', - fontSize: '0.82rem', + color: "var(--muted-text)", + fontSize: "0.82rem", px: 2, - textTransform: 'none', - '&:hover': { bgcolor: 'var(--hover-bg)' }, + textTransform: "none", + "&:hover": { bgcolor: "var(--hover-bg)" }, }} > Cancel @@ -122,17 +150,17 @@ export default function CustomModelDialog({ variant="contained" size="small" sx={{ - fontSize: '0.82rem', + fontSize: "0.82rem", px: 2.5, - bgcolor: 'var(--accent-yellow)', - color: '#000', - textTransform: 'none', + bgcolor: "var(--accent-yellow)", + color: "#000", + textTransform: "none", fontWeight: 700, - boxShadow: 'none', - '&:hover': { bgcolor: '#FFB340', boxShadow: 'none' }, + boxShadow: "none", + "&:hover": { bgcolor: "#FFB340", boxShadow: "none" }, }} > - {submitting ? 'Switching…' : 'Switch model'} + {submitting ? 
"Switching…" : "Switch model"} diff --git a/frontend/src/hooks/useModelCatalog.ts b/frontend/src/hooks/useModelCatalog.ts new file mode 100644 index 00000000..42d8f7bc --- /dev/null +++ b/frontend/src/hooks/useModelCatalog.ts @@ -0,0 +1,195 @@ +import { useState, useCallback, useEffect, useMemo } from "react"; +import { apiFetch } from "@/utils/api"; + +export interface ModelOption { + id: string; + label: string; + description: string; + provider: string; + providerLabel: string; + avatarUrl?: string; + recommended?: boolean; + source?: string; +} + +export interface ProviderOption { + id: string; + label: string; + avatarUrl?: string; + supportsCustomModel?: boolean; + customModelHint?: string; + customModelMode?: string; + prefix?: string; +} + +interface ModelCatalogResponse { + current?: string; + currentInfo?: ModelOption | null; + available?: ModelOption[]; + providers?: ProviderOption[]; +} + +interface SessionResponse { + model?: string; +} + +const toModelOption = (value: unknown): ModelOption | null => { + if (!value || typeof value !== "object") return null; + const v = value as Record; + if (typeof v.id !== "string" || typeof v.label !== "string") return null; + return { + id: v.id, + label: v.label, + description: typeof v.description === "string" ? v.description : "", + provider: typeof v.provider === "string" ? v.provider : "", + providerLabel: typeof v.providerLabel === "string" ? v.providerLabel : "", + avatarUrl: typeof v.avatarUrl === "string" ? v.avatarUrl : undefined, + recommended: Boolean(v.recommended), + source: typeof v.source === "string" ? v.source : undefined, + }; +}; + +const toProviderOption = (value: unknown): ProviderOption | null => { + if (!value || typeof value !== "object") return null; + const v = value as Record; + if (typeof v.id !== "string" || typeof v.label !== "string") return null; + return { + id: v.id, + label: v.label, + avatarUrl: typeof v.avatarUrl === "string" ? v.avatarUrl : undefined, + supportsCustomModel: Boolean(v.supportsCustomModel), + customModelHint: + typeof v.customModelHint === "string" ? v.customModelHint : undefined, + customModelMode: + typeof v.customModelMode === "string" ? v.customModelMode : undefined, + prefix: typeof v.prefix === "string" ? v.prefix : undefined, + }; +}; + +export function useModelCatalog(sessionId?: string) { + const [modelOptions, setModelOptions] = useState([]); + const [providerOptions, setProviderOptions] = useState([]); + const [selectedModelPath, setSelectedModelPath] = useState(""); + const [selectedModelInfo, setSelectedModelInfo] = + useState(null); + + useEffect(() => { + let cancelled = false; + + const loadCatalog = async () => { + try { + const res = await apiFetch("/api/config/model"); + if (!res.ok || cancelled) return; + const data = (await res.json()) as ModelCatalogResponse; + const available = (data.available || []) + .map(toModelOption) + .filter((v): v is ModelOption => v !== null); + const providers = (data.providers || []) + .map(toProviderOption) + .filter((v): v is ProviderOption => v !== null); + const currentInfo = toModelOption(data.currentInfo ?? 
null); + if (cancelled) return; + + setModelOptions(available); + setProviderOptions(providers); + setSelectedModelPath(data.current || ""); + setSelectedModelInfo(currentInfo); + } catch { + // ignore — catalog is optional, models may not be configured + } + }; + + void loadCatalog(); + return () => { + cancelled = true; + }; + }, []); + + useEffect(() => { + if (!sessionId) return; + let cancelled = false; + apiFetch(`/api/session/${sessionId}`) + .then((res) => (res.ok ? res.json() : null)) + .then((data: SessionResponse | null) => { + if (cancelled || !data?.model) return; + setSelectedModelPath(data.model); + const model = modelOptions.find((m) => m.id === data.model); + if (model) { + setSelectedModelInfo(model); + return; + } + setSelectedModelInfo((prev) => + prev && prev.id === data.model + ? prev + : { + id: data.model, + label: data.model, + description: "Custom model", + provider: "", + providerLabel: "", + }, + ); + }) + .catch(() => { + // ignore + }); + return () => { + cancelled = true; + }; + }, [sessionId, modelOptions]); + + const selectedModel = + selectedModelInfo || + modelOptions.find((m) => m.id === selectedModelPath) || + (selectedModelPath + ? { + id: selectedModelPath, + label: selectedModelPath, + description: "Custom model", + provider: "", + providerLabel: "", + } + : null); + + const switchModel = useCallback( + async (modelPath: string, info?: ModelOption) => { + if (!sessionId) return; + const res = await apiFetch(`/api/session/${sessionId}/model`, { + method: "POST", + body: JSON.stringify({ model: modelPath }), + }); + if (res.ok) { + setSelectedModelPath(modelPath); + setSelectedModelInfo( + info || modelOptions.find((m) => m.id === modelPath) || null, + ); + } + }, + [sessionId, modelOptions], + ); + + const groups = useMemo( + () => + providerOptions.map((provider) => ({ + provider, + models: modelOptions + .filter((m) => m.provider === provider.id) + .sort((a, b) => { + const ar = a.recommended ? 0 : 1; + const br = b.recommended ? 0 : 1; + if (ar !== br) return ar - br; + return a.label.localeCompare(b.label); + }), + })), + [providerOptions, modelOptions], + ); + + return { + modelOptions, + providerOptions, + selectedModelPath, + selectedModel, + switchModel, + groups, + }; +} diff --git a/tests/test_provider_adapters.py b/tests/test_provider_adapters.py index 193647ab..2e1141e8 100644 --- a/tests/test_provider_adapters.py +++ b/tests/test_provider_adapters.py @@ -186,7 +186,7 @@ def test_lm_studio_adapter_uses_raw_model_name(monkeypatch): params = _resolve_llm_params("lm_studio/google/gemma-3-12b") assert params == { - "model": "lm_studio/google/gemma-3-12b", + "model": "google/gemma-3-12b", "api_base": "http://127.0.0.1:1234/v1", "api_key": "lm-studio", }
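
Usage sketch for the new hook (illustrative, not part of the patch): it shows
how a consumer is expected to wire up useModelCatalog — `groups` drives the
grouped menu, `selectedModel` the trigger label, and `switchModel` persists the
choice via POST /api/session/{id}/model before updating local state. Only the
hook's exported surface added above is assumed; the ModelPickerDemo component,
its prop, and its markup are hypothetical.

// ModelPickerDemo.tsx — hypothetical consumer of useModelCatalog
import { useModelCatalog } from "@/hooks/useModelCatalog";

export default function ModelPickerDemo({ sessionId }: { sessionId?: string }) {
  const { groups, selectedModel, switchModel } = useModelCatalog(sessionId);

  return (
    <div>
      {/* selectedModel falls back to a synthetic "Custom model" entry when
          the session reports an id that is not in the catalog */}
      <p>Current model: {selectedModel?.label ?? "none"}</p>
      {groups.map(({ provider, models }) => (
        <section key={provider.id}>
          <h4>{provider.label}</h4>
          {models.map((m) => (
            // switchModel is a no-op until a sessionId exists
            <button key={m.id} onClick={() => void switchModel(m.id, m)}>
              {m.recommended ? "★ " : ""}
              {m.label}
            </button>
          ))}
        </section>
      ))}
    </div>
  );
}

Because `groups` is memoized on [providerOptions, modelOptions], the sorted,
recommended-first arrays keep a stable identity across renders unless the
catalog itself changes, so a consumer like this does not re-sort on every
keystroke in the surrounding chat input.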