diff --git a/anton/chat.py b/anton/chat.py
index b86e36c..0e6a29c 100644
--- a/anton/chat.py
+++ b/anton/chat.py
@@ -2,6 +2,7 @@
 import asyncio
 import os
+import re
 import sys
 import time
 from collections.abc import AsyncIterator, Callable
@@ -65,6 +66,21 @@
     "Only involve the user if the problem truly requires something only they can provide."
 )
 
+_VERIFIER_STATUS_PREFIXES = (
+    "STATUS: COMPLETE",
+    "STATUS: INCOMPLETE",
+    "STATUS: STUCK",
+)
+
+
+def _strip_ollama_think_tags(text: str) -> str:
+    return re.sub(r"<think>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
+
+
+def _is_parseable_verifier_status(text: str) -> bool:
+    upper = text.upper()
+    return any(prefix in upper for prefix in _VERIFIER_STATUS_PREFIXES)
+
 
 class ChatSession:
     """Manages a multi-turn conversation with tool-call delegation."""
@@ -749,6 +765,11 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato
                     "has been fully completed based on the conversation above."
                 ),
             }]
+            verifier_request_options = (
+                {"think": False}
+                if self._llm.planning_provider_name == "ollama"
+                else None
+            )
             verification = await self._llm.plan(
                 system=(
                     "You are a task-completion verifier. Given the conversation, determine "
@@ -766,14 +787,21 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato
                 ),
                 messages=verify_messages,
                 max_tokens=256,
+                request_options=verifier_request_options,
             )
-            status_text = (verification.content or "").strip().upper()
+            verification_text = (verification.content or "").strip()
+            if self._llm.planning_provider_name == "ollama":
+                verification_text = _strip_ollama_think_tags(verification_text).strip()
+            if not verification_text or not _is_parseable_verifier_status(verification_text):
+                break
+
+            status_text = verification_text.upper()
             if "STATUS: COMPLETE" in status_text:
                 break
 
             if "STATUS: STUCK" in status_text:
                 # Stuck — inject diagnosis request and let the LLM explain
