7 changes: 7 additions & 0 deletions README.md
@@ -27,6 +27,10 @@ Create a `.env` file in the project root (or export these in your shell):

```bash
ANTHROPIC_API_KEY=<your-anthropic-api-key> # if using anthropic models
OLLAMA_BASE_URL=http://localhost:11434/v1 # if using ollama/ models
LMSTUDIO_BASE_URL=http://127.0.0.1:1234/v1 # if using lm_studio/ models
VLLM_BASE_URL=http://127.0.0.1:8000/v1 # if using vllm/ models
OPENAI_COMPAT_BASE_URL=http://127.0.0.1:8000/v1 # generic OpenAI-compatible backend
HF_TOKEN=<your-hugging-face-token>
GITHUB_TOKEN=<github-personal-access-token>
```
@@ -50,6 +54,9 @@ ml-intern "fine-tune llama on my dataset"

```bash
ml-intern --model anthropic/claude-opus-4-6 "your prompt"
ml-intern --model ollama/llama3.1:8b "your prompt"
ml-intern --model lm_studio/google/gemma-4-e4b "your prompt"
ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
ml-intern --max-iterations 100 "your prompt"
ml-intern --no-stream "your prompt"
```
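
For any other OpenAI-compatible server, the generic `openai-compat/` prefix works the same way. A minimal sketch, assuming a self-hosted backend; the URL, key, and model id below are placeholders to replace with your own:

```bash
# Hypothetical self-hosted endpoint; the built-in defaults apply when these are unset.
export OPENAI_COMPAT_BASE_URL=http://my-server:8000/v1
export OPENAI_COMPAT_API_KEY=my-secret-key
ml-intern --model openai-compat/my-custom-model "your prompt"
```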
69 changes: 69 additions & 0 deletions agent/core/llm_params.py
@@ -8,6 +8,34 @@
import os


_OPENAI_COMPAT_PROVIDERS: dict[str, dict[str, str]] = {
"ollama/": {
"api_base_env": "OLLAMA_BASE_URL",
"api_base_default": "http://localhost:11434/v1",
"api_key_env": "OLLAMA_API_KEY",
"api_key_default": "ollama",
},
"lm_studio/": {
"api_base_env": "LMSTUDIO_BASE_URL",
"api_base_default": "http://127.0.0.1:1234/v1",
"api_key_env": "LMSTUDIO_API_KEY",
"api_key_default": "lm-studio",
},
"vllm/": {
"api_base_env": "VLLM_BASE_URL",
"api_base_default": "http://127.0.0.1:8000/v1",
"api_key_env": "VLLM_API_KEY",
"api_key_default": "EMPTY",
},
"openai-compat/": {
"api_base_env": "OPENAI_COMPAT_BASE_URL",
"api_base_default": "http://127.0.0.1:8000/v1",
"api_key_env": "OPENAI_COMPAT_API_KEY",
"api_key_default": "EMPTY",
},
}


def _patch_litellm_effort_validation() -> None:
"""Neuter LiteLLM 1.83's hardcoded effort-level validation.

@@ -84,6 +112,39 @@ class UnsupportedEffortError(ValueError):
"""


def _resolve_openai_compat_params(
model_name: str,
reasoning_effort: str | None = None,
strict: bool = False,
) -> dict:
for prefix, config in _OPENAI_COMPAT_PROVIDERS.items():
if not model_name.startswith(prefix):
continue

actual_model = model_name[len(prefix) :]
params = {
"model": f"openai/{actual_model}",
"api_base": os.environ.get(
config["api_base_env"], config["api_base_default"]
).rstrip("/"),
"api_key": os.environ.get(
config["api_key_env"], config["api_key_default"]
),
}
if reasoning_effort:
if reasoning_effort not in _OPENAI_EFFORTS:
if strict:
raise UnsupportedEffortError(
"OpenAI-compatible backends don't accept "
f"effort={reasoning_effort!r}"
)
else:
params["extra_body"] = {"reasoning_effort": reasoning_effort}
return params

raise ValueError(f"Unsupported model id: {model_name}")


def _resolve_llm_params(
model_name: str,
session_hf_token: str | None = None,
@@ -109,6 +170,11 @@ def _resolve_llm_params(
• ``openai/<model>`` — ``reasoning_effort`` forwarded as a top-level
kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``.

• ``ollama/<model>``, ``lm_studio/<model>``, ``vllm/<model>``, and
``openai-compat/<model>`` — OpenAI-compatible backends reachable via a
configurable ``api_base``. ``reasoning_effort`` is forwarded via
``extra_body`` so local servers can ignore it safely if unsupported.

• Anything else is treated as a HuggingFace router id. We hit the
auto-routing OpenAI-compatible endpoint at
``https://router.huggingface.co/v1``. The id can be bare or carry an
@@ -166,6 +232,9 @@ def _resolve_llm_params(
params["reasoning_effort"] = reasoning_effort
return params

if any(model_name.startswith(prefix) for prefix in _OPENAI_COMPAT_PROVIDERS):
return _resolve_openai_compat_params(model_name, reasoning_effort, strict)

hf_model = model_name.removeprefix("huggingface/")
api_key = (
os.environ.get("INFERENCE_TOKEN")
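
For context on how these resolved params are meant to be consumed: the dict carries standard LiteLLM completion kwargs (`model` rewritten onto the `openai/` route, `api_base`, `api_key`, and an optional `extra_body`). A minimal sketch, assuming LiteLLM's `completion` API and a locally running Ollama server with the model already pulled; the prompt and model id are illustrative only:

```python
import litellm

from agent.core.llm_params import _resolve_llm_params

# Resolve kwargs for a local Ollama model; OLLAMA_BASE_URL falls back to
# http://localhost:11434/v1 when unset.
params = _resolve_llm_params("ollama/llama3.1:8b", reasoning_effort="low")

# The resolved dict unpacks straight into LiteLLM's OpenAI-compatible code path.
response = litellm.completion(
    messages=[{"role": "user", "content": "Say hello in one word."}],
    **params,
)
print(response.choices[0].message.content)
```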
21 changes: 18 additions & 3 deletions agent/core/model_switcher.py
@@ -33,6 +33,14 @@


_ROUTING_POLICIES = {"fastest", "cheapest", "preferred"}
_DIRECT_PREFIXES = (
"anthropic/",
"openai/",
"ollama/",
"lm_studio/",
"vllm/",
"openai-compat/",
)


def is_valid_model_id(model_id: str) -> bool:
@@ -41,6 +49,10 @@ def is_valid_model_id(model_id: str) -> bool:
Accepts:
• anthropic/<model>
• openai/<model>
• ollama/<model>
• lm_studio/<model>
• vllm/<model>
• openai-compat/<model>
• <org>/<model>[:<tag>] (HF router; tag = provider or policy)
• huggingface/<org>/<model>[:<tag>] (same, accepts legacy prefix)

@@ -63,7 +75,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool:
Direct-API ids (Anthropic, OpenAI, and the local OpenAI-compatible
prefixes) return ``True`` without printing anything — the probe below
covers "does this model exist".
"""
if model_id.startswith(("anthropic/", "openai/")):
if model_id.startswith(_DIRECT_PREFIXES):
return True

from agent.core import hf_router_catalog as cat
@@ -136,7 +148,8 @@ def print_model_listing(config, console) -> None:
console.print(
"\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n"
"Add ':fastest', ':cheapest', ':preferred', or ':<provider>' to override routing.\n"
"Use 'anthropic/<model>' or 'openai/<model>' for direct API access.[/dim]"
"Use 'anthropic/<model>', 'openai/<model>', 'ollama/<model>',\n"
"'lm_studio/<model>', 'vllm/<model>', or 'openai-compat/<model>' for direct access.[/dim]"
)


@@ -146,7 +159,9 @@ def print_invalid_id(arg: str, console) -> None:
"[dim]Expected:\n"
" • <org>/<model>[:tag] (HF router — paste from huggingface.co)\n"
" • anthropic/<model>\n"
" • openai/<model>[/dim]"
" • openai/<model>\n"
" • ollama/<model> | lm_studio/<model> | vllm/<model>\n"
" • openai-compat/<model>[/dim]"
)


66 changes: 66 additions & 0 deletions tests/test_llm_params.py
@@ -0,0 +1,66 @@
from agent.core.llm_params import _resolve_llm_params
from agent.core.model_switcher import is_valid_model_id


def test_resolve_ollama_params(monkeypatch):
monkeypatch.delenv("OLLAMA_API_KEY", raising=False)
monkeypatch.setenv("OLLAMA_BASE_URL", "http://localhost:11434/v1")

params = _resolve_llm_params("ollama/llama3.1:8b", reasoning_effort="low")

assert params == {
"model": "openai/llama3.1:8b",
"api_base": "http://localhost:11434/v1",
"api_key": "ollama",
"extra_body": {"reasoning_effort": "low"},
}


def test_resolve_lm_studio_params(monkeypatch):
monkeypatch.delenv("LMSTUDIO_API_KEY", raising=False)
monkeypatch.setenv("LMSTUDIO_BASE_URL", "http://127.0.0.1:1234/v1")

params = _resolve_llm_params("lm_studio/google/gemma-4-e4b")

assert params == {
"model": "openai/google/gemma-4-e4b",
"api_base": "http://127.0.0.1:1234/v1",
"api_key": "lm-studio",
}


def test_resolve_vllm_params(monkeypatch):
monkeypatch.delenv("VLLM_API_KEY", raising=False)
monkeypatch.setenv("VLLM_BASE_URL", "http://127.0.0.1:8000/v1")

params = _resolve_llm_params(
"vllm/meta-llama/Llama-3.1-8B-Instruct",
reasoning_effort="medium",
)

assert params == {
"model": "openai/meta-llama/Llama-3.1-8B-Instruct",
"api_base": "http://127.0.0.1:8000/v1",
"api_key": "EMPTY",
"extra_body": {"reasoning_effort": "medium"},
}


def test_resolve_openai_compat_params(monkeypatch):
monkeypatch.setenv("OPENAI_COMPAT_BASE_URL", "http://127.0.0.1:9000/v1")
monkeypatch.setenv("OPENAI_COMPAT_API_KEY", "compat-key")

params = _resolve_llm_params("openai-compat/custom-model")

assert params == {
"model": "openai/custom-model",
"api_base": "http://127.0.0.1:9000/v1",
"api_key": "compat-key",
}


def test_model_switcher_accepts_local_openai_compat_prefixes():
assert is_valid_model_id("ollama/llama3.1:8b") is True
assert is_valid_model_id("lm_studio/google/gemma-4-e4b") is True
assert is_valid_model_id("vllm/meta-llama/Llama-3.1-8B-Instruct") is True
assert is_valid_model_id("openai-compat/custom-model") is True