From 6bfe7423a6c0ccc95d730ba6e1c1efe29d99f93c Mon Sep 17 00:00:00 2001 From: Luca Massaron Date: Sun, 26 Apr 2026 00:02:44 +0200 Subject: [PATCH 1/3] feat: add local model support via Ollama with SSRF and backend security --- agent/core/llm_params.py | 29 ++++++++ agent/core/model_switcher.py | 16 ++++- agent/core/session.py | 10 +++ agent/main.py | 98 ++++++++++++++++++++++--- agent/utils/ollama_utils.py | 122 ++++++++++++++++++++++++++++++++ agent/utils/terminal_display.py | 1 + backend/routes/agent.py | 41 +++++++++++ 7 files changed, 305 insertions(+), 12 deletions(-) create mode 100644 agent/utils/ollama_utils.py diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 880886b3..43c308df 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -5,6 +5,8 @@ creating circular imports. """ +import os + from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token @@ -98,6 +100,10 @@ def _resolve_llm_params( """ Build LiteLLM kwargs for a given model id. + • ``ollama/`` — local inference via Ollama. Hits the + OpenAI-compatible /v1 endpoint on localhost for improved tool calling. + Restricted to localhost/loopback to mitigate SSRF risk. + • ``anthropic/`` — native thinking config. We bypass LiteLLM's ``reasoning_effort`` → ``thinking`` mapping (which lags new Claude releases like 4.7 and sends the wrong API shape). Instead we pass @@ -137,6 +143,29 @@ def _resolve_llm_params( 3. huggingface_hub cache — ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` / local ``hf auth login`` cache. """ + if model_name.startswith("ollama/"): + # Use OpenAI compatible endpoint for Ollama to get better tool calling support + actual_model = model_name.replace("ollama/", "", 1) + api_base = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434") + + # SSRF Mitigation: Restrict local models to localhost/loopback + from urllib.parse import urlparse + + parsed = urlparse(api_base) + hostname = parsed.hostname or "" + if hostname not in ("localhost", "127.0.0.1"): + raise ValueError( + f"Security error: local model API base '{api_base}' must point to localhost or 127.0.0.1" + ) + + if not api_base.endswith("/v1"): + api_base = f"{api_base.rstrip('/')}/v1" + return { + "model": f"openai/{actual_model}", + "api_base": api_base, + "api_key": "ollama", + } + if model_name.startswith("anthropic/"): params: dict = {"model": model_name} if reasoning_effort: diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index ea419db1..8f034220 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -15,7 +15,12 @@ from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from prompt_toolkit import PromptSession + from agent.core.effort_probe import ProbeInconclusive, probe_effort +from agent.utils.ollama_utils import ensure_ollama_readiness # Suggested models shown by `/model` (not a gate). Users can paste any HF @@ -33,6 +38,7 @@ {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"}, {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"}, {"id": "deepseek-ai/DeepSeek-V4-Pro:deepinfra", "label": "DeepSeek V4 Pro"}, + {"id": "ollama/qwen3.6", "label": "Qwen 3.6 (Local)"}, ] @@ -160,6 +166,7 @@ async def probe_and_switch_model( session, console, hf_token: str | None, + prompt_session: PromptSession, ) -> None: """Validate model+effort with a 1-token ping, cache the effective effort, then commit the switch. 
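Note: the `ollama/` branch added to `_resolve_llm_params` above reduces to a small, self-contained transformation. A minimal standalone sketch (for illustration only; LiteLLM treats `openai/<name>` plus an explicit `api_base` as a generic OpenAI-compatible endpoint, and Ollama ignores the API key even though the client requires one):

```python
import os
from urllib.parse import urlparse

def resolve_ollama_params(model_name: str) -> dict:
    """Sketch of the ollama/ branch: loopback-only base URL, /v1 suffix,
    and a dummy api_key."""
    actual_model = model_name.removeprefix("ollama/")
    api_base = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
    if urlparse(api_base).hostname not in ("localhost", "127.0.0.1"):
        raise ValueError(f"api_base '{api_base}' must be loopback")
    if not api_base.endswith("/v1"):
        api_base = f"{api_base.rstrip('/')}/v1"
    return {"model": f"openai/{actual_model}", "api_base": api_base, "api_key": "ollama"}

assert resolve_ollama_params("ollama/qwen3.6") == {
    "model": "openai/qwen3.6",
    "api_base": "http://localhost:11434/v1",
    "api_key": "ollama",
}
```
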
@@ -175,8 +182,15 @@ async def probe_and_switch_model( Transient errors (5xx, timeout) complete the switch with a yellow warning; the next real call re-surfaces the error if it's persistent. """ + normalized = model_id.removeprefix("huggingface/") + + # Local model pre-flight + if normalized.startswith("ollama/"): + if not await ensure_ollama_readiness(normalized, prompt_session): + return + preference = config.reasoning_effort - if not _print_hf_routing_info(model_id, console): + if not _print_hf_routing_info(normalized, console): return if not preference: diff --git a/agent/core/session.py b/agent/core/session.py index 370bb3a6..c40c76cb 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -18,6 +18,12 @@ logger = logging.getLogger(__name__) +# Local max-token overrides — used for models not yet in the litellm catalog +# or when we want to pin a specific window. +# NOTE: Ollama models not listed here will fall back to _DEFAULT_MAX_TOKENS (200k). +_MAX_TOKENS_MAP: dict[str, int] = { + "ollama/qwen3.6": 128_000, +} _DEFAULT_MAX_TOKENS = 200_000 _TURN_COMPLETE_NOTIFICATION_CHARS = 39000 @@ -32,6 +38,10 @@ def _get_max_tokens_safe(model_name: str) -> int: look up the bare model. Falls back to a conservative 200k default for models not in the catalog (typically HF-router-only models). """ + # Check local overrides first + if model_name in _MAX_TOKENS_MAP: + return _MAX_TOKENS_MAP[model_name] + from litellm import get_model_info candidates = [model_name] diff --git a/agent/main.py b/agent/main.py index 606aaf8e..c75eb884 100644 --- a/agent/main.py +++ b/agent/main.py @@ -29,6 +29,7 @@ from agent.core.tools import ToolRouter from agent.messaging.gateway import NotificationGateway from agent.utils.reliability_checks import check_training_script_save_pattern +from agent.utils.ollama_utils import ensure_ollama_readiness from agent.utils.terminal_display import ( get_console, print_approval_header, @@ -77,6 +78,7 @@ def _configure_runtime_logging() -> None: logging.getLogger("LiteLLM").setLevel(logging.ERROR) logging.getLogger("litellm").setLevel(logging.ERROR) + def _safe_get_args(arguments: dict) -> dict: """Safely extract args dict from arguments, handling cases where LLM passes string.""" args = arguments.get("args", {}) @@ -729,6 +731,7 @@ async def _handle_slash_command( session_holder: list, submission_queue: asyncio.Queue, submission_id: list[int], + prompt_session: PromptSession, ) -> Submission | None: """ Handle a slash command. 
Returns a Submission to enqueue, or None if
@@ -770,10 +773,35 @@
         normalized = arg.removeprefix("huggingface/")
         session = session_holder[0] if session_holder else None
         await model_switcher.probe_and_switch_model(
-            normalized, config, session, console, resolve_hf_token(),
+            normalized, config, session, console, resolve_hf_token(), prompt_session,
         )
         return None
 
+    if command == "/add-model":
+        if not arg:
+            print("Usage: /add-model <model_id>")
+            print("Example: /add-model ollama/mistral")
+            return None
+        model_id = arg
+        if not model_id.startswith("ollama/"):
+            model_id = f"ollama/{model_id}"
+
+        # Check if it's already in suggested models
+        from agent.core.model_switcher import SUGGESTED_MODELS
+        if any(m["id"] == model_id for m in SUGGESTED_MODELS):
+            print(f"Model {model_id} is already available.")
+            return None
+
+        # Check if Ollama is running and model is available (prompt to pull if not)
+        if not await ensure_ollama_readiness(model_id, prompt_session):
+            return None
+
+        label = model_id.replace("ollama/", "", 1).capitalize() + " (Local)"
+        SUGGESTED_MODELS.append({"id": model_id, "label": label})
+        print(f"Added local model: {model_id}")
+        print(f"Switch to it with: /model {model_id}")
+        return None
+
     if command == "/yolo":
         config.yolo_mode = not config.yolo_mode
         state = "ON" if config.yolo_mode else "OFF"
@@ -948,15 +976,23 @@ async def main(model: str | None = None):
     # Create prompt session for input (needed early for token prompt)
     prompt_session = PromptSession()
 
-    # HF token — required, prompt if missing
-    hf_token = resolve_hf_token()
-    if not hf_token:
-        hf_token = await _prompt_and_save_hf_token(prompt_session)
-
+    # Load config early to check model name
     config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
     if model:
         config.model_name = model
 
+    # Check if Ollama is running and model is available (prompt to pull if not)
+    if config.model_name.startswith("ollama/"):
+        if not await ensure_ollama_readiness(config.model_name, prompt_session):
+            print("\nExiting: Local model not ready.")
+            return
+
+    # HF token — required, prompt if missing (unless it's a local model)
+    hf_token = resolve_hf_token()
+    is_local = config.model_name.startswith("ollama/")
+    if not hf_token and not is_local:
+        hf_token = await _prompt_and_save_hf_token(prompt_session)
+
     # Resolve username for banner
     hf_user = _get_hf_user(hf_token)
@@ -1014,7 +1050,21 @@ async def main(model: str | None = None):
         )
     )
 
-    await ready_event.wait()
+    # Wait for ready or agent task failure
+    ready_task = asyncio.create_task(ready_event.wait())
+    done, pending = await asyncio.wait(
+        [ready_task, agent_task],
+        return_when=asyncio.FIRST_COMPLETED
+    )
+    if agent_task in done:
+        # Agent task died before ready
+        ready_task.cancel()
+        try:
+            await agent_task
+        except Exception as e:
+            print_error(f"Agent failed to initialize: {e}")
+        listener_task.cancel()
+        return
 
     submission_id = [0]  # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
@@ -1110,7 +1163,12 @@ def _install_sigint() -> bool:
         # Handle slash commands
         if user_input.strip().startswith("/"):
             sub = await _handle_slash_command(
-                user_input.strip(), config, session_holder, submission_queue, submission_id
+                user_input.strip(),
+                config,
+                session_holder,
+                submission_queue,
+                submission_id,
+                prompt_session,
             )
             if sub is None:
                 # Command handled locally, loop back for input
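Note: the startup race introduced in `main()` above generalizes to a small pattern. A minimal sketch (names hypothetical) that surfaces a worker crash instead of hanging forever on a readiness event:

```python
import asyncio

async def wait_ready_or_crash(ready: asyncio.Event, worker: asyncio.Task) -> None:
    """Return once `ready` is set, but bail out immediately if `worker`
    finishes (or dies) first; awaiting it re-raises its exception."""
    ready_task = asyncio.create_task(ready.wait())
    done, _ = await asyncio.wait({ready_task, worker}, return_when=asyncio.FIRST_COMPLETED)
    if worker in done:
        ready_task.cancel()
        await worker  # propagates the worker's exception, if any
```
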
@@ -1174,14 +1232,32 @@ async def headless_main(
     logging.basicConfig(level=logging.WARNING)
     _configure_runtime_logging()
 
+    config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
+
+    if model:
+        config.model_name = model
+
+    # Check if Ollama is running and model is available
+    if config.model_name.startswith("ollama/"):
+        from agent.utils.ollama_utils import is_ollama_running, is_model_available
+
+        if not await is_ollama_running():
+            print(f"ERROR: Ollama server is not reachable for local model {config.model_name}", file=sys.stderr)
+            sys.exit(1)
+        if not await is_model_available(config.model_name):
+            print(f"ERROR: Model {config.model_name} is not available on Ollama. Run 'ollama pull {config.model_name.replace('ollama/', '', 1)}' first.", file=sys.stderr)
+            sys.exit(1)
+
+    # HF token — required, prompt if missing (unless it's a local model)
     hf_token = resolve_hf_token()
-    if not hf_token:
+    is_local = config.model_name.startswith("ollama/")
+    if not hf_token and not is_local:
         print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr)
         sys.exit(1)
-    print(f"HF token loaded", file=sys.stderr)
+    if hf_token:
+        print("HF token loaded", file=sys.stderr)
 
-    config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
     config.yolo_mode = True  # Auto-approve everything in headless mode
     notification_gateway = NotificationGateway(config.messaging)
     await notification_gateway.start()
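Note: both the interactive and headless preflights above ultimately parse Ollama's `/api/tags` listing, implemented in the new `ollama_utils` module below. The payload looks roughly like this (model entries illustrative):

```python
# Abbreviated /api/tags response; is_model_available() accepts either the
# exact name or the name with an implicit ":latest" tag.
tags_response = {
    "models": [
        {"name": "qwen3.6:latest"},
        {"name": "mistral:latest"},
    ]
}
```
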
diff --git a/agent/utils/ollama_utils.py b/agent/utils/ollama_utils.py
new file mode 100644
index 00000000..c1e52b56
--- /dev/null
+++ b/agent/utils/ollama_utils.py
@@ -0,0 +1,122 @@
+import os
+import json
+import httpx
+from typing import Optional
+from urllib.parse import urlparse
+from agent.utils.terminal_display import get_console
+
+def get_ollama_base_url() -> str:
+    """Read OLLAMA_API_BASE from environment, defaulting to localhost:11434.
+    Validates that the hostname is localhost or 127.0.0.1 to mitigate SSRF.
+    """
+    url = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434").rstrip("/")
+
+    # SSRF Mitigation: Restrict local models to localhost/loopback
+    parsed = urlparse(url)
+    hostname = parsed.hostname or ""
+    if hostname not in ("localhost", "127.0.0.1"):
+        raise ValueError(
+            f"Security error: OLLAMA_API_BASE '{url}' must point to localhost or 127.0.0.1"
+        )
+    return url
+
+async def is_ollama_running() -> bool:
+    """Check if the Ollama server is reachable using async httpx."""
+    try:
+        url = f"{get_ollama_base_url()}/api/tags"
+        async with httpx.AsyncClient() as client:
+            response = await client.get(url, timeout=2.0)
+            return response.status_code == 200
+    except (httpx.RequestError, ValueError):
+        return False
+
+async def is_model_available(model_name: str) -> bool:
+    """Check if a specific model is already pulled in Ollama using async httpx."""
+    try:
+        url = f"{get_ollama_base_url()}/api/tags"
+        async with httpx.AsyncClient() as client:
+            response = await client.get(url, timeout=2.0)
+            if response.status_code != 200:
+                return False
+
+            tags = response.json().get("models", [])
+            actual_name = model_name.replace("ollama/", "", 1)
+
+            # Ollama tags can be 'name:latest' or just 'name'
+            for model in tags:
+                name = model.get("name", "")
+                if name == actual_name or name == f"{actual_name}:latest":
+                    return True
+            return False
+    except (httpx.RequestError, ValueError):
+        return False
+
+async def pull_ollama_model(model_name: str, prompt_session=None) -> bool:
+    """Pull a model from Ollama with real-time progress tracking using async httpx."""
+    actual_name = model_name.replace("ollama/", "", 1)
+    url = f"{get_ollama_base_url()}/api/pull"
+
+    get_console().print(f"Pulling '{actual_name}' from Ollama...")
+
+    try:
+        async with httpx.AsyncClient(timeout=None) as client:
+            async with client.stream("POST", url, json={"name": actual_name}) as response:
+                if response.status_code != 200:
+                    get_console().print(f"[bold red]Error pulling model:[/bold red] {response.status_code}")
+                    return False
+
+                last_status = ""
+                async for line in response.aiter_lines():
+                    if line:
+                        data = json.loads(line)
+                        status = data.get("status", "")
+                        completed = data.get("completed")
+                        total = data.get("total")
+
+                        # Refresh the percentage on every chunk; re-print bare statuses only on change
+                        if total and completed is not None:
+                            percent = (completed / total) * 100
+                            print(f"\r{status}: {percent:.1f}%", end="", flush=True)
+                        elif status != last_status:
+                            print(f"\r{status}", end="", flush=True)
+                        last_status = status
+
+                get_console().print("\n[green]Pull complete![/green]")
+                return True
+    except (httpx.RequestError, ValueError) as e:
+        get_console().print(f"\n[bold red]Failed to pull model:[/bold red] {e}")
+        return False
+
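+# For reference (values illustrative): /api/pull streams NDJSON progress
+# records, one JSON object per line, which the loop above parses:
+#   {"status": "pulling manifest"}
+#   {"status": "pulling 9f59...", "total": 4700000000, "completed": 1200000000}
+#   {"status": "verifying sha256 digest"}
+#   {"status": "success"}
+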
+async def ensure_ollama_readiness(model_id: str, prompt_session) -> bool:
+    """
+    Check server and model availability. Prompt to pull if missing.
+    Returns True if ready to proceed, False otherwise.
+    """
+    try:
+        base_url = get_ollama_base_url()
+    except ValueError as e:
+        get_console().print(f"\n[bold red]Configuration Error:[/bold red] {e}")
+        return False
+
+    if not await is_ollama_running():
+        get_console().print("\n[bold red]Error:[/bold red] Ollama server is not reachable.")
+        get_console().print(f"Make sure 'ollama serve' is running at {base_url}")
+        return False
+
+    if not await is_model_available(model_id):
+        get_console().print(f"\nModel '{model_id}' not found locally on Ollama.")
+
+        try:
+            choice = await prompt_session.prompt_async(
+                f"Would you like to pull {model_id}? (y/n): "
+            )
+            if choice.strip().lower() in ("y", "yes"):
+                return await pull_ollama_model(model_id)
+            else:
+                get_console().print("Model pull cancelled.")
+                return False
+        except (EOFError, KeyboardInterrupt):
+            get_console().print("\nCancelled.")
+            return False
+
+    return True
diff --git a/agent/utils/terminal_display.py b/agent/utils/terminal_display.py
index f2b73301..95c29de2 100644
--- a/agent/utils/terminal_display.py
+++ b/agent/utils/terminal_display.py
@@ -422,6 +422,7 @@ def print_yolo_approve(count: int) -> None:
 {_I} [cyan]/undo[/cyan]          Undo last turn
 {_I} [cyan]/compact[/cyan]       Compact context window
 {_I} [cyan]/model[/cyan] [id]    Show available models or switch
+{_I} [cyan]/add-model[/cyan] [id] Add a local Ollama model (auto-prefixes ollama/ if missing)
 {_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off)
 {_I} [cyan]/yolo[/cyan]          Toggle auto-approve mode
 {_I} [cyan]/status[/cyan]        Current model & turn count
diff --git a/backend/routes/agent.py b/backend/routes/agent.py
index ed33650d..68c19de0 100644
--- a/backend/routes/agent.py
+++ b/backend/routes/agent.py
@@ -301,6 +301,28 @@ async def get_model() -> dict:
     }
 
 
+@router.post("/config/model")
+async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict:
+    """Set the LLM model. Applies to new conversations."""
+    model_id = body.get("model")
+    if not model_id:
+        raise HTTPException(status_code=400, detail="Missing 'model' field")
+
+    # Security: Ollama models are for local CLI use only to prevent SSRF in hosted environments.
+    if model_id.startswith("ollama/"):
+        raise HTTPException(
+            status_code=400,
+            detail="Local models (ollama/) can only be used via the CLI for security reasons."
+        )
+
+    valid_ids = {m["id"] for m in AVAILABLE_MODELS}
+    if model_id not in valid_ids:
+        raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
+    session_manager.config.model_name = model_id
+    logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}")
+    return {"model": model_id}
+
+
 _TITLE_STRIP_CHARS = str.maketrans("", "", "`*_~#[]()")
 
 
@@ -394,6 +416,12 @@ async def create_session(
     if isinstance(body, dict):
         model = body.get("model")
 
+        if model and model.startswith("ollama/"):
+            raise HTTPException(
+                status_code=400,
+                detail="Local models (ollama/) can only be used via the CLI for security reasons."
+            )
+
         valid_ids = {m["id"] for m in AVAILABLE_MODELS}
         if model and model not in valid_ids:
             raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
@@ -439,6 +467,12 @@ async def restore_session_summary(
     hf_token = resolve_hf_request_token(request)
 
     model = body.get("model")
+    if model and model.startswith("ollama/"):
+        raise HTTPException(
+            status_code=400,
+            detail="Local models (ollama/) can only be used via the CLI for security reasons."
+        )
+
     valid_ids = {m["id"] for m in AVAILABLE_MODELS}
     if model and model not in valid_ids:
         raise HTTPException(status_code=400, detail=f"Unknown model: {model}")
@@ -505,6 +539,13 @@ async def set_session_model(
     model_id = body.get("model")
     if not model_id:
         raise HTTPException(status_code=400, detail="Missing 'model' field")
+
+    if model_id.startswith("ollama/"):
+        raise HTTPException(
+            status_code=400,
+            detail="Local models (ollama/) can only be used via the CLI for security reasons."
+        )
+
     valid_ids = {m["id"] for m in AVAILABLE_MODELS}
     if model_id not in valid_ids:
         raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
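Note: the new backend guard can be sanity-checked end to end with something like the following sketch. The wiring is hypothetical — the real app object, route prefix, and the import location of `get_current_user` may differ, so treat this as illustrative rather than a verbatim test:

```python
from fastapi import FastAPI
from fastapi.testclient import TestClient
from backend.routes.agent import router, get_current_user  # import paths assumed

app = FastAPI()
app.include_router(router)
# Bypass the auth dependency for the check
app.dependency_overrides[get_current_user] = lambda: {"username": "tester"}

client = TestClient(app)
resp = client.post("/config/model", json={"model": "ollama/qwen3.6"})
assert resp.status_code == 400
assert "CLI" in resp.json()["detail"]
```
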
From 536f25f7995d9ef0927f4c2db09b2f21e2cc867a Mon Sep 17 00:00:00 2001
From: Luca Massaron
Date: Fri, 1 May 2026 19:57:33 +0200
Subject: [PATCH 2/3] fix: handle socket timeouts during LLM streaming and
 increase default timeout

---
 agent/core/agent_loop.py | 121 ++++++++++++++++++++++-----------------
 1 file changed, 67 insertions(+), 54 deletions(-)

diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py
index 03a4457a..292cf811 100644
--- a/agent/core/agent_loop.py
+++ b/agent/core/agent_loop.py
@@ -356,6 +356,7 @@ async def _record_manual_approved_spend_if_needed(
 _MAX_LLM_RETRIES = 3
 _LLM_RETRY_DELAYS = [5, 15, 30]  # seconds between retries
 _LLM_RATE_LIMIT_RETRY_DELAYS = [30, 60]  # exceed Bedrock's ~60s TPM bucket window
+_LLM_TIMEOUT = 900  # seconds (15 minutes)
 
 
 def _is_rate_limit_error(error: Exception) -> bool:
@@ -727,23 +728,86 @@ def _assistant_message_from_result(
 
 async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
     """Call the LLM with streaming, emitting assistant_chunk events."""
-    response = None
     _healed_effort = False  # one-shot safety net per call
     _healed_thinking_signature = False
     messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
     t_start = time.monotonic()
+
+    # Initialize accumulators
+    full_content = ""
+    tool_calls_acc: dict[int, dict] = {}
+    token_count = 0
+    finish_reason = None
+    final_usage_chunk = None
+    chunks = []
+    should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
+
     for _llm_attempt in range(_MAX_LLM_RETRIES):
         try:
+            # Reset accumulators on retry so a mid-stream failure discards partial state
+            full_content = ""
+            tool_calls_acc = {}
+            token_count = 0
+            finish_reason = None
+            final_usage_chunk = None
+            chunks = []
+
             response = await acompletion(
                 messages=messages,
                 tools=tools,
                 tool_choice="auto",
                 stream=True,
                 stream_options={"include_usage": True},
-                timeout=600,
+                timeout=_LLM_TIMEOUT,
                 **llm_params,
             )
+
+            async for chunk in response:
+                chunks.append(chunk)
+                if session.is_cancelled:
+                    tool_calls_acc.clear()
+                    break
+
+                choice = chunk.choices[0] if chunk.choices else None
+                if not choice:
+                    if hasattr(chunk, "usage") and chunk.usage:
+                        token_count = chunk.usage.total_tokens
+                        final_usage_chunk = chunk
+                    continue
+
+                delta = choice.delta
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+
+                if delta.content:
+                    full_content += delta.content
+                    await session.send_event(
+                        Event(event_type="assistant_chunk", data={"content": delta.content})
+                    )
+
+                if delta.tool_calls:
+                    for tc_delta in delta.tool_calls:
+                        idx = tc_delta.index
+                        if idx not in tool_calls_acc:
+                            tool_calls_acc[idx] = {
+                                "id": "", "type": "function",
+                                "function": {"name": "", "arguments": ""},
+                            }
+                        if tc_delta.id:
+                            tool_calls_acc[idx]["id"] = tc_delta.id
+                        if tc_delta.function:
+                            if tc_delta.function.name:
+                                tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name
+                            if tc_delta.function.arguments:
+                                tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments
+
+                if hasattr(chunk, "usage") and chunk.usage:
+                    token_count = chunk.usage.total_tokens
+                    final_usage_chunk = chunk
+
+            # Success: exit the retry loop
             break
+
         except ContextWindowExceededError:
             raise
         except Exception as e:
@@ -779,57 +843,6 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
                 continue
             raise
 
-    full_content = ""
-    tool_calls_acc: dict[int, dict] = {}
-    token_count = 0
-    finish_reason = None
-    final_usage_chunk = None
-    chunks = []
-    should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
-
-    async for chunk in response:
-        chunks.append(chunk)
-        if session.is_cancelled:
-            tool_calls_acc.clear()
-            break
-
-        choice = chunk.choices[0] if chunk.choices else None
-        if not choice:
-            if hasattr(chunk, "usage") and chunk.usage:
-                token_count = chunk.usage.total_tokens
-                final_usage_chunk = chunk
-            continue
-
-        delta = choice.delta
-        if choice.finish_reason:
-            finish_reason = choice.finish_reason
-
-        if delta.content:
-            full_content += delta.content
-            await session.send_event(
-                Event(event_type="assistant_chunk", data={"content": delta.content})
-            )
-
-        if delta.tool_calls:
-            for tc_delta in delta.tool_calls:
-                idx = tc_delta.index
-                if idx not in tool_calls_acc:
-                    tool_calls_acc[idx] = {
-                        "id": "", "type": "function",
-                        "function": {"name": "", "arguments": ""},
-                    }
-                if tc_delta.id:
-                    tool_calls_acc[idx]["id"] = tc_delta.id
-                if tc_delta.function:
-                    if tc_delta.function.name:
-                        tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name
-                    if tc_delta.function.arguments:
-                        tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments
-
-        if hasattr(chunk, "usage") and chunk.usage:
-            token_count = chunk.usage.total_tokens
-            final_usage_chunk = chunk
-
     usage = await telemetry.record_llm_call(
         session,
         model=llm_params.get("model", session.config.model_name),
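Note: the core change above — moving stream consumption inside the retry loop so a socket timeout mid-stream triggers a clean retry instead of an unhandled error — reduces to this pattern (a sketch; the real code also accumulates tool calls, usage, and events):

```python
from typing import AsyncIterator, Awaitable, Callable

async def consume_with_retries(
    start_stream: Callable[[], Awaitable[AsyncIterator[str]]],
    retries: int = 3,
) -> str:
    """Retry the request *and* its consumption as one unit: a timeout while
    iterating restarts from scratch with a fresh accumulator, so partial
    chunks from a dead connection never leak into the result."""
    for attempt in range(retries):
        content = ""  # fresh accumulator per attempt
        try:
            stream = await start_stream()
            async for chunk in stream:
                content += chunk
            return content
        except TimeoutError:
            if attempt == retries - 1:
                raise
    raise RuntimeError("unreachable")
```
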
@@ -873,7 +886,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
         tools=tools,
         tool_choice="auto",
         stream=False,
-        timeout=600,
+        timeout=_LLM_TIMEOUT,
         **llm_params,
     )
     break

From 2260dfd164b1224879fcc5e844ab39cf9913ae73 Mon Sep 17 00:00:00 2001
From: Luca Massaron
Date: Fri, 1 May 2026 23:15:45 +0200
Subject: [PATCH 3/3] feat: persist and reuse the last used model

---
 agent/core/model_switcher.py |  8 ++++++
 agent/main.py                | 21 +++++++++++++---
 agent/utils/persistence.py   | 49 ++++++++++++++++++++++++++++++++++++
 3 files changed, 74 insertions(+), 4 deletions(-)
 create mode 100644 agent/utils/persistence.py

diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py
index 8f034220..7951324f 100644
--- a/agent/core/model_switcher.py
+++ b/agent/core/model_switcher.py
@@ -21,6 +21,7 @@
 from agent.core.effort_probe import ProbeInconclusive, probe_effort
 from agent.utils.ollama_utils import ensure_ollama_readiness
+from agent.utils.persistence import save_last_model, get_persisted_models
 
 
 # Suggested models shown by `/model` (not a gate). 
Users can paste any HF @@ -41,6 +42,11 @@ {"id": "ollama/qwen3.6", "label": "Qwen 3.6 (Local)"}, ] +# Load any additional models persisted by the user +for _pm in get_persisted_models(): + if not any(_m["id"] == _pm["id"] for _m in SUGGESTED_MODELS): + SUGGESTED_MODELS.append(_pm) + _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} @@ -244,3 +250,5 @@ def _commit_switch(model_id, config, session, effective, cache: bool) -> None: session.model_effective_effort.pop(model_id, None) else: config.model_name = model_id + + save_last_model(model_id) diff --git a/agent/main.py b/agent/main.py index c75eb884..a2f7a253 100644 --- a/agent/main.py +++ b/agent/main.py @@ -30,6 +30,7 @@ from agent.messaging.gateway import NotificationGateway from agent.utils.reliability_checks import check_training_script_save_pattern from agent.utils.ollama_utils import ensure_ollama_readiness +from agent.utils.persistence import get_last_model, save_last_model from agent.utils.terminal_display import ( get_console, print_approval_header, @@ -797,6 +798,11 @@ async def _handle_slash_command( return None label = model_id.replace("ollama/", "", 1).capitalize() + " (Local)" + + # Persist it so it shows up in future sessions + from agent.utils.persistence import add_persisted_model + add_persisted_model(model_id, label) + SUGGESTED_MODELS.append({"id": model_id, "label": label}) print(f"Added local model: {model_id}") print(f"Switch to it with: /model {model_id}") @@ -978,8 +984,12 @@ async def main(model: str | None = None): # Load config early to check model name config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) - if model: - config.model_name = model + + # Resolve model: CLI arg > Cache > Config default + config.model_name = model or get_last_model() or config.model_name + + # Persist the choice so we remember it next time + save_last_model(config.model_name) # Check if Ollama is running and model is available (prompt to pull if not) if config.model_name.startswith("ollama/"): @@ -1234,8 +1244,11 @@ async def headless_main( config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) - if model: - config.model_name = model + # Resolve model: CLI arg > Cache > Config default + config.model_name = model or get_last_model() or config.model_name + + # Persist the choice so we remember it next time + save_last_model(config.model_name) # Check if Ollama is running and model is available if config.model_name.startswith("ollama/"): diff --git a/agent/utils/persistence.py b/agent/utils/persistence.py new file mode 100644 index 00000000..72256081 --- /dev/null +++ b/agent/utils/persistence.py @@ -0,0 +1,49 @@ +import json +from pathlib import Path + +STATE_PATH = Path.home() / ".cache" / "ml-intern" / "state.json" + + +def save_last_model(model_name: str) -> None: + """Persist the last successfully used model name.""" + _update_state({"last_model": model_name}) + + +def get_last_model() -> str | None: + """Retrieve the last used model name.""" + return _get_state().get("last_model") + + +def add_persisted_model(model_id: str, label: str) -> None: + """Add a model to the persistent suggested models list.""" + state = _get_state() + models = state.get("added_models", []) + if not any(m["id"] == model_id for m in models): + models.append({"id": model_id, "label": label}) + _update_state({"added_models": models}) + + +def get_persisted_models() -> list[dict[str, str]]: + """Retrieve the list of manually added models.""" + return _get_state().get("added_models", []) + + +def _get_state() -> dict: + try: + 
if STATE_PATH.exists(): + with open(STATE_PATH, "r", encoding="utf-8") as f: + return json.load(f) + except Exception: + pass + return {} + + +def _update_state(updates: dict) -> None: + try: + STATE_PATH.parent.mkdir(parents=True, exist_ok=True) + state = _get_state() + state.update(updates) + with open(STATE_PATH, "w", encoding="utf-8") as f: + json.dump(state, f) + except Exception: + pass
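
Note: with PATCH 3 applied, the round trip through `~/.cache/ml-intern/state.json` can be exercised directly (this writes to that real file; model IDs illustrative):

```python
from agent.utils.persistence import (
    save_last_model, get_last_model, add_persisted_model, get_persisted_models,
)

save_last_model("ollama/qwen3.6")
add_persisted_model("ollama/mistral", "Mistral (Local)")
assert get_last_model() == "ollama/qwen3.6"
assert {"id": "ollama/mistral", "label": "Mistral (Local)"} in get_persisted_models()
# state.json now holds, roughly:
# {"last_model": "ollama/qwen3.6",
#  "added_models": [{"id": "ollama/mistral", "label": "Mistral (Local)"}]}
```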