-                reason = (verification.content or "").strip()
+                reason = verification_text
                 self._history.append({
                     "role": "user",
                     "content": (
@@ -795,7 +823,7 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato
 
             # INCOMPLETE — continue working
             continuation += 1
-            reason = (verification.content or "").strip()
+            reason = verification_text
             self._history.append({
                 "role": "user",
                 "content": (
@@ -957,7 +985,7 @@ def _rebuild_session(
     runtime_context = _build_runtime_context(settings)
     api_key = (
         settings.anthropic_api_key if settings.coding_provider == "anthropic"
-        else settings.openai_api_key
+        else settings.openai_api_key if settings.coding_provider in {"openai", "openai-compatible"} else ""
     ) or ""
     return ChatSession(
         state["llm_client"],
@@ -1211,107 +1239,22 @@ async def _handle_setup_models(
     session_id: str | None = None,
 ) -> ChatSession:
     """Setup sub-menu: provider, API key, and models."""
-    from rich.prompt import Prompt
-
+    from anton.llm.setup import configure_llm_settings
     from anton.workspace import Workspace as _Workspace
 
     # Always persist API keys and model settings to global ~/.anton/.env
     global_ws = _Workspace(Path.home())
 
-    console.print()
-    console.print("[anton.cyan]Current configuration:[/]")
-    console.print(f"  Provider (planning): [bold]{settings.planning_provider}[/]")
-    console.print(f"  Provider (coding): [bold]{settings.coding_provider}[/]")
-    console.print(f"  Planning model: [bold]{settings.planning_model}[/]")
-    console.print(f"  Coding model: [bold]{settings.coding_model}[/]")
-    console.print()
-
-    # --- Provider ---
-    providers = {"1": "anthropic", "2": "openai", "3": "openai-compatible"}
-    current_num = {"anthropic": "1", "openai": "2", "openai-compatible": "3"}.get(settings.planning_provider, "1")
-    console.print("[anton.cyan]Available providers:[/]")
-    console.print(r"  [bold]1[/] Anthropic (Claude) [dim]\[recommended][/]")
-    console.print(r"  [bold]2[/] OpenAI (GPT / o-series) [dim]\[experimental][/]")
-    console.print(r"  [bold]3[/] OpenAI-compatible (custom endpoint) [dim]\[experimental][/]")
-    console.print()
-
-    choice = Prompt.ask(
-        "Select provider",
-        choices=["1", "2", "3"],
-        default=current_num,
-        console=console,
-    )
-    provider = providers[choice]
-
-    # --- Base URL (OpenAI-compatible only) ---
-    if provider == "openai-compatible":
-        current_base_url = settings.openai_base_url or ""
-        console.print()
-        base_url = Prompt.ask(
-            f"API base URL [dim](e.g. http://localhost:11434/v1)[/]",
-            default=current_base_url,
-            console=console,
-        )
-        base_url = base_url.strip()
-        if base_url:
-            settings.openai_base_url = base_url
-            global_ws.set_secret("ANTON_OPENAI_BASE_URL", base_url)
-
-    # --- API key ---
-    key_attr = "anthropic_api_key" if provider == "anthropic" else "openai_api_key"
-    current_key = getattr(settings, key_attr) or ""
-    masked = current_key[:4] + "..." + current_key[-4:] if len(current_key) > 8 else "***"
-    console.print()
-    api_key = Prompt.ask(
-        f"API key for {provider.title()} [dim](Enter to keep {masked})[/]",
-        default="",
-        console=console,
-    )
-    api_key = api_key.strip()
-
-    # --- Models ---
-    defaults = {
-        "anthropic": ("claude-sonnet-4-6", "claude-haiku-4-5-20251001"),
-        "openai": ("gpt-5-mini", "gpt-5-nano"),
-    }
-    default_planning, default_coding = defaults.get(provider, ("", ""))
-
-    console.print()
-    planning_model = Prompt.ask(
-        "Planning model",
-        default=settings.planning_model if provider == settings.planning_provider else default_planning,
-        console=console,
-    )
-    coding_model = Prompt.ask(
-        "Coding model",
-        default=settings.coding_model if provider == settings.coding_provider else default_coding,
-        console=console,
+    applied = configure_llm_settings(
+        console,
+        settings,
+        global_ws,
+        show_current_config=True,
     )
-
-    # --- Persist to global ~/.anton/.env ---
-    settings.planning_provider = provider
-    settings.coding_provider = provider
-    settings.planning_model = planning_model
-    settings.coding_model = coding_model
-
-    global_ws.set_secret("ANTON_PLANNING_PROVIDER", provider)
-    global_ws.set_secret("ANTON_CODING_PROVIDER", provider)
-    global_ws.set_secret("ANTON_PLANNING_MODEL", planning_model)
-    global_ws.set_secret("ANTON_CODING_MODEL", coding_model)
-
-    if api_key:
-        setattr(settings, key_attr, api_key)
-        key_name = f"ANTON_{provider.upper()}_API_KEY"
-        global_ws.set_secret(key_name, api_key)
-
-    # Validate that we actually have an API key for the chosen provider
-    final_key = getattr(settings, key_attr)
-    if not final_key:
-        console.print()
-        console.print(f"[anton.error]No API key set for {provider}. Configuration not applied.[/]")
-        console.print()
+    if not applied:
         return session
+    global_ws.apply_env_to_process()
 
     console.print()
     console.print("[anton.success]Configuration updated.[/]")
     console.print()
@@ -1619,6 +1562,7 @@ def _minds_test_llm(base_url: str, api_key: str, verify: bool = True) -> bool:
     "ANTON_PLANNING_PROVIDER", "ANTON_CODING_PROVIDER",
     "ANTON_PLANNING_MODEL", "ANTON_CODING_MODEL",
     "ANTON_ANTHROPIC_API_KEY", "ANTON_OPENAI_API_KEY", "ANTON_OPENAI_BASE_URL",
+    "ANTON_OLLAMA_BASE_URL",
 }
 
 _SECRET_PATTERNS = ("KEY", "TOKEN", "SECRET", "PAT", "PASSWORD")
@@ -2380,7 +2324,7 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool
     coding_api_key = (
         settings.anthropic_api_key if settings.coding_provider == "anthropic"
-        else settings.openai_api_key
+        else settings.openai_api_key if settings.coding_provider in {"openai", "openai-compatible"} else ""
     ) or ""
     session = ChatSession(
         state["llm_client"],
diff --git a/anton/chat_ui.py b/anton/chat_ui.py
index 98e9f28..accf6e3 100644
--- a/anton/chat_ui.py
+++ b/anton/chat_ui.py
@@ -131,6 +131,7 @@ def _tool_display_text(name: str, input_json: str) -> str:
 PHASE_LABELS = {
     "memory_recall": "Memory",
     "planning": "Planning",
+    "reasoning": "Thinking",
     "executing": "Executing",
     "complete": "Complete",
     "failed": "Failed",
@@ -154,6 +155,7 @@ def __init__(self, console: Console, toolbar: dict | None = None) -> None:
         self._buffer = ""  # answer text accumulated during streaming
         self._in_tool_phase = False
         self._last_was_tool = False
+        self._answer_started = False
         self._initial_text = ""
         self._initial_printed = False
         self._active = False
@@ -234,6 +236,7 @@ def start(self) -> None:
         self._initial_printed = False
         self._in_tool_phase = False
         self._last_was_tool = False
+        self._answer_started = False
         self._cancel_msg = ""
         self._active = True
         self._start_spinner()
@@ -244,6 +247,7 @@ def append_text(self, delta: str) -> None:
         if self._in_tool_phase:
             self._buffer += delta
             self._last_was_tool = False
+            self._answer_started = True
             self._line3_peek = self._extract_peek(self._buffer)
             self._update_spinner()
         else:
@@ -308,6 +312,13 @@ def update_progress(self, phase: str, message: str, eta: float | None = None) ->
             self._update_spinner()
             return
 
+        if phase == "reasoning":
+            self._line2_status = message or "Thinking..."
+ self._line3_peek = "" + self._set_status(self._line2_status) + self._update_spinner() + return + if phase == "scratchpad_start": # Print the scratchpad activity line NOW (before execution) for act in reversed(self._activities): diff --git a/anton/cli.py b/anton/cli.py index a701342..3879846 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -28,6 +28,7 @@ def _reexec() -> None: # Core dependencies from pyproject.toml that anton needs at runtime _REQUIRED_PACKAGES: dict[str, str] = { "anthropic": "anthropic>=0.42.0", + "ollama": "ollama>=0.6.1", "openai": "openai>=1.0", "pydantic": "pydantic>=2.0", "pydantic_settings": "pydantic-settings>=2.0", @@ -224,9 +225,11 @@ def main( def _has_api_key(settings) -> bool: - """Check if all configured providers have API keys.""" + """Check if the configured providers are ready to use.""" providers = {settings.planning_provider, settings.coding_provider} for p in providers: + if p == "ollama": + continue if p == "anthropic" and not (settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY")): return False if p in ("openai", "openai-compatible") and not (settings.openai_api_key or os.environ.get("OPENAI_API_KEY")): @@ -235,12 +238,11 @@ def _has_api_key(settings) -> bool: def _ensure_api_key(settings) -> None: - """Prompt the user to configure a provider and API key if none is set.""" + """Prompt the user to configure an LLM provider if needed.""" if _has_api_key(settings): return - from rich.prompt import Prompt - + from anton.llm.setup import configure_llm_settings from anton.workspace import Workspace ws = Workspace(Path.home()) @@ -248,7 +250,14 @@ def _ensure_api_key(settings) -> None: if settings.minds_enabled: _ensure_minds_api_key(settings, ws) else: - _ensure_anthropic_api_key(settings, ws) + applied = configure_llm_settings( + console, + settings, + ws, + show_current_config=False, + ) + if not applied: + raise typer.Exit(1) # Reload env vars into the process so the scratchpad subprocess inherits them ws.apply_env_to_process() diff --git a/anton/config/settings.py b/anton/config/settings.py index 1aedfc2..b114cc8 100644 --- a/anton/config/settings.py +++ b/anton/config/settings.py @@ -34,6 +34,7 @@ class AntonSettings(BaseSettings): anthropic_api_key: str | None = None openai_api_key: str | None = None openai_base_url: str | None = None + ollama_base_url: str = "http://localhost:11434" memory_enabled: bool = True memory_dir: str = ".anton" diff --git a/anton/llm/anthropic.py b/anton/llm/anthropic.py index 8aec58b..13d759b 100644 --- a/anton/llm/anthropic.py +++ b/anton/llm/anthropic.py @@ -37,6 +37,7 @@ async def complete( tools: list[dict] | None = None, tool_choice: dict | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> LLMResponse: kwargs: dict = { "model": model, @@ -96,6 +97,7 @@ async def stream( messages: list[dict], tools: list[dict] | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> AsyncIterator[StreamEvent]: kwargs: dict = { "model": model, diff --git a/anton/llm/client.py b/anton/llm/client.py index a58c217..a6a5bc8 100644 --- a/anton/llm/client.py +++ b/anton/llm/client.py @@ -13,14 +13,18 @@ class LLMClient: def __init__( self, *, + planning_provider_name: str = "anthropic", planning_provider: LLMProvider, planning_model: str, + coding_provider_name: str = "anthropic", coding_provider: LLMProvider, coding_model: str, max_tokens: int = 8192, ) -> None: + self._planning_provider_name = planning_provider_name self._planning_provider = planning_provider 
self._planning_model = planning_model + self._coding_provider_name = coding_provider_name self._coding_provider = coding_provider self._coding_model = coding_model self._max_tokens = max_tokens @@ -32,6 +36,7 @@ async def plan( messages: list[dict], tools: list[dict] | None = None, max_tokens: int | None = None, + request_options: dict | None = None, ) -> LLMResponse: return await self._planning_provider.complete( model=self._planning_model, @@ -39,6 +44,7 @@ async def plan( messages=messages, tools=tools, max_tokens=max_tokens or self._max_tokens, + request_options=request_options, ) async def plan_stream( @@ -48,6 +54,7 @@ async def plan_stream( messages: list[dict], tools: list[dict] | None = None, max_tokens: int | None = None, + request_options: dict | None = None, ) -> AsyncIterator[StreamEvent]: async for event in self._planning_provider.stream( model=self._planning_model, @@ -55,14 +62,23 @@ async def plan_stream( messages=messages, tools=tools, max_tokens=max_tokens or self._max_tokens, + request_options=request_options, ): yield event + @property + def planning_provider_name(self) -> str: + return self._planning_provider_name + @property def coding_provider(self) -> LLMProvider: """The LLM provider used for coding/skill execution.""" return self._coding_provider + @property + def coding_provider_name(self) -> str: + return self._coding_provider_name + @property def coding_model(self) -> str: """The model name used for coding/skill execution.""" @@ -75,6 +91,7 @@ async def code( messages: list[dict], tools: list[dict] | None = None, max_tokens: int | None = None, + request_options: dict | None = None, ) -> LLMResponse: return await self._coding_provider.complete( model=self._coding_model, @@ -82,16 +99,19 @@ async def code( messages=messages, tools=tools, max_tokens=max_tokens or self._max_tokens, + request_options=request_options, ) @classmethod def from_settings(cls, settings: AntonSettings) -> LLMClient: from anton.llm.anthropic import AnthropicProvider + from anton.llm.ollama import OllamaProvider from anton.llm.openai import OpenAIProvider providers = { "anthropic": lambda: AnthropicProvider(api_key=settings.anthropic_api_key), "openai": lambda: OpenAIProvider(api_key=settings.openai_api_key, base_url=settings.openai_base_url, ssl_verify=settings.minds_ssl_verify), + "ollama": lambda: OllamaProvider(base_url=settings.ollama_base_url), "openai-compatible": lambda: OpenAIProvider(api_key=settings.openai_api_key, base_url=settings.openai_base_url, ssl_verify=settings.minds_ssl_verify), } @@ -104,8 +124,10 @@ def from_settings(cls, settings: AntonSettings) -> LLMClient: raise ValueError(f"Unknown coding provider: {settings.coding_provider}") return cls( + planning_provider_name=settings.planning_provider, planning_provider=planning_factory(), planning_model=settings.planning_model, + coding_provider_name=settings.coding_provider, coding_provider=coding_factory(), coding_model=settings.coding_model, max_tokens=getattr(settings, "max_tokens", 8192), diff --git a/anton/llm/ollama.py b/anton/llm/ollama.py new file mode 100644 index 0000000..a1882a6 --- /dev/null +++ b/anton/llm/ollama.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from dataclasses import dataclass + +import ollama + +from anton.llm.openai_compat import translate_tools +from anton.llm.provider import ( + ContextOverflowError, + LLMProvider, + LLMResponse, + StreamComplete, + StreamEvent, + StreamTaskProgress, + StreamTextDelta, + StreamToolUseDelta, 
+ StreamToolUseEnd, + StreamToolUseStart, + ToolCall, + Usage, + compute_context_pressure, +) + +_DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434" + + +@dataclass(frozen=True) +class OllamaModelInfo: + name: str + size: str = "" + quantization: str = "" + parameter_size: str = "" + + @property + def display_name(self) -> str: + details = [detail for detail in (self.parameter_size, self.quantization) if detail] + if details: + return f"{self.name} ({', '.join(details)})" + if self.size: + return f"{self.name} ({self.size})" + return self.name + + +def normalize_ollama_base_url(base_url: str | None) -> str: + url = (base_url or _DEFAULT_OLLAMA_BASE_URL).strip() + if not url: + url = _DEFAULT_OLLAMA_BASE_URL + if not url.startswith(("http://", "https://")): + url = f"http://{url}" + url = url.rstrip("/") + if url.endswith("/v1"): + url = url[:-3].rstrip("/") + return url + + +def list_ollama_models(base_url: str | None = None) -> list[OllamaModelInfo]: + client = ollama.Client(host=normalize_ollama_base_url(base_url)) + response = client.list() + models: list[OllamaModelInfo] = [] + for model in response.models: + size = str(model.size) if model.size is not None else "" + details = model.details + models.append( + OllamaModelInfo( + name=model.model or "", + size=size, + quantization=details.quantization_level if details and details.quantization_level else "", + parameter_size=details.parameter_size if details and details.parameter_size else "", + ) + ) + return models + + +def translate_messages_to_ollama(system: str, messages: list[dict]) -> list[dict]: + """Convert Anthropic-style messages to native Ollama chat format.""" + result: list[dict] = [] + tool_name_by_id: dict[str, str] = {} + if system: + result.append({"role": "system", "content": system}) + + for message in messages: + role = message["role"] + content = message.get("content") + + if isinstance(content, str): + result.append({"role": role, "content": content}) + continue + + if not isinstance(content, list): + result.append({"role": role, "content": str(content) if content else ""}) + continue + + if role == "assistant": + text_parts: list[str] = [] + tool_calls: list[dict] = [] + for block in content: + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif block.get("type") == "tool_use": + tool_name_by_id[block["id"]] = block["name"] + tool_calls.append({ + "function": { + "name": block["name"], + "arguments": block.get("input", {}), + } + }) + msg: dict[str, object] = {"role": "assistant"} + if text_parts: + msg["content"] = "\n".join(text_parts) + if tool_calls: + msg["tool_calls"] = tool_calls + result.append(msg) + continue + + if role == "user": + text_parts: list[str] = [] + images: list[str] = [] + for block in content: + block_type = block.get("type") + if block_type == "tool_result": + if text_parts or images: + result.append(_build_ollama_user_message(text_parts, images)) + text_parts = [] + images = [] + tool_name = tool_name_by_id.get(block["tool_use_id"], block["tool_use_id"]) + tool_content = block.get("content", "") + if isinstance(tool_content, list): + tool_content = "\n".join( + item.get("text", "") + for item in tool_content + if item.get("type") == "text" + ) + result.append({ + "role": "tool", + "tool_name": tool_name, + "content": str(tool_content), + }) + elif block_type == "text": + text_parts.append(block.get("text", "")) + elif block_type == "image": + source = block.get("source", {}) + if source.get("type") == "base64" and source.get("data"): + 
images.append(source["data"]) + if text_parts or images: + result.append(_build_ollama_user_message(text_parts, images)) + continue + + text = " ".join( + block.get("text", "") + for block in content + if block.get("type") == "text" + ) + result.append({"role": role, "content": text or ""}) + + return result + + +def _build_ollama_user_message(text_parts: list[str], images: list[str]) -> dict[str, object]: + message: dict[str, object] = {"role": "user"} + if text_parts: + message["content"] = "\n".join(text_parts) + if images: + message["images"] = images + return message + + +class OllamaProvider(LLMProvider): + def __init__(self, base_url: str | None = None) -> None: + self._base_url = normalize_ollama_base_url(base_url) + self._client = ollama.AsyncClient(host=self._base_url) + + async def complete( + self, + *, + model: str, + system: str, + messages: list[dict], + tools: list[dict] | None = None, + tool_choice: dict | None = None, + max_tokens: int = 4096, + request_options: dict | None = None, + ) -> LLMResponse: + kwargs = self._build_chat_kwargs( + model=model, + system=system, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_tokens=max_tokens, + request_options=request_options, + ) + try: + response = await self._client.chat(**kwargs) + except ollama.ResponseError as exc: + message = str(exc).lower() + if "context" in message and "length" in message: + raise ContextOverflowError(str(exc)) from exc + if exc.status_code >= 400 and exc.status_code < 500: + raise + raise ConnectionError(str(exc)) from exc + except ConnectionError as exc: + raise ConnectionError(str(exc)) from exc + + content_text = response.message.content or "" + tool_calls = self._tool_calls_from_message(response.message.tool_calls or []) + input_tokens = response.prompt_eval_count or 0 + output_tokens = response.eval_count or 0 + return LLMResponse( + content=content_text, + tool_calls=tool_calls, + usage=Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + context_pressure=compute_context_pressure(model, input_tokens), + ), + stop_reason=response.done_reason, + ) + + async def stream( + self, + *, + model: str, + system: str, + messages: list[dict], + tools: list[dict] | None = None, + max_tokens: int = 4096, + request_options: dict | None = None, + ) -> AsyncIterator[StreamEvent]: + kwargs = self._build_chat_kwargs( + model=model, + system=system, + messages=messages, + tools=tools, + tool_choice=None, + max_tokens=max_tokens, + request_options=request_options, + ) + kwargs["stream"] = True + + content_text = "" + tool_calls: list[ToolCall] = [] + input_tokens = 0 + output_tokens = 0 + stop_reason: str | None = None + showed_reasoning = False + saw_content = False + next_tool_index = 1 + + try: + stream = await self._client.chat(**kwargs) + async for chunk in stream: + message = chunk.message + if not showed_reasoning and not saw_content and message.thinking: + showed_reasoning = True + yield StreamTaskProgress(phase="reasoning", message="Thinking...") + + if message.content: + saw_content = True + content_text += message.content + yield StreamTextDelta(text=message.content) + + if message.tool_calls: + for call in message.tool_calls: + tool_call = self._tool_call_from_ollama(call, next_tool_index) + next_tool_index += 1 + tool_calls.append(tool_call) + yield StreamToolUseStart(id=tool_call.id, name=tool_call.name) + yield StreamToolUseDelta( + id=tool_call.id, + json_delta=json.dumps(tool_call.input), + ) + yield StreamToolUseEnd(id=tool_call.id) + + if chunk.prompt_eval_count 
is not None: + input_tokens = chunk.prompt_eval_count + if chunk.eval_count is not None: + output_tokens = chunk.eval_count + if chunk.done_reason: + stop_reason = chunk.done_reason + except ollama.ResponseError as exc: + message = str(exc).lower() + if "context" in message and "length" in message: + raise ContextOverflowError(str(exc)) from exc + if exc.status_code >= 400 and exc.status_code < 500: + raise + raise ConnectionError(str(exc)) from exc + except ConnectionError as exc: + raise ConnectionError(str(exc)) from exc + + yield StreamComplete( + response=LLMResponse( + content=content_text, + tool_calls=tool_calls, + usage=Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + context_pressure=compute_context_pressure(model, input_tokens), + ), + stop_reason=stop_reason, + ) + ) + + def _build_chat_kwargs( + self, + *, + model: str, + system: str, + messages: list[dict], + tools: list[dict] | None, + tool_choice: dict | None, + max_tokens: int, + request_options: dict | None, + ) -> dict: + kwargs: dict[str, object] = { + "model": model, + "messages": translate_messages_to_ollama(system, messages), + "options": {"num_predict": max_tokens}, + } + + translated_tools = translate_tools(tools or []) if tools else None + if translated_tools: + kwargs["tools"] = self._apply_tool_choice(translated_tools, tool_choice) + + if request_options and "think" in request_options: + kwargs["think"] = request_options["think"] + + return kwargs + + @staticmethod + def _apply_tool_choice(tools: list[dict], tool_choice: dict | None) -> list[dict]: + if not tool_choice or tool_choice.get("type") != "tool": + return tools + target_name = tool_choice.get("name") + if not target_name: + return tools + filtered = [ + tool + for tool in tools + if tool.get("function", {}).get("name") == target_name + ] + return filtered or tools + + def _tool_calls_from_message(self, calls: list) -> list[ToolCall]: + return [ + self._tool_call_from_ollama(call, index) + for index, call in enumerate(calls, start=1) + ] + + @staticmethod + def _tool_call_from_ollama(call, index: int) -> ToolCall: + function = call.function + return ToolCall( + id=f"ollama_tool_{index}", + name=function.name, + input=dict(function.arguments or {}), + ) diff --git a/anton/llm/openai.py b/anton/llm/openai.py index 403bc55..f7fab1f 100644 --- a/anton/llm/openai.py +++ b/anton/llm/openai.py @@ -5,6 +5,11 @@ import openai +from anton.llm.openai_compat import ( + _translate_messages, + _translate_tool_choice, + _translate_tools, +) from anton.llm.provider import ( ContextOverflowError, LLMProvider, @@ -21,149 +26,6 @@ ) -def _translate_tools(tools: list[dict]) -> list[dict]: - """Anthropic tool format -> OpenAI function-calling format.""" - result = [] - for tool in tools: - result.append({ - "type": "function", - "function": { - "name": tool["name"], - "description": tool.get("description", ""), - "parameters": tool.get("input_schema", {}), - }, - }) - return result - - -def _translate_tool_choice(tool_choice: dict) -> dict | str: - """Anthropic tool_choice -> OpenAI tool_choice.""" - tc_type = tool_choice.get("type") - if tc_type == "tool": - return {"type": "function", "function": {"name": tool_choice["name"]}} - if tc_type == "any": - return "required" - if tc_type == "auto": - return "auto" - return "auto" - - -def _translate_messages(system: str, messages: list[dict]) -> list[dict]: - """Convert Anthropic-style messages to OpenAI chat format. 
- - Handles: - - system prompt -> {"role": "system", ...} - - plain text messages pass through - - assistant messages with tool_use content blocks -> tool_calls array - - user messages with tool_result content blocks -> role:tool messages - """ - result: list[dict] = [] - if system: - result.append({"role": "system", "content": system}) - - for msg in messages: - role = msg["role"] - content = msg.get("content") - - # Plain string content — pass through - if isinstance(content, str): - result.append({"role": role, "content": content}) - continue - - # Content is a list of blocks (Anthropic format) - if isinstance(content, list): - if role == "assistant": - result.extend(_translate_assistant_blocks(content)) - elif role == "user": - result.extend(_translate_user_blocks(content)) - else: - # Fallback: join text blocks - text = " ".join( - b.get("text", "") for b in content if b.get("type") == "text" - ) - result.append({"role": role, "content": text or ""}) - continue - - # Fallback - result.append({"role": role, "content": str(content) if content else ""}) - - return result - - -def _translate_assistant_blocks(blocks: list[dict]) -> list[dict]: - """Convert assistant content blocks to OpenAI message(s).""" - text_parts: list[str] = [] - tool_calls: list[dict] = [] - - for block in blocks: - if block.get("type") == "text": - text_parts.append(block["text"]) - elif block.get("type") == "tool_use": - tool_calls.append({ - "id": block["id"], - "type": "function", - "function": { - "name": block["name"], - "arguments": json.dumps(block.get("input", {})), - }, - }) - - msg: dict = {"role": "assistant"} - content = "\n".join(text_parts) if text_parts else None - msg["content"] = content - if tool_calls: - msg["tool_calls"] = tool_calls - return [msg] - - -def _translate_user_blocks(blocks: list[dict]) -> list[dict]: - """Convert user content blocks (including tool_result and image) to OpenAI messages.""" - result: list[dict] = [] - content_parts: list[dict] = [] # Accumulates text + image_url blocks - - for block in blocks: - if block.get("type") == "tool_result": - # Flush any accumulated content parts first - if content_parts: - result.append({"role": "user", "content": content_parts}) - content_parts = [] - # tool_result -> role:tool message - content = block.get("content", "") - if isinstance(content, list): - content = "\n".join( - b.get("text", "") for b in content if b.get("type") == "text" - ) - result.append({ - "role": "tool", - "tool_call_id": block["tool_use_id"], - "content": str(content), - }) - elif block.get("type") == "text": - content_parts.append({"type": "text", "text": block.get("text", "")}) - elif block.get("type") == "image": - # Anthropic image block -> OpenAI image_url block - source = block.get("source", {}) - if source.get("type") == "base64": - media_type = source.get("media_type", "image/png") - data = source.get("data", "") - content_parts.append({ - "type": "image_url", - "image_url": {"url": f"data:{media_type};base64,{data}"}, - }) - - if content_parts: - # If only text parts, flatten to a simple string for compatibility - if all(p.get("type") == "text" for p in content_parts): - result.append({ - "role": "user", - "content": "\n".join(p["text"] for p in content_parts), - }) - else: - result.append({"role": "user", "content": content_parts}) - - return result - - class OpenAIProvider(LLMProvider): def __init__( self, @@ -191,6 +53,7 @@ async def complete( tools: list[dict] | None = None, tool_choice: dict | None = None, max_tokens: int = 4096, + request_options: 
dict | None = None, ) -> LLMResponse: oai_messages = _translate_messages(system, messages) @@ -257,6 +120,7 @@ async def stream( messages: list[dict], tools: list[dict] | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> AsyncIterator[StreamEvent]: oai_messages = _translate_messages(system, messages) diff --git a/anton/llm/openai_compat.py b/anton/llm/openai_compat.py new file mode 100644 index 0000000..e5fc96c --- /dev/null +++ b/anton/llm/openai_compat.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import json + + +def translate_tools(tools: list[dict]) -> list[dict]: + """Anthropic tool format -> OpenAI/Ollama function-calling format.""" + result = [] + for tool in tools: + result.append({ + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool.get("input_schema", {}), + }, + }) + return result + + +def translate_tool_choice(tool_choice: dict) -> dict | str: + """Anthropic tool_choice -> OpenAI tool_choice.""" + tc_type = tool_choice.get("type") + if tc_type == "tool": + return {"type": "function", "function": {"name": tool_choice["name"]}} + if tc_type == "any": + return "required" + if tc_type == "auto": + return "auto" + return "auto" + + +def translate_messages(system: str, messages: list[dict]) -> list[dict]: + """Convert Anthropic-style messages to OpenAI chat format.""" + result: list[dict] = [] + if system: + result.append({"role": "system", "content": system}) + + for msg in messages: + role = msg["role"] + content = msg.get("content") + + if isinstance(content, str): + result.append({"role": role, "content": content}) + continue + + if isinstance(content, list): + if role == "assistant": + result.extend(_translate_assistant_blocks(content)) + elif role == "user": + result.extend(_translate_user_blocks(content)) + else: + text = " ".join( + block.get("text", "") + for block in content + if block.get("type") == "text" + ) + result.append({"role": role, "content": text or ""}) + continue + + result.append({"role": role, "content": str(content) if content else ""}) + + return result + + +def _translate_assistant_blocks(blocks: list[dict]) -> list[dict]: + text_parts: list[str] = [] + tool_calls: list[dict] = [] + + for block in blocks: + if block.get("type") == "text": + text_parts.append(block["text"]) + elif block.get("type") == "tool_use": + tool_calls.append({ + "id": block["id"], + "type": "function", + "function": { + "name": block["name"], + "arguments": json.dumps(block.get("input", {})), + }, + }) + + msg: dict = {"role": "assistant"} + msg["content"] = "\n".join(text_parts) if text_parts else None + if tool_calls: + msg["tool_calls"] = tool_calls + return [msg] + + +def _translate_user_blocks(blocks: list[dict]) -> list[dict]: + result: list[dict] = [] + content_parts: list[dict] = [] + + for block in blocks: + if block.get("type") == "tool_result": + if content_parts: + result.append({"role": "user", "content": content_parts}) + content_parts = [] + content = block.get("content", "") + if isinstance(content, list): + content = "\n".join( + item.get("text", "") + for item in content + if item.get("type") == "text" + ) + result.append({ + "role": "tool", + "tool_call_id": block["tool_use_id"], + "content": str(content), + }) + elif block.get("type") == "text": + content_parts.append({"type": "text", "text": block.get("text", "")}) + elif block.get("type") == "image": + source = block.get("source", {}) + if source.get("type") == "base64": + media_type = 
source.get("media_type", "image/png") + data = source.get("data", "") + content_parts.append({ + "type": "image_url", + "image_url": {"url": f"data:{media_type};base64,{data}"}, + }) + + if content_parts: + if all(part.get("type") == "text" for part in content_parts): + result.append({ + "role": "user", + "content": "\n".join(part["text"] for part in content_parts), + }) + else: + result.append({"role": "user", "content": content_parts}) + + return result + + +_translate_assistant_blocks = _translate_assistant_blocks +_translate_messages = translate_messages +_translate_tool_choice = translate_tool_choice +_translate_tools = translate_tools +_translate_user_blocks = _translate_user_blocks diff --git a/anton/llm/provider.py b/anton/llm/provider.py index 359a77d..d1f6a8c 100644 --- a/anton/llm/provider.py +++ b/anton/llm/provider.py @@ -144,6 +144,7 @@ async def complete( tools: list[dict] | None = None, tool_choice: dict | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> LLMResponse: ... async def stream( @@ -154,6 +155,7 @@ async def stream( messages: list[dict], tools: list[dict] | None = None, max_tokens: int = 4096, + request_options: dict | None = None, ) -> AsyncIterator[StreamEvent]: """Stream LLM responses. Default falls back to complete().""" response = await self.complete( @@ -162,6 +164,7 @@ async def stream( messages=messages, tools=tools, max_tokens=max_tokens, + request_options=request_options, ) if response.content: yield StreamTextDelta(text=response.content) diff --git a/anton/llm/setup.py b/anton/llm/setup.py new file mode 100644 index 0000000..cc575d3 --- /dev/null +++ b/anton/llm/setup.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from rich.prompt import Prompt + +from anton.llm.ollama import OllamaModelInfo, list_ollama_models, normalize_ollama_base_url + +if TYPE_CHECKING: + from rich.console import Console + + from anton.config.settings import AntonSettings + from anton.workspace import Workspace + + +@dataclass(frozen=True) +class ProviderOption: + key: str + label: str + badge: str + default_planning_model: str + default_coding_model: str + + +PROVIDER_OPTIONS: tuple[ProviderOption, ...] 
= ( + ProviderOption( + key="anthropic", + label="Anthropic (Claude)", + badge="recommended", + default_planning_model="claude-sonnet-4-6", + default_coding_model="claude-haiku-4-5-20251001", + ), + ProviderOption( + key="openai", + label="OpenAI (GPT / o-series)", + badge="experimental", + default_planning_model="gpt-5-mini", + default_coding_model="gpt-5-nano", + ), + ProviderOption( + key="ollama", + label="Ollama (local models)", + badge="local", + default_planning_model="", + default_coding_model="", + ), + ProviderOption( + key="openai-compatible", + label="OpenAI-compatible (custom endpoint)", + badge="experimental", + default_planning_model="", + default_coding_model="", + ), +) + + +def configure_llm_settings( + console, + settings, + workspace, + *, + show_current_config: bool = True, +) -> bool: + """Prompt for provider/model configuration and persist it to the workspace.""" + if show_current_config: + console.print() + console.print("[anton.cyan]Current configuration:[/]") + console.print(f" Provider (planning): [bold]{settings.planning_provider}[/]") + console.print(f" Provider (coding): [bold]{settings.coding_provider}[/]") + console.print(f" Planning model: [bold]{settings.planning_model}[/]") + console.print(f" Coding model: [bold]{settings.coding_model}[/]") + console.print() + else: + console.print() + console.print("[anton.cyan]LLM configuration[/]") + console.print() + + option_by_number = { + str(index): option + for index, option in enumerate(PROVIDER_OPTIONS, start=1) + } + current_number = current_provider_number(settings.planning_provider) + + console.print("[anton.cyan]Available providers:[/]") + for number, option in option_by_number.items(): + console.print( + f" [bold]{number}[/] {option.label:<36} [dim]\\[{option.badge}][/]" + ) + console.print() + + choice = Prompt.ask( + "Select provider", + choices=list(option_by_number), + default=current_number, + console=console, + ) + provider = option_by_number[choice] + + if provider.key == "ollama": + config = _prompt_for_ollama_config(console, settings, provider) + if config is None: + return False + planning_model, coding_model = config + settings.ollama_base_url = normalize_ollama_base_url(settings.ollama_base_url) + workspace.set_secret("ANTON_OLLAMA_BASE_URL", settings.ollama_base_url) + else: + if provider.key == "openai-compatible": + current_base_url = settings.openai_base_url or "" + console.print() + base_url = Prompt.ask( + "API base URL [dim](e.g. 
http://localhost:11434/v1)[/]", + default=current_base_url, + console=console, + ).strip() + if base_url: + settings.openai_base_url = base_url + workspace.set_secret("ANTON_OPENAI_BASE_URL", base_url) + + api_key = _prompt_for_api_key(console, settings, provider.key) + if api_key is None: + return False + + planning_model, coding_model = _prompt_for_cloud_models(console, settings, provider) + if api_key: + key_name = api_key_env_name(provider.key) + if key_name: + workspace.set_secret(key_name, api_key) + _set_provider_api_key(settings, provider.key, api_key) + + settings.planning_provider = provider.key + settings.coding_provider = provider.key + settings.planning_model = planning_model + settings.coding_model = coding_model + workspace.set_secret("ANTON_PLANNING_PROVIDER", provider.key) + workspace.set_secret("ANTON_CODING_PROVIDER", provider.key) + workspace.set_secret("ANTON_PLANNING_MODEL", planning_model) + workspace.set_secret("ANTON_CODING_MODEL", coding_model) + return True + + +def current_provider_number(provider: str) -> str: + for index, option in enumerate(PROVIDER_OPTIONS, start=1): + if option.key == provider: + return str(index) + return "1" + + +def api_key_env_name(provider: str) -> str | None: + if provider == "anthropic": + return "ANTON_ANTHROPIC_API_KEY" + if provider in {"openai", "openai-compatible"}: + return "ANTON_OPENAI_API_KEY" + return None + + +def _prompt_for_api_key(console, settings, provider: str) -> str | None: + key_attr = "anthropic_api_key" if provider == "anthropic" else "openai_api_key" + current_key = getattr(settings, key_attr) or "" + masked = _mask_secret(current_key) if current_key else "***" + console.print() + api_key = Prompt.ask( + f"API key for {provider.title()} [dim](Enter to keep {masked})[/]", + default="", + console=console, + ).strip() + if api_key: + return api_key + if current_key: + return "" + console.print() + console.print(f"[anton.error]No API key set for {provider}. Configuration not applied.[/]") + console.print() + return None + + +def _prompt_for_cloud_models(console, settings, provider: ProviderOption) -> tuple[str, str]: + console.print() + planning_model = Prompt.ask( + "Planning model", + default=( + settings.planning_model + if provider.key == settings.planning_provider + else provider.default_planning_model + ), + console=console, + ) + coding_model = Prompt.ask( + "Coding model", + default=( + settings.coding_model + if provider.key == settings.coding_provider + else provider.default_coding_model + ), + console=console, + ) + return planning_model, coding_model + + +def _prompt_for_ollama_config(console, settings, provider: ProviderOption) -> tuple[str, str] | None: + current_url = settings.ollama_base_url or "http://localhost:11434" + console.print() + base_url = Prompt.ask( + "Ollama URL [dim](e.g. http://localhost:11434)[/]", + default=current_url, + console=console, + ).strip() + settings.ollama_base_url = normalize_ollama_base_url(base_url or current_url) + + try: + models = list_ollama_models(settings.ollama_base_url) + except Exception as exc: + console.print() + console.print( + "[anton.warning]Could not query Ollama for local models.[/]" + f" [dim]({exc})[/]" + ) + console.print("[anton.muted]Enter model names manually instead.[/]") + console.print() + planning_model = Prompt.ask( + "Planning model", + default=settings.planning_model if provider.key == settings.planning_provider else "", + console=console, + ).strip() + if not planning_model: + console.print("[anton.error]No model provided. 
Configuration not applied.[/]") + console.print() + return None + coding_model = Prompt.ask( + "Coding model", + default=settings.coding_model if provider.key == settings.coding_provider else planning_model, + console=console, + ).strip() + return planning_model, coding_model or planning_model + + if not models: + console.print() + console.print("[anton.warning]No local Ollama models found.[/]") + console.print("[anton.muted]Pull a model with `ollama pull ` or enter a name manually.[/]") + console.print() + planning_model = Prompt.ask( + "Planning model", + default=settings.planning_model if provider.key == settings.planning_provider else "", + console=console, + ).strip() + if not planning_model: + console.print("[anton.error]No model provided. Configuration not applied.[/]") + console.print() + return None + coding_model = Prompt.ask( + "Coding model", + default=settings.coding_model if provider.key == settings.coding_provider else planning_model, + console=console, + ).strip() + return planning_model, coding_model or planning_model + + console.print() + console.print("[anton.cyan]Available local Ollama models:[/]") + for index, model in enumerate(models, start=1): + console.print(f" [bold]{index}[/] {model.display_name}") + manual_choice = str(len(models) + 1) + console.print(f" [bold]{manual_choice}[/] Enter model name manually") + + planning_model = _prompt_for_ollama_model_choice( + console=console, + label="Planning model", + models=models, + default_model=( + settings.planning_model + if provider.key == settings.planning_provider + else models[0].name + ), + ) + coding_model = _prompt_for_ollama_model_choice( + console=console, + label="Coding model", + models=models, + default_model=( + settings.coding_model + if provider.key == settings.coding_provider + else planning_model + ), + manual_default=planning_model, + ) + return planning_model, coding_model + + +def _prompt_for_ollama_model_choice( + *, + console, + label: str, + models: list[OllamaModelInfo], + default_model: str, + manual_default: str = "", +) -> str: + choice_by_number = { + str(index): model.name + for index, model in enumerate(models, start=1) + } + manual_choice = str(len(models) + 1) + choices = [*choice_by_number.keys(), manual_choice] + default_choice = manual_choice + for number, model_name in choice_by_number.items(): + if model_name == default_model: + default_choice = number + break + + choice = Prompt.ask( + label, + choices=choices, + default=default_choice, + console=console, + ) + if choice == manual_choice: + return Prompt.ask( + f"{label} name", + default=manual_default or default_model, + console=console, + ).strip() + return choice_by_number[choice] + + +def _mask_secret(value: str, *, keep: int = 4) -> str: + if len(value) <= keep * 2: + return "*" * max(len(value), 3) + return f"{value[:keep]}...{value[-keep:]}" + + +def _set_provider_api_key(settings, provider: str, api_key: str) -> None: + if provider == "anthropic": + settings.anthropic_api_key = api_key + elif provider in {"openai", "openai-compatible"}: + settings.openai_api_key = api_key diff --git a/anton/scratchpad.py b/anton/scratchpad.py index 84aa7a9..d907182 100644 --- a/anton/scratchpad.py +++ b/anton/scratchpad.py @@ -345,6 +345,8 @@ async def start(self) -> None: env["OPENAI_API_KEY"] = env["ANTON_OPENAI_API_KEY"] if "OPENAI_BASE_URL" not in env and "ANTON_OPENAI_BASE_URL" in env: env["OPENAI_BASE_URL"] = env["ANTON_OPENAI_BASE_URL"] + if "OLLAMA_HOST" not in env and "ANTON_OLLAMA_BASE_URL" in env: + env["OLLAMA_HOST"] = 
env["ANTON_OLLAMA_BASE_URL"] # If settings provided an explicit API key (e.g. from ~/.anton/.env or # Pydantic settings), inject it so the subprocess SDK can authenticate. if self._coding_api_key: diff --git a/anton/scratchpad_boot.py b/anton/scratchpad_boot.py index fe8e997..6812441 100644 --- a/anton/scratchpad_boot.py +++ b/anton/scratchpad_boot.py @@ -20,12 +20,16 @@ _scratchpad_provider_name = os.environ.get("ANTON_SCRATCHPAD_PROVIDER", "anthropic") if _scratchpad_provider_name in ("openai", "openai-compatible"): from anton.llm.openai import OpenAIProvider as _ProviderClass + elif _scratchpad_provider_name == "ollama": + from anton.llm.ollama import OllamaProvider as _ProviderClass else: from anton.llm.anthropic import AnthropicProvider as _ProviderClass _llm_ssl_verify = os.environ.get("ANTON_MINDS_SSL_VERIFY", "true").lower() != "false" if _scratchpad_provider_name in ("openai", "openai-compatible"): _llm_provider = _ProviderClass(ssl_verify=_llm_ssl_verify) + elif _scratchpad_provider_name == "ollama": + _llm_provider = _ProviderClass(base_url=os.environ.get("ANTON_OLLAMA_BASE_URL")) else: _llm_provider = _ProviderClass() # Anthropic doesn't need ssl_verify _llm_model = _scratchpad_model diff --git a/pyproject.toml b/pyproject.toml index 99ff96f..8bf4c2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ description = "Autonomous coding copilot" requires-python = ">=3.11" dependencies = [ "anthropic>=0.42.0", + "ollama>=0.6.1", "openai>=1.0", "pydantic>=2.0", "pydantic-settings>=2.0", diff --git a/tests/test_chat.py b/tests/test_chat.py index ef33a77..bdf1655 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -173,3 +173,134 @@ async def _plan_stream(**kwargs): session._summarize_history.assert_not_called() compacted = [e for e in events if isinstance(e, StreamContextCompacted)] assert len(compacted) == 0 + + +class _FakeAsyncIter: + def __init__(self, items): + self._items = list(items) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._items: + raise StopAsyncIteration + return self._items.pop(0) + + +class TestOllamaVerifier: + async def test_ollama_verifier_disables_thinking(self, monkeypatch): + mock_llm = AsyncMock() + mock_llm.planning_provider_name = "ollama" + mock_llm.plan = AsyncMock(return_value=_text_response("STATUS: COMPLETE — done")) + + calls = 0 + + def fake_plan_stream(**kwargs): + nonlocal calls + calls += 1 + if calls == 1: + return _FakeAsyncIter([ + StreamComplete( + response=LLMResponse( + content="", + tool_calls=[ToolCall(id="tool_1", name="scratchpad", input={"action": "view", "name": "main"})], + usage=Usage(), + stop_reason="tool_use", + ) + ) + ]) + return _FakeAsyncIter([ + StreamComplete(response=_text_response("All done.")) + ]) + + mock_llm.plan_stream = fake_plan_stream + + async def fake_dispatch_tool(session, name, input): + return "tool ok" + + monkeypatch.setattr("anton.chat.dispatch_tool", fake_dispatch_tool) + + session = ChatSession(mock_llm) + events = [event async for event in session.turn_stream("help")] + + assert any(isinstance(event, StreamComplete) for event in events) + assert mock_llm.plan.await_count == 1 + assert mock_llm.plan.call_args.kwargs["request_options"] == {"think": False} + + async def test_ollama_unparseable_verifier_stops_retry_loop(self, monkeypatch): + mock_llm = AsyncMock() + mock_llm.planning_provider_name = "ollama" + mock_llm.plan = AsyncMock(return_value=_text_response("")) + + calls = 0 + + def fake_plan_stream(**kwargs): + nonlocal calls + calls += 1 + 
if calls == 1: + return _FakeAsyncIter([ + StreamComplete( + response=LLMResponse( + content="", + tool_calls=[ToolCall(id="tool_1", name="scratchpad", input={"action": "view", "name": "main"})], + usage=Usage(), + stop_reason="tool_use", + ) + ) + ]) + return _FakeAsyncIter([ + StreamComplete(response=_text_response("All done.")) + ]) + + mock_llm.plan_stream = fake_plan_stream + + async def fake_dispatch_tool(session, name, input): + return "tool ok" + + monkeypatch.setattr("anton.chat.dispatch_tool", fake_dispatch_tool) + + session = ChatSession(mock_llm) + events = [event async for event in session.turn_stream("help")] + + assert any(isinstance(event, StreamComplete) for event in events) + assert calls == 2 + + async def test_non_ollama_unparseable_verifier_continues_working(self, monkeypatch): + mock_llm = AsyncMock() + mock_llm.planning_provider_name = "anthropic" + mock_llm.plan = AsyncMock(return_value=_text_response("")) + + calls = 0 + + def fake_plan_stream(**kwargs): + nonlocal calls + calls += 1 + if calls == 1: + return _FakeAsyncIter([ + StreamComplete( + response=LLMResponse( + content="", + tool_calls=[ToolCall(id="tool_1", name="scratchpad", input={"action": "view", "name": "main"})], + usage=Usage(), + stop_reason="tool_use", + ) + ) + ]) + return _FakeAsyncIter([ + StreamComplete(response=_text_response("All done.")) + ]) + + mock_llm.plan_stream = fake_plan_stream + + async def fake_dispatch_tool(session, name, input): + return "tool ok" + + monkeypatch.setattr("anton.chat.dispatch_tool", fake_dispatch_tool) + + session = ChatSession(mock_llm) + events = [event async for event in session.turn_stream("help")] + + assert any(isinstance(event, StreamComplete) for event in events) + assert calls == 3 + assert mock_llm.plan.call_args.kwargs["request_options"] is None diff --git a/tests/test_chat_scratchpad.py b/tests/test_chat_scratchpad.py index 94d0d0a..9ce89b6 100644 --- a/tests/test_chat_scratchpad.py +++ b/tests/test_chat_scratchpad.py @@ -228,6 +228,7 @@ def fake_plan_stream(**kwargs): ]) mock_llm.plan_stream = fake_plan_stream + mock_llm.plan = AsyncMock(return_value=_text_response("STATUS: COMPLETE — done")) session = ChatSession(mock_llm) try: @@ -270,6 +271,7 @@ def fake_plan_stream(**kwargs): return _FakeAsyncIter([StreamComplete(response=final_response)]) mock_llm.plan_stream = fake_plan_stream + mock_llm.plan = AsyncMock(return_value=_text_response("STATUS: COMPLETE — done")) session = ChatSession(mock_llm) try: diff --git a/tests/test_chat_ui.py b/tests/test_chat_ui.py index 471f9bf..3e9ebc4 100644 --- a/tests/test_chat_ui.py +++ b/tests/test_chat_ui.py @@ -86,9 +86,20 @@ def test_update_progress_without_eta(self, MockLive): assert live.update.call_count >= 1 def test_phase_labels_cover_all_phases(self): - expected = {"memory_recall", "planning", "executing", "complete", "failed", "scratchpad"} + expected = {"memory_recall", "planning", "reasoning", "executing", "complete", "failed", "scratchpad"} assert expected == set(PHASE_LABELS.keys()) + @patch("anton.chat_ui.Live") + def test_reasoning_progress_updates_spinner(self, MockLive): + display, _ = self._make_display() + display.start() + live = MockLive.return_value + + display.update_progress("reasoning", "Thinking...") + + assert display._line2_status == "Thinking..." 
+        assert live.update.call_count >= 1
+
 
 class TestActivityTracking:
     def _make_display(self):
diff --git a/tests/test_client.py b/tests/test_client.py
index 897d41d..904f98c 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -6,6 +6,7 @@
 
 from anton.config.settings import AntonSettings
 from anton.llm.client import LLMClient
+from anton.llm.ollama import OllamaProvider
 from anton.llm.provider import LLMProvider, LLMResponse, Usage
 
 
@@ -98,3 +99,20 @@ def test_unknown_coding_provider_raises(self):
         )
         with pytest.raises(ValueError, match="Unknown coding provider"):
             LLMClient.from_settings(settings)
+
+    def test_from_settings_ollama(self):
+        with patch("anton.llm.ollama.ollama"):
+            settings = AntonSettings(
+                planning_provider="ollama",
+                coding_provider="ollama",
+                planning_model="qwen3.5:4b",
+                coding_model="qwen3.5:4b",
+                ollama_base_url="http://localhost:11434/v1",
+                _env_file=None,
+            )
+            client = LLMClient.from_settings(settings)
+            assert isinstance(client, LLMClient)
+            assert isinstance(client._planning_provider, OllamaProvider)
+            assert isinstance(client._coding_provider, OllamaProvider)
+            assert client.planning_provider_name == "ollama"
+            assert client.coding_provider_name == "ollama"
diff --git a/tests/test_llm_setup.py b/tests/test_llm_setup.py
new file mode 100644
index 0000000..b7b7b13
--- /dev/null
+++ b/tests/test_llm_setup.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from anton.config.settings import AntonSettings
+from anton.llm.ollama import OllamaModelInfo
+from anton.llm.setup import configure_llm_settings
+from anton.workspace import Workspace
+
+
+@pytest.fixture(autouse=True)
+def clean_llm_env(monkeypatch):
+    for key in (
+        "ANTON_ANTHROPIC_API_KEY",
+        "ANTON_OPENAI_API_KEY",
+        "ANTON_OPENAI_BASE_URL",
+        "ANTON_OLLAMA_BASE_URL",
+        "ANTON_PLANNING_PROVIDER",
+        "ANTON_CODING_PROVIDER",
+        "ANTON_PLANNING_MODEL",
+        "ANTON_CODING_MODEL",
+    ):
+        monkeypatch.delenv(key, raising=False)
+
+
+class TestConfigureLlmSettings:
+    @patch("anton.llm.setup.list_ollama_models")
+    @patch("anton.llm.setup.Prompt.ask")
+    def test_ollama_discovery_success(self, mock_ask, mock_list, tmp_path):
+        mock_ask.side_effect = ["3", "http://localhost:11434", "1", "2"]
+        mock_list.return_value = [
+            OllamaModelInfo(name="qwen3.5:4b", parameter_size="4.7B", quantization="Q4_K_M"),
+            OllamaModelInfo(name="qwen3:0.6b", parameter_size="751.63M", quantization="Q4_K_M"),
+        ]
+
+        console = MagicMock()
+        workspace = Workspace(tmp_path)
+        settings = AntonSettings(_env_file=None)
+
+        applied = configure_llm_settings(console, settings, workspace, show_current_config=False)
+
+        assert applied is True
+        assert settings.planning_provider == "ollama"
+        assert settings.coding_provider == "ollama"
+        assert settings.ollama_base_url == "http://localhost:11434"
+        assert settings.planning_model == "qwen3.5:4b"
+        assert settings.coding_model == "qwen3:0.6b"
+        assert workspace.get_secret("ANTON_OLLAMA_BASE_URL") == "http://localhost:11434"
+
+    @patch("anton.llm.setup.list_ollama_models")
+    @patch("anton.llm.setup.Prompt.ask")
+    def test_ollama_manual_fallback_when_discovery_fails(self, mock_ask, mock_list, tmp_path):
+        mock_ask.side_effect = ["3", "localhost:11434/v1", "qwen3.5:4b", "qwen3.5:4b"]
+        mock_list.side_effect = ConnectionError("boom")
+
+        console = MagicMock()
+        workspace = Workspace(tmp_path)
+        settings = AntonSettings(_env_file=None)
+
+        applied = configure_llm_settings(console, settings, workspace, show_current_config=False)
+
+        assert applied is True
+        assert settings.planning_provider == "ollama"
+        assert settings.ollama_base_url == "http://localhost:11434"
+        assert settings.planning_model == "qwen3.5:4b"
+        assert settings.coding_model == "qwen3.5:4b"
+
+    @patch("anton.llm.setup.Prompt.ask")
+    def test_anthropic_setup_persists_api_key(self, mock_ask, tmp_path):
+        mock_ask.side_effect = ["1", "sk-ant-test", "claude-sonnet-4-6", "claude-haiku-4-5-20251001"]
+
+        console = MagicMock()
+        workspace = Workspace(tmp_path)
+        settings = AntonSettings(_env_file=None)
+
+        applied = configure_llm_settings(console, settings, workspace, show_current_config=False)
+
+        assert applied is True
+        assert settings.planning_provider == "anthropic"
+        assert settings.anthropic_api_key == "sk-ant-test"
+        assert workspace.get_secret("ANTON_ANTHROPIC_API_KEY") == "sk-ant-test"
diff --git a/tests/test_ollama_provider.py b/tests/test_ollama_provider.py
new file mode 100644
index 0000000..a0daa78
--- /dev/null
+++ b/tests/test_ollama_provider.py
@@ -0,0 +1,162 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from anton.llm.ollama import (
+    OllamaProvider,
+    normalize_ollama_base_url,
+    translate_messages_to_ollama,
+)
+from anton.llm.provider import (
+    StreamComplete,
+    StreamTaskProgress,
+    StreamTextDelta,
+    StreamToolUseDelta,
+    StreamToolUseEnd,
+    StreamToolUseStart,
+)
+
+
+class _AsyncIter:
+    def __init__(self, items):
+        self._items = list(items)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if not self._items:
+            raise StopAsyncIteration
+        return self._items.pop(0)
+
+
+def _ollama_response(*, content="", tool_calls=None, prompt_tokens=5, completion_tokens=10, done_reason="stop"):
+    return SimpleNamespace(
+        message=SimpleNamespace(content=content, tool_calls=tool_calls or [], thinking=None),
+        prompt_eval_count=prompt_tokens,
+        eval_count=completion_tokens,
+        done_reason=done_reason,
+    )
+
+
+class TestOllamaProvider:
+    async def test_complete_text_response(self):
+        with patch("anton.llm.ollama.ollama") as mock_ollama:
+            mock_client = AsyncMock()
+            mock_ollama.AsyncClient.return_value = mock_client
+            mock_client.chat = AsyncMock(return_value=_ollama_response(content="Hello world"))
+
+            provider = OllamaProvider(base_url="http://localhost:11434/v1")
+            result = await provider.complete(
+                model="qwen3.5:4b",
+                system="be helpful",
+                messages=[{"role": "user", "content": "hi"}],
+                max_tokens=123,
+                request_options={"think": False},
+            )
+
+            call_kwargs = mock_client.chat.call_args.kwargs
+            assert call_kwargs["think"] is False
+            assert call_kwargs["options"] == {"num_predict": 123}
+            assert result.content == "Hello world"
+            assert result.tool_calls == []
+            assert result.usage.input_tokens == 5
+            assert result.usage.output_tokens == 10
+            assert result.stop_reason == "stop"
+
+    async def test_complete_tool_use_response(self):
+        with patch("anton.llm.ollama.ollama") as mock_ollama:
+            mock_client = AsyncMock()
+            mock_ollama.AsyncClient.return_value = mock_client
+            tool_call = SimpleNamespace(
+                function=SimpleNamespace(name="lookup_weather", arguments={"city": "London"})
+            )
+            mock_client.chat = AsyncMock(
+                return_value=_ollama_response(content="", tool_calls=[tool_call], done_reason="tool_calls")
+            )
+
+            provider = OllamaProvider(base_url="http://localhost:11434")
+            result = await provider.complete(
+                model="qwen3.5:4b",
+                system="plan",
+                messages=[{"role": "user", "content": "weather?"}],
+                tools=[{"name": "lookup_weather", "description": "d", "input_schema": {"type": "object"}}],
+            )
+
+            assert result.content == ""
+            assert len(result.tool_calls) == 1
+            assert result.tool_calls[0].id == "ollama_tool_1"
+            assert result.tool_calls[0].name == "lookup_weather"
+            assert result.tool_calls[0].input == {"city": "London"}
+            assert result.stop_reason == "tool_calls"
+
+    async def test_stream_emits_reasoning_progress_and_tool_events(self):
+        with patch("anton.llm.ollama.ollama") as mock_ollama:
+            mock_client = AsyncMock()
+            mock_ollama.AsyncClient.return_value = mock_client
+            tool_call = SimpleNamespace(
+                function=SimpleNamespace(name="lookup_weather", arguments={"city": "London"})
+            )
+            mock_client.chat = AsyncMock(
+                return_value=_AsyncIter([
+                    SimpleNamespace(
+                        message=SimpleNamespace(content="", thinking="Working", tool_calls=[]),
+                        prompt_eval_count=5,
+                        eval_count=0,
+                        done_reason=None,
+                    ),
+                    SimpleNamespace(
+                        message=SimpleNamespace(content="Done.", thinking=None, tool_calls=[tool_call]),
+                        prompt_eval_count=None,
+                        eval_count=6,
+                        done_reason="stop",
+                    ),
+                ])
+            )
+
+            provider = OllamaProvider(base_url="http://localhost:11434")
+            events = [event async for event in provider.stream(
+                model="qwen3.5:4b",
+                system="sys",
+                messages=[{"role": "user", "content": "hi"}],
+            )]
+
+            assert any(isinstance(event, StreamTaskProgress) and event.phase == "reasoning" for event in events)
+            assert any(isinstance(event, StreamTextDelta) and event.text == "Done." for event in events)
+            assert any(isinstance(event, StreamToolUseStart) and event.name == "lookup_weather" for event in events)
+            assert any(isinstance(event, StreamToolUseDelta) for event in events)
+            assert any(isinstance(event, StreamToolUseEnd) for event in events)
+            complete = next(event for event in events if isinstance(event, StreamComplete))
+            assert complete.response.content == "Done."
+            assert complete.response.tool_calls[0].input == {"city": "London"}
+
+
+class TestOllamaHelpers:
+    def test_normalize_ollama_base_url_strips_v1(self):
+        assert normalize_ollama_base_url("localhost:11434/v1") == "http://localhost:11434"
+
+    def test_translate_messages_to_ollama_handles_tool_results(self):
+        messages = [
+            {"role": "user", "content": "Find the file"},
+            {
+                "role": "assistant",
+                "content": [
+                    {"type": "text", "text": "I will check."},
+                    {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/tmp/test.txt"}},
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "tool_result", "tool_use_id": "tool_1", "content": "hello"},
+                ],
+            },
+        ]
+
+        translated = translate_messages_to_ollama("system", messages)
+
+        assert translated[0] == {"role": "system", "content": "system"}
+        assert translated[2]["role"] == "assistant"
+        assert translated[2]["tool_calls"][0]["function"]["name"] == "read_file"
+        assert translated[3] == {"role": "tool", "tool_name": "read_file", "content": "hello"}
diff --git a/tests/test_scratchpad.py b/tests/test_scratchpad.py
index a44a316..782668c 100644
--- a/tests/test_scratchpad.py
+++ b/tests/test_scratchpad.py
@@ -458,6 +458,22 @@ async def test_api_key_bridged(self, monkeypatch):
         finally:
             await pad.close()
 
+    async def test_ollama_host_bridged(self, monkeypatch):
+        """ANTON_OLLAMA_BASE_URL should be bridged to OLLAMA_HOST."""
+        monkeypatch.setenv("ANTON_OLLAMA_BASE_URL", "http://localhost:11434")
+        monkeypatch.delenv("OLLAMA_HOST", raising=False)
+        pad = Scratchpad(name="ollama-host-test", _coding_provider="ollama", _coding_model="qwen3.5:4b")
+        await pad.start()
+        try:
+            cell = await pad.execute(
+                "import os; print(os.environ.get('OLLAMA_HOST', 'MISSING'))"
+            )
+            assert cell.stdout.strip() == "http://localhost:11434"
+            llm_cell = await pad.execute("llm = get_llm(); print(llm.model)")
+            assert llm_cell.stdout.strip() == "qwen3.5:4b"
+        finally:
+            await pad.close()
+
 
 class TestScratchpadVenv:
     async def test_venv_created_on_start(self):
diff --git a/tests/test_settings.py b/tests/test_settings.py
index 5833fe8..e67947b 100644
--- a/tests/test_settings.py
+++ b/tests/test_settings.py
@@ -8,35 +8,54 @@
 from anton.config.settings import AntonSettings
 
 
+@pytest.fixture(autouse=True)
+def clean_anton_env(monkeypatch):
+    for key in (
+        "ANTON_PLANNING_PROVIDER",
+        "ANTON_PLANNING_MODEL",
+        "ANTON_CODING_PROVIDER",
+        "ANTON_CODING_MODEL",
+        "ANTON_ANTHROPIC_API_KEY",
+        "ANTON_OPENAI_API_KEY",
+        "ANTON_OPENAI_BASE_URL",
+        "ANTON_OLLAMA_BASE_URL",
+    ):
+        monkeypatch.delenv(key, raising=False)
+
+
 class TestAntonSettingsDefaults:
     def test_default_planning_provider(self):
-        s = AntonSettings(anthropic_api_key="test")
+        s = AntonSettings(anthropic_api_key="test", _env_file=None)
         assert s.planning_provider == "anthropic"
 
     def test_default_planning_model(self):
-        s = AntonSettings(anthropic_api_key="test")
+        s = AntonSettings(anthropic_api_key="test", _env_file=None)
         assert s.planning_model == "claude-sonnet-4-6"
 
     def test_default_coding_provider(self):
-        s = AntonSettings(anthropic_api_key="test")
+        s = AntonSettings(anthropic_api_key="test", _env_file=None)
         assert s.coding_provider == "anthropic"
 
     def test_default_coding_model(self):
-        s = AntonSettings(anthropic_api_key="test")
+        s = AntonSettings(anthropic_api_key="test", _env_file=None)
         assert s.coding_model == "claude-haiku-4-5-20251001"
 
     def test_default_memory_dir(self):
-        s = AntonSettings(anthropic_api_key="test")
+        s = AntonSettings(anthropic_api_key="test", _env_file=None)
         assert s.memory_dir == ".anton"
 
     def test_default_context_dir(self):
-        s = AntonSettings(anthropic_api_key="test")
+        s = AntonSettings(anthropic_api_key="test", _env_file=None)
         assert s.context_dir == ".anton/context"
 
     def test_default_api_key_is_none(self):
         s = AntonSettings(_env_file=None)
         assert s.anthropic_api_key is None
 
+    def test_default_ollama_base_url(self):
+        s = AntonSettings(_env_file=None)
+        assert s.ollama_base_url == "http://localhost:11434"
+
 
 class TestAntonSettingsEnvOverride:
     def test_env_overrides_planning_model(self, monkeypatch):
@@ -49,6 +68,11 @@ def test_env_overrides_api_key(self, monkeypatch):
         s = AntonSettings(_env_file=None)
         assert s.anthropic_api_key == "sk-test-key"
 
+    def test_env_overrides_ollama_base_url(self, monkeypatch):
+        monkeypatch.setenv("ANTON_OLLAMA_BASE_URL", "http://example.test:11434")
+        s = AntonSettings(_env_file=None)
+        assert s.ollama_base_url == "http://example.test:11434"
+
 
 class TestWorkspaceResolution:
     def test_resolve_workspace_defaults_to_cwd(self, tmp_path, monkeypatch):
         monkeypatch.chdir(tmp_path)