From 71ff91850a08ccc39cd8e068a0fd9bcd14908baf Mon Sep 17 00:00:00 2001
From: bhargav1000
Date: Fri, 24 Apr 2026 16:32:38 +0800
Subject: [PATCH 1/6] Add local model support

---
 README.md                                  |  76 +++++++
 agent/core/llm_params.py                   |  37 ++++
 agent/core/model_switcher.py               |  17 +-
 backend/model_catalog.py                   |  96 +++++++++
 backend/routes/agent.py                    |  51 +----
 frontend/src/components/Chat/ChatInput.tsx | 235 ++++++++++++++++++---
 tests/unit/test_llm_params.py              |  62 +++++-
 tests/unit/test_local_model_validation.py  |  34 +++
 tests/unit/test_user_quotas.py             |  14 +-
 9 files changed, 537 insertions(+), 85 deletions(-)
 create mode 100644 backend/model_catalog.py
 create mode 100644 tests/unit/test_local_model_validation.py

diff --git a/README.md b/README.md
index 8a6c1ccd..38418a10 100644
--- a/README.md
+++ b/README.md
@@ -106,6 +106,82 @@ JSON file:
 }
 ```
 
+### Model Selection
+
+In the CLI, run `/model` to list suggested models and see the active model:
+
+```text
+/model
+```
+
+Switch models by passing the model id:
+
+```text
+/model moonshotai/Kimi-K2.6
+/model bedrock/us.anthropic.claude-opus-4-6-v1
+```
+
+You can also choose a model at startup:
+
+```bash
+ml-intern --model moonshotai/Kimi-K2.6 "your prompt"
+```
+
+### Local Models
+
+Local model support uses OpenAI-compatible HTTP endpoints through LiteLLM. The agent does not load model weights directly from disk; a local inference server must already be running.
+
+Supported local model id prefixes:
+
+| Prefix | Default endpoint | Example |
+| --- | --- | --- |
+| `ollama/` | `http://localhost:11434/v1` | `ollama/llama3.1` |
+| `vllm/` | `http://localhost:8000/v1` | `vllm/Qwen3.5-2B` |
+| `llamacpp/` | `http://localhost:8001/v1` | `llamacpp/unsloth/Qwen3.5-2B` |
+| `local://` | `${LOCAL_LLM_BASE_URL}/v1` | `local://my-model` |
+
+Override endpoints with environment variables:
+
+```bash
+OLLAMA_BASE_URL=http://localhost:11434
+VLLM_BASE_URL=http://localhost:8000
+LLAMACPP_BASE_URL=http://localhost:8001
+LOCAL_LLM_BASE_URL=http://localhost:8000
+```
+
+For example, with Ollama:
+
+```bash
+ollama serve   # skip if the Ollama daemon is already running
+ollama pull llama3.1
+ml-intern
+```
+
+Then switch inside the CLI:
+
+```text
+/model ollama/llama3.1
+```
+
+For llama.cpp, start an OpenAI-compatible server first, then point the agent at it if you are not using the default port:
+
+```bash
+export LLAMACPP_BASE_URL=http://localhost:8080
+ml-intern
+```
+
+```text
+/model llamacpp/<model>
+```
+
+For the web UI/API, enable local model selection:
+
+```bash
+ENABLE_LOCAL_MODELS=true
+```
+
+When `ENABLE_LOCAL_MODELS=true`, the backend exposes local model presets and accepts custom local paths with the prefixes above. The web model menu also shows a custom local model path field, so you can enter values like `ollama/qwen2.5-coder` or `local://my-model`.
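+
+If a local model misbehaves, first confirm the server really speaks the OpenAI API. A minimal sketch, assuming the default Ollama endpoint from the table above (adjust host, port, and model name to your setup):
+
+```bash
+# Any OpenAI-compatible server should answer a plain chat-completions request.
+curl http://localhost:11434/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{"model": "llama3.1", "messages": [{"role": "user", "content": "Say hi"}]}'
+```
+
+A JSON completion here, combined with failures in the agent, usually points at a wrong `*_BASE_URL` value rather than at the server itself.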
+
 ## Architecture
 
 ### Component Overview
diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py
index 880886b3..0717767e 100644
--- a/agent/core/llm_params.py
+++ b/agent/core/llm_params.py
@@ -79,6 +79,7 @@ def _widened(model: str) -> bool:
 _ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"}
 _OPENAI_EFFORTS = {"minimal", "low", "medium", "high", "xhigh"}
 _HF_EFFORTS = {"low", "medium", "high"}
+_LOCAL_DEFAULT_API_KEY = "sk-no-key-required"
 
 
 class UnsupportedEffortError(ValueError):
@@ -180,6 +181,42 @@ def _resolve_llm_params(
             params["reasoning_effort"] = reasoning_effort
         return params
 
+    if model_name.startswith("ollama/"):
+        local_model = model_name.split("/", 1)[1]
+        api_base = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
+        return {
+            "model": f"openai/{local_model}",
+            "api_base": f"{api_base.rstrip('/')}/v1",
+            "api_key": os.environ.get("OLLAMA_API_KEY", _LOCAL_DEFAULT_API_KEY),
+        }
+
+    if model_name.startswith("vllm/"):
+        local_model = model_name.split("/", 1)[1]
+        api_base = os.environ.get("VLLM_BASE_URL", "http://localhost:8000")
+        return {
+            "model": f"openai/{local_model}",
+            "api_base": f"{api_base.rstrip('/')}/v1",
+            "api_key": os.environ.get("VLLM_API_KEY", _LOCAL_DEFAULT_API_KEY),
+        }
+
+    if model_name.startswith("llamacpp/"):
+        local_model = model_name.split("/", 1)[1]
+        api_base = os.environ.get("LLAMACPP_BASE_URL", "http://localhost:8001")
+        return {
+            "model": f"openai/{local_model}",
+            "api_base": f"{api_base.rstrip('/')}/v1",
+            "api_key": os.environ.get("LLAMACPP_API_KEY", _LOCAL_DEFAULT_API_KEY),
+        }
+
+    if model_name.startswith("local://"):
+        local_model = model_name.split("://", 1)[1]
+        api_base = os.environ.get("LOCAL_LLM_BASE_URL", "http://localhost:8000")
+        return {
+            "model": f"openai/{local_model}",
+            "api_base": f"{api_base.rstrip('/')}/v1",
+            "api_key": os.environ.get("LOCAL_LLM_API_KEY", _LOCAL_DEFAULT_API_KEY),
+        }
+
     hf_model = model_name.removeprefix("huggingface/")
     api_key = _resolve_hf_router_token(session_hf_token)
     params = {
diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py
index 63c0f40c..5ac676d7 100644
--- a/agent/core/model_switcher.py
+++ b/agent/core/model_switcher.py
@@ -36,6 +36,7 @@
 
 _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"}
+_LOCAL_MODEL_PREFIXES = ("ollama/", "vllm/", "llamacpp/", "local://")
 
 
 def is_valid_model_id(model_id: str) -> bool:
@@ -44,13 +45,21 @@ def is_valid_model_id(model_id: str) -> bool:
 
     Accepts:
       • anthropic/<model>
      • openai/<model>
+      • ollama/<model>, vllm/<model>, llamacpp/<model>, local://<model>
       • <org>/<model>[:<tag>]             (HF router; tag = provider or policy)
       • huggingface/<org>/<model>[:<tag>] (same, accepts legacy prefix)
 
     Actual availability is verified against the HF router catalog on
     switch, and by the provider on the probe's ping call.
     """
-    if not model_id or "/" not in model_id:
+    if not model_id:
+        return False
+    if any(
+        model_id.startswith(prefix) and len(model_id) > len(prefix)
+        for prefix in _LOCAL_MODEL_PREFIXES
+    ):
+        return True
+    if "/" not in model_id:
         return False
     head = model_id.split(":", 1)[0]
     parts = head.split("/")
@@ -66,7 +75,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool:
 
     Anthropic / OpenAI ids return ``True`` without printing anything —
     the probe below covers "does this model exist".
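+    Local-endpoint ids (``ollama/``, ``vllm/``, ``llamacpp/``, ``local://``)
+    take the same early-return path: there is no HF routing info to print
+    for them.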
""" - if model_id.startswith(("anthropic/", "openai/")): + if model_id.startswith(("anthropic/", "openai/", *_LOCAL_MODEL_PREFIXES)): return True from agent.core import hf_router_catalog as cat @@ -139,7 +148,9 @@ def print_model_listing(config, console) -> None: console.print( "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" - "Use 'anthropic/' or 'openai/' for direct API access.[/dim]" + "Use 'anthropic/' or 'openai/' for direct API access.\n" + "Use 'ollama/', 'vllm/', 'llamacpp/', or 'local://' " + "for local OpenAI-compatible endpoints.[/dim]" ) diff --git a/backend/model_catalog.py b/backend/model_catalog.py new file mode 100644 index 00000000..8e86c3f5 --- /dev/null +++ b/backend/model_catalog.py @@ -0,0 +1,96 @@ +"""Model catalog and validation helpers for agent API routes.""" + +import os +from typing import Any + +LOCAL_MODEL_PREFIXES = ("ollama/", "vllm/", "llamacpp/", "local://") + + +def local_models_enabled() -> bool: + return os.environ.get("ENABLE_LOCAL_MODELS", "false").lower() in { + "1", + "true", + "yes", + "on", + } + + +def get_available_models() -> list[dict[str, Any]]: + models: list[dict[str, Any]] = [ + { + "id": "moonshotai/Kimi-K2.6", + "label": "Kimi K2.6", + "provider": "huggingface", + "tier": "free", + "recommended": True, + }, + { + "id": "bedrock/us.anthropic.claude-opus-4-6-v1", + "label": "Claude Opus 4.6", + "provider": "anthropic", + "tier": "pro", + "recommended": True, + }, + { + "id": "MiniMaxAI/MiniMax-M2.7", + "label": "MiniMax M2.7", + "provider": "huggingface", + "tier": "free", + }, + { + "id": "zai-org/GLM-5.1", + "label": "GLM 5.1", + "provider": "huggingface", + "tier": "free", + }, + ] + + if local_models_enabled(): + models.extend( + [ + { + "id": "ollama/llama3.1", + "label": "Llama 3.1 (Ollama)", + "provider": "local", + "tier": "free", + }, + { + "id": "vllm/Qwen3.5-2B", + "label": "Qwen3.5-2B (vLLM)", + "provider": "local", + "tier": "free", + }, + { + "id": "llamacpp/unsloth/Qwen3.5-2B", + "label": "Qwen3.5-2B (llama.cpp)", + "provider": "local", + "tier": "free", + "recommended": True, + }, + ] + ) + + return models + + +def available_model_ids() -> set[str]: + return {m["id"] for m in get_available_models()} + + +def is_custom_local_model_id(model_id: str) -> bool: + if not local_models_enabled(): + return False + if not isinstance(model_id, str): + return False + if not model_id or model_id != model_id.strip() or any( + char.isspace() for char in model_id + ): + return False + return any( + model_id.startswith(prefix) and len(model_id) > len(prefix) + for prefix in LOCAL_MODEL_PREFIXES + ) + + +def is_valid_model_id(model_id: str) -> bool: + return model_id in available_model_ids() or is_custom_local_model_id(model_id) diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 6a688f73..622b5999 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -19,6 +19,7 @@ ) from fastapi.responses import StreamingResponse from litellm import acompletion +from model_catalog import get_available_models, is_anthropic_model, is_valid_model_id from models import ( ApprovalRequest, HealthResponse, @@ -41,51 +42,18 @@ router = APIRouter(prefix="/api", tags=["agent"]) -AVAILABLE_MODELS = [ - { - "id": "moonshotai/Kimi-K2.6", - "label": "Kimi K2.6", - "provider": "huggingface", - "tier": "free", - "recommended": True, - }, - { - "id": "bedrock/us.anthropic.claude-opus-4-6-v1", - "label": "Claude Opus 4.6", - "provider": "anthropic", - 
"tier": "pro", - "recommended": True, - }, - { - "id": "MiniMaxAI/MiniMax-M2.7", - "label": "MiniMax M2.7", - "provider": "huggingface", - "tier": "free", - }, - { - "id": "zai-org/GLM-5.1", - "label": "GLM 5.1", - "provider": "huggingface", - "tier": "free", - }, -] - - -def _is_anthropic_model(model_id: str) -> bool: - return "anthropic" in model_id - async def _require_hf_for_anthropic(request: Request, model_id: str) -> None: """403 if a non-``huggingface``-org user tries to select an Anthropic model. Anthropic models are billed to the Space's ``ANTHROPIC_API_KEY``; every - other model in ``AVAILABLE_MODELS`` is routed through HF Router and + other configured cloud model is routed through HF Router and billed via ``X-HF-Bill-To``. The gate only fires for Anthropic so non-HF users can still freely switch between the free models. Pattern: https://github.com/huggingface/ml-intern/pull/63 """ - if not _is_anthropic_model(model_id): + if not is_anthropic_model(model_id): return if not await require_huggingface_org_member(request): raise HTTPException( @@ -117,7 +85,7 @@ async def _enforce_claude_quota( if agent_session.claude_counted: return model_name = agent_session.session.config.model_name - if not _is_anthropic_model(model_name): + if not is_anthropic_model(model_name): return user_id = user["user_id"] used = await user_quotas.get_claude_used_today(user_id) @@ -318,7 +286,7 @@ async def get_model() -> dict: """Get current model and available models. No auth required.""" return { "current": session_manager.config.model_name, - "available": AVAILABLE_MODELS, + "available": get_available_models(), } @@ -405,8 +373,7 @@ async def create_session( if isinstance(body, dict): model = body.get("model") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model and model not in valid_ids: + if model and not is_valid_model_id(model): raise HTTPException(status_code=400, detail=f"Unknown model: {model}") # Opus is gated to HF staff (PR #63). 
Only fires when the resolved model @@ -443,8 +410,7 @@ async def restore_session_summary( hf_token = resolve_hf_request_token(request) model = body.get("model") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model and model not in valid_ids: + if model and not is_valid_model_id(model): raise HTTPException(status_code=400, detail=f"Unknown model: {model}") resolved_model = model or session_manager.config.model_name @@ -502,8 +468,7 @@ async def set_session_model( model_id = body.get("model") if not model_id: raise HTTPException(status_code=400, detail="Missing 'model' field") - valid_ids = {m["id"] for m in AVAILABLE_MODELS} - if model_id not in valid_ids: + if not is_valid_model_id(model_id): raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") await _require_hf_for_anthropic(request, model_id) agent_session = session_manager.sessions.get(session_id) diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index 28eec904..2f0c8677 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -1,8 +1,9 @@ import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; -import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material'; +import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip, Divider, Button } from '@mui/material'; import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; import StopIcon from '@mui/icons-material/Stop'; +import MemoryIcon from '@mui/icons-material/Memory'; import { apiFetch } from '@/utils/api'; import { useUserQuota } from '@/hooks/useUserQuota'; import ClaudeCapDialog from '@/components/ClaudeCapDialog'; @@ -17,6 +18,15 @@ interface ModelOption { description: string; modelPath: string; avatarUrl: string; + provider?: string; + recommended?: boolean; +} + +interface BackendModel { + id: string; + label: string; + provider: string; + tier?: string; recommended?: boolean; } @@ -25,43 +35,82 @@ const getHfAvatarUrl = (modelId: string) => { return `https://huggingface.co/api/avatars/${org}`; }; -const MODEL_OPTIONS: ModelOption[] = [ +const DEFAULT_MODEL_OPTIONS: ModelOption[] = [ { - id: 'kimi-k2.6', + id: 'moonshotai/Kimi-K2.6', name: 'Kimi K2.6', description: 'Novita', modelPath: 'moonshotai/Kimi-K2.6', avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.6'), + provider: 'huggingface', recommended: true, }, { - id: 'claude-opus', + id: 'bedrock/us.anthropic.claude-opus-4-6-v1', name: 'Claude Opus 4.6', description: 'Anthropic', modelPath: CLAUDE_MODEL_PATH, avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', + provider: 'anthropic', recommended: true, }, { - id: 'minimax-m2.7', + id: 'MiniMaxAI/MiniMax-M2.7', name: 'MiniMax M2.7', description: 'Novita', modelPath: 'MiniMaxAI/MiniMax-M2.7', avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.7'), + provider: 'huggingface', }, { - id: 'glm-5.1', + id: 'zai-org/GLM-5.1', name: 'GLM 5.1', description: 'Together', modelPath: 'zai-org/GLM-5.1', avatarUrl: getHfAvatarUrl('zai-org/GLM-5.1'), + provider: 'huggingface', }, ]; -const findModelByPath = (path: string): ModelOption | undefined => { - return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); +const providerDescription = (model: BackendModel): string => { + if (model.provider === 'local') return 'Local'; + 
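+  // Anthropic and HF get display names; any other provider id is shown as-is.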
if (model.provider === 'anthropic') return 'Anthropic'; + return model.provider === 'huggingface' ? 'Hugging Face' : model.provider; +}; + +const modelOptionFromBackend = (model: BackendModel): ModelOption => ({ + id: model.id, + name: model.label, + description: providerDescription(model), + modelPath: model.id, + avatarUrl: model.provider === 'local' ? '' : getHfAvatarUrl(model.id), + provider: model.provider, + recommended: model.recommended, +}); + +const LOCAL_MODEL_PREFIXES = ['ollama/', 'vllm/', 'llamacpp/', 'local://']; + +const findModelByPath = (options: ModelOption[], path: string): ModelOption | undefined => { + return options.find(m => m.modelPath === path || m.id === path); }; +const isLocalModelPath = (path: string) => LOCAL_MODEL_PREFIXES.some(prefix => path.startsWith(prefix)); + +const labelFromLocalPath = (path: string) => { + if (path.startsWith('local://')) return path.slice('local://'.length); + const prefix = LOCAL_MODEL_PREFIXES.find(p => path.startsWith(p)); + return prefix ? path.slice(prefix.length) : path; +}; + +const customLocalModelOption = (path: string): ModelOption => ({ + id: path, + name: labelFromLocalPath(path), + description: 'Custom local path', + modelPath: path, + avatarUrl: '', + provider: 'local', +}); + interface ChatInputProps { sessionId?: string; onSend: (text: string) => void; @@ -74,13 +123,43 @@ interface ChatInputProps { } const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath); -const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0]; +const isLocalModel = (m: ModelOption) => m.provider === 'local' || isLocalModelPath(m.modelPath); + +function ModelAvatar({ model, size }: { model: ModelOption; size: number }) { + if (isLocalModel(model)) { + return ( + + + + ); + } + return ( + {model.name} + ); +} export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJobs, onContinueBlockedJobsWithNamespace, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { const [input, setInput] = useState(''); const inputRef = useRef(null); - const [selectedModelId, setSelectedModelId] = useState(MODEL_OPTIONS[0].id); + const [modelOptions, setModelOptions] = useState(DEFAULT_MODEL_OPTIONS); + const [selectedModelId, setSelectedModelId] = useState(DEFAULT_MODEL_OPTIONS[0].id); const [modelAnchorEl, setModelAnchorEl] = useState(null); + const [customModelPath, setCustomModelPath] = useState(''); const { quota, refresh: refreshQuota } = useUserQuota(); // The daily-cap dialog is triggered from two places: (a) a 429 returned // from the chat transport when the user tries to send on Opus over cap — @@ -93,6 +172,26 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ const setJobsUpgradeRequired = useAgentStore((s) => s.setJobsUpgradeRequired); const lastSentRef = useRef(''); + useEffect(() => { + let cancelled = false; + apiFetch('/api/config/model') + .then((res) => (res.ok ? res.json() : null)) + .then((data) => { + if (cancelled || !Array.isArray(data?.available)) return; + const nextOptions = data.available.map(modelOptionFromBackend); + if (nextOptions.length === 0) return; + setModelOptions(nextOptions); + const current = typeof data.current === 'string' + ? findModelByPath(nextOptions, data.current) + : undefined; + setSelectedModelId((prev) => ( + current?.id ?? (nextOptions.some((model: ModelOption) => model.id === prev) ? 
prev : nextOptions[0].id) + )); + }) + .catch(() => { /* keep bundled defaults */ }); + return () => { cancelled = true; }; + }, []); + // Model is per-session: fetch this tab's current model every time the // session changes. Other tabs keep their own selections independently. useEffect(() => { @@ -103,15 +202,23 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ .then((data) => { if (cancelled) return; if (data?.model) { - const model = findModelByPath(data.model); + let model = findModelByPath(modelOptions, data.model); + if (!model && isLocalModelPath(data.model)) { + model = customLocalModelOption(data.model); + setModelOptions((prev) => ( + findModelByPath(prev, data.model) ? prev : [...prev, model as ModelOption] + )); + } if (model) setSelectedModelId(model.id); } }) .catch(() => { /* ignore */ }); return () => { cancelled = true; }; - }, [sessionId]); + }, [sessionId, modelOptions]); - const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; + const selectedModel = modelOptions.find(m => m.id === selectedModelId) || modelOptions[0]; + const customLocalEnabled = modelOptions.some(isLocalModel); + const firstFreeModel = () => modelOptions.find(m => !isClaudeModel(m)) ?? modelOptions[0]; // Auto-focus the textarea when the session becomes ready useEffect(() => { @@ -173,6 +280,26 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ } catch { /* ignore */ } }; + const handleUseCustomModel = async () => { + const path = customModelPath.trim(); + if (!path || !sessionId) return; + const model = customLocalModelOption(path); + try { + const res = await apiFetch(`/api/session/${sessionId}/model`, { + method: 'POST', + body: JSON.stringify({ model: model.modelPath }), + }); + if (res.ok) { + setModelOptions((prev) => ( + findModelByPath(prev, model.modelPath) ? prev : [...prev, model] + )); + setSelectedModelId(model.id); + setCustomModelPath(''); + handleModelClose(); + } + } catch { /* ignore */ } + }; + // Dialog close: just clear the flag. The typed text is already restored. const handleCapDialogClose = useCallback(() => { setClaudeQuotaExhausted(false); @@ -183,7 +310,7 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ const handleUseFreeModel = useCallback(async () => { setClaudeQuotaExhausted(false); if (!sessionId) return; - const free = MODEL_OPTIONS.find(m => m.modelPath === FIRST_FREE_MODEL_PATH) + const free = modelOptions.find(m => m.modelPath === FIRST_FREE_MODEL_PATH) ?? 
firstFreeModel(); try { const res = await apiFetch(`/api/session/${sessionId}/model`, { @@ -200,7 +327,7 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ } } } catch { /* ignore */ } - }, [sessionId, onSend, setClaudeQuotaExhausted]); + }, [sessionId, onSend, setClaudeQuotaExhausted, modelOptions]); const handleClaudeUpgradeClick = useCallback(async () => { if (!sessionId) return; @@ -377,11 +504,7 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ powered by - {selectedModel.name} + {selectedModel.name} @@ -412,7 +535,7 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ } }} > - {MODEL_OPTIONS.map((model) => ( + {modelOptions.map((model) => ( handleSelectModel(model)} @@ -425,11 +548,7 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ }} > - {model.name} + ))} + {customLocalEnabled && ( + <> + + e.stopPropagation()} + onKeyDown={(e) => e.stopPropagation()} + > + + Custom local model path + + + setCustomModelPath(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'Enter') { + e.preventDefault(); + handleUseCustomModel(); + } + }} + placeholder="ollama/qwen2.5-coder" + fullWidth + variant="outlined" + sx={{ + '& .MuiInputBase-root': { + color: 'var(--text)', + bgcolor: 'rgba(255,255,255,0.04)', + }, + '& .MuiOutlinedInput-notchedOutline': { + borderColor: 'var(--divider)', + }, + }} + /> + + + + + )} Date: Sat, 25 Apr 2026 01:01:40 +0800 Subject: [PATCH 2/6] Address local model review feedback --- agent/core/llm_params.py | 11 ++++++ agent/core/model_switcher.py | 9 ++--- backend/model_catalog.py | 11 +++--- frontend/src/components/Chat/ChatInput.tsx | 46 +++++++++++++++++----- ml_intern/__init__.py | 1 + ml_intern/local_models.py | 10 +++++ pyproject.toml | 2 +- tests/unit/test_llm_params.py | 10 ++--- tests/unit/test_local_model_validation.py | 7 ++++ 9 files changed, 79 insertions(+), 28 deletions(-) create mode 100644 ml_intern/__init__.py create mode 100644 ml_intern/local_models.py diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 0717767e..061911d0 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -90,6 +90,13 @@ class UnsupportedEffortError(ValueError): """ +def _raise_for_local_effort(reasoning_effort: str | None, strict: bool) -> None: + if reasoning_effort and strict: + raise UnsupportedEffortError( + "Local OpenAI-compatible endpoints don't accept reasoning_effort" + ) + + def _resolve_llm_params( model_name: str, session_hf_token: str | None = None, @@ -182,6 +189,7 @@ def _resolve_llm_params( return params if model_name.startswith("ollama/"): + _raise_for_local_effort(reasoning_effort, strict) local_model = model_name.split("/", 1)[1] api_base = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434") return { @@ -191,6 +199,7 @@ def _resolve_llm_params( } if model_name.startswith("vllm/"): + _raise_for_local_effort(reasoning_effort, strict) local_model = model_name.split("/", 1)[1] api_base = os.environ.get("VLLM_BASE_URL", "http://localhost:8000") return { @@ -200,6 +209,7 @@ def _resolve_llm_params( } if model_name.startswith("llamacpp/"): + _raise_for_local_effort(reasoning_effort, strict) local_model = model_name.split("/", 1)[1] api_base = os.environ.get("LLAMACPP_BASE_URL", "http://localhost:8001") return { @@ -209,6 +219,7 @@ def _resolve_llm_params( } if model_name.startswith("local://"): + _raise_for_local_effort(reasoning_effort, strict) local_model = model_name.split("://", 1)[1] 
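+        # The scheme is stripped once with split("://", 1), so slashes
+        # inside the model name survive.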
api_base = os.environ.get("LOCAL_LLM_BASE_URL", "http://localhost:8000") return { diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index 5ac676d7..7a4f6bef 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -16,6 +16,7 @@ from __future__ import annotations from agent.core.effort_probe import ProbeInconclusive, probe_effort +from ml_intern.local_models import LOCAL_MODEL_PREFIXES, is_local_model_id # Suggested models shown by `/model` (not a gate). Users can paste any HF @@ -36,7 +37,6 @@ _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} -_LOCAL_MODEL_PREFIXES = ("ollama/", "vllm/", "llamacpp/", "local://") def is_valid_model_id(model_id: str) -> bool: @@ -54,10 +54,7 @@ def is_valid_model_id(model_id: str) -> bool: """ if not model_id: return False - if any( - model_id.startswith(prefix) and len(model_id) > len(prefix) - for prefix in _LOCAL_MODEL_PREFIXES - ): + if is_local_model_id(model_id): return True if "/" not in model_id: return False @@ -75,7 +72,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool: Anthropic / OpenAI ids return ``True`` without printing anything — the probe below covers "does this model exist". """ - if model_id.startswith(("anthropic/", "openai/", *_LOCAL_MODEL_PREFIXES)): + if model_id.startswith(("anthropic/", "openai/", *LOCAL_MODEL_PREFIXES)): return True from agent.core import hf_router_catalog as cat diff --git a/backend/model_catalog.py b/backend/model_catalog.py index 8e86c3f5..86a4cd42 100644 --- a/backend/model_catalog.py +++ b/backend/model_catalog.py @@ -3,7 +3,7 @@ import os from typing import Any -LOCAL_MODEL_PREFIXES = ("ollama/", "vllm/", "llamacpp/", "local://") +from ml_intern.local_models import is_local_model_id def local_models_enabled() -> bool: @@ -86,11 +86,12 @@ def is_custom_local_model_id(model_id: str) -> bool: char.isspace() for char in model_id ): return False - return any( - model_id.startswith(prefix) and len(model_id) > len(prefix) - for prefix in LOCAL_MODEL_PREFIXES - ) + return is_local_model_id(model_id) def is_valid_model_id(model_id: str) -> bool: return model_id in available_model_ids() or is_custom_local_model_id(model_id) + + +def is_anthropic_model(model_id: str) -> bool: + return model_id.startswith(("anthropic/", "bedrock/")) and "anthropic" in model_id diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index 2f0c8677..0d068b57 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -160,6 +160,7 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ const [selectedModelId, setSelectedModelId] = useState(DEFAULT_MODEL_OPTIONS[0].id); const [modelAnchorEl, setModelAnchorEl] = useState(null); const [customModelPath, setCustomModelPath] = useState(''); + const [customModelError, setCustomModelError] = useState(''); const { quota, refresh: refreshQuota } = useUserQuota(); // The daily-cap dialog is triggered from two places: (a) a 429 returned // from the chat transport when the user tries to send on Opus over cap — @@ -283,21 +284,36 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ const handleUseCustomModel = async () => { const path = customModelPath.trim(); if (!path || !sessionId) return; + setCustomModelError(''); const model = customLocalModelOption(path); try { const res = await apiFetch(`/api/session/${sessionId}/model`, { method: 'POST', body: JSON.stringify({ model: 
model.modelPath }), }); - if (res.ok) { - setModelOptions((prev) => ( - findModelByPath(prev, model.modelPath) ? prev : [...prev, model] - )); - setSelectedModelId(model.id); - setCustomModelPath(''); - handleModelClose(); + if (!res.ok) { + let message = `Unable to use ${path}`; + try { + const data = await res.json(); + if (typeof data?.detail === 'string') { + message = data.detail; + } else if (typeof data?.detail?.message === 'string') { + message = data.detail.message; + } + } catch { /* keep fallback */ } + setCustomModelError(message); + return; } - } catch { /* ignore */ } + setModelOptions((prev) => ( + findModelByPath(prev, model.modelPath) ? prev : [...prev, model] + )); + setSelectedModelId(model.id); + setCustomModelPath(''); + setCustomModelError(''); + handleModelClose(); + } catch { + setCustomModelError('Unable to switch to that local model.'); + } }; // Dialog close: just clear the flag. The typed text is already restored. @@ -614,7 +630,10 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ setCustomModelPath(e.target.value)} + onChange={(e) => { + setCustomModelPath(e.target.value); + setCustomModelError(''); + }} onKeyDown={(e) => { if (e.key === 'Enter') { e.preventDefault(); @@ -624,6 +643,15 @@ export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJ placeholder="ollama/qwen2.5-coder" fullWidth variant="outlined" + error={!!customModelError} + helperText={customModelError || ' '} + FormHelperTextProps={{ + sx: { + mx: 0, + color: customModelError ? 'var(--accent-red)' : 'transparent', + fontSize: '11px', + }, + }} sx={{ '& .MuiInputBase-root': { color: 'var(--text)', diff --git a/ml_intern/__init__.py b/ml_intern/__init__.py new file mode 100644 index 00000000..74c0c6e0 --- /dev/null +++ b/ml_intern/__init__.py @@ -0,0 +1 @@ +"""Shared lightweight helpers for ML Intern packages.""" diff --git a/ml_intern/local_models.py b/ml_intern/local_models.py new file mode 100644 index 00000000..63e93b23 --- /dev/null +++ b/ml_intern/local_models.py @@ -0,0 +1,10 @@ +"""Shared helpers for local OpenAI-compatible model ids.""" + +LOCAL_MODEL_PREFIXES = ("ollama/", "vllm/", "llamacpp/", "local://") + + +def is_local_model_id(model_id: str) -> bool: + return any( + model_id.startswith(prefix) and len(model_id) > len(prefix) + for prefix in LOCAL_MODEL_PREFIXES + ) diff --git a/pyproject.toml b/pyproject.toml index bd0f7b53..38167c7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ build-backend = "setuptools.build_meta" # runtime (resolves to /configs/cli_agent_config.json). # Without it, `uv tool install` / `pip install` produce a broken install # that imports fine but crashes at startup with FileNotFoundError. 
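+# ml_intern* ships the shared local-model helpers (ml_intern/local_models.py)
+# imported by both agent/ and backend/.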
-include = ["agent*", "configs"] +include = ["agent*", "configs", "ml_intern*"] [tool.setuptools.package-data] configs = ["*.json"] diff --git a/tests/unit/test_llm_params.py b/tests/unit/test_llm_params.py index 30eb74c1..2b4abd52 100644 --- a/tests/unit/test_llm_params.py +++ b/tests/unit/test_llm_params.py @@ -150,10 +150,6 @@ def test_resolve_generic_local_params_trims_trailing_slash(monkeypatch): assert params["api_base"] == "http://127.0.0.1:9000/v1" -def test_local_reasoning_effort_rejected_in_strict_mode(): - with pytest.raises(UnsupportedEffortError, match="Local OpenAI-compatible endpoints"): - _resolve_llm_params( - "ollama/llama3.1", - reasoning_effort="high", - strict=True, - ) +def test_local_params_reject_reasoning_effort_in_strict_mode(): + with pytest.raises(UnsupportedEffortError, match="reasoning_effort"): + _resolve_llm_params("ollama/llama3.1", reasoning_effort="high", strict=True) diff --git a/tests/unit/test_local_model_validation.py b/tests/unit/test_local_model_validation.py index aef428ef..3cc993da 100644 --- a/tests/unit/test_local_model_validation.py +++ b/tests/unit/test_local_model_validation.py @@ -32,3 +32,10 @@ def test_custom_local_model_ids_reject_empty_or_whitespace(monkeypatch): assert not model_catalog.is_valid_model_id(" ollama/qwen") assert not model_catalog.is_valid_model_id("ollama/qwen coder") assert not model_catalog.is_valid_model_id("some-org/model") + + +def test_anthropic_detection_is_anchored_to_cloud_prefixes(): + assert model_catalog.is_anthropic_model("anthropic/claude-opus-4-6") + assert model_catalog.is_anthropic_model("bedrock/us.anthropic.claude-opus-4-6-v1") + assert not model_catalog.is_anthropic_model("local://my-anthropic-wrapper") + assert not model_catalog.is_anthropic_model("ollama/anthropic-clone") From 97c6b9c7f84d3a35745621ece62044b9ce256541 Mon Sep 17 00:00:00 2001 From: bhargav1000 Date: Sat, 25 Apr 2026 01:16:46 +0800 Subject: [PATCH 3/6] Make uv test command use project dev deps --- pyproject.toml | 6 ++++++ uv.lock | 14 +++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 38167c7f..b816af6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,12 @@ all = [ "ml-intern[eval,dev]", ] +[dependency-groups] +dev = [ + "pytest>=9.0.2", + "pytest-asyncio>=1.2.0", +] + [project.scripts] ml-intern = "agent.main:cli" diff --git a/uv.lock b/uv.lock index 07eee3d3..06fddbb7 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.12'", @@ -1815,6 +1815,12 @@ eval = [ { name = "tenacity" }, ] +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-asyncio" }, +] + [package.metadata] requires-dist = [ { name = "apscheduler", specifier = ">=3.10,<4" }, @@ -1846,6 +1852,12 @@ requires-dist = [ ] provides-extras = ["eval", "dev", "all"] +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=9.0.2" }, + { name = "pytest-asyncio", specifier = ">=1.2.0" }, +] + [[package]] name = "mmh3" version = "5.2.0" From abe4fe4a1449e32902ba6a1c0b1295135f54b7fd Mon Sep 17 00:00:00 2001 From: bhargav1000 Date: Tue, 28 Apr 2026 13:56:56 +0800 Subject: [PATCH 4/6] Fix local model env import --- agent/core/llm_params.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 061911d0..3f64df00 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -5,6 +5,8 @@ 
creating circular imports. """ +import os + from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token From 2c6fc0b8e642a8f209c2033592a857d35797d436 Mon Sep 17 00:00:00 2001 From: bhargav1000 Date: Tue, 28 Apr 2026 13:57:58 +0800 Subject: [PATCH 5/6] Align repetition guard test wording --- tests/unit/test_doom_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_doom_loop.py b/tests/unit/test_doom_loop.py index bbdac454..3a31a5a4 100644 --- a/tests/unit/test_doom_loop.py +++ b/tests/unit/test_doom_loop.py @@ -207,7 +207,7 @@ def test_check_for_doom_loop_returns_corrective_prompt_for_identical_run(): msgs = [_assistant_call("read", '{"p": 1}')] * 3 out = check_for_doom_loop(msgs) assert out is not None - assert "DOOM LOOP DETECTED" in out + assert "REPETITION GUARD" in out assert "'read'" in out @@ -218,7 +218,7 @@ def test_check_for_doom_loop_returns_corrective_prompt_for_cycle(): msgs.append(_assistant_call("b", "{}")) out = check_for_doom_loop(msgs) assert out is not None - assert "DOOM LOOP DETECTED" in out + assert "REPETITION GUARD" in out assert "a → b" in out From 37f8622c88e90d0ca26a17cda063719f93c257e9 Mon Sep 17 00:00:00 2001 From: bhargav1000 Date: Tue, 28 Apr 2026 14:25:06 +0800 Subject: [PATCH 6/6] Address local model review feedback --- README.md | 2 ++ agent/core/llm_params.py | 17 +++++++++++++---- backend/model_catalog.py | 4 +--- tests/unit/test_llm_params.py | 13 +++++++++++++ tests/unit/test_local_model_validation.py | 11 ++--------- 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 38418a10..edc89e6e 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,8 @@ LLAMACPP_BASE_URL=http://localhost:8001 LOCAL_LLM_BASE_URL=http://localhost:8000 ``` +Keep these endpoint variables server-controlled. Do not expose them as user-editable web/API inputs; they determine where the backend sends LLM traffic. + For example, with Ollama: ```bash diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 3f64df00..576d40ce 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -5,10 +5,13 @@ creating circular imports. 
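+For local OpenAI-compatible endpoints, an unsupported ``reasoning_effort``
+raises in strict mode and is logged and dropped otherwise.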
""" +import logging import os from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token +logger = logging.getLogger(__name__) + def _resolve_hf_router_token(session_hf_token: str | None = None) -> str | None: """Backward-compatible private wrapper used by tests and older imports.""" @@ -93,10 +96,16 @@ class UnsupportedEffortError(ValueError): def _raise_for_local_effort(reasoning_effort: str | None, strict: bool) -> None: - if reasoning_effort and strict: - raise UnsupportedEffortError( - "Local OpenAI-compatible endpoints don't accept reasoning_effort" - ) + if not reasoning_effort: + return + message = "Local OpenAI-compatible endpoints don't accept reasoning_effort" + if strict: + raise UnsupportedEffortError(message) + logger.warning( + "%s; dropping reasoning_effort=%r for this local model call", + message, + reasoning_effort, + ) def _resolve_llm_params( diff --git a/backend/model_catalog.py b/backend/model_catalog.py index 86a4cd42..44d734fd 100644 --- a/backend/model_catalog.py +++ b/backend/model_catalog.py @@ -82,9 +82,7 @@ def is_custom_local_model_id(model_id: str) -> bool: return False if not isinstance(model_id, str): return False - if not model_id or model_id != model_id.strip() or any( - char.isspace() for char in model_id - ): + if not model_id or any(char.isspace() for char in model_id): return False return is_local_model_id(model_id) diff --git a/tests/unit/test_llm_params.py b/tests/unit/test_llm_params.py index 2b4abd52..3e970580 100644 --- a/tests/unit/test_llm_params.py +++ b/tests/unit/test_llm_params.py @@ -153,3 +153,16 @@ def test_resolve_generic_local_params_trims_trailing_slash(monkeypatch): def test_local_params_reject_reasoning_effort_in_strict_mode(): with pytest.raises(UnsupportedEffortError, match="reasoning_effort"): _resolve_llm_params("ollama/llama3.1", reasoning_effort="high", strict=True) + + +def test_local_params_drop_reasoning_effort_with_warning_in_non_strict_mode(caplog): + params = _resolve_llm_params( + "ollama/llama3.1", + reasoning_effort="high", + strict=False, + ) + + assert params["model"] == "openai/llama3.1" + assert "reasoning_effort" not in params + assert "extra_body" not in params + assert "don't accept reasoning_effort" in caplog.text diff --git a/tests/unit/test_local_model_validation.py b/tests/unit/test_local_model_validation.py index 3cc993da..26c078d5 100644 --- a/tests/unit/test_local_model_validation.py +++ b/tests/unit/test_local_model_validation.py @@ -1,20 +1,13 @@ """Tests for backend custom local model validation.""" -import sys -from pathlib import Path - -_ROOT_DIR = Path(__file__).resolve().parent.parent.parent -_BACKEND_DIR = _ROOT_DIR / "backend" -if str(_BACKEND_DIR) not in sys.path: - sys.path.insert(0, str(_BACKEND_DIR)) - -import model_catalog +from backend import model_catalog def test_custom_local_model_ids_require_feature_flag(monkeypatch): monkeypatch.delenv("ENABLE_LOCAL_MODELS", raising=False) assert not model_catalog.is_valid_model_id("ollama/qwen2.5-coder") + assert not model_catalog.is_custom_local_model_id("ollama/qwen2.5-coder") monkeypatch.setenv("ENABLE_LOCAL_MODELS", "true")