From a8f336cfc6bf119f7f0477463e20abcc5b051797 Mon Sep 17 00:00:00 2001 From: ianu82 Date: Thu, 19 Mar 2026 10:16:06 +0000 Subject: [PATCH 1/2] Add native Ollama provider support --- anton/chat.py | 140 ++++---------- anton/chat_ui.py | 11 ++ anton/cli.py | 19 +- anton/config/settings.py | 1 + anton/llm/anthropic.py | 2 + anton/llm/client.py | 22 +++ anton/llm/ollama.py | 354 ++++++++++++++++++++++++++++++++++ anton/llm/openai.py | 150 +------------- anton/llm/openai_compat.py | 139 +++++++++++++ anton/llm/provider.py | 3 + anton/llm/setup.py | 333 ++++++++++++++++++++++++++++++++ anton/scratchpad.py | 2 + anton/scratchpad_boot.py | 4 + pyproject.toml | 1 + tests/test_chat.py | 131 +++++++++++++ tests/test_chat_scratchpad.py | 2 + tests/test_chat_ui.py | 13 +- tests/test_client.py | 18 ++ tests/test_llm_setup.py | 83 ++++++++ tests/test_ollama_provider.py | 162 ++++++++++++++++ tests/test_scratchpad.py | 16 ++ tests/test_settings.py | 36 +++- 22 files changed, 1389 insertions(+), 253 deletions(-) create mode 100644 anton/llm/ollama.py create mode 100644 anton/llm/openai_compat.py create mode 100644 anton/llm/setup.py create mode 100644 tests/test_llm_setup.py create mode 100644 tests/test_ollama_provider.py diff --git a/anton/chat.py b/anton/chat.py index b86e36c..0e6a29c 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2,6 +2,7 @@ import asyncio import os +import re import sys import time from collections.abc import AsyncIterator, Callable @@ -65,6 +66,21 @@ "Only involve the user if the problem truly requires something only they can provide." ) +_VERIFIER_STATUS_PREFIXES = ( + "STATUS: COMPLETE", + "STATUS: INCOMPLETE", + "STATUS: STUCK", +) + + +def _strip_ollama_think_tags(text: str) -> str: + return re.sub(r".*?", "", text, flags=re.IGNORECASE | re.DOTALL) + + +def _is_parseable_verifier_status(text: str) -> bool: + upper = text.upper() + return any(prefix in upper for prefix in _VERIFIER_STATUS_PREFIXES) + class ChatSession: """Manages a multi-turn conversation with tool-call delegation.""" @@ -749,6 +765,11 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato "has been fully completed based on the conversation above." ), }] + verifier_request_options = ( + {"think": False} + if self._llm.planning_provider_name == "ollama" + else None + ) verification = await self._llm.plan( system=( "You are a task-completion verifier. Given the conversation, determine " @@ -766,14 +787,21 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato ), messages=verify_messages, max_tokens=256, + request_options=verifier_request_options, ) - status_text = (verification.content or "").strip().upper() + verification_text = (verification.content or "").strip() + if self._llm.planning_provider_name == "ollama": + verification_text = _strip_ollama_think_tags(verification_text).strip() + if not verification_text or not _is_parseable_verifier_status(verification_text): + break + + status_text = verification_text.upper() if "STATUS: COMPLETE" in status_text: break if "STATUS: STUCK" in status_text: # Stuck — inject diagnosis request and let the LLM explain - reason = (verification.content or "").strip() + reason = verification_text self._history.append({ "role": "user", "content": ( @@ -795,7 +823,7 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # INCOMPLETE — continue working continuation += 1 - reason = (verification.content or "").strip() + reason = verification_text self._history.append({ "role": "user", "content": ( @@ -957,7 +985,7 @@ def _rebuild_session( runtime_context = _build_runtime_context(settings) api_key = ( settings.anthropic_api_key if settings.coding_provider == "anthropic" - else settings.openai_api_key + else settings.openai_api_key if settings.coding_provider in {"openai", "openai-compatible"} else "" ) or "" return ChatSession( state["llm_client"], @@ -1211,107 +1239,22 @@ async def _handle_setup_models( session_id: str | None = None, ) -> ChatSession: """Setup sub-menu: provider, API key, and models.""" - from rich.prompt import Prompt - + from anton.llm.setup import configure_llm_settings from anton.workspace import Workspace as _Workspace # Always persist API keys and model settings to global ~/.anton/.env global_ws = _Workspace(Path.home()) - console.print() - console.print("[anton.cyan]Current configuration:[/]") - console.print(f" Provider (planning): [bold]{settings.planning_provider}[/]") - console.print(f" Provider (coding): [bold]{settings.coding_provider}[/]") - console.print(f" Planning model: [bold]{settings.planning_model}[/]") - console.print(f" Coding model: [bold]{settings.coding_model}[/]") - console.print() - - # --- Provider --- - providers = {"1": "anthropic", "2": "openai", "3": "openai-compatible"} - current_num = {"anthropic": "1", "openai": "2", "openai-compatible": "3"}.get(settings.planning_provider, "1") - console.print("[anton.cyan]Available providers:[/]") - console.print(r" [bold]1[/] Anthropic (Claude) [dim]\[recommended][/]") - console.print(r" [bold]2[/] OpenAI (GPT / o-series) [dim]\[experimental][/]") - console.print(r" [bold]3[/] OpenAI-compatible (custom endpoint) [dim]\[experimental][/]") - console.print() - - choice = Prompt.ask( - "Select provider", - choices=["1", "2", "3"], - default=current_num, - console=console, - ) - provider = providers[choice] - - # --- Base URL (OpenAI-compatible only) --- - if provider == "openai-compatible": - current_base_url = settings.openai_base_url or "" - console.print() - base_url = Prompt.ask( - f"API base URL [dim](e.g. http://localhost:11434/v1)[/]", - default=current_base_url, - console=console, - ) - base_url = base_url.strip() - if base_url: - settings.openai_base_url = base_url - global_ws.set_secret("ANTON_OPENAI_BASE_URL", base_url) - - # --- API key --- - key_attr = "anthropic_api_key" if provider == "anthropic" else "openai_api_key" - current_key = getattr(settings, key_attr) or "" - masked = current_key[:4] + "..." + current_key[-4:] if len(current_key) > 8 else "***" - console.print() - api_key = Prompt.ask( - f"API key for {provider.title()} [dim](Enter to keep {masked})[/]", - default="", - console=console, - ) - api_key = api_key.strip() - - # --- Models --- - defaults = { - "anthropic": ("claude-sonnet-4-6", "claude-haiku-4-5-20251001"), - "openai": ("gpt-5-mini", "gpt-5-nano"), - } - default_planning, default_coding = defaults.get(provider, ("", "")) - - console.print() - planning_model = Prompt.ask( - "Planning model", - default=settings.planning_model if provider == settings.planning_provider else default_planning, - console=console, - ) - coding_model = Prompt.ask( - "Coding model", - default=settings.coding_model if provider == settings.coding_provider else default_coding, - console=console, + applied = configure_llm_settings( + console, + settings, + global_ws, + show_current_config=True, ) - - # --- Persist to global ~/.anton/.env --- - settings.planning_provider = provider - settings.coding_provider = provider - settings.planning_model = planning_model - settings.coding_model = coding_model - - global_ws.set_secret("ANTON_PLANNING_PROVIDER", provider) - global_ws.set_secret("ANTON_CODING_PROVIDER", provider) - global_ws.set_secret("ANTON_PLANNING_MODEL", planning_model) - global_ws.set_secret("ANTON_CODING_MODEL", coding_model) - - if api_key: - setattr(settings, key_attr, api_key) - key_name = f"ANTON_{provider.upper()}_API_KEY" - global_ws.set_secret(key_name, api_key) - - # Validate that we actually have an API key for the chosen provider - final_key = getattr(settings, key_attr) - if not final_key: - console.print() - console.print(f"[anton.error]No API key set for {provider}. Configuration not applied.[/]") - console.print() + if not applied: return session + global_ws.apply_env_to_process() console.print() console.print("[anton.success]Configuration updated.[/]") console.print() @@ -1619,6 +1562,7 @@ def _minds_test_llm(base_url: str, api_key: str, verify: bool = True) -> bool: "ANTON_PLANNING_PROVIDER", "ANTON_CODING_PROVIDER", "ANTON_PLANNING_MODEL", "ANTON_CODING_MODEL", "ANTON_ANTHROPIC_API_KEY", "ANTON_OPENAI_API_KEY", "ANTON_OPENAI_BASE_URL", + "ANTON_OLLAMA_BASE_URL", } _SECRET_PATTERNS = ("KEY", "TOKEN", "SECRET", "PAT", "PASSWORD") @@ -2380,7 +2324,7 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool coding_api_key = ( settings.anthropic_api_key if settings.coding_provider == "anthropic" - else settings.openai_api_key + else settings.openai_api_key if settings.coding_provider in {"openai", "openai-compatible"} else "" ) or "" session = ChatSession( state["llm_client"], diff --git a/anton/chat_ui.py b/anton/chat_ui.py index 98e9f28..accf6e3 100644 --- a/anton/chat_ui.py +++ b/anton/chat_ui.py @@ -131,6 +131,7 @@ def _tool_display_text(name: str, input_json: str) -> str: PHASE_LABELS = { "memory_recall": "Memory", "planning": "Planning", + "reasoning": "Thinking", "executing": "Executing", "complete": "Complete", "failed": "Failed", @@ -154,6 +155,7 @@ def __init__(self, console: Console, toolbar: dict | None = None) -> None: self._buffer = "" # answer text accumulated during streaming self._in_tool_phase = False self._last_was_tool = False + self._answer_started = False self._initial_text = "" self._initial_printed = False self._active = False @@ -234,6 +236,7 @@ def start(self) -> None: self._initial_printed = False self._in_tool_phase = False self._last_was_tool = False + self._answer_started = False self._cancel_msg = "" self._active = True self._start_spinner() @@ -244,6 +247,7 @@ def append_text(self, delta: str) -> None: if self._in_tool_phase: self._buffer += delta self._last_was_tool = False + self._answer_started = True self._line3_peek = self._extract_peek(self._buffer) self._update_spinner() else: @@ -308,6 +312,13 @@ def update_progress(self, phase: str, message: str, eta: float | None = None) -> self._update_spinner() return + if phase == "reasoning": + self._line2_status = message or "Thinking..." + self._line3_peek = "" + self._set_status(self._line2_status) + self._update_spinner() + return + if phase == "scratchpad_start": # Print the scratchpad activity line NOW (before execution) for act in reversed(self._activities): diff --git a/anton/cli.py b/anton/cli.py index a701342..3879846 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -28,6 +28,7 @@ def _reexec() -> None: # Core dependencies from pyproject.toml that anton needs at runtime _REQUIRED_PACKAGES: dict[str, str] = { "anthropic": "anthropic>=0.42.0", + "ollama": "ollama>=0.6.1", "openai": "openai>=1.0", "pydantic": "pydantic>=2.0", "pydantic_settings": "pydantic-settings>=2.0", @@ -224,9 +225,11 @@ def main( def _has_api_key(settings) -> bool: - """Check if all configured providers have API keys.""" + """Check if the configured providers are ready to use.""" providers = {settings.planning_provider, settings.coding_provider} for p in providers: + if p == "ollama": + continue if p == "anthropic" and not (settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY")): return False if p in ("openai", "openai-compatible") and not (settings.openai_api_key or os.environ.get("OPENAI_API_KEY")): @@ -235,12 +238,11 @@ def _has_api_key(settings) -> bool: def _ensure_api_key(settings) -> None: - """Prompt the user to configure a provider and API key if none is set.""" + """Prompt the user to configure an LLM provider if needed.""" if _has_api_key(settings): return - from rich.prompt import Prompt - + from anton.llm.setup import configure_llm_settings from anton.workspace import Workspace ws = Workspace(Path.home()) @@ -248,7 +250,14 @@ def _ensure_api_key(settings) -> None: if settings.minds_enabled: _ensure_minds_api_key(settings, ws) else: - _ensure_anthropic_api_key(settings, ws) + applied = configure_llm_settings( + console, + settings, + ws, + show_current_config=False, + ) + if not applied: + raise typer.Exit(1) # Reload env vars into the process so the scratchpad subprocess inherits them ws.apply_env_to_process() diff --git a/anton/config/settings.py b/anton/config/settings.py index 1aedfc2..b114cc8 100644 --- a/anton/config/settings.py +++ b/anton/config/settings.py @@ -34,6 +34,7 @@ class AntonSettings(BaseSettings): anthropic_api_key: str | None = None openai_api_key: str | None = None openai_base_url: str | None = None + ollama_base_url: str = "http://localhost:11434" memory_enabled: bool = True memory_dir: str = ".anton" diff --git a/anton/llm/anthropic.py b/anton/llm/anthropic.py index 8aec58b..13d759b 100644 --- a/anton/llm/anthropic.py +++ b/anton/llm/anthropic.py @@ -37,6 +37,7 @@ async def complete( tools: list[dict] | None = None, tool_choice: dict | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> LLMResponse: kwargs: dict = { "model": model, @@ -96,6 +97,7 @@ async def stream( messages: list[dict], tools: list[dict] | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> AsyncIterator[StreamEvent]: kwargs: dict = { "model": model, diff --git a/anton/llm/client.py b/anton/llm/client.py index a58c217..a6a5bc8 100644 --- a/anton/llm/client.py +++ b/anton/llm/client.py @@ -13,14 +13,18 @@ class LLMClient: def __init__( self, *, + planning_provider_name: str = "anthropic", planning_provider: LLMProvider, planning_model: str, + coding_provider_name: str = "anthropic", coding_provider: LLMProvider, coding_model: str, max_tokens: int = 8192, ) -> None: + self._planning_provider_name = planning_provider_name self._planning_provider = planning_provider self._planning_model = planning_model + self._coding_provider_name = coding_provider_name self._coding_provider = coding_provider self._coding_model = coding_model self._max_tokens = max_tokens @@ -32,6 +36,7 @@ async def plan( messages: list[dict], tools: list[dict] | None = None, max_tokens: int | None = None, + request_options: dict | None = None, ) -> LLMResponse: return await self._planning_provider.complete( model=self._planning_model, @@ -39,6 +44,7 @@ async def plan( messages=messages, tools=tools, max_tokens=max_tokens or self._max_tokens, + request_options=request_options, ) async def plan_stream( @@ -48,6 +54,7 @@ async def plan_stream( messages: list[dict], tools: list[dict] | None = None, max_tokens: int | None = None, + request_options: dict | None = None, ) -> AsyncIterator[StreamEvent]: async for event in self._planning_provider.stream( model=self._planning_model, @@ -55,14 +62,23 @@ async def plan_stream( messages=messages, tools=tools, max_tokens=max_tokens or self._max_tokens, + request_options=request_options, ): yield event + @property + def planning_provider_name(self) -> str: + return self._planning_provider_name + @property def coding_provider(self) -> LLMProvider: """The LLM provider used for coding/skill execution.""" return self._coding_provider + @property + def coding_provider_name(self) -> str: + return self._coding_provider_name + @property def coding_model(self) -> str: """The model name used for coding/skill execution.""" @@ -75,6 +91,7 @@ async def code( messages: list[dict], tools: list[dict] | None = None, max_tokens: int | None = None, + request_options: dict | None = None, ) -> LLMResponse: return await self._coding_provider.complete( model=self._coding_model, @@ -82,16 +99,19 @@ async def code( messages=messages, tools=tools, max_tokens=max_tokens or self._max_tokens, + request_options=request_options, ) @classmethod def from_settings(cls, settings: AntonSettings) -> LLMClient: from anton.llm.anthropic import AnthropicProvider + from anton.llm.ollama import OllamaProvider from anton.llm.openai import OpenAIProvider providers = { "anthropic": lambda: AnthropicProvider(api_key=settings.anthropic_api_key), "openai": lambda: OpenAIProvider(api_key=settings.openai_api_key, base_url=settings.openai_base_url, ssl_verify=settings.minds_ssl_verify), + "ollama": lambda: OllamaProvider(base_url=settings.ollama_base_url), "openai-compatible": lambda: OpenAIProvider(api_key=settings.openai_api_key, base_url=settings.openai_base_url, ssl_verify=settings.minds_ssl_verify), } @@ -104,8 +124,10 @@ def from_settings(cls, settings: AntonSettings) -> LLMClient: raise ValueError(f"Unknown coding provider: {settings.coding_provider}") return cls( + planning_provider_name=settings.planning_provider, planning_provider=planning_factory(), planning_model=settings.planning_model, + coding_provider_name=settings.coding_provider, coding_provider=coding_factory(), coding_model=settings.coding_model, max_tokens=getattr(settings, "max_tokens", 8192), diff --git a/anton/llm/ollama.py b/anton/llm/ollama.py new file mode 100644 index 0000000..ba85884 --- /dev/null +++ b/anton/llm/ollama.py @@ -0,0 +1,354 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from dataclasses import dataclass + +import ollama + +from anton.llm.openai_compat import translate_tools +from anton.llm.provider import ( + ContextOverflowError, + LLMProvider, + LLMResponse, + StreamComplete, + StreamEvent, + StreamTaskProgress, + StreamTextDelta, + StreamToolUseDelta, + StreamToolUseEnd, + StreamToolUseStart, + ToolCall, + Usage, + compute_context_pressure, +) + +_DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434" + + +@dataclass(frozen=True) +class OllamaModelInfo: + name: str + size: str = "" + quantization: str = "" + parameter_size: str = "" + + @property + def display_name(self) -> str: + details = [detail for detail in (self.parameter_size, self.quantization) if detail] + if details: + return f"{self.name} ({', '.join(details)})" + if self.size: + return f"{self.name} ({self.size})" + return self.name + + +def normalize_ollama_base_url(base_url: str | None) -> str: + url = (base_url or _DEFAULT_OLLAMA_BASE_URL).strip() + if not url: + url = _DEFAULT_OLLAMA_BASE_URL + if not url.startswith(("http://", "https://")): + url = f"http://{url}" + url = url.rstrip("/") + if url.endswith("/v1"): + url = url[:-3].rstrip("/") + return url + + +def list_ollama_models(base_url: str | None = None) -> list[OllamaModelInfo]: + client = ollama.Client(host=normalize_ollama_base_url(base_url)) + response = client.list() + models: list[OllamaModelInfo] = [] + for model in response.models: + size = str(model.size) if model.size is not None else "" + details = model.details + models.append( + OllamaModelInfo( + name=model.model or "", + size=size, + quantization=details.quantization_level if details and details.quantization_level else "", + parameter_size=details.parameter_size if details and details.parameter_size else "", + ) + ) + return models + + +def translate_messages_to_ollama(system: str, messages: list[dict]) -> list[dict]: + """Convert Anthropic-style messages to native Ollama chat format.""" + result: list[dict] = [] + tool_name_by_id: dict[str, str] = {} + if system: + result.append({"role": "system", "content": system}) + + for message in messages: + role = message["role"] + content = message.get("content") + + if isinstance(content, str): + result.append({"role": role, "content": content}) + continue + + if not isinstance(content, list): + result.append({"role": role, "content": str(content) if content else ""}) + continue + + if role == "assistant": + text_parts: list[str] = [] + tool_calls: list[dict] = [] + for block in content: + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif block.get("type") == "tool_use": + tool_name_by_id[block["id"]] = block["name"] + tool_calls.append({ + "function": { + "name": block["name"], + "arguments": block.get("input", {}), + } + }) + msg: dict[str, object] = {"role": "assistant"} + if text_parts: + msg["content"] = "\n".join(text_parts) + if tool_calls: + msg["tool_calls"] = tool_calls + result.append(msg) + continue + + if role == "user": + text_parts: list[str] = [] + images: list[str] = [] + for block in content: + block_type = block.get("type") + if block_type == "tool_result": + if text_parts or images: + result.append(_build_ollama_user_message(text_parts, images)) + text_parts = [] + images = [] + tool_name = tool_name_by_id.get(block["tool_use_id"], block["tool_use_id"]) + tool_content = block.get("content", "") + if isinstance(tool_content, list): + tool_content = "\n".join( + item.get("text", "") + for item in tool_content + if item.get("type") == "text" + ) + result.append({ + "role": "tool", + "tool_name": tool_name, + "content": str(tool_content), + }) + elif block_type == "text": + text_parts.append(block.get("text", "")) + elif block_type == "image": + source = block.get("source", {}) + if source.get("type") == "base64" and source.get("data"): + images.append(source["data"]) + if text_parts or images: + result.append(_build_ollama_user_message(text_parts, images)) + continue + + text = " ".join( + block.get("text", "") + for block in content + if block.get("type") == "text" + ) + result.append({"role": role, "content": text or ""}) + + return result + + +def _build_ollama_user_message(text_parts: list[str], images: list[str]) -> dict[str, object]: + message: dict[str, object] = {"role": "user"} + if text_parts: + message["content"] = "\n".join(text_parts) + if images: + message["images"] = images + return message + + +class OllamaProvider(LLMProvider): + def __init__(self, base_url: str | None = None) -> None: + self._base_url = normalize_ollama_base_url(base_url) + self._client = ollama.AsyncClient(host=self._base_url) + + async def complete( + self, + *, + model: str, + system: str, + messages: list[dict], + tools: list[dict] | None = None, + tool_choice: dict | None = None, + max_tokens: int = 4096, + request_options: dict | None = None, + ) -> LLMResponse: + kwargs = self._build_chat_kwargs( + model=model, + system=system, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_tokens=max_tokens, + request_options=request_options, + ) + try: + response = await self._client.chat(**kwargs) + except ollama.ResponseError as exc: + message = str(exc).lower() + if "context" in message and "length" in message: + raise ContextOverflowError(str(exc)) from exc + raise ConnectionError(str(exc)) from exc + except ConnectionError as exc: + raise ConnectionError(str(exc)) from exc + + content_text = response.message.content or "" + tool_calls = self._tool_calls_from_message(response.message.tool_calls or []) + input_tokens = response.prompt_eval_count or 0 + output_tokens = response.eval_count or 0 + return LLMResponse( + content=content_text, + tool_calls=tool_calls, + usage=Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + context_pressure=compute_context_pressure(model, input_tokens), + ), + stop_reason=response.done_reason, + ) + + async def stream( + self, + *, + model: str, + system: str, + messages: list[dict], + tools: list[dict] | None = None, + max_tokens: int = 4096, + request_options: dict | None = None, + ) -> AsyncIterator[StreamEvent]: + kwargs = self._build_chat_kwargs( + model=model, + system=system, + messages=messages, + tools=tools, + tool_choice=None, + max_tokens=max_tokens, + request_options=request_options, + ) + kwargs["stream"] = True + + content_text = "" + tool_calls: list[ToolCall] = [] + input_tokens = 0 + output_tokens = 0 + stop_reason: str | None = None + showed_reasoning = False + saw_content = False + next_tool_index = 1 + + try: + stream = await self._client.chat(**kwargs) + async for chunk in stream: + message = chunk.message + if not showed_reasoning and not saw_content and message.thinking: + showed_reasoning = True + yield StreamTaskProgress(phase="reasoning", message="Thinking...") + + if message.content: + saw_content = True + content_text += message.content + yield StreamTextDelta(text=message.content) + + if message.tool_calls: + for call in message.tool_calls: + tool_call = self._tool_call_from_ollama(call, next_tool_index) + next_tool_index += 1 + tool_calls.append(tool_call) + yield StreamToolUseStart(id=tool_call.id, name=tool_call.name) + yield StreamToolUseDelta( + id=tool_call.id, + json_delta=json.dumps(tool_call.input), + ) + yield StreamToolUseEnd(id=tool_call.id) + + if chunk.prompt_eval_count is not None: + input_tokens = chunk.prompt_eval_count + if chunk.eval_count is not None: + output_tokens = chunk.eval_count + if chunk.done_reason: + stop_reason = chunk.done_reason + except ollama.ResponseError as exc: + message = str(exc).lower() + if "context" in message and "length" in message: + raise ContextOverflowError(str(exc)) from exc + raise ConnectionError(str(exc)) from exc + except ConnectionError as exc: + raise ConnectionError(str(exc)) from exc + + yield StreamComplete( + response=LLMResponse( + content=content_text, + tool_calls=tool_calls, + usage=Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + context_pressure=compute_context_pressure(model, input_tokens), + ), + stop_reason=stop_reason, + ) + ) + + def _build_chat_kwargs( + self, + *, + model: str, + system: str, + messages: list[dict], + tools: list[dict] | None, + tool_choice: dict | None, + max_tokens: int, + request_options: dict | None, + ) -> dict: + kwargs: dict[str, object] = { + "model": model, + "messages": translate_messages_to_ollama(system, messages), + "options": {"num_predict": max_tokens}, + } + + translated_tools = translate_tools(tools or []) if tools else None + if translated_tools: + kwargs["tools"] = self._apply_tool_choice(translated_tools, tool_choice) + + if request_options and "think" in request_options: + kwargs["think"] = request_options["think"] + + return kwargs + + @staticmethod + def _apply_tool_choice(tools: list[dict], tool_choice: dict | None) -> list[dict]: + if not tool_choice or tool_choice.get("type") != "tool": + return tools + target_name = tool_choice.get("name") + if not target_name: + return tools + filtered = [ + tool + for tool in tools + if tool.get("function", {}).get("name") == target_name + ] + return filtered or tools + + def _tool_calls_from_message(self, calls: list) -> list[ToolCall]: + return [ + self._tool_call_from_ollama(call, index) + for index, call in enumerate(calls, start=1) + ] + + @staticmethod + def _tool_call_from_ollama(call, index: int) -> ToolCall: + function = call.function + return ToolCall( + id=f"ollama_tool_{index}", + name=function.name, + input=dict(function.arguments or {}), + ) diff --git a/anton/llm/openai.py b/anton/llm/openai.py index 403bc55..f7fab1f 100644 --- a/anton/llm/openai.py +++ b/anton/llm/openai.py @@ -5,6 +5,11 @@ import openai +from anton.llm.openai_compat import ( + _translate_messages, + _translate_tool_choice, + _translate_tools, +) from anton.llm.provider import ( ContextOverflowError, LLMProvider, @@ -21,149 +26,6 @@ ) -def _translate_tools(tools: list[dict]) -> list[dict]: - """Anthropic tool format -> OpenAI function-calling format.""" - result = [] - for tool in tools: - result.append({ - "type": "function", - "function": { - "name": tool["name"], - "description": tool.get("description", ""), - "parameters": tool.get("input_schema", {}), - }, - }) - return result - - -def _translate_tool_choice(tool_choice: dict) -> dict | str: - """Anthropic tool_choice -> OpenAI tool_choice.""" - tc_type = tool_choice.get("type") - if tc_type == "tool": - return {"type": "function", "function": {"name": tool_choice["name"]}} - if tc_type == "any": - return "required" - if tc_type == "auto": - return "auto" - return "auto" - - -def _translate_messages(system: str, messages: list[dict]) -> list[dict]: - """Convert Anthropic-style messages to OpenAI chat format. - - Handles: - - system prompt -> {"role": "system", ...} - - plain text messages pass through - - assistant messages with tool_use content blocks -> tool_calls array - - user messages with tool_result content blocks -> role:tool messages - """ - result: list[dict] = [] - if system: - result.append({"role": "system", "content": system}) - - for msg in messages: - role = msg["role"] - content = msg.get("content") - - # Plain string content — pass through - if isinstance(content, str): - result.append({"role": role, "content": content}) - continue - - # Content is a list of blocks (Anthropic format) - if isinstance(content, list): - if role == "assistant": - result.extend(_translate_assistant_blocks(content)) - elif role == "user": - result.extend(_translate_user_blocks(content)) - else: - # Fallback: join text blocks - text = " ".join( - b.get("text", "") for b in content if b.get("type") == "text" - ) - result.append({"role": role, "content": text or ""}) - continue - - # Fallback - result.append({"role": role, "content": str(content) if content else ""}) - - return result - - -def _translate_assistant_blocks(blocks: list[dict]) -> list[dict]: - """Convert assistant content blocks to OpenAI message(s).""" - text_parts: list[str] = [] - tool_calls: list[dict] = [] - - for block in blocks: - if block.get("type") == "text": - text_parts.append(block["text"]) - elif block.get("type") == "tool_use": - tool_calls.append({ - "id": block["id"], - "type": "function", - "function": { - "name": block["name"], - "arguments": json.dumps(block.get("input", {})), - }, - }) - - msg: dict = {"role": "assistant"} - content = "\n".join(text_parts) if text_parts else None - msg["content"] = content - if tool_calls: - msg["tool_calls"] = tool_calls - return [msg] - - -def _translate_user_blocks(blocks: list[dict]) -> list[dict]: - """Convert user content blocks (including tool_result and image) to OpenAI messages.""" - result: list[dict] = [] - content_parts: list[dict] = [] # Accumulates text + image_url blocks - - for block in blocks: - if block.get("type") == "tool_result": - # Flush any accumulated content parts first - if content_parts: - result.append({"role": "user", "content": content_parts}) - content_parts = [] - # tool_result -> role:tool message - content = block.get("content", "") - if isinstance(content, list): - content = "\n".join( - b.get("text", "") for b in content if b.get("type") == "text" - ) - result.append({ - "role": "tool", - "tool_call_id": block["tool_use_id"], - "content": str(content), - }) - elif block.get("type") == "text": - content_parts.append({"type": "text", "text": block.get("text", "")}) - elif block.get("type") == "image": - # Anthropic image block -> OpenAI image_url block - source = block.get("source", {}) - if source.get("type") == "base64": - media_type = source.get("media_type", "image/png") - data = source.get("data", "") - content_parts.append({ - "type": "image_url", - "image_url": {"url": f"data:{media_type};base64,{data}"}, - }) - - if content_parts: - # If only text parts, flatten to a simple string for compatibility - if all(p.get("type") == "text" for p in content_parts): - result.append({ - "role": "user", - "content": "\n".join(p["text"] for p in content_parts), - }) - else: - result.append({"role": "user", "content": content_parts}) - - return result - - class OpenAIProvider(LLMProvider): def __init__( self, @@ -191,6 +53,7 @@ async def complete( tools: list[dict] | None = None, tool_choice: dict | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> LLMResponse: oai_messages = _translate_messages(system, messages) @@ -257,6 +120,7 @@ async def stream( messages: list[dict], tools: list[dict] | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> AsyncIterator[StreamEvent]: oai_messages = _translate_messages(system, messages) diff --git a/anton/llm/openai_compat.py b/anton/llm/openai_compat.py new file mode 100644 index 0000000..e5fc96c --- /dev/null +++ b/anton/llm/openai_compat.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import json + + +def translate_tools(tools: list[dict]) -> list[dict]: + """Anthropic tool format -> OpenAI/Ollama function-calling format.""" + result = [] + for tool in tools: + result.append({ + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool.get("input_schema", {}), + }, + }) + return result + + +def translate_tool_choice(tool_choice: dict) -> dict | str: + """Anthropic tool_choice -> OpenAI tool_choice.""" + tc_type = tool_choice.get("type") + if tc_type == "tool": + return {"type": "function", "function": {"name": tool_choice["name"]}} + if tc_type == "any": + return "required" + if tc_type == "auto": + return "auto" + return "auto" + + +def translate_messages(system: str, messages: list[dict]) -> list[dict]: + """Convert Anthropic-style messages to OpenAI chat format.""" + result: list[dict] = [] + if system: + result.append({"role": "system", "content": system}) + + for msg in messages: + role = msg["role"] + content = msg.get("content") + + if isinstance(content, str): + result.append({"role": role, "content": content}) + continue + + if isinstance(content, list): + if role == "assistant": + result.extend(_translate_assistant_blocks(content)) + elif role == "user": + result.extend(_translate_user_blocks(content)) + else: + text = " ".join( + block.get("text", "") + for block in content + if block.get("type") == "text" + ) + result.append({"role": role, "content": text or ""}) + continue + + result.append({"role": role, "content": str(content) if content else ""}) + + return result + + +def _translate_assistant_blocks(blocks: list[dict]) -> list[dict]: + text_parts: list[str] = [] + tool_calls: list[dict] = [] + + for block in blocks: + if block.get("type") == "text": + text_parts.append(block["text"]) + elif block.get("type") == "tool_use": + tool_calls.append({ + "id": block["id"], + "type": "function", + "function": { + "name": block["name"], + "arguments": json.dumps(block.get("input", {})), + }, + }) + + msg: dict = {"role": "assistant"} + msg["content"] = "\n".join(text_parts) if text_parts else None + if tool_calls: + msg["tool_calls"] = tool_calls + return [msg] + + +def _translate_user_blocks(blocks: list[dict]) -> list[dict]: + result: list[dict] = [] + content_parts: list[dict] = [] + + for block in blocks: + if block.get("type") == "tool_result": + if content_parts: + result.append({"role": "user", "content": content_parts}) + content_parts = [] + content = block.get("content", "") + if isinstance(content, list): + content = "\n".join( + item.get("text", "") + for item in content + if item.get("type") == "text" + ) + result.append({ + "role": "tool", + "tool_call_id": block["tool_use_id"], + "content": str(content), + }) + elif block.get("type") == "text": + content_parts.append({"type": "text", "text": block.get("text", "")}) + elif block.get("type") == "image": + source = block.get("source", {}) + if source.get("type") == "base64": + media_type = source.get("media_type", "image/png") + data = source.get("data", "") + content_parts.append({ + "type": "image_url", + "image_url": {"url": f"data:{media_type};base64,{data}"}, + }) + + if content_parts: + if all(part.get("type") == "text" for part in content_parts): + result.append({ + "role": "user", + "content": "\n".join(part["text"] for part in content_parts), + }) + else: + result.append({"role": "user", "content": content_parts}) + + return result + + +_translate_assistant_blocks = _translate_assistant_blocks +_translate_messages = translate_messages +_translate_tool_choice = translate_tool_choice +_translate_tools = translate_tools +_translate_user_blocks = _translate_user_blocks diff --git a/anton/llm/provider.py b/anton/llm/provider.py index 359a77d..d1f6a8c 100644 --- a/anton/llm/provider.py +++ b/anton/llm/provider.py @@ -144,6 +144,7 @@ async def complete( tools: list[dict] | None = None, tool_choice: dict | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> LLMResponse: ... async def stream( @@ -154,6 +155,7 @@ async def stream( messages: list[dict], tools: list[dict] | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> AsyncIterator[StreamEvent]: """Stream LLM responses. Default falls back to complete().""" response = await self.complete( @@ -162,6 +164,7 @@ async def stream( messages=messages, tools=tools, max_tokens=max_tokens, + request_options=request_options, ) if response.content: yield StreamTextDelta(text=response.content) diff --git a/anton/llm/setup.py b/anton/llm/setup.py new file mode 100644 index 0000000..cc575d3 --- /dev/null +++ b/anton/llm/setup.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from rich.prompt import Prompt + +from anton.llm.ollama import OllamaModelInfo, list_ollama_models, normalize_ollama_base_url + +if TYPE_CHECKING: + from rich.console import Console + + from anton.config.settings import AntonSettings + from anton.workspace import Workspace + + +@dataclass(frozen=True) +class ProviderOption: + key: str + label: str + badge: str + default_planning_model: str + default_coding_model: str + + +PROVIDER_OPTIONS: tuple[ProviderOption, ...] = ( + ProviderOption( + key="anthropic", + label="Anthropic (Claude)", + badge="recommended", + default_planning_model="claude-sonnet-4-6", + default_coding_model="claude-haiku-4-5-20251001", + ), + ProviderOption( + key="openai", + label="OpenAI (GPT / o-series)", + badge="experimental", + default_planning_model="gpt-5-mini", + default_coding_model="gpt-5-nano", + ), + ProviderOption( + key="ollama", + label="Ollama (local models)", + badge="local", + default_planning_model="", + default_coding_model="", + ), + ProviderOption( + key="openai-compatible", + label="OpenAI-compatible (custom endpoint)", + badge="experimental", + default_planning_model="", + default_coding_model="", + ), +) + + +def configure_llm_settings( + console, + settings, + workspace, + *, + show_current_config: bool = True, +) -> bool: + """Prompt for provider/model configuration and persist it to the workspace.""" + if show_current_config: + console.print() + console.print("[anton.cyan]Current configuration:[/]") + console.print(f" Provider (planning): [bold]{settings.planning_provider}[/]") + console.print(f" Provider (coding): [bold]{settings.coding_provider}[/]") + console.print(f" Planning model: [bold]{settings.planning_model}[/]") + console.print(f" Coding model: [bold]{settings.coding_model}[/]") + console.print() + else: + console.print() + console.print("[anton.cyan]LLM configuration[/]") + console.print() + + option_by_number = { + str(index): option + for index, option in enumerate(PROVIDER_OPTIONS, start=1) + } + current_number = current_provider_number(settings.planning_provider) + + console.print("[anton.cyan]Available providers:[/]") + for number, option in option_by_number.items(): + console.print( + f" [bold]{number}[/] {option.label:<36} [dim]\\[{option.badge}][/]" + ) + console.print() + + choice = Prompt.ask( + "Select provider", + choices=list(option_by_number), + default=current_number, + console=console, + ) + provider = option_by_number[choice] + + if provider.key == "ollama": + config = _prompt_for_ollama_config(console, settings, provider) + if config is None: + return False + planning_model, coding_model = config + settings.ollama_base_url = normalize_ollama_base_url(settings.ollama_base_url) + workspace.set_secret("ANTON_OLLAMA_BASE_URL", settings.ollama_base_url) + else: + if provider.key == "openai-compatible": + current_base_url = settings.openai_base_url or "" + console.print() + base_url = Prompt.ask( + "API base URL [dim](e.g. http://localhost:11434/v1)[/]", + default=current_base_url, + console=console, + ).strip() + if base_url: + settings.openai_base_url = base_url + workspace.set_secret("ANTON_OPENAI_BASE_URL", base_url) + + api_key = _prompt_for_api_key(console, settings, provider.key) + if api_key is None: + return False + + planning_model, coding_model = _prompt_for_cloud_models(console, settings, provider) + if api_key: + key_name = api_key_env_name(provider.key) + if key_name: + workspace.set_secret(key_name, api_key) + _set_provider_api_key(settings, provider.key, api_key) + + settings.planning_provider = provider.key + settings.coding_provider = provider.key + settings.planning_model = planning_model + settings.coding_model = coding_model + workspace.set_secret("ANTON_PLANNING_PROVIDER", provider.key) + workspace.set_secret("ANTON_CODING_PROVIDER", provider.key) + workspace.set_secret("ANTON_PLANNING_MODEL", planning_model) + workspace.set_secret("ANTON_CODING_MODEL", coding_model) + return True + + +def current_provider_number(provider: str) -> str: + for index, option in enumerate(PROVIDER_OPTIONS, start=1): + if option.key == provider: + return str(index) + return "1" + + +def api_key_env_name(provider: str) -> str | None: + if provider == "anthropic": + return "ANTON_ANTHROPIC_API_KEY" + if provider in {"openai", "openai-compatible"}: + return "ANTON_OPENAI_API_KEY" + return None + + +def _prompt_for_api_key(console, settings, provider: str) -> str | None: + key_attr = "anthropic_api_key" if provider == "anthropic" else "openai_api_key" + current_key = getattr(settings, key_attr) or "" + masked = _mask_secret(current_key) if current_key else "***" + console.print() + api_key = Prompt.ask( + f"API key for {provider.title()} [dim](Enter to keep {masked})[/]", + default="", + console=console, + ).strip() + if api_key: + return api_key + if current_key: + return "" + console.print() + console.print(f"[anton.error]No API key set for {provider}. Configuration not applied.[/]") + console.print() + return None + + +def _prompt_for_cloud_models(console, settings, provider: ProviderOption) -> tuple[str, str]: + console.print() + planning_model = Prompt.ask( + "Planning model", + default=( + settings.planning_model + if provider.key == settings.planning_provider + else provider.default_planning_model + ), + console=console, + ) + coding_model = Prompt.ask( + "Coding model", + default=( + settings.coding_model + if provider.key == settings.coding_provider + else provider.default_coding_model + ), + console=console, + ) + return planning_model, coding_model + + +def _prompt_for_ollama_config(console, settings, provider: ProviderOption) -> tuple[str, str] | None: + current_url = settings.ollama_base_url or "http://localhost:11434" + console.print() + base_url = Prompt.ask( + "Ollama URL [dim](e.g. http://localhost:11434)[/]", + default=current_url, + console=console, + ).strip() + settings.ollama_base_url = normalize_ollama_base_url(base_url or current_url) + + try: + models = list_ollama_models(settings.ollama_base_url) + except Exception as exc: + console.print() + console.print( + "[anton.warning]Could not query Ollama for local models.[/]" + f" [dim]({exc})[/]" + ) + console.print("[anton.muted]Enter model names manually instead.[/]") + console.print() + planning_model = Prompt.ask( + "Planning model", + default=settings.planning_model if provider.key == settings.planning_provider else "", + console=console, + ).strip() + if not planning_model: + console.print("[anton.error]No model provided. Configuration not applied.[/]") + console.print() + return None + coding_model = Prompt.ask( + "Coding model", + default=settings.coding_model if provider.key == settings.coding_provider else planning_model, + console=console, + ).strip() + return planning_model, coding_model or planning_model + + if not models: + console.print() + console.print("[anton.warning]No local Ollama models found.[/]") + console.print("[anton.muted]Pull a model with `ollama pull ` or enter a name manually.[/]") + console.print() + planning_model = Prompt.ask( + "Planning model", + default=settings.planning_model if provider.key == settings.planning_provider else "", + console=console, + ).strip() + if not planning_model: + console.print("[anton.error]No model provided. Configuration not applied.[/]") + console.print() + return None + coding_model = Prompt.ask( + "Coding model", + default=settings.coding_model if provider.key == settings.coding_provider else planning_model, + console=console, + ).strip() + return planning_model, coding_model or planning_model + + console.print() + console.print("[anton.cyan]Available local Ollama models:[/]") + for index, model in enumerate(models, start=1): + console.print(f" [bold]{index}[/] {model.display_name}") + manual_choice = str(len(models) + 1) + console.print(f" [bold]{manual_choice}[/] Enter model name manually") + + planning_model = _prompt_for_ollama_model_choice( + console=console, + label="Planning model", + models=models, + default_model=( + settings.planning_model + if provider.key == settings.planning_provider + else models[0].name + ), + ) + coding_model = _prompt_for_ollama_model_choice( + console=console, + label="Coding model", + models=models, + default_model=( + settings.coding_model + if provider.key == settings.coding_provider + else planning_model + ), + manual_default=planning_model, + ) + return planning_model, coding_model + + +def _prompt_for_ollama_model_choice( + *, + console, + label: str, + models: list[OllamaModelInfo], + default_model: str, + manual_default: str = "", +) -> str: + choice_by_number = { + str(index): model.name + for index, model in enumerate(models, start=1) + } + manual_choice = str(len(models) + 1) + choices = [*choice_by_number.keys(), manual_choice] + default_choice = manual_choice + for number, model_name in choice_by_number.items(): + if model_name == default_model: + default_choice = number + break + + choice = Prompt.ask( + label, + choices=choices, + default=default_choice, + console=console, + ) + if choice == manual_choice: + return Prompt.ask( + f"{label} name", + default=manual_default or default_model, + console=console, + ).strip() + return choice_by_number[choice] + + +def _mask_secret(value: str, *, keep: int = 4) -> str: + if len(value) <= keep * 2: + return "*" * max(len(value), 3) + return f"{value[:keep]}...{value[-keep:]}" + + +def _set_provider_api_key(settings, provider: str, api_key: str) -> None: + if provider == "anthropic": + settings.anthropic_api_key = api_key + elif provider in {"openai", "openai-compatible"}: + settings.openai_api_key = api_key diff --git a/anton/scratchpad.py b/anton/scratchpad.py index 84aa7a9..d907182 100644 --- a/anton/scratchpad.py +++ b/anton/scratchpad.py @@ -345,6 +345,8 @@ async def start(self) -> None: env["OPENAI_API_KEY"] = env["ANTON_OPENAI_API_KEY"] if "OPENAI_BASE_URL" not in env and "ANTON_OPENAI_BASE_URL" in env: env["OPENAI_BASE_URL"] = env["ANTON_OPENAI_BASE_URL"] + if "OLLAMA_HOST" not in env and "ANTON_OLLAMA_BASE_URL" in env: + env["OLLAMA_HOST"] = env["ANTON_OLLAMA_BASE_URL"] # If settings provided an explicit API key (e.g. from ~/.anton/.env or # Pydantic settings), inject it so the subprocess SDK can authenticate. if self._coding_api_key: diff --git a/anton/scratchpad_boot.py b/anton/scratchpad_boot.py index fe8e997..6812441 100644 --- a/anton/scratchpad_boot.py +++ b/anton/scratchpad_boot.py @@ -20,12 +20,16 @@ _scratchpad_provider_name = os.environ.get("ANTON_SCRATCHPAD_PROVIDER", "anthropic") if _scratchpad_provider_name in ("openai", "openai-compatible"): from anton.llm.openai import OpenAIProvider as _ProviderClass + elif _scratchpad_provider_name == "ollama": + from anton.llm.ollama import OllamaProvider as _ProviderClass else: from anton.llm.anthropic import AnthropicProvider as _ProviderClass _llm_ssl_verify = os.environ.get("ANTON_MINDS_SSL_VERIFY", "true").lower() != "false" if _scratchpad_provider_name in ("openai", "openai-compatible"): _llm_provider = _ProviderClass(ssl_verify=_llm_ssl_verify) + elif _scratchpad_provider_name == "ollama": + _llm_provider = _ProviderClass(base_url=os.environ.get("ANTON_OLLAMA_BASE_URL")) else: _llm_provider = _ProviderClass() # Anthropic doesn't need ssl_verify _llm_model = _scratchpad_model diff --git a/pyproject.toml b/pyproject.toml index 99ff96f..8bf4c2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ description = "Autonomous coding copilot" requires-python = ">=3.11" dependencies = [ "anthropic>=0.42.0", + "ollama>=0.6.1", "openai>=1.0", "pydantic>=2.0", "pydantic-settings>=2.0", diff --git a/tests/test_chat.py b/tests/test_chat.py index ef33a77..bdf1655 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -173,3 +173,134 @@ async def _plan_stream(**kwargs): session._summarize_history.assert_not_called() compacted = [e for e in events if isinstance(e, StreamContextCompacted)] assert len(compacted) == 0 + + +class _FakeAsyncIter: + def __init__(self, items): + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class TestOllamaVerifier: + async def test_ollama_verifier_disables_thinking(self, monkeypatch): + mock_llm = AsyncMock() + mock_llm.planning_provider_name = "ollama" + mock_llm.plan = AsyncMock(return_value=_text_response("STATUS: COMPLETE — done")) + + calls = 0 + + def fake_plan_stream(**kwargs): + nonlocal calls + calls += 1 + if calls == 1: + return _FakeAsyncIter([ + StreamComplete( + response=LLMResponse( + content="", + tool_calls=[ToolCall(id="tool_1", name="scratchpad", input={"action": "view", "name": "main"})], + usage=Usage(), + stop_reason="tool_use", + ) + ) + ]) + return _FakeAsyncIter([ + StreamComplete(response=_text_response("All done.")) + ]) + + mock_llm.plan_stream = fake_plan_stream + + async def fake_dispatch_tool(session, name, input): + return "tool ok" + + monkeypatch.setattr("anton.chat.dispatch_tool", fake_dispatch_tool) + + session = ChatSession(mock_llm) + events = [event async for event in session.turn_stream("help")] + + assert any(isinstance(event, StreamComplete) for event in events) + assert mock_llm.plan.await_count == 1 + assert mock_llm.plan.call_args.kwargs["request_options"] == {"think": False} + + async def test_ollama_unparseable_verifier_stops_retry_loop(self, monkeypatch): + mock_llm = AsyncMock() + mock_llm.planning_provider_name = "ollama" + mock_llm.plan = AsyncMock(return_value=_text_response("")) + + calls = 0 + + def fake_plan_stream(**kwargs): + nonlocal calls + calls += 1 + if calls == 1: + return _FakeAsyncIter([ + StreamComplete( + response=LLMResponse( + content="", + tool_calls=[ToolCall(id="tool_1", name="scratchpad", input={"action": "view", "name": "main"})], + usage=Usage(), + stop_reason="tool_use", + ) + ) + ]) + return _FakeAsyncIter([ + StreamComplete(response=_text_response("All done.")) + ]) + + mock_llm.plan_stream = fake_plan_stream + + async def fake_dispatch_tool(session, name, input): + return "tool ok" + + monkeypatch.setattr("anton.chat.dispatch_tool", fake_dispatch_tool) + + session = ChatSession(mock_llm) + events = [event async for event in session.turn_stream("help")] + + assert any(isinstance(event, StreamComplete) for event in events) + assert calls == 2 + + async def test_non_ollama_unparseable_verifier_continues_working(self, monkeypatch): + mock_llm = AsyncMock() + mock_llm.planning_provider_name = "anthropic" + mock_llm.plan = AsyncMock(return_value=_text_response("")) + + calls = 0 + + def fake_plan_stream(**kwargs): + nonlocal calls + calls += 1 + if calls == 1: + return _FakeAsyncIter([ + StreamComplete( + response=LLMResponse( + content="", + tool_calls=[ToolCall(id="tool_1", name="scratchpad", input={"action": "view", "name": "main"})], + usage=Usage(), + stop_reason="tool_use", + ) + ) + ]) + return _FakeAsyncIter([ + StreamComplete(response=_text_response("All done.")) + ]) + + mock_llm.plan_stream = fake_plan_stream + + async def fake_dispatch_tool(session, name, input): + return "tool ok" + + monkeypatch.setattr("anton.chat.dispatch_tool", fake_dispatch_tool) + + session = ChatSession(mock_llm) + events = [event async for event in session.turn_stream("help")] + + assert any(isinstance(event, StreamComplete) for event in events) + assert calls == 3 + assert mock_llm.plan.call_args.kwargs["request_options"] is None diff --git a/tests/test_chat_scratchpad.py b/tests/test_chat_scratchpad.py index 94d0d0a..9ce89b6 100644 --- a/tests/test_chat_scratchpad.py +++ b/tests/test_chat_scratchpad.py @@ -228,6 +228,7 @@ def fake_plan_stream(**kwargs): ]) mock_llm.plan_stream = fake_plan_stream + mock_llm.plan = AsyncMock(return_value=_text_response("STATUS: COMPLETE — done")) session = ChatSession(mock_llm) try: @@ -270,6 +271,7 @@ def fake_plan_stream(**kwargs): return _FakeAsyncIter([StreamComplete(response=final_response)]) mock_llm.plan_stream = fake_plan_stream + mock_llm.plan = AsyncMock(return_value=_text_response("STATUS: COMPLETE — done")) session = ChatSession(mock_llm) try: diff --git a/tests/test_chat_ui.py b/tests/test_chat_ui.py index 471f9bf..3e9ebc4 100644 --- a/tests/test_chat_ui.py +++ b/tests/test_chat_ui.py @@ -86,9 +86,20 @@ def test_update_progress_without_eta(self, MockLive): assert live.update.call_count >= 1 def test_phase_labels_cover_all_phases(self): - expected = {"memory_recall", "planning", "executing", "complete", "failed", "scratchpad"} + expected = {"memory_recall", "planning", "reasoning", "executing", "complete", "failed", "scratchpad"} assert expected == set(PHASE_LABELS.keys()) + @patch("anton.chat_ui.Live") + def test_reasoning_progress_updates_spinner(self, MockLive): + display, _ = self._make_display() + display.start() + live = MockLive.return_value + + display.update_progress("reasoning", "Thinking...") + + assert display._line2_status == "Thinking..." + assert live.update.call_count >= 1 + class TestActivityTracking: def _make_display(self): diff --git a/tests/test_client.py b/tests/test_client.py index 897d41d..904f98c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,6 +6,7 @@ from anton.config.settings import AntonSettings from anton.llm.client import LLMClient +from anton.llm.ollama import OllamaProvider from anton.llm.provider import LLMProvider, LLMResponse, Usage @@ -98,3 +99,20 @@ def test_unknown_coding_provider_raises(self): ) with pytest.raises(ValueError, match="Unknown coding provider"): LLMClient.from_settings(settings) + + def test_from_settings_ollama(self): + with patch("anton.llm.ollama.ollama"): + settings = AntonSettings( + planning_provider="ollama", + coding_provider="ollama", + planning_model="qwen3.5:4b", + coding_model="qwen3.5:4b", + ollama_base_url="http://localhost:11434/v1", + _env_file=None, + ) + client = LLMClient.from_settings(settings) + assert isinstance(client, LLMClient) + assert isinstance(client._planning_provider, OllamaProvider) + assert isinstance(client._coding_provider, OllamaProvider) + assert client.planning_provider_name == "ollama" + assert client.coding_provider_name == "ollama" diff --git a/tests/test_llm_setup.py b/tests/test_llm_setup.py new file mode 100644 index 0000000..b7b7b13 --- /dev/null +++ b/tests/test_llm_setup.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from anton.config.settings import AntonSettings +from anton.llm.ollama import OllamaModelInfo +from anton.llm.setup import configure_llm_settings +from anton.workspace import Workspace + + +@pytest.fixture(autouse=True) +def clean_llm_env(monkeypatch): + for key in ( + "ANTON_ANTHROPIC_API_KEY", + "ANTON_OPENAI_API_KEY", + "ANTON_OPENAI_BASE_URL", + "ANTON_OLLAMA_BASE_URL", + "ANTON_PLANNING_PROVIDER", + "ANTON_CODING_PROVIDER", + "ANTON_PLANNING_MODEL", + "ANTON_CODING_MODEL", + ): + monkeypatch.delenv(key, raising=False) + + +class TestConfigureLlmSettings: + @patch("anton.llm.setup.list_ollama_models") + @patch("anton.llm.setup.Prompt.ask") + def test_ollama_discovery_success(self, mock_ask, mock_list, tmp_path): + mock_ask.side_effect = ["3", "http://localhost:11434", "1", "2"] + mock_list.return_value = [ + OllamaModelInfo(name="qwen3.5:4b", parameter_size="4.7B", quantization="Q4_K_M"), + OllamaModelInfo(name="qwen3:0.6b", parameter_size="751.63M", quantization="Q4_K_M"), + ] + + console = MagicMock() + workspace = Workspace(tmp_path) + settings = AntonSettings(_env_file=None) + + applied = configure_llm_settings(console, settings, workspace, show_current_config=False) + + assert applied is True + assert settings.planning_provider == "ollama" + assert settings.coding_provider == "ollama" + assert settings.ollama_base_url == "http://localhost:11434" + assert settings.planning_model == "qwen3.5:4b" + assert settings.coding_model == "qwen3:0.6b" + assert workspace.get_secret("ANTON_OLLAMA_BASE_URL") == "http://localhost:11434" + + @patch("anton.llm.setup.list_ollama_models") + @patch("anton.llm.setup.Prompt.ask") + def test_ollama_manual_fallback_when_discovery_fails(self, mock_ask, mock_list, tmp_path): + mock_ask.side_effect = ["3", "localhost:11434/v1", "qwen3.5:4b", "qwen3.5:4b"] + mock_list.side_effect = ConnectionError("boom") + + console = MagicMock() + workspace = Workspace(tmp_path) + settings = AntonSettings(_env_file=None) + + applied = configure_llm_settings(console, settings, workspace, show_current_config=False) + + assert applied is True + assert settings.planning_provider == "ollama" + assert settings.ollama_base_url == "http://localhost:11434" + assert settings.planning_model == "qwen3.5:4b" + assert settings.coding_model == "qwen3.5:4b" + + @patch("anton.llm.setup.Prompt.ask") + def test_anthropic_setup_persists_api_key(self, mock_ask, tmp_path): + mock_ask.side_effect = ["1", "sk-ant-test", "claude-sonnet-4-6", "claude-haiku-4-5-20251001"] + + console = MagicMock() + workspace = Workspace(tmp_path) + settings = AntonSettings(_env_file=None) + + applied = configure_llm_settings(console, settings, workspace, show_current_config=False) + + assert applied is True + assert settings.planning_provider == "anthropic" + assert settings.anthropic_api_key == "sk-ant-test" + assert workspace.get_secret("ANTON_ANTHROPIC_API_KEY") == "sk-ant-test" diff --git a/tests/test_ollama_provider.py b/tests/test_ollama_provider.py new file mode 100644 index 0000000..a0daa78 --- /dev/null +++ b/tests/test_ollama_provider.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from anton.llm.ollama import ( + OllamaProvider, + normalize_ollama_base_url, + translate_messages_to_ollama, +) +from anton.llm.provider import ( + StreamComplete, + StreamTaskProgress, + StreamTextDelta, + StreamToolUseDelta, + StreamToolUseEnd, + StreamToolUseStart, +) + + +class _AsyncIter: + def __init__(self, items): + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +def _ollama_response(*, content="", tool_calls=None, prompt_tokens=5, completion_tokens=10, done_reason="stop"): + return SimpleNamespace( + message=SimpleNamespace(content=content, tool_calls=tool_calls or [], thinking=None), + prompt_eval_count=prompt_tokens, + eval_count=completion_tokens, + done_reason=done_reason, + ) + + +class TestOllamaProvider: + async def test_complete_text_response(self): + with patch("anton.llm.ollama.ollama") as mock_ollama: + mock_client = AsyncMock() + mock_ollama.AsyncClient.return_value = mock_client + mock_client.chat = AsyncMock(return_value=_ollama_response(content="Hello world")) + + provider = OllamaProvider(base_url="http://localhost:11434/v1") + result = await provider.complete( + model="qwen3.5:4b", + system="be helpful", + messages=[{"role": "user", "content": "hi"}], + max_tokens=123, + request_options={"think": False}, + ) + + call_kwargs = mock_client.chat.call_args.kwargs + assert call_kwargs["think"] is False + assert call_kwargs["options"] == {"num_predict": 123} + assert result.content == "Hello world" + assert result.tool_calls == [] + assert result.usage.input_tokens == 5 + assert result.usage.output_tokens == 10 + assert result.stop_reason == "stop" + + async def test_complete_tool_use_response(self): + with patch("anton.llm.ollama.ollama") as mock_ollama: + mock_client = AsyncMock() + mock_ollama.AsyncClient.return_value = mock_client + tool_call = SimpleNamespace( + function=SimpleNamespace(name="lookup_weather", arguments={"city": "London"}) + ) + mock_client.chat = AsyncMock( + return_value=_ollama_response(content="", tool_calls=[tool_call], done_reason="tool_calls") + ) + + provider = OllamaProvider(base_url="http://localhost:11434") + result = await provider.complete( + model="qwen3.5:4b", + system="plan", + messages=[{"role": "user", "content": "weather?"}], + tools=[{"name": "lookup_weather", "description": "d", "input_schema": {"type": "object"}}], + ) + + assert result.content == "" + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].id == "ollama_tool_1" + assert result.tool_calls[0].name == "lookup_weather" + assert result.tool_calls[0].input == {"city": "London"} + assert result.stop_reason == "tool_calls" + + async def test_stream_emits_reasoning_progress_and_tool_events(self): + with patch("anton.llm.ollama.ollama") as mock_ollama: + mock_client = AsyncMock() + mock_ollama.AsyncClient.return_value = mock_client + tool_call = SimpleNamespace( + function=SimpleNamespace(name="lookup_weather", arguments={"city": "London"}) + ) + mock_client.chat = AsyncMock( + return_value=_AsyncIter([ + SimpleNamespace( + message=SimpleNamespace(content="", thinking="Working", tool_calls=[]), + prompt_eval_count=5, + eval_count=0, + done_reason=None, + ), + SimpleNamespace( + message=SimpleNamespace(content="Done.", thinking=None, tool_calls=[tool_call]), + prompt_eval_count=None, + eval_count=6, + done_reason="stop", + ), + ]) + ) + + provider = OllamaProvider(base_url="http://localhost:11434") + events = [event async for event in provider.stream( + model="qwen3.5:4b", + system="sys", + messages=[{"role": "user", "content": "hi"}], + )] + + assert any(isinstance(event, StreamTaskProgress) and event.phase == "reasoning" for event in events) + assert any(isinstance(event, StreamTextDelta) and event.text == "Done." for event in events) + assert any(isinstance(event, StreamToolUseStart) and event.name == "lookup_weather" for event in events) + assert any(isinstance(event, StreamToolUseDelta) for event in events) + assert any(isinstance(event, StreamToolUseEnd) for event in events) + complete = next(event for event in events if isinstance(event, StreamComplete)) + assert complete.response.content == "Done." + assert complete.response.tool_calls[0].input == {"city": "London"} + + +class TestOllamaHelpers: + def test_normalize_ollama_base_url_strips_v1(self): + assert normalize_ollama_base_url("localhost:11434/v1") == "http://localhost:11434" + + def test_translate_messages_to_ollama_handles_tool_results(self): + messages = [ + {"role": "user", "content": "Find the file"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I will check."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/tmp/test.txt"}}, + ], + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "tool_1", "content": "hello"}, + ], + }, + ] + + translated = translate_messages_to_ollama("system", messages) + + assert translated[0] == {"role": "system", "content": "system"} + assert translated[2]["role"] == "assistant" + assert translated[2]["tool_calls"][0]["function"]["name"] == "read_file" + assert translated[3] == {"role": "tool", "tool_name": "read_file", "content": "hello"} diff --git a/tests/test_scratchpad.py b/tests/test_scratchpad.py index a44a316..782668c 100644 --- a/tests/test_scratchpad.py +++ b/tests/test_scratchpad.py @@ -458,6 +458,22 @@ async def test_api_key_bridged(self, monkeypatch): finally: await pad.close() + async def test_ollama_host_bridged(self, monkeypatch): + """ANTON_OLLAMA_BASE_URL should be bridged to OLLAMA_HOST.""" + monkeypatch.setenv("ANTON_OLLAMA_BASE_URL", "http://localhost:11434") + monkeypatch.delenv("OLLAMA_HOST", raising=False) + pad = Scratchpad(name="ollama-host-test", _coding_provider="ollama", _coding_model="qwen3.5:4b") + await pad.start() + try: + cell = await pad.execute( + "import os; print(os.environ.get('OLLAMA_HOST', 'MISSING'))" + ) + assert cell.stdout.strip() == "http://localhost:11434" + llm_cell = await pad.execute("llm = get_llm(); print(llm.model)") + assert llm_cell.stdout.strip() == "qwen3.5:4b" + finally: + await pad.close() + class TestScratchpadVenv: async def test_venv_created_on_start(self): diff --git a/tests/test_settings.py b/tests/test_settings.py index 5833fe8..e67947b 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -8,35 +8,54 @@ from anton.config.settings import AntonSettings +@pytest.fixture(autouse=True) +def clean_anton_env(monkeypatch): + for key in ( + "ANTON_PLANNING_PROVIDER", + "ANTON_PLANNING_MODEL", + "ANTON_CODING_PROVIDER", + "ANTON_CODING_MODEL", + "ANTON_ANTHROPIC_API_KEY", + "ANTON_OPENAI_API_KEY", + "ANTON_OPENAI_BASE_URL", + "ANTON_OLLAMA_BASE_URL", + ): + monkeypatch.delenv(key, raising=False) + + class TestAntonSettingsDefaults: def test_default_planning_provider(self): - s = AntonSettings(anthropic_api_key="test") + s = AntonSettings(anthropic_api_key="test", _env_file=None) assert s.planning_provider == "anthropic" def test_default_planning_model(self): - s = AntonSettings(anthropic_api_key="test") + s = AntonSettings(anthropic_api_key="test", _env_file=None) assert s.planning_model == "claude-sonnet-4-6" def test_default_coding_provider(self): - s = AntonSettings(anthropic_api_key="test") + s = AntonSettings(anthropic_api_key="test", _env_file=None) assert s.coding_provider == "anthropic" def test_default_coding_model(self): - s = AntonSettings(anthropic_api_key="test") + s = AntonSettings(anthropic_api_key="test", _env_file=None) assert s.coding_model == "claude-haiku-4-5-20251001" def test_default_memory_dir(self): - s = AntonSettings(anthropic_api_key="test") + s = AntonSettings(anthropic_api_key="test", _env_file=None) assert s.memory_dir == ".anton" def test_default_context_dir(self): - s = AntonSettings(anthropic_api_key="test") + s = AntonSettings(anthropic_api_key="test", _env_file=None) assert s.context_dir == ".anton/context" def test_default_api_key_is_none(self): s = AntonSettings(_env_file=None) assert s.anthropic_api_key is None + def test_default_ollama_base_url(self): + s = AntonSettings(_env_file=None) + assert s.ollama_base_url == "http://localhost:11434" + class TestAntonSettingsEnvOverride: def test_env_overrides_planning_model(self, monkeypatch): @@ -49,6 +68,11 @@ def test_env_overrides_api_key(self, monkeypatch): s = AntonSettings(_env_file=None) assert s.anthropic_api_key == "sk-test-key" + def test_env_overrides_ollama_base_url(self, monkeypatch): + monkeypatch.setenv("ANTON_OLLAMA_BASE_URL", "http://example.test:11434") + s = AntonSettings(_env_file=None) + assert s.ollama_base_url == "http://example.test:11434" + class TestWorkspaceResolution: def test_resolve_workspace_defaults_to_cwd(self, tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) From 199419e34ec57fd36ecca2c26b0aaafc1958ffca Mon Sep 17 00:00:00 2001 From: ianu82 Date: Thu, 19 Mar 2026 10:57:23 +0000 Subject: [PATCH 2/2] Preserve 4xx Ollama errors instead of collapsing into ConnectionError ollama.ResponseError carries a status_code, so 4xx errors like "model not found" (404) should surface as actionable errors rather than being misreported as connectivity failures. Only 5xx errors are now mapped to ConnectionError, matching the pattern used by the Anthropic and OpenAI providers. Co-Authored-By: Claude Opus 4.6 (1M context) --- anton/llm/ollama.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/anton/llm/ollama.py b/anton/llm/ollama.py index ba85884..a1882a6 100644 --- a/anton/llm/ollama.py +++ b/anton/llm/ollama.py @@ -197,6 +197,8 @@ async def complete( message = str(exc).lower() if "context" in message and "length" in message: raise ContextOverflowError(str(exc)) from exc + if exc.status_code >= 400 and exc.status_code < 500: + raise raise ConnectionError(str(exc)) from exc except ConnectionError as exc: raise ConnectionError(str(exc)) from exc @@ -281,6 +283,8 @@ async def stream( message = str(exc).lower() if "context" in message and "length" in message: raise ContextOverflowError(str(exc)) from exc + if exc.status_code >= 400 and exc.status_code < 500: + raise raise ConnectionError(str(exc)) from exc except ConnectionError as exc: raise ConnectionError(str(exc)) from exc