121 changes: 67 additions & 54 deletions agent/core/agent_loop.py
@@ -356,6 +356,7 @@ async def _record_manual_approved_spend_if_needed(
_MAX_LLM_RETRIES = 3
_LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
_LLM_RATE_LIMIT_RETRY_DELAYS = [30, 60] # exceed Bedrock's ~60s TPM bucket window
_LLM_TIMEOUT = 900 # seconds (15 minutes)


def _is_rate_limit_error(error: Exception) -> bool:
@@ -727,23 +728,86 @@ def _assistant_message_from_result(

async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
"""Call the LLM with streaming, emitting assistant_chunk events."""
response = None
_healed_effort = False # one-shot safety net per call
_healed_thinking_signature = False
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
t_start = time.monotonic()

# Initialize accumulators
full_content = ""
tool_calls_acc: dict[int, dict] = {}
token_count = 0
finish_reason = None
final_usage_chunk = None
chunks = []
should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))

for _llm_attempt in range(_MAX_LLM_RETRIES):
try:
# Reset accumulators on retry
full_content = ""
tool_calls_acc = {}
token_count = 0
finish_reason = None
final_usage_chunk = None
chunks = []

response = await acompletion(
messages=messages,
tools=tools,
tool_choice="auto",
stream=True,
stream_options={"include_usage": True},
timeout=600,
timeout=_LLM_TIMEOUT,
**llm_params,
)

async for chunk in response:
chunks.append(chunk)
if session.is_cancelled:
tool_calls_acc.clear()
break

choice = chunk.choices[0] if chunk.choices else None
if not choice:
if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
final_usage_chunk = chunk
continue

delta = choice.delta
if choice.finish_reason:
finish_reason = choice.finish_reason

if delta.content:
full_content += delta.content
await session.send_event(
Event(event_type="assistant_chunk", data={"content": delta.content})
)

if delta.tool_calls:
for tc_delta in delta.tool_calls:
idx = tc_delta.index
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {
"id": "", "type": "function",
"function": {"name": "", "arguments": ""},
}
if tc_delta.id:
tool_calls_acc[idx]["id"] = tc_delta.id
if tc_delta.function:
if tc_delta.function.name:
tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name
if tc_delta.function.arguments:
tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments

if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
final_usage_chunk = chunk

# Success: exit the retry loop
break

except ContextWindowExceededError:
raise
except Exception as e:
@@ -779,57 +843,6 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
continue
raise

full_content = ""
tool_calls_acc: dict[int, dict] = {}
token_count = 0
finish_reason = None
final_usage_chunk = None
chunks = []
should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))

async for chunk in response:
chunks.append(chunk)
if session.is_cancelled:
tool_calls_acc.clear()
break

choice = chunk.choices[0] if chunk.choices else None
if not choice:
if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
final_usage_chunk = chunk
continue

delta = choice.delta
if choice.finish_reason:
finish_reason = choice.finish_reason

if delta.content:
full_content += delta.content
await session.send_event(
Event(event_type="assistant_chunk", data={"content": delta.content})
)

if delta.tool_calls:
for tc_delta in delta.tool_calls:
idx = tc_delta.index
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {
"id": "", "type": "function",
"function": {"name": "", "arguments": ""},
}
if tc_delta.id:
tool_calls_acc[idx]["id"] = tc_delta.id
if tc_delta.function:
if tc_delta.function.name:
tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name
if tc_delta.function.arguments:
tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments

if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
final_usage_chunk = chunk

usage = await telemetry.record_llm_call(
session,
model=llm_params.get("model", session.config.model_name),
@@ -873,7 +886,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
tools=tools,
tool_choice="auto",
stream=False,
timeout=600,
timeout=_LLM_TIMEOUT,
**llm_params,
)
break
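The backoff handling lives in the except block that this hunk truncates, so the exact delay selection is not visible here. A minimal sketch of the pattern the constants above imply, with an illustrative helper name (stream_once) that is not part of the PR: rate-limit errors use the longer _LLM_RATE_LIMIT_RETRY_DELAYS table to outlast Bedrock's ~60s TPM bucket, other errors use _LLM_RETRY_DELAYS, and every attempt starts from freshly reset accumulators.

import asyncio

async def _retry_streaming_call(stream_once, is_rate_limit_error,
                                max_retries=3,
                                delays=(5, 15, 30),
                                rate_limit_delays=(30, 60)):
    # Sketch only: `stream_once` stands in for one attempt of the
    # accumulate-and-stream body shown above, starting from empty accumulators.
    for attempt in range(max_retries):
        try:
            return await stream_once()
        except Exception as exc:
            if attempt == max_retries - 1:
                raise
            table = rate_limit_delays if is_rate_limit_error(exc) else delays
            await asyncio.sleep(table[min(attempt, len(table) - 1)])
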
29 changes: 29 additions & 0 deletions agent/core/llm_params.py
@@ -5,6 +5,8 @@
creating circular imports.
"""

import os

from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token


@@ -98,6 +100,10 @@ def _resolve_llm_params(
"""
Build LiteLLM kwargs for a given model id.

• ``ollama/<model>`` — local inference via Ollama. Hits the
OpenAI-compatible /v1 endpoint on localhost for improved tool calling.
Restricted to localhost/loopback to mitigate SSRF risk.

• ``anthropic/<model>`` — native thinking config. We bypass LiteLLM's
``reasoning_effort`` → ``thinking`` mapping (which lags new Claude
releases like 4.7 and sends the wrong API shape). Instead we pass
@@ -137,6 +143,29 @@ def _resolve_llm_params(
3. huggingface_hub cache — ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` /
local ``hf auth login`` cache.
"""
if model_name.startswith("ollama/"):
# Use the OpenAI-compatible endpoint for Ollama for better tool-calling support
actual_model = model_name.replace("ollama/", "", 1)
api_base = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")

# SSRF Mitigation: Restrict local models to localhost/loopback
from urllib.parse import urlparse

parsed = urlparse(api_base)
hostname = parsed.hostname or ""
if hostname not in ("localhost", "127.0.0.1"):
raise ValueError(
f"Security error: local model API base '{api_base}' must point to localhost or 127.0.0.1"
)

if not api_base.endswith("/v1"):
api_base = f"{api_base.rstrip('/')}/v1"
return {
"model": f"openai/{actual_model}",
"api_base": api_base,
"api_key": "ollama",
}

if model_name.startswith("anthropic/"):
params: dict = {"model": model_name}
if reasoning_effort:
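As a standalone restatement of the loopback check and /v1 normalization in the new ollama/ branch (the helper name below is illustrative, not part of the module):

from urllib.parse import urlparse

def _normalize_ollama_api_base(api_base: str) -> str:
    # Same checks as the branch above: loopback-only host, then ensure a /v1 suffix.
    hostname = urlparse(api_base).hostname or ""
    if hostname not in ("localhost", "127.0.0.1"):
        raise ValueError(
            f"Security error: local model API base '{api_base}' must point to localhost or 127.0.0.1"
        )
    if not api_base.endswith("/v1"):
        api_base = f"{api_base.rstrip('/')}/v1"
    return api_base

assert _normalize_ollama_api_base("http://localhost:11434") == "http://localhost:11434/v1"
assert _normalize_ollama_api_base("http://127.0.0.1:11434/v1") == "http://127.0.0.1:11434/v1"

With the default OLLAMA_API_BASE, "ollama/qwen3.6" therefore resolves to {"model": "openai/qwen3.6", "api_base": "http://localhost:11434/v1", "api_key": "ollama"}.
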
24 changes: 23 additions & 1 deletion agent/core/model_switcher.py
@@ -15,7 +15,13 @@

from __future__ import annotations

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from prompt_toolkit import PromptSession

from agent.core.effort_probe import ProbeInconclusive, probe_effort
from agent.utils.ollama_utils import ensure_ollama_readiness
from agent.utils.persistence import save_last_model, get_persisted_models


# Suggested models shown by `/model` (not a gate). Users can paste any HF
@@ -33,8 +39,14 @@
{"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
{"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
{"id": "deepseek-ai/DeepSeek-V4-Pro:deepinfra", "label": "DeepSeek V4 Pro"},
{"id": "ollama/qwen3.6", "label": "Qwen 3.6 (Local)"},
]

# Load any additional models persisted by the user
for _pm in get_persisted_models():
if not any(_m["id"] == _pm["id"] for _m in SUGGESTED_MODELS):
SUGGESTED_MODELS.append(_pm)


_ROUTING_POLICIES = {"fastest", "cheapest", "preferred"}

@@ -160,6 +172,7 @@ async def probe_and_switch_model(
session,
console,
hf_token: str | None,
prompt_session: PromptSession,
) -> None:
"""Validate model+effort with a 1-token ping, cache the effective effort,
then commit the switch.
@@ -175,8 +188,15 @@
Transient errors (5xx, timeout) complete the switch with a yellow
warning; the next real call re-surfaces the error if it's persistent.
"""
normalized = model_id.removeprefix("huggingface/")

# Local model pre-flight
if normalized.startswith("ollama/"):
if not await ensure_ollama_readiness(normalized, prompt_session):
return

preference = config.reasoning_effort
if not _print_hf_routing_info(model_id, console):
if not _print_hf_routing_info(normalized, console):
return

if not preference:
@@ -230,3 +250,5 @@ def _commit_switch(model_id, config, session, effective, cache: bool) -> None:
session.model_effective_effort.pop(model_id, None)
else:
config.model_name = model_id

save_last_model(model_id)
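
The save_last_model / get_persisted_models helpers imported above come from agent/utils/persistence.py, which is not part of this diff. One plausible shape for them, assuming a JSON state file; the location and schema below are assumptions, not the PR's actual implementation:

import json
from pathlib import Path

_STATE_FILE = Path.home() / ".agent" / "models.json"  # assumed location

def get_persisted_models() -> list[dict]:
    # Assumed schema: {"models": [{"id": ..., "label": ...}], "last_model": "..."}
    if not _STATE_FILE.exists():
        return []
    return json.loads(_STATE_FILE.read_text()).get("models", [])

def save_last_model(model_id: str) -> None:
    _STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
    data = json.loads(_STATE_FILE.read_text()) if _STATE_FILE.exists() else {}
    data["last_model"] = model_id
    _STATE_FILE.write_text(json.dumps(data, indent=2))
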
10 changes: 10 additions & 0 deletions agent/core/session.py
@@ -18,6 +18,12 @@

logger = logging.getLogger(__name__)

# Local max-token overrides — used for models not yet in the litellm catalog
# or when we want to pin a specific window.
# NOTE: Ollama models not listed here will fall back to _DEFAULT_MAX_TOKENS (200k).
_MAX_TOKENS_MAP: dict[str, int] = {
"ollama/qwen3.6": 128_000,
}
_DEFAULT_MAX_TOKENS = 200_000
_TURN_COMPLETE_NOTIFICATION_CHARS = 39000

@@ -32,6 +38,10 @@ def _get_max_tokens_safe(model_name: str) -> int:
look up the bare model. Falls back to a conservative 200k default for
models not in the catalog (typically HF-router-only models).
"""
# Check local overrides first
if model_name in _MAX_TOKENS_MAP:
return _MAX_TOKENS_MAP[model_name]

from litellm import get_model_info

candidates = [model_name]
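After this change the lookup order in _get_max_tokens_safe is: local override map, then the litellm catalog, then the 200k default. The catalog part of the function is truncated above, so the sketch below compresses it; the max_input_tokens key and the bare try/except are assumptions about that elided body.

def _max_context_window(model_name: str) -> int:
    # 1. Local override, e.g. "ollama/qwen3.6" -> 128_000
    if model_name in _MAX_TOKENS_MAP:
        return _MAX_TOKENS_MAP[model_name]
    # 2. litellm catalog (the real function also tries the bare model id)
    from litellm import get_model_info
    try:
        return get_model_info(model_name)["max_input_tokens"]
    except Exception:
        # 3. Conservative default for catalog-unknown models
        return _DEFAULT_MAX_TOKENS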