From bb56de6ddadf5957a85d4695d6d45fe58eadf395 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 5 May 2026 18:27:57 +0200 Subject: [PATCH 1/3] Add CLI local model support Co-authored-by: OpenAI Codex --- README.md | 33 +++++++++++- agent/core/llm_params.py | 52 ++++++++++++++++++ agent/core/local_models.py | 57 ++++++++++++++++++++ agent/core/model_switcher.py | 60 +++++++++++++++++++-- agent/main.py | 27 +++++----- tests/unit/test_cli_local_models.py | 83 +++++++++++++++++++++++++++++ tests/unit/test_llm_params.py | 74 +++++++++++++++++++++++++ 7 files changed, 369 insertions(+), 17 deletions(-) create mode 100644 agent/core/local_models.py create mode 100644 tests/unit/test_cli_local_models.py diff --git a/README.md b/README.md index 10d80337..51d223fe 100644 --- a/README.md +++ b/README.md @@ -28,10 +28,16 @@ Create a `.env` file in the project root (or export these in your shell): ```bash ANTHROPIC_API_KEY= # if using anthropic models OPENAI_API_KEY= # if using openai models +OLLAMA_BASE_URL=http://localhost:11434 # if using ollama/ local models +VLLM_BASE_URL=http://localhost:8000 # if using vllm/ local models +LMSTUDIO_BASE_URL=http://127.0.0.1:1234 # if using lm_studio/ local models +LLAMACPP_BASE_URL=http://localhost:8080 # if using llamacpp/ local models HF_TOKEN= GITHUB_TOKEN= ``` -If no `HF_TOKEN` is set, the CLI will prompt you to paste one on first launch. To get a GITHUB_TOKEN follow the tutorial [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token). +If no `HF_TOKEN` is set, the CLI will prompt you to paste one on first launch +unless you start on a local model. To get a GITHUB_TOKEN follow the tutorial +[here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token). ### Usage @@ -52,10 +58,35 @@ ml-intern "fine-tune llama on my dataset" ```bash ml-intern --model anthropic/claude-opus-4-6 "your prompt" ml-intern --model openai/gpt-5.5 "your prompt" +ml-intern --model ollama/llama3.1:8b "your prompt" +ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt" ml-intern --max-iterations 100 "your prompt" ml-intern --no-stream "your prompt" ``` +**Local models:** + +Local model support uses OpenAI-compatible HTTP endpoints through LiteLLM. The +agent does not load model weights directly from disk; start your inference +server first, then select it with a provider-specific model prefix: + +```bash +ml-intern --model ollama/llama3.1:8b "your prompt" +ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt" +``` + +Inside interactive mode, switch with `/model`: + +```text +/model ollama/llama3.1:8b +/model lm_studio/google/gemma-3-4b +/model llamacpp/llama-3.1-8b-instruct +``` + +Supported local prefixes are `ollama/`, `vllm/`, `lm_studio/`, and +`llamacpp/`. Each prefix has a matching `*_BASE_URL` and optional `*_API_KEY` +environment variable. Base URLs may include or omit `/v1`. + ## Sharing Traces Every session is auto-uploaded to your **own private Hugging Face dataset** diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 028dd6df..12abe138 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -5,7 +5,16 @@ creating circular imports. """ +import os + from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token +from agent.core.local_models import ( + LOCAL_MODEL_API_KEY_DEFAULT, + is_local_model_id, + is_reserved_local_model_id, + local_model_name, + local_model_provider, +) def _resolve_hf_router_token(session_hf_token: str | None = None) -> str | None: @@ -96,6 +105,37 @@ class UnsupportedEffortError(ValueError): """ +def _local_api_base(base_url: str) -> str: + base = base_url.strip().rstrip("/") + if base.endswith("/v1"): + return base + return f"{base}/v1" + + +def _resolve_local_model_params( + model_name: str, + reasoning_effort: str | None = None, + strict: bool = False, +) -> dict: + if reasoning_effort and strict: + raise UnsupportedEffortError( + "Local OpenAI-compatible endpoints don't accept reasoning_effort" + ) + + provider = local_model_provider(model_name) + local_name = local_model_name(model_name) + if provider is None or local_name is None or not is_local_model_id(model_name): + raise ValueError(f"Unsupported local model id: {model_name}") + + raw_base = os.environ.get(provider["base_url_env"]) or provider["base_url_default"] + api_key = os.environ.get(provider["api_key_env"]) or LOCAL_MODEL_API_KEY_DEFAULT + return { + "model": f"openai/{local_name}", + "api_base": _local_api_base(raw_base), + "api_key": api_key, + } + + def _resolve_llm_params( model_name: str, session_hf_token: str | None = None, @@ -121,6 +161,12 @@ def _resolve_llm_params( • ``openai/`` — ``reasoning_effort`` forwarded as a top-level kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``. + • ``ollama/``, ``vllm/``, ``lm_studio/``, and + ``llamacpp/`` — local OpenAI-compatible endpoints. The id prefix + selects a configurable localhost base URL, and the model suffix is sent + to LiteLLM as ``openai/``. These endpoints don't receive + ``reasoning_effort``. + • Anything else is treated as a HuggingFace router id. We hit the auto-routing OpenAI-compatible endpoint at ``https://router.huggingface.co/v1``. The id can be bare or carry an @@ -187,6 +233,12 @@ def _resolve_llm_params( params["reasoning_effort"] = reasoning_effort return params + if is_reserved_local_model_id(model_name): + raise ValueError(f"Unsupported local model id: {model_name}") + + if local_model_provider(model_name) is not None: + return _resolve_local_model_params(model_name, reasoning_effort, strict) + hf_model = model_name.removeprefix("huggingface/") api_key = _resolve_hf_router_token(session_hf_token) params = { diff --git a/agent/core/local_models.py b/agent/core/local_models.py new file mode 100644 index 00000000..f7d7cb1b --- /dev/null +++ b/agent/core/local_models.py @@ -0,0 +1,57 @@ +"""Helpers for CLI local OpenAI-compatible model ids.""" + +LOCAL_MODEL_PROVIDERS: dict[str, dict[str, str]] = { + "ollama/": { + "base_url_env": "OLLAMA_BASE_URL", + "base_url_default": "http://localhost:11434", + "api_key_env": "OLLAMA_API_KEY", + }, + "vllm/": { + "base_url_env": "VLLM_BASE_URL", + "base_url_default": "http://localhost:8000", + "api_key_env": "VLLM_API_KEY", + }, + "lm_studio/": { + "base_url_env": "LMSTUDIO_BASE_URL", + "base_url_default": "http://127.0.0.1:1234", + "api_key_env": "LMSTUDIO_API_KEY", + }, + "llamacpp/": { + "base_url_env": "LLAMACPP_BASE_URL", + "base_url_default": "http://localhost:8080", + "api_key_env": "LLAMACPP_API_KEY", + }, +} + +LOCAL_MODEL_PREFIXES = tuple(LOCAL_MODEL_PROVIDERS) +RESERVED_LOCAL_MODEL_PREFIXES = ("openai-compat/",) +LOCAL_MODEL_API_KEY_DEFAULT = "sk-local-no-key-required" + + +def local_model_provider(model_id: str) -> dict[str, str] | None: + """Return provider config for a local model id, if it uses a local prefix.""" + for prefix, config in LOCAL_MODEL_PROVIDERS.items(): + if model_id.startswith(prefix): + return config + return None + + +def local_model_name(model_id: str) -> str | None: + """Return the backend model name with the local provider prefix removed.""" + for prefix in LOCAL_MODEL_PREFIXES: + if model_id.startswith(prefix): + name = model_id[len(prefix) :] + return name or None + return None + + +def is_local_model_id(model_id: str) -> bool: + """Return True for non-empty, whitespace-free local model ids.""" + if not model_id or any(char.isspace() for char in model_id): + return False + return local_model_name(model_id) is not None + + +def is_reserved_local_model_id(model_id: str) -> bool: + """Return True for local-style prefixes intentionally not supported.""" + return model_id.startswith(RESERVED_LOCAL_MODEL_PREFIXES) diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index 14b5233d..5a8c1742 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -15,7 +15,17 @@ from __future__ import annotations +import asyncio + +from litellm import acompletion + from agent.core.effort_probe import ProbeInconclusive, probe_effort +from agent.core.llm_params import _resolve_llm_params +from agent.core.local_models import ( + LOCAL_MODEL_PREFIXES, + is_local_model_id, + is_reserved_local_model_id, +) # Suggested models shown by `/model` (not a gate). Users can paste any HF @@ -40,6 +50,8 @@ _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} +_DIRECT_PREFIXES = ("anthropic/", "openai/", *LOCAL_MODEL_PREFIXES) +_LOCAL_PROBE_TIMEOUT = 15.0 def is_valid_model_id(model_id: str) -> bool: @@ -48,13 +60,22 @@ def is_valid_model_id(model_id: str) -> bool: Accepts: • anthropic/ • openai/ + • ollama/, vllm/, lm_studio/, llamacpp//[:] (HF router; tag = provider or policy) • huggingface//[:] (same, accepts legacy prefix) Actual availability is verified against the HF router catalog on switch, and by the provider on the probe's ping call. """ - if not model_id or "/" not in model_id: + if not model_id: + return False + if is_local_model_id(model_id): + return True + if is_reserved_local_model_id(model_id): + return False + if any(model_id.startswith(prefix) for prefix in LOCAL_MODEL_PREFIXES): + return False + if "/" not in model_id: return False head = model_id.split(":", 1)[0] parts = head.split("/") @@ -70,7 +91,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool: Anthropic / OpenAI ids return ``True`` without printing anything — the probe below covers "does this model exist". """ - if model_id.startswith(("anthropic/", "openai/")): + if model_id.startswith(_DIRECT_PREFIXES): return True from agent.core import hf_router_catalog as cat @@ -141,7 +162,9 @@ def print_model_listing(config, console) -> None: console.print( "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" - "Use 'anthropic/' or 'openai/' for direct API access.[/dim]" + "Use 'anthropic/' or 'openai/' for direct API access.\n" + "Use 'ollama/', 'vllm/', 'lm_studio/', or " + "'llamacpp/' for local OpenAI-compatible endpoints.[/dim]" ) @@ -151,7 +174,21 @@ def print_invalid_id(arg: str, console) -> None: "[dim]Expected:\n" " • /[:tag] (HF router — paste from huggingface.co)\n" " • anthropic/\n" - " • openai/[/dim]" + " • openai/\n" + " • ollama/ | vllm/ | lm_studio/ | llamacpp/[/dim]" + ) + + +async def _probe_local_model(model_id: str) -> None: + params = _resolve_llm_params(model_id) + await asyncio.wait_for( + acompletion( + messages=[{"role": "user", "content": "ping"}], + max_tokens=1, + stream=False, + **params, + ), + timeout=_LOCAL_PROBE_TIMEOUT, ) @@ -176,6 +213,21 @@ async def probe_and_switch_model( Transient errors (5xx, timeout) complete the switch with a yellow warning; the next real call re-surfaces the error if it's persistent. """ + if is_local_model_id(model_id): + console.print(f"[dim]checking local model {model_id}...[/dim]") + try: + await _probe_local_model(model_id) + except Exception as e: + console.print(f"[bold red]Switch failed:[/bold red] {e}") + console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") + return + + _commit_switch(model_id, config, session, effective=None, cache=True) + console.print( + f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" + ) + return + preference = config.reasoning_effort if not _print_hf_routing_info(model_id, console): return diff --git a/agent/main.py b/agent/main.py index 1fadcbda..a7262707 100644 --- a/agent/main.py +++ b/agent/main.py @@ -25,6 +25,7 @@ from agent.core.agent_loop import submission_loop from agent.core import model_switcher from agent.core.hf_tokens import resolve_hf_token +from agent.core.local_models import is_local_model_id from agent.core.session import OpType from agent.core.tools import ToolRouter from agent.messaging.gateway import NotificationGateway @@ -967,15 +968,15 @@ async def main(model: str | None = None): # Create prompt session for input (needed early for token prompt) prompt_session = PromptSession() - # HF token — required, prompt if missing - hf_token = resolve_hf_token() - if not hf_token: - hf_token = await _prompt_and_save_hf_token(prompt_session) - config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) if model: config.model_name = model + # HF token — required for Hub-backed models/tools, but not for local LLMs. + hf_token = resolve_hf_token() + if not hf_token and not is_local_model_id(config.model_name): + hf_token = await _prompt_and_save_hf_token(prompt_session) + # Resolve username for banner hf_user = _get_hf_user(hf_token) @@ -1198,25 +1199,27 @@ async def headless_main( logging.basicConfig(level=logging.WARNING) _configure_runtime_logging() + config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) + config.yolo_mode = True # Auto-approve everything in headless mode + + if model: + config.model_name = model + hf_token = resolve_hf_token() - if not hf_token: + if not hf_token and not is_local_model_id(config.model_name): print( "ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr, ) sys.exit(1) - print("HF token loaded", file=sys.stderr) + if hf_token: + print("HF token loaded", file=sys.stderr) - config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) - config.yolo_mode = True # Auto-approve everything in headless mode notification_gateway = NotificationGateway(config.messaging) await notification_gateway.start() hf_user = _get_hf_user(hf_token) - if model: - config.model_name = model - if max_iterations is not None: config.max_iterations = max_iterations diff --git a/tests/unit/test_cli_local_models.py b/tests/unit/test_cli_local_models.py new file mode 100644 index 00000000..9988d7e9 --- /dev/null +++ b/tests/unit/test_cli_local_models.py @@ -0,0 +1,83 @@ +import pytest + +from agent.core import model_switcher +from agent.core.local_models import is_local_model_id + + +def test_local_model_helper_accepts_supported_prefixes(): + assert is_local_model_id("ollama/llama3.1:8b") + assert is_local_model_id("vllm/meta-llama/Llama-3.1-8B-Instruct") + assert is_local_model_id("lm_studio/google/gemma-3-4b") + assert is_local_model_id("llamacpp/unsloth/Qwen3.5-2B") + + +def test_model_switcher_accepts_supported_local_prefixes(): + assert model_switcher.is_valid_model_id("ollama/llama3.1:8b") + assert model_switcher.is_valid_model_id("vllm/meta-llama/Llama-3.1-8B") + assert model_switcher.is_valid_model_id("lm_studio/google/gemma-3-4b") + assert model_switcher.is_valid_model_id("llamacpp/llama-3.1-8b") + + +def test_model_switcher_rejects_empty_or_whitespace_local_ids(): + assert not model_switcher.is_valid_model_id("ollama/") + assert not model_switcher.is_valid_model_id("vllm/") + assert not model_switcher.is_valid_model_id("lm_studio/") + assert not model_switcher.is_valid_model_id("llamacpp/") + assert not model_switcher.is_valid_model_id("ollama/llama 3.1") + + +def test_openai_compat_prefix_is_not_supported(): + assert not model_switcher.is_valid_model_id("openai-compat/custom-model") + + +def test_local_models_skip_hf_router_catalog_output(): + class NoPrintConsole: + def print(self, *args, **kwargs): + raise AssertionError("local models should not print HF catalog info") + + assert model_switcher._print_hf_routing_info( + "ollama/llama3.1:8b", + NoPrintConsole(), + ) + + +@pytest.mark.asyncio +async def test_probe_and_switch_local_model_uses_no_effort(monkeypatch): + calls = [] + + async def fake_acompletion(**kwargs): + calls.append(kwargs) + return object() + + monkeypatch.setattr(model_switcher, "acompletion", fake_acompletion) + + class Config: + model_name = "openai/gpt-5.5" + reasoning_effort = "max" + + class Session: + def __init__(self): + self.model_id = None + self.model_effective_effort = {} + + def update_model(self, model_id): + self.model_id = model_id + + class Console: + def print(self, *args, **kwargs): + pass + + session = Session() + await model_switcher.probe_and_switch_model( + "ollama/llama3.1:8b", + Config(), + session, + Console(), + hf_token=None, + ) + + assert session.model_id == "ollama/llama3.1:8b" + assert session.model_effective_effort["ollama/llama3.1:8b"] is None + assert calls[0]["model"] == "openai/llama3.1:8b" + assert "reasoning_effort" not in calls[0] + assert "extra_body" not in calls[0] diff --git a/tests/unit/test_llm_params.py b/tests/unit/test_llm_params.py index 5234461a..9bf4940f 100644 --- a/tests/unit/test_llm_params.py +++ b/tests/unit/test_llm_params.py @@ -1,3 +1,5 @@ +import pytest + from agent.core.hf_tokens import resolve_hf_request_token from agent.core.llm_params import ( UnsupportedEffortError, @@ -30,6 +32,78 @@ def test_openai_max_effort_is_still_rejected(): raise AssertionError("Expected UnsupportedEffortError for max effort") +def test_resolve_ollama_params_adds_v1_and_uses_default_key(monkeypatch): + monkeypatch.delenv("OLLAMA_API_KEY", raising=False) + monkeypatch.setenv("OLLAMA_BASE_URL", "http://localhost:11434") + + params = _resolve_llm_params("ollama/llama3.1:8b") + + assert params == { + "model": "openai/llama3.1:8b", + "api_base": "http://localhost:11434/v1", + "api_key": "sk-local-no-key-required", + } + + +def test_resolve_vllm_params_keeps_existing_v1_and_trims_slash(monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + monkeypatch.setenv("VLLM_BASE_URL", "http://localhost:8000/v1/") + + params = _resolve_llm_params("vllm/meta-llama/Llama-3.1-8B-Instruct") + + assert params["model"] == "openai/meta-llama/Llama-3.1-8B-Instruct" + assert params["api_base"] == "http://localhost:8000/v1" + assert params["api_key"] == "sk-local-no-key-required" + + +def test_resolve_lm_studio_params_uses_api_key_override(monkeypatch): + monkeypatch.setenv("LMSTUDIO_BASE_URL", "http://127.0.0.1:1234") + monkeypatch.setenv("LMSTUDIO_API_KEY", "local-secret") + + params = _resolve_llm_params("lm_studio/google/gemma-3-4b") + + assert params["model"] == "openai/google/gemma-3-4b" + assert params["api_base"] == "http://127.0.0.1:1234/v1" + assert params["api_key"] == "local-secret" + + +def test_resolve_llamacpp_params_strips_provider_prefix(monkeypatch): + monkeypatch.delenv("LLAMACPP_API_KEY", raising=False) + monkeypatch.setenv("LLAMACPP_BASE_URL", "http://localhost:8080") + + params = _resolve_llm_params("llamacpp/unsloth/Qwen3.5-2B") + + assert params["model"] == "openai/unsloth/Qwen3.5-2B" + assert params["api_base"] == "http://localhost:8080/v1" + + +def test_local_params_reject_reasoning_effort_in_strict_mode(): + with pytest.raises(UnsupportedEffortError, match="reasoning_effort"): + _resolve_llm_params("ollama/llama3.1", reasoning_effort="high", strict=True) + + +def test_local_params_drop_reasoning_effort_in_non_strict_mode(): + params = _resolve_llm_params( + "ollama/llama3.1", + reasoning_effort="high", + strict=False, + ) + + assert params["model"] == "openai/llama3.1" + assert "reasoning_effort" not in params + assert "extra_body" not in params + + +def test_openai_compat_prefix_is_not_a_local_escape_hatch(): + with pytest.raises(ValueError, match="Unsupported local model id"): + _resolve_llm_params("openai-compat/custom-model") + + +def test_empty_local_model_id_is_not_treated_as_hf_router(): + with pytest.raises(ValueError, match="Unsupported local model id"): + _resolve_llm_params("ollama/") + + def test_hf_router_token_prefers_inference_token(monkeypatch): monkeypatch.setenv("INFERENCE_TOKEN", " inference-token ") monkeypatch.setenv("HF_TOKEN", "hf-token") From 8d21bf10c6a5dc505b17ad451b2c5825851cc7de Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 6 May 2026 11:28:53 +0200 Subject: [PATCH 2/3] Add shared local model endpoint fallback Support LOCAL_LLM_BASE_URL and LOCAL_LLM_API_KEY as shared fallbacks while preserving provider-specific local overrides. Co-authored-by: OpenAI Codex --- README.md | 13 +++++++------ agent/core/llm_params.py | 14 ++++++++++++-- agent/core/local_models.py | 2 ++ tests/unit/test_llm_params.py | 15 +++++++++++++++ 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 4c6a9b73..0a692e16 100644 --- a/README.md +++ b/README.md @@ -28,10 +28,8 @@ Create a `.env` file in the project root (or export these in your shell): ```bash ANTHROPIC_API_KEY= # if using anthropic models OPENAI_API_KEY= # if using openai models -OLLAMA_BASE_URL=http://localhost:11434 # if using ollama/ local models -VLLM_BASE_URL=http://localhost:8000 # if using vllm/ local models -LMSTUDIO_BASE_URL=http://127.0.0.1:1234 # if using lm_studio/ local models -LLAMACPP_BASE_URL=http://localhost:8080 # if using llamacpp/ local models +LOCAL_LLM_BASE_URL=http://localhost:8000 # shared fallback for local model prefixes +LOCAL_LLM_API_KEY= # optional shared local API key HF_TOKEN= GITHUB_TOKEN= ``` @@ -88,8 +86,11 @@ Inside interactive mode, switch with `/model`: ``` Supported local prefixes are `ollama/`, `vllm/`, `lm_studio/`, and -`llamacpp/`. Each prefix has a matching `*_BASE_URL` and optional `*_API_KEY` -environment variable. Base URLs may include or omit `/v1`. +`llamacpp/`. Set `LOCAL_LLM_BASE_URL` and optional `LOCAL_LLM_API_KEY` to use +one shared local endpoint, or override a specific provider with its matching +`*_BASE_URL` / `*_API_KEY` variable, such as `OLLAMA_BASE_URL` or +`VLLM_API_KEY`. Provider-specific variables take precedence over the shared +local variables. Base URLs may include or omit `/v1`. ## Sharing Traces diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 12abe138..4d0c6d87 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -10,6 +10,8 @@ from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token from agent.core.local_models import ( LOCAL_MODEL_API_KEY_DEFAULT, + LOCAL_MODEL_API_KEY_ENV, + LOCAL_MODEL_BASE_URL_ENV, is_local_model_id, is_reserved_local_model_id, local_model_name, @@ -127,8 +129,16 @@ def _resolve_local_model_params( if provider is None or local_name is None or not is_local_model_id(model_name): raise ValueError(f"Unsupported local model id: {model_name}") - raw_base = os.environ.get(provider["base_url_env"]) or provider["base_url_default"] - api_key = os.environ.get(provider["api_key_env"]) or LOCAL_MODEL_API_KEY_DEFAULT + raw_base = ( + os.environ.get(provider["base_url_env"]) + or os.environ.get(LOCAL_MODEL_BASE_URL_ENV) + or provider["base_url_default"] + ) + api_key = ( + os.environ.get(provider["api_key_env"]) + or os.environ.get(LOCAL_MODEL_API_KEY_ENV) + or LOCAL_MODEL_API_KEY_DEFAULT + ) return { "model": f"openai/{local_name}", "api_base": _local_api_base(raw_base), diff --git a/agent/core/local_models.py b/agent/core/local_models.py index f7d7cb1b..9f8a9491 100644 --- a/agent/core/local_models.py +++ b/agent/core/local_models.py @@ -25,6 +25,8 @@ LOCAL_MODEL_PREFIXES = tuple(LOCAL_MODEL_PROVIDERS) RESERVED_LOCAL_MODEL_PREFIXES = ("openai-compat/",) +LOCAL_MODEL_BASE_URL_ENV = "LOCAL_LLM_BASE_URL" +LOCAL_MODEL_API_KEY_ENV = "LOCAL_LLM_API_KEY" LOCAL_MODEL_API_KEY_DEFAULT = "sk-local-no-key-required" diff --git a/tests/unit/test_llm_params.py b/tests/unit/test_llm_params.py index 9bf4940f..a7c7b4cd 100644 --- a/tests/unit/test_llm_params.py +++ b/tests/unit/test_llm_params.py @@ -59,6 +59,8 @@ def test_resolve_vllm_params_keeps_existing_v1_and_trims_slash(monkeypatch): def test_resolve_lm_studio_params_uses_api_key_override(monkeypatch): monkeypatch.setenv("LMSTUDIO_BASE_URL", "http://127.0.0.1:1234") monkeypatch.setenv("LMSTUDIO_API_KEY", "local-secret") + monkeypatch.setenv("LOCAL_LLM_BASE_URL", "http://localhost:9999") + monkeypatch.setenv("LOCAL_LLM_API_KEY", "shared-secret") params = _resolve_llm_params("lm_studio/google/gemma-3-4b") @@ -67,6 +69,19 @@ def test_resolve_lm_studio_params_uses_api_key_override(monkeypatch): assert params["api_key"] == "local-secret" +def test_resolve_local_params_uses_shared_fallback_env(monkeypatch): + monkeypatch.delenv("VLLM_BASE_URL", raising=False) + monkeypatch.delenv("VLLM_API_KEY", raising=False) + monkeypatch.setenv("LOCAL_LLM_BASE_URL", "http://localhost:9000/v1/") + monkeypatch.setenv("LOCAL_LLM_API_KEY", "shared-local-secret") + + params = _resolve_llm_params("vllm/custom-model") + + assert params["model"] == "openai/custom-model" + assert params["api_base"] == "http://localhost:9000/v1" + assert params["api_key"] == "shared-local-secret" + + def test_resolve_llamacpp_params_strips_provider_prefix(monkeypatch): monkeypatch.delenv("LLAMACPP_API_KEY", raising=False) monkeypatch.setenv("LLAMACPP_BASE_URL", "http://localhost:8080") From 925c29e3b0ca57d6f11f839f72599622df16555c Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Wed, 6 May 2026 11:54:18 +0200 Subject: [PATCH 3/3] Address local model review feedback Clarify local probe failure behavior, add regression coverage for rejected local switches, and simplify local model validation. Co-authored-by: OpenAI Codex --- agent/core/llm_params.py | 6 ++--- agent/core/model_switcher.py | 6 +++-- tests/unit/test_cli_local_models.py | 38 +++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 4d0c6d87..f95695fb 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -12,7 +12,6 @@ LOCAL_MODEL_API_KEY_DEFAULT, LOCAL_MODEL_API_KEY_ENV, LOCAL_MODEL_BASE_URL_ENV, - is_local_model_id, is_reserved_local_model_id, local_model_name, local_model_provider, @@ -124,11 +123,12 @@ def _resolve_local_model_params( "Local OpenAI-compatible endpoints don't accept reasoning_effort" ) - provider = local_model_provider(model_name) local_name = local_model_name(model_name) - if provider is None or local_name is None or not is_local_model_id(model_name): + if local_name is None: raise ValueError(f"Unsupported local model id: {model_name}") + provider = local_model_provider(model_name) + assert provider is not None raw_base = ( os.environ.get(provider["base_url_env"]) or os.environ.get(LOCAL_MODEL_BASE_URL_ENV) diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index 5a8c1742..34eaccdd 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -210,8 +210,10 @@ async def probe_and_switch_model( * ✗ hard error (auth, model-not-found, quota) — we reject the switch and keep the current model so the user isn't stranded - Transient errors (5xx, timeout) complete the switch with a yellow - warning; the next real call re-surfaces the error if it's persistent. + For non-local models, transient errors (5xx, timeout) complete the switch + with a yellow warning; the next real call re-surfaces the error if it's + persistent. Local models reject every probe error, including timeouts, and + keep the current model. """ if is_local_model_id(model_id): console.print(f"[dim]checking local model {model_id}...[/dim]") diff --git a/tests/unit/test_cli_local_models.py b/tests/unit/test_cli_local_models.py index 9988d7e9..836fb3fd 100644 --- a/tests/unit/test_cli_local_models.py +++ b/tests/unit/test_cli_local_models.py @@ -81,3 +81,41 @@ def print(self, *args, **kwargs): assert calls[0]["model"] == "openai/llama3.1:8b" assert "reasoning_effort" not in calls[0] assert "extra_body" not in calls[0] + + +@pytest.mark.asyncio +async def test_probe_and_switch_local_model_rejects_probe_errors(monkeypatch): + async def failing_acompletion(**kwargs): + raise ConnectionRefusedError("no server") + + monkeypatch.setattr(model_switcher, "acompletion", failing_acompletion) + + class Config: + model_name = "openai/gpt-5.5" + reasoning_effort = None + + class Session: + def __init__(self): + self.model_id = None + self.model_effective_effort = {} + + def update_model(self, model_id): + self.model_id = model_id + + class Console: + def print(self, *args, **kwargs): + pass + + config = Config() + session = Session() + await model_switcher.probe_and_switch_model( + "ollama/llama3.1:8b", + config, + session, + Console(), + hf_token=None, + ) + + assert config.model_name == "openai/gpt-5.5" + assert session.model_id is None + assert "ollama/llama3.1:8b" not in session.model_effective_effort