diff --git a/README.md b/README.md
index ab2f7d52..0a692e16 100644
--- a/README.md
+++ b/README.md
@@ -28,10 +28,14 @@ Create a `.env` file in the project root (or export these in your shell):
 
 ```bash
 ANTHROPIC_API_KEY=                       # if using anthropic models
 OPENAI_API_KEY=                          # if using openai models
+LOCAL_LLM_BASE_URL=http://localhost:8000 # shared fallback for local model prefixes
+LOCAL_LLM_API_KEY=                       # optional shared local API key
 HF_TOKEN=
 GITHUB_TOKEN=
 ```
 
-If no `HF_TOKEN` is set, the CLI will prompt you to paste one on first launch. To get a GITHUB_TOKEN follow the tutorial [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token).
+If no `HF_TOKEN` is set, the CLI will prompt you to paste one on first launch
+unless you start on a local model. To get a `GITHUB_TOKEN`, follow the tutorial
+[here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token).
 
 ### Usage
@@ -52,12 +56,41 @@ ml-intern "fine-tune llama on my dataset"
 
 ```bash
 ml-intern --model anthropic/claude-opus-4-7 "your prompt"               # requires ANTHROPIC_API_KEY
 ml-intern --model openai/gpt-5.5 "your prompt"                          # requires OPENAI_API_KEY
+ml-intern --model ollama/llama3.1:8b "your prompt"
+ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
 ml-intern --max-iterations 100 "your prompt"
 ml-intern --no-stream "your prompt"
 ```
 
 Run `ml-intern` then `/model` to see the full list of suggested model ids
-(Claude, GPT, and HF-router models like MiniMax, Kimi, GLM, DeepSeek).
+(Claude, GPT, HF-router models like MiniMax, Kimi, GLM, DeepSeek, and local
+model prefixes).
+
+**Local models:**
+
+Local model support uses OpenAI-compatible HTTP endpoints through LiteLLM. The
+agent does not load model weights directly from disk; start your inference
+server first, then select it with a provider-specific model prefix:
+
+```bash
+ml-intern --model ollama/llama3.1:8b "your prompt"
+ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
+```
+
+Inside interactive mode, switch with `/model`:
+
+```text
+/model ollama/llama3.1:8b
+/model lm_studio/google/gemma-3-4b
+/model llamacpp/llama-3.1-8b-instruct
+```
+
+Supported local prefixes are `ollama/`, `vllm/`, `lm_studio/`, and
+`llamacpp/`. Set `LOCAL_LLM_BASE_URL` and, optionally, `LOCAL_LLM_API_KEY` to
+use one shared local endpoint, or override a specific provider with its
+matching `*_BASE_URL` / `*_API_KEY` variable, such as `OLLAMA_BASE_URL` or
+`VLLM_API_KEY`. Provider-specific variables take precedence over the shared
+local variables. Base URLs may include or omit `/v1`.
 
 ## Sharing Traces
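Before pointing the CLI at a local endpoint, it can help to confirm that the server really speaks the OpenAI-compatible protocol. A minimal sketch (illustrative, not part of the change above), assuming an Ollama server on its default port with `llama3.1:8b` already pulled and the `openai` Python client installed:

```python
# Smoke-test a local OpenAI-compatible endpoint before handing it to ml-intern.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:11434/v1",   # same URL you would put in OLLAMA_BASE_URL
    api_key="sk-local-no-key-required",     # most local servers ignore the key entirely
)

reply = client.chat.completions.create(
    model="llama3.1:8b",                    # bare model name — no `ollama/` prefix here
    messages=[{"role": "user", "content": "ping"}],
    max_tokens=8,
)
print(reply.choices[0].message.content)
```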
""" +import os + from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token +from agent.core.local_models import ( + LOCAL_MODEL_API_KEY_DEFAULT, + LOCAL_MODEL_API_KEY_ENV, + LOCAL_MODEL_BASE_URL_ENV, + is_reserved_local_model_id, + local_model_name, + local_model_provider, +) def _resolve_hf_router_token(session_hf_token: str | None = None) -> str | None: @@ -96,6 +106,46 @@ class UnsupportedEffortError(ValueError): """ +def _local_api_base(base_url: str) -> str: + base = base_url.strip().rstrip("/") + if base.endswith("/v1"): + return base + return f"{base}/v1" + + +def _resolve_local_model_params( + model_name: str, + reasoning_effort: str | None = None, + strict: bool = False, +) -> dict: + if reasoning_effort and strict: + raise UnsupportedEffortError( + "Local OpenAI-compatible endpoints don't accept reasoning_effort" + ) + + local_name = local_model_name(model_name) + if local_name is None: + raise ValueError(f"Unsupported local model id: {model_name}") + + provider = local_model_provider(model_name) + assert provider is not None + raw_base = ( + os.environ.get(provider["base_url_env"]) + or os.environ.get(LOCAL_MODEL_BASE_URL_ENV) + or provider["base_url_default"] + ) + api_key = ( + os.environ.get(provider["api_key_env"]) + or os.environ.get(LOCAL_MODEL_API_KEY_ENV) + or LOCAL_MODEL_API_KEY_DEFAULT + ) + return { + "model": f"openai/{local_name}", + "api_base": _local_api_base(raw_base), + "api_key": api_key, + } + + def _resolve_llm_params( model_name: str, session_hf_token: str | None = None, @@ -121,6 +171,12 @@ def _resolve_llm_params( • ``openai/`` — ``reasoning_effort`` forwarded as a top-level kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``. + • ``ollama/``, ``vllm/``, ``lm_studio/``, and + ``llamacpp/`` — local OpenAI-compatible endpoints. The id prefix + selects a configurable localhost base URL, and the model suffix is sent + to LiteLLM as ``openai/``. These endpoints don't receive + ``reasoning_effort``. + • Anything else is treated as a HuggingFace router id. We hit the auto-routing OpenAI-compatible endpoint at ``https://router.huggingface.co/v1``. 
diff --git a/agent/core/local_models.py b/agent/core/local_models.py
new file mode 100644
index 00000000..9f8a9491
--- /dev/null
+++ b/agent/core/local_models.py
@@ -0,0 +1,59 @@
+"""Helpers for CLI local OpenAI-compatible model ids."""
+
+LOCAL_MODEL_PROVIDERS: dict[str, dict[str, str]] = {
+    "ollama/": {
+        "base_url_env": "OLLAMA_BASE_URL",
+        "base_url_default": "http://localhost:11434",
+        "api_key_env": "OLLAMA_API_KEY",
+    },
+    "vllm/": {
+        "base_url_env": "VLLM_BASE_URL",
+        "base_url_default": "http://localhost:8000",
+        "api_key_env": "VLLM_API_KEY",
+    },
+    "lm_studio/": {
+        "base_url_env": "LMSTUDIO_BASE_URL",
+        "base_url_default": "http://127.0.0.1:1234",
+        "api_key_env": "LMSTUDIO_API_KEY",
+    },
+    "llamacpp/": {
+        "base_url_env": "LLAMACPP_BASE_URL",
+        "base_url_default": "http://localhost:8080",
+        "api_key_env": "LLAMACPP_API_KEY",
+    },
+}
+
+LOCAL_MODEL_PREFIXES = tuple(LOCAL_MODEL_PROVIDERS)
+RESERVED_LOCAL_MODEL_PREFIXES = ("openai-compat/",)
+LOCAL_MODEL_BASE_URL_ENV = "LOCAL_LLM_BASE_URL"
+LOCAL_MODEL_API_KEY_ENV = "LOCAL_LLM_API_KEY"
+LOCAL_MODEL_API_KEY_DEFAULT = "sk-local-no-key-required"
+
+
+def local_model_provider(model_id: str) -> dict[str, str] | None:
+    """Return provider config for a local model id, if it uses a local prefix."""
+    for prefix, config in LOCAL_MODEL_PROVIDERS.items():
+        if model_id.startswith(prefix):
+            return config
+    return None
+
+
+def local_model_name(model_id: str) -> str | None:
+    """Return the backend model name with the local provider prefix removed."""
+    for prefix in LOCAL_MODEL_PREFIXES:
+        if model_id.startswith(prefix):
+            name = model_id[len(prefix) :]
+            return name or None
+    return None
+
+
+def is_local_model_id(model_id: str) -> bool:
+    """Return True for non-empty, whitespace-free local model ids."""
+    if not model_id or any(char.isspace() for char in model_id):
+        return False
+    return local_model_name(model_id) is not None
+
+
+def is_reserved_local_model_id(model_id: str) -> bool:
+    """Return True for local-style prefixes intentionally not supported."""
+    return model_id.startswith(RESERVED_LOCAL_MODEL_PREFIXES)
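A quick sketch of how the helpers compose; the expected values in the comments follow the table and functions above (and mirror the unit tests further down):

```python
# Expected results shown as comments; no new behaviour beyond the module above.
from agent.core.local_models import (
    is_local_model_id,
    is_reserved_local_model_id,
    local_model_name,
    local_model_provider,
)

local_model_provider("ollama/llama3.1:8b")["base_url_default"]   # "http://localhost:11434"
local_model_name("vllm/meta-llama/Llama-3.1-8B-Instruct")        # "meta-llama/Llama-3.1-8B-Instruct"
local_model_name("ollama/")                                      # None — empty suffix
is_local_model_id("lm_studio/google/gemma-3-4b")                 # True
is_local_model_id("ollama/llama 3.1")                            # False — whitespace is rejected
is_reserved_local_model_id("openai-compat/custom-model")         # True — deliberately unsupported
```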
diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py
index 14b5233d..34eaccdd 100644
--- a/agent/core/model_switcher.py
+++ b/agent/core/model_switcher.py
@@ -15,7 +15,17 @@
 from __future__ import annotations
 
+import asyncio
+
+from litellm import acompletion
+
 from agent.core.effort_probe import ProbeInconclusive, probe_effort
+from agent.core.llm_params import _resolve_llm_params
+from agent.core.local_models import (
+    LOCAL_MODEL_PREFIXES,
+    is_local_model_id,
+    is_reserved_local_model_id,
+)
 
 
 # Suggested models shown by `/model` (not a gate). Users can paste any HF
@@ -40,6 +50,8 @@
 _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"}
 
+_DIRECT_PREFIXES = ("anthropic/", "openai/", *LOCAL_MODEL_PREFIXES)
+_LOCAL_PROBE_TIMEOUT = 15.0
 
 
 def is_valid_model_id(model_id: str) -> bool:
@@ -48,13 +60,22 @@ def is_valid_model_id(model_id: str) -> bool:
     Accepts:
       • anthropic/<model>
       • openai/<model>
+      • ollama/<model>, vllm/<model>, lm_studio/<model>, llamacpp/<model>
      • <org>/<model>[:<tag>]              (HF router; tag = provider or policy)
      • huggingface/<org>/<model>[:<tag>]  (same, accepts legacy prefix)
 
     Actual availability is verified against the HF router catalog on switch,
     and by the provider on the probe's ping call.
     """
-    if not model_id or "/" not in model_id:
+    if not model_id:
+        return False
+    if is_local_model_id(model_id):
+        return True
+    if is_reserved_local_model_id(model_id):
+        return False
+    if any(model_id.startswith(prefix) for prefix in LOCAL_MODEL_PREFIXES):
+        return False
+    if "/" not in model_id:
         return False
     head = model_id.split(":", 1)[0]
     parts = head.split("/")
@@ -70,7 +91,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool:
     Anthropic / OpenAI ids return ``True`` without printing anything — the
     probe below covers "does this model exist".
     """
-    if model_id.startswith(("anthropic/", "openai/")):
+    if model_id.startswith(_DIRECT_PREFIXES):
         return True
 
     from agent.core import hf_router_catalog as cat
@@ -141,7 +162,9 @@ def print_model_listing(config, console) -> None:
     console.print(
         "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n"
         "Add ':fastest', ':cheapest', ':preferred', or ':<provider>' to override routing.\n"
-        "Use 'anthropic/' or 'openai/' for direct API access.[/dim]"
+        "Use 'anthropic/' or 'openai/' for direct API access.\n"
+        "Use 'ollama/', 'vllm/', 'lm_studio/', or "
+        "'llamacpp/' for local OpenAI-compatible endpoints.[/dim]"
     )
@@ -151,7 +174,21 @@ def print_invalid_id(arg: str, console) -> None:
         "[dim]Expected:\n"
         " • <org>/<model>[:tag]   (HF router — paste from huggingface.co)\n"
         " • anthropic/<model>\n"
-        " • openai/<model>[/dim]"
+        " • openai/<model>\n"
+        " • ollama/<model> | vllm/<model> | lm_studio/<model> | llamacpp/<model>[/dim]"
     )
+
+
+async def _probe_local_model(model_id: str) -> None:
+    params = _resolve_llm_params(model_id)
+    await asyncio.wait_for(
+        acompletion(
+            messages=[{"role": "user", "content": "ping"}],
+            max_tokens=1,
+            stream=False,
+            **params,
+        ),
+        timeout=_LOCAL_PROBE_TIMEOUT,
+    )
@@ -173,9 +210,26 @@ async def probe_and_switch_model(
       * ✗ hard error (auth, model-not-found, quota) — we reject the switch
         and keep the current model so the user isn't stranded
 
-    Transient errors (5xx, timeout) complete the switch with a yellow
-    warning; the next real call re-surfaces the error if it's persistent.
+    For non-local models, transient errors (5xx, timeout) complete the switch
+    with a yellow warning; the next real call re-surfaces the error if it's
+    persistent. Local models reject every probe error, including timeouts, and
+    keep the current model.
     """
+    if is_local_model_id(model_id):
+        console.print(f"[dim]checking local model {model_id}...[/dim]")
+        try:
+            await _probe_local_model(model_id)
+        except Exception as e:
+            console.print(f"[bold red]Switch failed:[/bold red] {e}")
+            console.print(f"[dim]Keeping current model: {config.model_name}[/dim]")
+            return
+
+        _commit_switch(model_id, config, session, effective=None, cache=True)
+        console.print(
+            f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]"
+        )
+        return
+
     preference = config.reasoning_effort
     if not _print_hf_routing_info(model_id, console):
         return
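The probe the switcher sends to a local endpoint is a single one-token ping. The same request can be reproduced standalone when debugging a rejected switch — a sketch, assuming a running local server and using an example model id:

```python
# Re-create the switcher's one-token ping outside the CLI to debug a refused switch.
import asyncio

from litellm import acompletion

from agent.core.llm_params import _resolve_llm_params


async def ping(model_id: str) -> None:
    params = _resolve_llm_params(model_id)
    await asyncio.wait_for(
        acompletion(
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=1,
            stream=False,
            **params,
        ),
        timeout=15.0,  # mirrors _LOCAL_PROBE_TIMEOUT above
    )


asyncio.run(ping("ollama/llama3.1:8b"))  # raises if the server is down or the model is missing
```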
""" + if is_local_model_id(model_id): + console.print(f"[dim]checking local model {model_id}...[/dim]") + try: + await _probe_local_model(model_id) + except Exception as e: + console.print(f"[bold red]Switch failed:[/bold red] {e}") + console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") + return + + _commit_switch(model_id, config, session, effective=None, cache=True) + console.print( + f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" + ) + return + preference = config.reasoning_effort if not _print_hf_routing_info(model_id, console): return diff --git a/agent/main.py b/agent/main.py index 1fadcbda..a7262707 100644 --- a/agent/main.py +++ b/agent/main.py @@ -25,6 +25,7 @@ from agent.core.agent_loop import submission_loop from agent.core import model_switcher from agent.core.hf_tokens import resolve_hf_token +from agent.core.local_models import is_local_model_id from agent.core.session import OpType from agent.core.tools import ToolRouter from agent.messaging.gateway import NotificationGateway @@ -967,15 +968,15 @@ async def main(model: str | None = None): # Create prompt session for input (needed early for token prompt) prompt_session = PromptSession() - # HF token — required, prompt if missing - hf_token = resolve_hf_token() - if not hf_token: - hf_token = await _prompt_and_save_hf_token(prompt_session) - config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) if model: config.model_name = model + # HF token — required for Hub-backed models/tools, but not for local LLMs. + hf_token = resolve_hf_token() + if not hf_token and not is_local_model_id(config.model_name): + hf_token = await _prompt_and_save_hf_token(prompt_session) + # Resolve username for banner hf_user = _get_hf_user(hf_token) @@ -1198,25 +1199,27 @@ async def headless_main( logging.basicConfig(level=logging.WARNING) _configure_runtime_logging() + config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) + config.yolo_mode = True # Auto-approve everything in headless mode + + if model: + config.model_name = model + hf_token = resolve_hf_token() - if not hf_token: + if not hf_token and not is_local_model_id(config.model_name): print( "ERROR: No HF token found. 
diff --git a/tests/unit/test_cli_local_models.py b/tests/unit/test_cli_local_models.py
new file mode 100644
index 00000000..836fb3fd
--- /dev/null
+++ b/tests/unit/test_cli_local_models.py
@@ -0,0 +1,121 @@
+import pytest
+
+from agent.core import model_switcher
+from agent.core.local_models import is_local_model_id
+
+
+def test_local_model_helper_accepts_supported_prefixes():
+    assert is_local_model_id("ollama/llama3.1:8b")
+    assert is_local_model_id("vllm/meta-llama/Llama-3.1-8B-Instruct")
+    assert is_local_model_id("lm_studio/google/gemma-3-4b")
+    assert is_local_model_id("llamacpp/unsloth/Qwen3.5-2B")
+
+
+def test_model_switcher_accepts_supported_local_prefixes():
+    assert model_switcher.is_valid_model_id("ollama/llama3.1:8b")
+    assert model_switcher.is_valid_model_id("vllm/meta-llama/Llama-3.1-8B")
+    assert model_switcher.is_valid_model_id("lm_studio/google/gemma-3-4b")
+    assert model_switcher.is_valid_model_id("llamacpp/llama-3.1-8b")
+
+
+def test_model_switcher_rejects_empty_or_whitespace_local_ids():
+    assert not model_switcher.is_valid_model_id("ollama/")
+    assert not model_switcher.is_valid_model_id("vllm/")
+    assert not model_switcher.is_valid_model_id("lm_studio/")
+    assert not model_switcher.is_valid_model_id("llamacpp/")
+    assert not model_switcher.is_valid_model_id("ollama/llama 3.1")
+
+
+def test_openai_compat_prefix_is_not_supported():
+    assert not model_switcher.is_valid_model_id("openai-compat/custom-model")
+
+
+def test_local_models_skip_hf_router_catalog_output():
+    class NoPrintConsole:
+        def print(self, *args, **kwargs):
+            raise AssertionError("local models should not print HF catalog info")
+
+    assert model_switcher._print_hf_routing_info(
+        "ollama/llama3.1:8b",
+        NoPrintConsole(),
+    )
+
+
+@pytest.mark.asyncio
+async def test_probe_and_switch_local_model_uses_no_effort(monkeypatch):
+    calls = []
+
+    async def fake_acompletion(**kwargs):
+        calls.append(kwargs)
+        return object()
+
+    monkeypatch.setattr(model_switcher, "acompletion", fake_acompletion)
+
+    class Config:
+        model_name = "openai/gpt-5.5"
+        reasoning_effort = "max"
+
+    class Session:
+        def __init__(self):
+            self.model_id = None
+            self.model_effective_effort = {}
+
+        def update_model(self, model_id):
+            self.model_id = model_id
+
+    class Console:
+        def print(self, *args, **kwargs):
+            pass
+
+    session = Session()
+    await model_switcher.probe_and_switch_model(
+        "ollama/llama3.1:8b",
+        Config(),
+        session,
+        Console(),
+        hf_token=None,
+    )
+
+    assert session.model_id == "ollama/llama3.1:8b"
+    assert session.model_effective_effort["ollama/llama3.1:8b"] is None
+    assert calls[0]["model"] == "openai/llama3.1:8b"
+    assert "reasoning_effort" not in calls[0]
+    assert "extra_body" not in calls[0]
+
+
+@pytest.mark.asyncio
+async def test_probe_and_switch_local_model_rejects_probe_errors(monkeypatch):
+    async def failing_acompletion(**kwargs):
+        raise ConnectionRefusedError("no server")
+
+    monkeypatch.setattr(model_switcher, "acompletion", failing_acompletion)
+
+    class Config:
+        model_name = "openai/gpt-5.5"
+        reasoning_effort = None
+
+    class Session:
+        def __init__(self):
+            self.model_id = None
+            self.model_effective_effort = {}
+
+        def update_model(self, model_id):
+            self.model_id = model_id
+
+    class Console:
+        def print(self, *args, **kwargs):
+            pass
+
+    config = Config()
+    session = Session()
+    await model_switcher.probe_and_switch_model(
+        "ollama/llama3.1:8b",
+        config,
+        session,
+        Console(),
+        hf_token=None,
+    )
+
+    assert config.model_name == "openai/gpt-5.5"
+    assert session.model_id is None
+    assert "ollama/llama3.1:8b" not in session.model_effective_effort
diff --git a/tests/unit/test_llm_params.py b/tests/unit/test_llm_params.py
index 5234461a..a7c7b4cd 100644
--- a/tests/unit/test_llm_params.py
+++ b/tests/unit/test_llm_params.py
@@ -1,3 +1,5 @@
+import pytest
+
 from agent.core.hf_tokens import resolve_hf_request_token
 from agent.core.llm_params import (
     UnsupportedEffortError,
@@ -30,6 +32,93 @@ def test_openai_max_effort_is_still_rejected():
         raise AssertionError("Expected UnsupportedEffortError for max effort")
 
 
+def test_resolve_ollama_params_adds_v1_and_uses_default_key(monkeypatch):
+    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)
+    monkeypatch.setenv("OLLAMA_BASE_URL", "http://localhost:11434")
+
+    params = _resolve_llm_params("ollama/llama3.1:8b")
+
+    assert params == {
+        "model": "openai/llama3.1:8b",
+        "api_base": "http://localhost:11434/v1",
+        "api_key": "sk-local-no-key-required",
+    }
+
+
+def test_resolve_vllm_params_keeps_existing_v1_and_trims_slash(monkeypatch):
+    monkeypatch.delenv("VLLM_API_KEY", raising=False)
+    monkeypatch.setenv("VLLM_BASE_URL", "http://localhost:8000/v1/")
+
+    params = _resolve_llm_params("vllm/meta-llama/Llama-3.1-8B-Instruct")
+
+    assert params["model"] == "openai/meta-llama/Llama-3.1-8B-Instruct"
+    assert params["api_base"] == "http://localhost:8000/v1"
+    assert params["api_key"] == "sk-local-no-key-required"
+
+
+def test_resolve_lm_studio_params_uses_api_key_override(monkeypatch):
+    monkeypatch.setenv("LMSTUDIO_BASE_URL", "http://127.0.0.1:1234")
+    monkeypatch.setenv("LMSTUDIO_API_KEY", "local-secret")
+    monkeypatch.setenv("LOCAL_LLM_BASE_URL", "http://localhost:9999")
+    monkeypatch.setenv("LOCAL_LLM_API_KEY", "shared-secret")
+
+    params = _resolve_llm_params("lm_studio/google/gemma-3-4b")
+
+    assert params["model"] == "openai/google/gemma-3-4b"
+    assert params["api_base"] == "http://127.0.0.1:1234/v1"
+    assert params["api_key"] == "local-secret"
+
+
+def test_resolve_local_params_uses_shared_fallback_env(monkeypatch):
+    monkeypatch.delenv("VLLM_BASE_URL", raising=False)
+    monkeypatch.delenv("VLLM_API_KEY", raising=False)
+    monkeypatch.setenv("LOCAL_LLM_BASE_URL", "http://localhost:9000/v1/")
+    monkeypatch.setenv("LOCAL_LLM_API_KEY", "shared-local-secret")
+
+    params = _resolve_llm_params("vllm/custom-model")
+
+    assert params["model"] == "openai/custom-model"
+    assert params["api_base"] == "http://localhost:9000/v1"
+    assert params["api_key"] == "shared-local-secret"
+
+
+def test_resolve_llamacpp_params_strips_provider_prefix(monkeypatch):
+    monkeypatch.delenv("LLAMACPP_API_KEY", raising=False)
+    monkeypatch.setenv("LLAMACPP_BASE_URL", "http://localhost:8080")
+
+    params = _resolve_llm_params("llamacpp/unsloth/Qwen3.5-2B")
+
+    assert params["model"] == "openai/unsloth/Qwen3.5-2B"
+    assert params["api_base"] == "http://localhost:8080/v1"
_resolve_llm_params("ollama/llama3.1", reasoning_effort="high", strict=True) + + +def test_local_params_drop_reasoning_effort_in_non_strict_mode(): + params = _resolve_llm_params( + "ollama/llama3.1", + reasoning_effort="high", + strict=False, + ) + + assert params["model"] == "openai/llama3.1" + assert "reasoning_effort" not in params + assert "extra_body" not in params + + +def test_openai_compat_prefix_is_not_a_local_escape_hatch(): + with pytest.raises(ValueError, match="Unsupported local model id"): + _resolve_llm_params("openai-compat/custom-model") + + +def test_empty_local_model_id_is_not_treated_as_hf_router(): + with pytest.raises(ValueError, match="Unsupported local model id"): + _resolve_llm_params("ollama/") + + def test_hf_router_token_prefers_inference_token(monkeypatch): monkeypatch.setenv("INFERENCE_TOKEN", " inference-token ") monkeypatch.setenv("HF_TOKEN", "hf-token")