diff --git a/anton/channel/branding.py b/anton/channel/branding.py index cf3419b..a05c882 100644 --- a/anton/channel/branding.py +++ b/anton/channel/branding.py @@ -61,8 +61,8 @@ def pick_tagline(seed: int | None = None) -> str: def _build_robot_text(mouth: str, bubble: str) -> Text: """Build the full robot as a Rich Text object with styling.""" - g = "bold cyan" - m = "dim" + g = "anton.glow" + m = "anton.muted" # Pad bubble to avoid layout jitter (longest phrase is ~16 chars) padded = bubble.ljust(16) lines = [ diff --git a/anton/chat.py b/anton/chat.py index b86e36c..0ec523b 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1,8 +1,12 @@ from __future__ import annotations import asyncio +import json as _json import os +import re as _re import sys +import uuid +import yaml as _yaml import time from collections.abc import AsyncIterator, Callable from pathlib import Path @@ -40,6 +44,15 @@ format_cell_result, prepare_scratchpad_exec, ) +from anton.data_vault import DataVault, _slug_env_prefix +from anton.datasource_registry import ( + DatasourceEngine, + DatasourceField, + DatasourceRegistry, + _YAML_BLOCK_RE, +) + +from rich.prompt import Confirm, Prompt if TYPE_CHECKING: from rich.console import Console @@ -96,10 +109,15 @@ def __init__( self._console = console self._history: list[dict] = list(initial_history) if initial_history else [] self._pending_memory_confirmations: list = [] - self._turn_count = sum(1 for m in self._history if m.get("role") == "user") if initial_history else 0 + self._turn_count = ( + sum(1 for m in self._history if m.get("role") == "user") + if initial_history + else 0 + ) self._history_store = history_store self._session_id = session_id self._cancel_event = asyncio.Event() + self._active_datasource: str | None = None # slug like "hubspot-2" self._scratchpads = ScratchpadManager( coding_provider=coding_provider, coding_model=getattr(llm_client, "coding_model", ""), @@ -135,17 +153,19 @@ def repair_history(self) -> None: ] if not tool_ids: 
return - self._history.append({ - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": tid, - "content": "Cancelled by user.", - } - for tid in tool_ids - ], - }) + self._history.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tid, + "content": "Cancelled by user.", + } + for tid in tool_ids + ], + } + ) def _persist_history(self) -> None: """Save current history to disk if a history store is configured.""" @@ -155,7 +175,9 @@ def _persist_history(self) -> None: async def _build_system_prompt(self, user_message: str = "") -> str: prompt = CHAT_SYSTEM_PROMPT.format( runtime_context=self._runtime_context, - visualizations_section=build_visualizations_prompt(self._proactive_dashboards), + visualizations_section=build_visualizations_prompt( + self._proactive_dashboards + ), ) # Inject memory context (replaces old self_awareness) if self._cortex is not None: @@ -172,31 +194,68 @@ async def _build_system_prompt(self, user_message: str = "") -> str: md_context = self._workspace.build_anton_md_context() if md_context: prompt += md_context + # Inject connected datasource context without credentials + ds_ctx = _build_datasource_context(active_only=self._active_datasource) + if ds_ctx: + prompt += ds_ctx return prompt # Packages the LLM is most likely to care about when writing scratchpad code. 
_NOTABLE_PACKAGES: set[str] = { - "numpy", "pandas", "matplotlib", "seaborn", "scipy", "scikit-learn", - "requests", "httpx", "aiohttp", "beautifulsoup4", "lxml", - "pillow", "sympy", "networkx", "sqlalchemy", "pydantic", - "rich", "tqdm", "click", "fastapi", "flask", "django", - "openai", "anthropic", "tiktoken", "transformers", "torch", - "polars", "pyarrow", "openpyxl", "xlsxwriter", - "plotly", "bokeh", "altair", - "pytest", "hypothesis", - "yaml", "pyyaml", "toml", "tomli", "tomllib", - "jinja2", "markdown", "pygments", - "cryptography", "paramiko", "boto3", + "numpy", + "pandas", + "matplotlib", + "seaborn", + "scipy", + "scikit-learn", + "requests", + "httpx", + "aiohttp", + "beautifulsoup4", + "lxml", + "pillow", + "sympy", + "networkx", + "sqlalchemy", + "pydantic", + "rich", + "tqdm", + "click", + "fastapi", + "flask", + "django", + "openai", + "anthropic", + "tiktoken", + "transformers", + "torch", + "polars", + "pyarrow", + "openpyxl", + "xlsxwriter", + "plotly", + "bokeh", + "altair", + "pytest", + "hypothesis", + "yaml", + "pyyaml", + "toml", + "tomli", + "tomllib", + "jinja2", + "markdown", + "pygments", + "cryptography", + "paramiko", + "boto3", } def _build_tools(self) -> list[dict]: scratchpad_tool = dict(SCRATCHPAD_TOOL) pkg_list = self._scratchpads._available_packages if pkg_list: - notable = sorted( - p for p in pkg_list - if p.lower() in self._NOTABLE_PACKAGES - ) + notable = sorted(p for p in pkg_list if p.lower() in self._NOTABLE_PACKAGES) if notable: pkg_line = ", ".join(notable) extra = f"\n\nInstalled packages ({len(pkg_list)} total, notable: {pkg_line})." 
@@ -208,7 +267,9 @@ def _build_tools(self) -> list[dict]: if self._cortex is not None: wisdom = self._cortex.get_scratchpad_context() if wisdom: - scratchpad_tool["description"] += f"\n\nLessons from past sessions:\n{wisdom}" + scratchpad_tool[ + "description" + ] += f"\n\nLessons from past sessions:\n{wisdom}" tools = [scratchpad_tool] if self._cortex is not None: @@ -216,6 +277,7 @@ def _build_tools(self) -> list[dict]: elif self._self_awareness is not None: # Legacy fallback from anton.tools import MEMORIZE_TOOL as _MT + tools.append(_MT) if self._episodic is not None and self._episodic.enabled: tools.append(RECALL_TOOL) @@ -253,8 +315,7 @@ async def _summarize_history(self) -> None: if not isinstance(content, list): break has_tool_result = any( - isinstance(b, dict) and b.get("type") == "tool_result" - for b in content + isinstance(b, dict) and b.get("type") == "tool_result" for b in content ) if not has_tool_result: break @@ -285,9 +346,13 @@ async def _summarize_history(self) -> None: if block.get("type") == "text": lines.append(f"[{role}]: {block['text'][:1000]}") elif block.get("type") == "tool_use": - lines.append(f"[{role}/tool_use]: {block.get('name', '')}({str(block.get('input', ''))[:500]})") + lines.append( + f"[{role}/tool_use]: {block.get('name', '')}({str(block.get('input', ''))[:500]})" + ) elif block.get("type") == "tool_result": - lines.append(f"[tool_result]: {str(block.get('content', ''))[:500]}") + lines.append( + f"[tool_result]: {str(block.get('content', ''))[:500]}" + ) old_text = "\n".join(lines) # Cap at ~8000 chars to avoid overloading the summarizer @@ -320,7 +385,11 @@ async def _summarize_history(self) -> None: # If the recent portion starts with a user message, insert a minimal # assistant separator to avoid consecutive user messages (API error). 
if recent_turns and recent_turns[0].get("role") == "user": - self._history = [summary_msg, {"role": "assistant", "content": "Understood."}, *recent_turns] + self._history = [ + summary_msg, + {"role": "assistant", "content": "Understood."}, + *recent_turns, + ] else: self._history = [summary_msg] + recent_turns @@ -367,15 +436,19 @@ async def turn(self, user_input: str | list[dict]) -> str: while response.tool_calls: tool_round += 1 if tool_round > _MAX_TOOL_ROUNDS: - self._history.append({"role": "assistant", "content": response.content or ""}) - self._history.append({ - "role": "user", - "content": ( - f"SYSTEM: You have used {_MAX_TOOL_ROUNDS} tool-call rounds on this turn. " - "Stop retrying. Summarize what you accomplished and what failed, " - "then tell the user what they can do to unblock the issue." - ), - }) + self._history.append( + {"role": "assistant", "content": response.content or ""} + ) + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: You have used {_MAX_TOOL_ROUNDS} tool-call rounds on this turn. " + "Stop retrying. Summarize what you accomplished and what failed, " + "then tell the user what they can do to unblock the issue." 
+ ), + } + ) response = await self._llm.plan( system=system, messages=self._history, @@ -387,12 +460,14 @@ async def turn(self, user_input: str | list[dict]) -> str: if response.content: assistant_content.append({"type": "text", "text": response.content}) for tc in response.tool_calls: - assistant_content.append({ - "type": "tool_use", - "id": tc.id, - "name": tc.name, - "input": tc.input, - }) + assistant_content.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.input, + } + ) self._history.append({"role": "assistant", "content": assistant_content}) # Process each tool call via registry @@ -403,15 +478,21 @@ async def turn(self, user_input: str | list[dict]) -> str: except Exception as exc: result_text = f"Tool '{tc.name}' failed: {exc}" + result_text = _scrub_credentials(result_text) result_text = _apply_error_tracking( - result_text, tc.name, error_streak, resilience_nudged, + result_text, + tc.name, + error_streak, + resilience_nudged, ) - tool_results.append({ - "type": "tool_result", - "tool_use_id": tc.id, - "content": result_text, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": tc.id, + "content": result_text, + } + ) self._history.append({"role": "user", "content": tool_results}) @@ -446,13 +527,17 @@ async def turn(self, user_input: str | list[dict]) -> str: return reply - async def turn_stream(self, user_input: str | list[dict]) -> AsyncIterator[StreamEvent]: + async def turn_stream( + self, user_input: str | list[dict] + ) -> AsyncIterator[StreamEvent]: """Streaming version of turn(). 
Yields events as they arrive.""" self._history.append({"role": "user", "content": user_input}) # Log user input to episodic memory if self._episodic is not None: - content = user_input if isinstance(user_input, str) else str(user_input)[:2000] + content = ( + user_input if isinstance(user_input, str) else str(user_input)[:2000] + ) self._episodic.log_turn(self._turn_count + 1, "user", content) user_msg_str = user_input if isinstance(user_input, str) else "" @@ -465,7 +550,9 @@ async def turn_stream(self, user_input: str | list[dict]) -> AsyncIterator[Strea # Log assistant response to episodic memory if self._episodic is not None and assistant_text_parts: self._episodic.log_turn( - self._turn_count + 1, "assistant", "".join(assistant_text_parts)[:2000], + self._turn_count + 1, + "assistant", + "".join(assistant_text_parts)[:2000], ) # Identity extraction (Default Mode Network — every 5 turns) @@ -477,7 +564,9 @@ async def turn_stream(self, user_input: str | list[dict]) -> AsyncIterator[Strea # Periodic memory vacuum (Systems Consolidation) self._cortex.maybe_vacuum() - async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterator[StreamEvent]: + async def _stream_and_handle_tools( + self, user_message: str = "" + ) -> AsyncIterator[StreamEvent]: """Stream one LLM call, handle tool loops, yield all events.""" system = await self._build_system_prompt(user_message) tools = self._build_tools() @@ -520,7 +609,10 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato llm_response = response.response # Proactive compaction - if not _compacted_this_turn and llm_response.usage.context_pressure > _CONTEXT_PRESSURE_THRESHOLD: + if ( + not _compacted_this_turn + and llm_response.usage.context_pressure > _CONTEXT_PRESSURE_THRESHOLD + ): await self._summarize_history() self._compact_scratchpads() _compacted_this_turn = True @@ -543,15 +635,19 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato 
tool_round += 1 if tool_round > _MAX_TOOL_ROUNDS: _max_rounds_hit = True - self._history.append({"role": "assistant", "content": llm_response.content or ""}) - self._history.append({ - "role": "user", - "content": ( - f"SYSTEM: You have used {_MAX_TOOL_ROUNDS} tool-call rounds on this turn. " - "Stop retrying. Summarize what you accomplished and what failed, " - "then tell the user what they can do to unblock the issue." - ), - }) + self._history.append( + {"role": "assistant", "content": llm_response.content or ""} + ) + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: You have used {_MAX_TOOL_ROUNDS} tool-call rounds on this turn. " + "Stop retrying. Summarize what you accomplished and what failed, " + "then tell the user what they can do to unblock the issue." + ), + } + ) async for event in self._llm.plan_stream( system=system, messages=self._history, @@ -562,15 +658,21 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # Build assistant message with content blocks assistant_content: list[dict] = [] if llm_response.content: - assistant_content.append({"type": "text", "text": llm_response.content}) + assistant_content.append( + {"type": "text", "text": llm_response.content} + ) for tc in llm_response.tool_calls: - assistant_content.append({ - "type": "tool_use", - "id": tc.id, - "name": tc.name, - "input": tc.input, - }) - self._history.append({"role": "assistant", "content": assistant_content}) + assistant_content.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.input, + } + ) + self._history.append( + {"role": "assistant", "content": assistant_content} + ) # Process each tool call tool_results: list[dict] = [] @@ -579,7 +681,9 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato if self._episodic is not None: tc_desc = str(tc.input)[:2000] self._episodic.log_turn( - self._turn_count + 1, "tool_call", tc_desc, + self._turn_count + 1, + 
"tool_call", + tc_desc, tool=tc.name, ) @@ -590,7 +694,13 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato if isinstance(prep, str): result_text = prep else: - pad, code, description, estimated_time, estimated_seconds = prep + ( + pad, + code, + description, + estimated_time, + estimated_seconds, + ) = prep # Signal intent + ETA before execution begins yield StreamTaskProgress( phase="scratchpad_start", @@ -598,8 +708,10 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato eta_seconds=estimated_seconds, ) import time as _time + _sp_t0 = _time.monotonic() from anton.scratchpad import Cell + cell = None async for item in pad.execute_streaming( code, @@ -620,18 +732,26 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato message=description or "Done", eta_seconds=_sp_elapsed, ) - result_text = format_cell_result(cell) if cell else "No result produced." + result_text = ( + format_cell_result(cell) + if cell + else "No result produced." + ) # Log scratchpad cell to episodic memory if self._episodic is not None and cell is not None: self._episodic.log_turn( - self._turn_count + 1, "scratchpad", + self._turn_count + 1, + "scratchpad", (cell.stdout or "")[:2000], description=description, ) else: result_text = await dispatch_tool(self, tc.name, tc.input) - if tc.name == "scratchpad" and tc.input.get("action") == "dump": + if ( + tc.name == "scratchpad" + and tc.input.get("action") == "dump" + ): yield StreamToolResult(content=result_text) result_text = ( "The full notebook has been displayed to the user above. 
" @@ -644,24 +764,34 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # Log tool result to episodic memory if self._episodic is not None: self._episodic.log_turn( - self._turn_count + 1, "tool_result", result_text[:2000], + self._turn_count + 1, + "tool_result", + result_text[:2000], tool=tc.name, ) + result_text = _scrub_credentials(result_text) result_text = _apply_error_tracking( - result_text, tc.name, error_streak, resilience_nudged, + result_text, + tc.name, + error_streak, + resilience_nudged, ) - tool_results.append({ - "type": "tool_result", - "tool_use_id": tc.id, - "content": result_text, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": tc.id, + "content": result_text, + } + ) self._history.append({"role": "user", "content": tool_results}) # Signal that tools are done and LLM is now analyzing - yield StreamTaskProgress(phase="analyzing", message="Analyzing results...") + yield StreamTaskProgress( + phase="analyzing", message="Analyzing results..." + ) # Stream follow-up response = None @@ -696,7 +826,11 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato llm_response = response.response # Proactive compaction during tool loop - if not _compacted_this_turn and llm_response.usage.context_pressure > _CONTEXT_PRESSURE_THRESHOLD: + if ( + not _compacted_this_turn + and llm_response.usage.context_pressure + > _CONTEXT_PRESSURE_THRESHOLD + ): await self._summarize_history() self._compact_scratchpads() _compacted_this_turn = True @@ -716,18 +850,20 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato if continuation >= _MAX_CONTINUATIONS: # Budget exhausted — ask LLM to diagnose and present to user - self._history.append({ - "role": "user", - "content": ( - "SYSTEM: You have attempted to complete this task multiple times " - "but verification indicates it is still not done. Do NOT try again. " - "Instead:\n" - "1. 
Summarize exactly what was accomplished so far.\n" - "2. Identify the specific blocker or failure preventing completion.\n" - "3. Suggest concrete next steps the user can take to unblock this.\n" - "Be honest and specific — do not be vague about what went wrong." - ), - }) + self._history.append( + { + "role": "user", + "content": ( + "SYSTEM: You have attempted to complete this task multiple times " + "but verification indicates it is still not done. Do NOT try again. " + "Instead:\n" + "1. Summarize exactly what was accomplished so far.\n" + "2. Identify the specific blocker or failure preventing completion.\n" + "3. Suggest concrete next steps the user can take to unblock this.\n" + "Be honest and specific — do not be vague about what went wrong." + ), + } + ) yield StreamTaskProgress( phase="analyzing", message="Diagnosing incomplete task..." ) @@ -742,13 +878,15 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # Ask the LLM to self-assess completion. # Use a copy of history with a trailing user message so models # that don't support assistant-prefill won't reject the request. - verify_messages = list(self._history) + [{ - "role": "user", - "content": ( - "SYSTEM: Evaluate whether the task the user originally requested " - "has been fully completed based on the conversation above." - ), - }] + verify_messages = list(self._history) + [ + { + "role": "user", + "content": ( + "SYSTEM: Evaluate whether the task the user originally requested " + "has been fully completed based on the conversation above." + ), + } + ] verification = await self._llm.plan( system=( "You are a task-completion verifier. 
Given the conversation, determine " @@ -774,15 +912,17 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato if "STATUS: STUCK" in status_text: # Stuck — inject diagnosis request and let the LLM explain reason = (verification.content or "").strip() - self._history.append({ - "role": "user", - "content": ( - f"SYSTEM: Task verification determined this task is stuck.\n" - f"Verifier assessment: {reason}\n\n" - "Explain to the user what went wrong, what you tried, and " - "suggest specific next steps they can take to unblock this." - ), - }) + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: Task verification determined this task is stuck.\n" + f"Verifier assessment: {reason}\n\n" + "Explain to the user what went wrong, what you tried, and " + "suggest specific next steps they can take to unblock this." + ), + } + ) yield StreamTaskProgress( phase="analyzing", message="Diagnosing blocked task..." ) @@ -796,16 +936,18 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # INCOMPLETE — continue working continuation += 1 reason = (verification.content or "").strip() - self._history.append({ - "role": "user", - "content": ( - f"SYSTEM: Task verification determined this task is not yet complete " - f"(attempt {continuation}/{_MAX_CONTINUATIONS}).\n" - f"Verifier assessment: {reason}\n\n" - "Continue working on the original request. Pick up where you left off " - "and finish the remaining work. Do not repeat work already done." - ), - }) + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: Task verification determined this task is not yet complete " + f"(attempt {continuation}/{_MAX_CONTINUATIONS}).\n" + f"Verifier assessment: {reason}\n\n" + "Continue working on the original request. Pick up where you left off " + "and finish the remaining work. Do not repeat work already done." 
+ ), + } + ) yield StreamTaskProgress( phase="analyzing", message=f"Task incomplete — continuing ({continuation}/{_MAX_CONTINUATIONS})...", @@ -897,6 +1039,147 @@ def _apply_error_tracking( return result_text +# DS_* var names whose values are known to be secret (passwords, tokens, keys). +# Populated at startup and after each successful connect. +_DS_SECRET_VARS: set[str] = set() + +# DS_* var names for **ALL** fields of registered engines. +_DS_KNOWN_VARS: set[str] = set() + + +def _reset_registered_ds_vars() -> None: + """Clear the DS_* var registries so they can be rebuilt from current vault state.""" + _DS_SECRET_VARS.clear() + _DS_KNOWN_VARS.clear() + + +def parse_connection_slug( + slug: str, + known_engines: list[str], + *, + vault: DataVault | None = None, +) -> tuple[str, str] | None: + """Split a connection slug into (engine, name) using longest-prefix matching. + + First tries each known registry engine longest-first so that 'sql-server-prod-db' is + correctly parsed as engine='sql-server', name='prod-db' rather than + engine='sql', name='server-prod-db'. + + If nothing matches and a vault is supplied, falls back to scanning vault + connections for an exact slug match — handles custom/unregistered engines. + + Returns None if no match found or name part is empty. + """ + for engine in sorted(known_engines, key=len, reverse=True): + prefix = engine + "-" + if slug.startswith(prefix) and len(slug) > len(prefix): + return (engine, slug[len(prefix):]) + + if vault is not None: + for conn in vault.list_connections(): + if f"{conn['engine']}-{conn['name']}" == slug: + return (conn["engine"], conn["name"]) + + return None + + +def _register_secret_vars( + engine_def: "DatasourceEngine", *, engine: str = "", name: str = "" +) -> None: + """Record which DS_* var names correspond to known/secret fields for engine_def. + + If engine and name are given, registers namespaced vars (DS_ENGINE_NAME__FIELD). 
+ Otherwise registers flat vars (DS_FIELD) — for temporary test_snippet execution. + """ + all_fields = list(engine_def.fields) + for am in engine_def.auth_methods or []: + all_fields.extend(am.fields) + for f in all_fields: + if engine and name: + prefix = _slug_env_prefix(engine, name) + key = f"{prefix}__{f.name.upper()}" + else: + key = f"DS_{f.name.upper()}" + _DS_KNOWN_VARS.add(key) + if f.secret: + _DS_SECRET_VARS.add(key) + + +def _scrub_credentials(text: str) -> str: + """Remove secret DS_* values from scratchpad output before it reaches the LLM. + + Only redacts vars registered as secret via _register_secret_vars (driven by + DatasourceField.secret=true in datasources.md). Non-secret fields of known + engines (DS_HOST, DS_PORT, DS_BASE_URL, …) are left readable so the LLM can + reason about connection errors. For truly unknown DS_* vars (custom engines + not yet in the registry) the fallback scrubs any long value — conservative + but safe. + """ + for key in _DS_SECRET_VARS: + value = os.environ.get(key, "") + if not value: + continue + text = text.replace(value, f"[{key}]") + for key, value in os.environ.items(): + if not key.startswith("DS_") or key in _DS_KNOWN_VARS: + continue + # Length guard only for unknown DS_* vars (not registered secrets). + # Unknown vars are matched heuristically — a short value like "on" + # or "true" in a DS_ENABLE_X var should not be scrubbed. + # Registered secret vars bypass this check entirely. + if not value or len(value) <= 8: + continue + text = text.replace(value, f"[{key}]") + return text + + +def _build_datasource_context(active_only: str | None = None) -> str: + """Build a system-prompt section listing available DS_* env vars by name. + + Shows the LLM what data sources are connected and which environment + variable names to use — without exposing any credential values. + + If active_only is set, only the matching slug is included. 
+ """ + try: + vault = DataVault() + conns = vault.list_connections() + except Exception: + return "" + if not conns: + return "" + lines = ["\n\n## Connected Data Sources"] + lines.append( + "Credentials are pre-injected as namespaced DS___ " + "environment variables. Use them directly in scratchpad code " + "(e.g. DS_POSTGRES_PROD_DB__HOST). " + "Never read ~/.anton/data_vault/ files directly.\n" + ) + for c in conns: + slug = f"{c['engine']}-{c['name']}" + if active_only and slug != active_only: + continue + fields = vault.load(c["engine"], c["name"]) or {} + prefix = _slug_env_prefix(c["engine"], c["name"]) + var_names = ", ".join(f"{prefix}__{k.upper()}" for k in fields) + lines.append(f"- `{slug}` ({c['engine']}) → {var_names}") + return "\n".join(lines) + + +def _restore_namespaced_env(vault: DataVault) -> None: + """Clear all DS_* vars, then reinject every saved connection as namespaced.""" + from anton.datasource_registry import DatasourceRegistry + + _reset_registered_ds_vars() + vault.clear_ds_env() + dreg = DatasourceRegistry() + for conn in vault.list_connections(): + vault.inject_env(conn["engine"], conn["name"]) # flat=False by default + edef = dreg.get(conn["engine"]) + if edef is not None: + _register_secret_vars(edef, engine=conn["engine"], name=conn["name"]) + + def _build_runtime_context(settings: AntonSettings) -> str: """Build runtime context string including Minds datasource info if configured.""" ctx = ( @@ -906,15 +1189,16 @@ def _build_runtime_context(settings: AntonSettings) -> str: f"- Workspace: {settings.workspace_path}\n" f"- Memory mode: {settings.memory_mode}" ) - if settings.minds_api_key and (settings.minds_mind_name or settings.minds_datasource): + if settings.minds_api_key and ( + settings.minds_mind_name or settings.minds_datasource + ): engine = settings.minds_datasource_engine or "unknown" ctx += f"\n\n**CONNECTED MIND (Minds):**\n" if settings.minds_mind_name: ctx += f"- Mind: {settings.minds_mind_name}\n" if 
settings.minds_datasource: ctx += ( - f"- Datasource: {settings.minds_datasource}\n" - f"- Engine: {engine}\n" + f"- Datasource: {settings.minds_datasource}\n" f"- Engine: {engine}\n" ) ctx += ( f"- Minds URL: {settings.minds_url}\n" @@ -956,7 +1240,8 @@ def _rebuild_session( runtime_context = _build_runtime_context(settings) api_key = ( - settings.anthropic_api_key if settings.coding_provider == "anthropic" + settings.anthropic_api_key + if settings.coding_provider == "anthropic" else settings.openai_api_key ) or "" return ChatSession( @@ -1006,17 +1291,34 @@ def _show_scope(label: str, hc) -> int: identity = hc.recall_identity() rules = hc.recall_rules() lessons_raw = hc._read_full_lessons() - rule_count = sum(1 for ln in rules.splitlines() if ln.strip().startswith("- ")) if rules else 0 - lesson_count = sum(1 for ln in lessons_raw.splitlines() if ln.strip().startswith("- ")) if lessons_raw else 0 + rule_count = ( + sum(1 for ln in rules.splitlines() if ln.strip().startswith("- ")) + if rules + else 0 + ) + lesson_count = ( + sum(1 for ln in lessons_raw.splitlines() if ln.strip().startswith("- ")) + if lessons_raw + else 0 + ) topics: list[str] = [] if hc._topics_dir.is_dir(): - topics = [p.stem for p in sorted(hc._topics_dir.iterdir()) if p.suffix == ".md"] + topics = [ + p.stem for p in sorted(hc._topics_dir.iterdir()) if p.suffix == ".md" + ] console.print(f" [anton.cyan]{label}[/] [dim]({hc._dir})[/]") if identity: - entries = [ln.strip()[2:] for ln in identity.splitlines() if ln.strip().startswith("- ")] + entries = [ + ln.strip()[2:] + for ln in identity.splitlines() + if ln.strip().startswith("- ") + ] if entries: - console.print(f" Identity: {', '.join(entries[:3])}" + (" ..." if len(entries) > 3 else "")) + console.print( + f" Identity: {', '.join(entries[:3])}" + + (" ..." 
if len(entries) > 3 else "") + ) else: console.print(" Identity: [dim](set)[/]") else: @@ -1147,7 +1449,9 @@ async def _handle_resume( new_session._turn_count = sum(1 for m in history if m.get("role") == "user") console.print() - console.print(f"[anton.success]Resumed session from {selected['date']} ({selected['turns']} turns)[/]") + console.print( + f"[anton.success]Resumed session from {selected['date']} ({selected['turns']} turns)[/]" + ) console.print() return new_session, sid @@ -1189,9 +1493,16 @@ async def _handle_setup( return session elif top_choice == "1": return await _handle_setup_models( - console, settings, workspace, state, - self_awareness, cortex, session, episodic=episodic, - history_store=history_store, session_id=session_id, + console, + settings, + workspace, + state, + self_awareness, + cortex, + session, + episodic=episodic, + history_store=history_store, + session_id=session_id, ) else: _handle_setup_memory(console, settings, workspace, cortex, episodic=episodic) @@ -1214,103 +1525,89 @@ async def _handle_setup_models( from rich.prompt import Prompt from anton.workspace import Workspace as _Workspace + from anton.cli import _SetupRetry, _setup_minds, _setup_other_provider # Always persist API keys and model settings to global ~/.anton/.env global_ws = _Workspace(Path.home()) + def _provider_label(provider: str) -> str: + if provider == "openai-compatible": + if settings.minds_url and "mdb.ai" in settings.minds_url: + return "Minds-Cloud" + return "Minds-Enterprise" + return provider.capitalize() + + def _model_label(model: str, role: str) -> str: + if model in ("_reason_", "_code_"): + return f"smart_router({role})" + return model + + provider_display = _provider_label(settings.planning_provider) + planning_display = _model_label(settings.planning_model, "planning") + coding_display = _model_label(settings.coding_model, "coding") + console.print() console.print("[anton.cyan]Current configuration:[/]") - console.print(f" Provider (planning): 
[bold]{settings.planning_provider}[/]") - console.print(f" Provider (coding): [bold]{settings.coding_provider}[/]") - console.print(f" Planning model: [bold]{settings.planning_model}[/]") - console.print(f" Coding model: [bold]{settings.coding_model}[/]") + console.print(f" Provider: [bold]{provider_display}[/]") + if planning_display == coding_display: + console.print(f" Model: [bold]{planning_display}[/]") + else: + console.print(f" Planning: [bold]{planning_display}[/]") + console.print(f" Coding: [bold]{coding_display}[/]") console.print() # --- Provider --- providers = {"1": "anthropic", "2": "openai", "3": "openai-compatible"} - current_num = {"anthropic": "1", "openai": "2", "openai-compatible": "3"}.get(settings.planning_provider, "1") + current_num = {"anthropic": "1", "openai": "2", "openai-compatible": "3"}.get( + settings.planning_provider, "1" + ) console.print("[anton.cyan]Available providers:[/]") - console.print(r" [bold]1[/] Anthropic (Claude) [dim]\[recommended][/]") - console.print(r" [bold]2[/] OpenAI (GPT / o-series) [dim]\[experimental][/]") - console.print(r" [bold]3[/] OpenAI-compatible (custom endpoint) [dim]\[experimental][/]") - console.print() - - choice = Prompt.ask( - "Select provider", - choices=["1", "2", "3"], - default=current_num, - console=console, + console.print( + r" [bold]1[/] Anthropic (Claude) [dim]\[recommended][/]" ) - provider = providers[choice] - - # --- Base URL (OpenAI-compatible only) --- - if provider == "openai-compatible": - current_base_url = settings.openai_base_url or "" - console.print() - base_url = Prompt.ask( - f"API base URL [dim](e.g. 
http://localhost:11434/v1)[/]", - default=current_base_url, - console=console, - ) - base_url = base_url.strip() - if base_url: - settings.openai_base_url = base_url - global_ws.set_secret("ANTON_OPENAI_BASE_URL", base_url) - - # --- API key --- - key_attr = "anthropic_api_key" if provider == "anthropic" else "openai_api_key" - current_key = getattr(settings, key_attr) or "" - masked = current_key[:4] + "..." + current_key[-4:] if len(current_key) > 8 else "***" - console.print() - api_key = Prompt.ask( - f"API key for {provider.title()} [dim](Enter to keep {masked})[/]", - default="", - console=console, + console.print( + r" [bold]2[/] OpenAI (GPT / o-series) [dim]\[experimental][/]" + ) + console.print( + r" [bold]3[/] OpenAI-compatible (custom endpoint) [dim]\[experimental][/]" ) - api_key = api_key.strip() + console.print() + # Use the same onboarding flow from cli.py - # --- Models --- - defaults = { - "anthropic": ("claude-sonnet-4-6", "claude-haiku-4-5-20251001"), - "openai": ("gpt-5-mini", "gpt-5-nano"), - } - default_planning, default_coding = defaults.get(provider, ("", "")) + def _print_choices(): + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") + console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") + console.print(" [bold]q[/] [anton.muted]Back[/]") + console.print() - console.print() - planning_model = Prompt.ask( - "Planning model", - default=settings.planning_model if provider == settings.planning_provider else default_planning, - console=console, - ) - coding_model = Prompt.ask( - "Coding model", - default=settings.coding_model if provider == settings.coding_provider else default_coding, - console=console, - ) + _print_choices() - # --- Persist to global ~/.anton/.env --- - settings.planning_provider = provider - settings.coding_provider = provider - 
settings.planning_model = planning_model - settings.coding_model = coding_model + while True: + choice = Prompt.ask( + "Choose LLM Provider", + choices=["1", "2", "3", "q"], + default="q", + console=console, + ) - global_ws.set_secret("ANTON_PLANNING_PROVIDER", provider) - global_ws.set_secret("ANTON_CODING_PROVIDER", provider) - global_ws.set_secret("ANTON_PLANNING_MODEL", planning_model) - global_ws.set_secret("ANTON_CODING_MODEL", coding_model) + if choice == "q": + return session - if api_key: - setattr(settings, key_attr, api_key) - key_name = f"ANTON_{provider.upper()}_API_KEY" - global_ws.set_secret(key_name, api_key) + try: + if choice == "1": + _setup_minds(settings, global_ws) + elif choice == "2": + _setup_minds(settings, global_ws, default_url=None) + else: + _setup_other_provider(settings, global_ws) + break + except _SetupRetry: + console.print() + _print_choices() + continue - # Validate that we actually have an API key for the chosen provider - final_key = getattr(settings, key_attr) - if not final_key: - console.print() - console.print(f"[anton.error]No API key set for {provider}. 
Configuration not applied.[/]") - console.print() - return session + global_ws.apply_env_to_process() console.print() console.print("[anton.success]Configuration updated.[/]") @@ -1345,9 +1642,15 @@ def _handle_setup_memory( # --- Memory mode --- console.print(" Memory mode:") - console.print(r" [bold]1[/] Autopilot — Anton decides what to remember [dim]\[recommended][/]") - console.print(r" [bold]2[/] Co-pilot — save obvious, confirm ambiguous [dim]\[selective][/]") - console.print(r" [bold]3[/] Off — never save memory (still reads existing) [dim]\[suppressed][/]") + console.print( + r" [bold]1[/] Autopilot — Anton decides what to remember [dim]\[recommended][/]" + ) + console.print( + r" [bold]2[/] Co-pilot — save obvious, confirm ambiguous [dim]\[selective][/]" + ) + console.print( + r" [bold]3[/] Off — never save memory (still reads existing) [dim]\[suppressed][/]" + ) console.print() mode_map = {"1": "autopilot", "2": "copilot", "3": "off"} @@ -1370,7 +1673,9 @@ def _handle_setup_memory( if episodic is not None: console.print() ep_status = "ON" if episodic.enabled else "OFF" - console.print(f" Episodic memory (conversation archive): Currently [bold]{ep_status}[/]") + console.print( + f" Episodic memory (conversation archive): Currently [bold]{ep_status}[/]" + ) toggle = Prompt.ask( " Toggle episodic memory? 
(y/n)", choices=["y", "n"], @@ -1381,7 +1686,9 @@ def _handle_setup_memory( new_state = not episodic.enabled episodic.enabled = new_state settings.episodic_memory = new_state - workspace.set_secret("ANTON_EPISODIC_MEMORY", "true" if new_state else "false") + workspace.set_secret( + "ANTON_EPISODIC_MEMORY", "true" if new_state else "false" + ) console.print(f" Episodic memory: [bold]{'ON' if new_state else 'OFF'}[/]") console.print() @@ -1464,7 +1771,10 @@ def _describe_minds_connection_error(err: Exception) -> tuple[str, str]: "Connection failed during TLS certificate verification.", "Common reasons: a self-signed, expired, or otherwise untrusted certificate.", ) - if isinstance(reason, (TimeoutError, socket.timeout)) or "timed out" in str(reason).lower(): + if ( + isinstance(reason, (TimeoutError, socket.timeout)) + or "timed out" in str(reason).lower() + ): return ( "Connection failed because the request timed out.", "Common reasons: the server is slow or unavailable, the URL is wrong, or there is a network path issue.", @@ -1507,7 +1817,10 @@ def _minds_request( req.add_header("Content-Type", "application/json") req.add_header("Accept", "application/json") # Browser-like headers to avoid Cloudflare bot detection - req.add_header("User-Agent", "Mozilla/5.0 (compatible; Anton/1.0; +https://github.com/mindsdb/anton)") + req.add_header( + "User-Agent", + "Mozilla/5.0 (compatible; Anton/1.0; +https://github.com/mindsdb/anton)", + ) req.add_header("Accept-Language", "en-US,en;q=0.9") req.add_header("Accept-Encoding", "identity") req.add_header("Connection", "keep-alive") @@ -1535,7 +1848,9 @@ def _minds_list_minds(base_url: str, api_key: str, verify: bool = True) -> list[ return data.get("minds", data if isinstance(data, list) else []) -def _minds_get_mind(base_url: str, api_key: str, mind_name: str, verify: bool = True) -> dict | None: +def _minds_get_mind( + base_url: str, api_key: str, mind_name: str, verify: bool = True +) -> dict | None: """Fetch a single mind's 
details from a Minds server.""" import json as _json @@ -1578,7 +1893,9 @@ def _minds_refresh_knowledge(settings: AntonSettings, cortex) -> None: cortex.project_hc._encode_with_lock(topic_path, topic_content, mode="write") -def _minds_list_datasources(base_url: str, api_key: str, verify: bool = True) -> list[dict]: +def _minds_list_datasources( + base_url: str, api_key: str, verify: bool = True +) -> list[dict]: """Fetch datasource list from a Minds server using stdlib urllib.""" import json as _json @@ -1597,11 +1914,13 @@ def _minds_test_llm(base_url: str, api_key: str, verify: bool = True) -> bool: import json as _json url = f"{base_url}/api/v1/chat/completions" - payload = _json.dumps({ - "model": "_code_", - "messages": [{"role": "user", "content": "ping"}], - "max_tokens": 1, - }).encode() + payload = _json.dumps( + { + "model": "_code_", + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 1, + } + ).encode() try: _minds_request(url, api_key, method="POST", payload=payload, verify=verify) @@ -1611,14 +1930,22 @@ def _minds_test_llm(base_url: str, api_key: str, verify: bool = True) -> bool: _MINDS_KEYS = { - "ANTON_MINDS_API_KEY", "ANTON_MINDS_URL", "ANTON_MINDS_MIND_NAME", - "ANTON_MINDS_DATASOURCE", "ANTON_MINDS_DATASOURCE_ENGINE", "ANTON_MINDS_SSL_VERIFY", + "ANTON_MINDS_API_KEY", + "ANTON_MINDS_URL", + "ANTON_MINDS_MIND_NAME", + "ANTON_MINDS_DATASOURCE", + "ANTON_MINDS_DATASOURCE_ENGINE", + "ANTON_MINDS_SSL_VERIFY", } _LLM_KEYS = { - "ANTON_PLANNING_PROVIDER", "ANTON_CODING_PROVIDER", - "ANTON_PLANNING_MODEL", "ANTON_CODING_MODEL", - "ANTON_ANTHROPIC_API_KEY", "ANTON_OPENAI_API_KEY", "ANTON_OPENAI_BASE_URL", + "ANTON_PLANNING_PROVIDER", + "ANTON_CODING_PROVIDER", + "ANTON_PLANNING_MODEL", + "ANTON_CODING_MODEL", + "ANTON_ANTHROPIC_API_KEY", + "ANTON_OPENAI_API_KEY", + "ANTON_OPENAI_BASE_URL", } _SECRET_PATTERNS = ("KEY", "TOKEN", "SECRET", "PAT", "PASSWORD") @@ -1635,6 +1962,7 @@ def _display_value(key: str, value: str) -> str: return 
value or "[dim][/]" +#TODO: The /data-connections menu is deprecated and will be removed in a future release. async def _handle_data_connections( console: Console, settings: AntonSettings, @@ -1642,7 +1970,7 @@ async def _handle_data_connections( session: ChatSession, ) -> ChatSession: """View and manage stored keys and connections across global and project vaults.""" - from rich.prompt import Confirm, Prompt + from rich.prompt import Prompt from anton.workspace import Workspace as _Workspace @@ -1653,7 +1981,9 @@ async def _handle_data_connections( # Merge with source tags: project keys override global for display, # but we track where each lives for writes/removals. - all_keys: dict[str, tuple[str, str, str]] = {} # key -> (value, source, scope_label) + all_keys: dict[str, tuple[str, str, str]] = ( + {} + ) # key -> (value, source, scope_label) for k, v in global_env.items(): all_keys[k] = (v, "global", "~/.anton/.env") for k, v in project_env.items(): @@ -1663,7 +1993,9 @@ async def _handle_data_connections( if not all_keys: console.print("[anton.warning]No connections or secrets configured.[/]") - console.print("[anton.muted]Use /connect to set up a Minds connection, or ask Anton to store a key.[/]") + console.print( + "[anton.muted]Use /connect to set up a Minds connection, or ask Anton to store a key.[/]" + ) console.print() return session @@ -1671,7 +2003,11 @@ def _print_table() -> list[tuple[str, str, str, str]]: """Print grouped key table and return flat list for menu selection.""" minds = {k: all_keys[k] for k in sorted(all_keys) if k in _MINDS_KEYS} llm = {k: all_keys[k] for k in sorted(all_keys) if k in _LLM_KEYS} - other = {k: all_keys[k] for k in sorted(all_keys) if k not in _MINDS_KEYS and k not in _LLM_KEYS} + other = { + k: all_keys[k] + for k in sorted(all_keys) + if k not in _MINDS_KEYS and k not in _LLM_KEYS + } flat: list[tuple[str, str, str, str]] = [] # (key, value, source, scope_label) idx = 1 @@ -1679,7 +2015,9 @@ def _print_table() -> 
list[tuple[str, str, str, str]]: if minds: console.print("[anton.cyan]Minds Connection[/]") for k, (v, src, lbl) in minds.items(): - console.print(f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]") + console.print( + f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]" + ) flat.append((k, v, src, lbl)) idx += 1 console.print() @@ -1687,7 +2025,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: if llm: console.print("[anton.cyan]LLM Configuration[/]") for k, (v, src, lbl) in llm.items(): - console.print(f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]") + console.print( + f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]" + ) flat.append((k, v, src, lbl)) idx += 1 console.print() @@ -1695,7 +2035,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: if other: console.print("[anton.cyan]Other Integrations[/]") for k, (v, src, lbl) in other.items(): - console.print(f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]") + console.print( + f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]" + ) flat.append((k, v, src, lbl)) idx += 1 console.print() @@ -1713,7 +2055,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: console.print(" [bold]q[/] Back") console.print() - action = Prompt.ask("Select", choices=["1", "2", "3", "q"], default="q", console=console) + action = Prompt.ask( + "Select", choices=["1", "2", "3", "q"], default="q", console=console + ) if action == "q": console.print() @@ -1772,7 +2116,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: continue key, _, src, lbl = flat[pick_idx] - if not Confirm.ask(f"Remove {key} from {lbl}?", default=False, console=console): + if not Confirm.ask( + f"Remove {key} from {lbl}?", default=False, console=console + ): console.print("[anton.muted]Cancelled.[/]") console.print() continue @@ -1786,14 +2132,20 @@ def _print_table() -> list[tuple[str, str, str, str]]: elif action == "3": # --- Add --- console.print() - new_key 
= Prompt.ask("Key name (e.g. HUBSPOT_API_KEY)", console=console).strip() + new_key = Prompt.ask( + "Key name (e.g. HUBSPOT_API_KEY)", console=console + ).strip() if not new_key: console.print("[anton.warning]Key name cannot be empty.[/]") console.print() continue if new_key in all_keys: - if not Confirm.ask(f"{new_key} already exists. Overwrite?", default=False, console=console): + if not Confirm.ask( + f"{new_key} already exists. Overwrite?", + default=False, + console=console, + ): console.print("[anton.muted]Cancelled.[/]") console.print() continue @@ -1816,7 +2168,11 @@ def _print_table() -> list[tuple[str, str, str, str]]: console=console, ) target_ws = global_ws if scope == "global" else workspace - scope_label = "~/.anton/.env" if scope == "global" else f"{workspace.base}/.anton/.env" + scope_label = ( + "~/.anton/.env" + if scope == "global" + else f"{workspace.base}/.anton/.env" + ) target_ws.set_secret(new_key, new_val) target_ws.apply_env_to_process() all_keys[new_key] = (new_val, scope, scope_label) @@ -1924,7 +2280,11 @@ async def _handle_connect( name = mind.get("name", "?") ds_list = mind.get("datasources", []) ds_count = len(ds_list) - ds_label = f"{ds_count} datasource{'s' if ds_count != 1 else ''}" if ds_count else "no datasources" + ds_label = ( + f"{ds_count} datasource{'s' if ds_count != 1 else ''}" + if ds_count + else "no datasources" + ) console.print(f" [bold]{i}[/] {name} [dim]({ds_label})[/]") console.print() @@ -1958,7 +2318,9 @@ async def _handle_connect( # --- Resolve engine type from datasources list --- if ds_name: try: - all_datasources = _minds_list_datasources(minds_url, api_key, verify=ssl_verify) + all_datasources = _minds_list_datasources( + minds_url, api_key, verify=ssl_verify + ) for ds in all_datasources: if ds.get("name") == ds_name: ds_engine = ds.get("engine", "unknown") @@ -1992,7 +2354,9 @@ async def _handle_connect( llm_ok = _minds_test_llm(minds_url, api_key, verify=ssl_verify) if llm_ok: - 
console.print("[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]") + console.print( + "[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]" + ) base_url = f"{minds_url.rstrip('/')}/api/v1" settings.openai_api_key = api_key settings.openai_base_url = base_url @@ -2008,7 +2372,9 @@ async def _handle_connect( global_ws.set_secret("ANTON_CODING_MODEL", "_code_") else: # Check if Anthropic key is already configured - has_anthropic = settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY") + has_anthropic = settings.anthropic_api_key or os.environ.get( + "ANTHROPIC_API_KEY" + ) if not has_anthropic: anthropic_key = Prompt.ask("Anthropic API key (for LLM)", console=console) if anthropic_key.strip(): @@ -2025,14 +2391,21 @@ async def _handle_connect( global_ws.set_secret("ANTON_CODING_MODEL", "claude-haiku-4-5-20251001") console.print("[anton.success]Anthropic API key saved.[/]") else: - console.print("[anton.warning]No API key provided — LLM calls will not work.[/]") + console.print( + "[anton.warning]No API key provided — LLM calls will not work.[/]" + ) global_ws.apply_env_to_process() console.print() return _rebuild_session( - settings=settings, state=state, self_awareness=self_awareness, - cortex=cortex, workspace=workspace, console=console, episodic=episodic, + settings=settings, + state=state, + self_awareness=self_awareness, + cortex=cortex, + workspace=workspace, + console=console, + episodic=episodic, ) @@ -2068,30 +2441,55 @@ def _format_file_message(text: str, paths: list[Path], console: Console) -> str: # Skip very large files (>500KB) — just reference them if size > 512_000: - parts.append(f"\n\n(File too large to inline — {_human_size(size)}. " - f"Use the scratchpad to read it.)\n") + parts.append( + f'\n\n(File too large to inline — {_human_size(size)}. 
' + f"Use the scratchpad to read it.)\n" + ) continue # Skip binary-looking files - if suffix in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".ico", ".webp", - ".pdf", ".zip", ".tar", ".gz", ".exe", ".dll", ".so", - ".pyc", ".pyo", ".whl", ".egg", ".db", ".sqlite"): - parts.append(f"\n\n(Binary file — {_human_size(size)}. " - f"Use the scratchpad to process it.)\n") + if suffix in ( + ".png", + ".jpg", + ".jpeg", + ".gif", + ".bmp", + ".ico", + ".webp", + ".pdf", + ".zip", + ".tar", + ".gz", + ".exe", + ".dll", + ".so", + ".pyc", + ".pyo", + ".whl", + ".egg", + ".db", + ".sqlite", + ): + parts.append( + f'\n\n(Binary file — {_human_size(size)}. ' + f"Use the scratchpad to process it.)\n" + ) continue try: content = p.read_text(errors="replace") except Exception: - parts.append(f"\n\n(Could not read file.)\n") + parts.append(f'\n\n(Could not read file.)\n') continue - parts.append(f"\n\n{content}\n") + parts.append(f'\n\n{content}\n') return "\n".join(parts) -def _format_clipboard_image_message(uploaded: object, user_text: str = "") -> list[dict]: +def _format_clipboard_image_message( + uploaded: object, user_text: str = "" +) -> list[dict]: """Build a multimodal LLM message for a clipboard image upload. Returns a list of content blocks (image + text) so the LLM can see @@ -2100,7 +2498,11 @@ def _format_clipboard_image_message(uploaded: object, user_text: str = "") -> li """ import base64 - text = user_text.strip() if user_text else "I've pasted an image from my clipboard. Analyze it." + text = ( + user_text.strip() + if user_text + else "I've pasted an image from my clipboard. Analyze it." + ) text += ( f"\n\nThe image is also saved at: {uploaded.path}\n" f"({uploaded.width}x{uploaded.height}, {_human_size(uploaded.size_bytes)}). 
" @@ -2146,6 +2548,7 @@ async def _ensure_clipboard(console: Console) -> bool: return False console.print("[anton.muted]Installing Pillow...[/]") import subprocess + proc = await asyncio.get_event_loop().run_in_executor( None, lambda: subprocess.run( @@ -2168,7 +2571,9 @@ async def _ensure_clipboard(console: Console) -> bool: ), ) if proc.returncode == 0: - console.print("[anton.success]Pillow installed. Clipboard is now available.[/]") + console.print( + "[anton.success]Pillow installed. Clipboard is now available.[/]" + ) return True console.print("[anton.error]Failed to install Pillow.[/]") return False @@ -2182,37 +2587,1064 @@ def _human_size(nbytes: int) -> str: return f"{nbytes:.1f}TB" -def _print_slash_help(console: Console) -> None: - """Print available slash commands.""" - console.print() - console.print("[anton.cyan]Available commands:[/]") - console.print(" [bold]/connect[/] — Connect to a Minds server and select a mind") - console.print(" [bold]/data-connections[/] — View and manage stored keys and connections") - console.print(" [bold]/setup[/] — Configure models or memory settings") - console.print(" [bold]/memory[/] — Show memory status dashboard") - console.print(" [bold]/paste[/] — Attach clipboard image to your message") - console.print(" [bold]/resume[/] — Resume a previous chat session") - console.print(" [bold]/help[/] — Show this help message") - console.print(" [bold]exit[/] — Quit the chat") - console.print() +def _remove_engine_block(text: str, slug: str) -> str: + """Return *text* with any YAML datasource block for *slug* removed.""" + cleaned = [] + prev = 0 + for m in _YAML_BLOCK_RE.finditer(text): + try: + data = _yaml.safe_load(m.group(3)) + is_dup = isinstance(data, dict) and str(data.get("engine", "")) == slug + except Exception: + is_dup = False + if is_dup: + pre = text[prev : m.start()].rstrip() + pre = _re.sub(r"\n---\s*$", "", pre) + cleaned.append(pre) + else: + cleaned.append(text[prev : m.end()]) + prev = m.end() + 
cleaned.append(text[prev:]) + return "".join(cleaned) -class _EscapeWatcher: - """Detect Escape keypress during streaming via cbreak terminal mode.""" +async def _handle_add_custom_datasource( + console: Console, + name: str, + registry, + session: "ChatSession", +): + """Ask for the tool name, use the LLM to identify required fields, then collect credentials.""" - def __init__(self, on_cancel: Callable[[], None] | None = None) -> None: - self.cancelled = asyncio.Event() - self._on_cancel = on_cancel - self._task: asyncio.Task | None = None - self._old_settings: list | None = None - self._stop = False + console.print() + preamble = "[anton.cyan](anton)[/] " + if name: + tool_name = name + name_context = f"'{name}' isn't in my built-in list.\n " + else: + tool_name = Prompt.ask( + f"{preamble}What is the name of the tool or service?", + console=console, + ) + if not tool_name.strip(): + return None + tool_name = tool_name.strip() + name_context = "" + + user_answer = Prompt.ask( + f"{preamble}{name_context}How do you authenticate with it? 
" + "Describe what credentials you have (don't paste actual values)", + console=console, + ) + if not user_answer.strip(): + return None - async def __aenter__(self) -> _EscapeWatcher: - if sys.platform != "win32" and sys.stdin.isatty(): - self._task = asyncio.create_task(self._watch()) - return self + console.print() + console.print("[anton.muted] Got it — working out the connection details…[/]") - async def __aexit__(self, *exc: object) -> None: + try: + response = await session._llm.plan( + system="You are a data source connection expert.", + messages=[ + { + "role": "user", + "content": ( + f"The user wants to connect to {repr(tool_name)} and said: {user_answer}\n\n" + "Return ONLY valid JSON (no markdown fences, no commentary):\n" + '{"display_name":"Human-readable name","pip":"pip-package or empty string",' + '"test_snippet":"python code that tests the connection using os.environ vars DS_FIELDNAME (uppercase field name with DS_ prefix) and prints ok on success, or empty string if untestable",' + '"fields":[{"name":"snake_case_name","value":"value if given inline else empty",' + '"secret":true or false,"required":true or false,"description":"what it is"}]}' + ), + } + ], + max_tokens=1024, + ) + text = response.content.strip() + # Keep + text = _re.sub(r"^```[^\n]*\n|```\s*$", "", text, flags=_re.MULTILINE).strip() + data = _json.loads(text) + except Exception: + console.print( + "[anton.warning] Couldn't identify connection details. 
Try again.[/]" + ) + console.print() + return None + + test_snippet = str(data.get("test_snippet", "")).strip() + raw_fields = data.get("fields") or [] + fields: list[DatasourceField] = [] + for f in raw_fields: + if not isinstance(f, dict) or not f.get("name"): + continue + fields.append( + DatasourceField( + name=f["name"], + required=bool(f.get("required", True)), + secret=bool(f.get("secret", False)), + description=str(f.get("description", "")), + ) + ) + + if not fields: + console.print("[anton.warning] Couldn't identify any connection fields.[/]") + console.print() + return None + + display_name = str(data.get("display_name", name)) + pip_pkg = str(data.get("pip", "")) + + # Show summary + console.print() + console.print(" [bold]── What I'll save ──────────────────────────[/]") + credentials: dict[str, str] = {} + for f, raw in zip(fields, raw_fields): + inline_value = str(raw.get("value", "")).strip() + if f.secret and inline_value: + console.print( + f" • [bold]{f.name:<14}[/] (secret — provided, stored securely)" + ) + credentials[f.name] = inline_value + elif f.secret: + console.print( + f" • [bold]{f.name:<14}[/] (secret — I'll ask for this)" + ) + else: + val_display = inline_value or "[anton.muted][/]" + console.print(f" • [bold]{f.name:<14}[/] {val_display}") + if inline_value: + credentials[f.name] = inline_value + console.print() + + # Prompt for any secret fields not provided inline + for f, raw in zip(fields, raw_fields): + if not f.secret: + continue + if str(raw.get("value", "")).strip(): + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + password=True, + console=console, + default="", + ) + if value: + credentials[f.name] = value + + # Prompt for any required non-secret fields not provided inline + for f, raw in zip(fields, raw_fields): + if f.secret: + continue + if not f.required: + continue + if f.name in credentials: + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + console=console, + default="", 
+ ) + if value: + credentials[f.name] = value + + # Offer to collect optional non-secret fields + for f, raw in zip(fields, raw_fields): + if f.secret or f.required or f.name in credentials: + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name} (optional — press Enter to skip)", + console=console, + default="", + ) + if value: + credentials[f.name] = value + + if not credentials: + console.print("[anton.warning] No credentials collected. Aborting.[/]") + console.print() + return None + + # Build engine slug and write definition to ~/.anton/datasources.md + slug = _re.sub(r"[^\w]", "_", display_name.lower()).strip("_") + field_lines = "\n".join( + f" - {{ name: {f.name}, required: {str(f.required).lower()}, " + f'secret: {str(f.secret).lower()}, description: "{f.description}" }}' + for f in fields + ) + test_snippet_yaml = "" + if test_snippet: + indented = "\n".join(f" {line}" for line in test_snippet.splitlines()) + test_snippet_yaml = f"test_snippet: |\n{indented}\n" + + yaml_block = ( + f"\n---\n\n## {display_name}\n" + "```yaml\n" + f"engine: {slug}\n" + f"display_name: {display_name}\n" + + (f"pip: {pip_pkg}\n" if pip_pkg else "") + + f"fields:\n{field_lines}\n" + + test_snippet_yaml + + "```\n" + ) + user_ds_path = Path("~/.anton/datasources.md").expanduser() + tmp_path = user_ds_path.with_suffix(".tmp") + + # Write to temp, validate it parses, then rename atomically + existing = ( + user_ds_path.read_text(encoding="utf-8") if user_ds_path.is_file() else "" + ) + + existing = _remove_engine_block(existing, slug) + + tmp_path.write_text(existing + yaml_block, encoding="utf-8") + + parsed = registry.validate_file(tmp_path) + if slug in parsed: + import shutil + + shutil.move(str(tmp_path), str(user_ds_path)) + else: + tmp_path.unlink(missing_ok=True) + console.print( + "[anton.warning]Could not validate engine definition — " + "credentials saved but engine not written to datasources.md.[/]" + ) + + registry.reload() + engine_def = 
registry.get(slug) + if engine_def is None: + # Fallback: construct inline so the flow can continue even if parse failed + engine_def = DatasourceEngine( + engine=slug, + display_name=display_name, + pip=pip_pkg, + fields=fields, + test_snippet=test_snippet, + ) + + # All required fields must be present before the caller saves credentials + missing_required = [f.name for f in fields if f.required and f.name not in credentials] + if missing_required: + console.print( + "[anton.warning] Cannot save — missing required fields: " + f"{', '.join(missing_required)}. Aborting.[/]" + ) + console.print() + return None + + return engine_def, credentials + + +async def _run_connection_test( + console: "Console", + scratchpads: "ScratchpadManager", + vault: "DataVault", + engine_def: "DatasourceEngine", + credentials: dict[str, str], + retry_fields: "list[DatasourceField]", +) -> bool: + """Inject flat DS_* vars, run engine_def.test_snippet, restore env. + + Returns True on success, False if the user declines retry after failure. + Mutates credentials in-place when the user re-enters secrets on retry. + """ + import os as _os + + while True: + console.print() + console.print("[anton.cyan](anton)[/] Got it. 
Testing connection…") + + vault.clear_ds_env() + for key, value in credentials.items(): + _os.environ[f"DS_{key.upper()}"] = value + _register_secret_vars(engine_def) # flat mode, for scrubbing during test + + try: + pad = await scratchpads.get_or_create("__datasource_test__") + await pad.reset() + if engine_def.pip: + await pad.install_packages([engine_def.pip]) + cell = await pad.execute(engine_def.test_snippet) + finally: + _restore_namespaced_env(vault) + + if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): + error_text = cell.error or cell.stderr.strip() or cell.stdout.strip() + first_line = next( + (ln for ln in error_text.splitlines() if ln.strip()), error_text + ) + console.print() + console.print("[anton.warning](anton)[/] ✗ Connection failed.") + console.print() + console.print(f" Error: {first_line}") + console.print() + retry = ( + Prompt.ask( + "[anton.cyan](anton)[/] Would you like to re-enter your credentials? [y/n]", + console=console, + default="n", + ) + .strip() + .lower() + ) + if retry != "y": + return False + console.print() + for f in retry_fields: + if not f.secret: + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + password=True, + console=console, + default="", + ) + if value: + credentials[f.name] = value + continue + + console.print("[anton.success] ✓ Connected successfully![/]") + return True + + +async def _handle_connect_datasource( + console: Console, + scratchpads: ScratchpadManager, + session: "ChatSession", + datasource_name: str | None = None, + prefill: str | None = None, +) -> "ChatSession": + """ + Connect a data source by entering credentials, either for a new name or re-entering for an existing one. 
+ """ + + vault = DataVault() + registry = DatasourceRegistry() + + if datasource_name is not None: + _parsed = parse_connection_slug( + datasource_name, [e.engine for e in registry.all_engines()], vault=vault + ) + if _parsed is None: + console.print( + f"[anton.warning]Invalid slug '{datasource_name}'. " + "Expected format: engine-name.[/]" + ) + console.print() + return session + edit_engine, edit_name = _parsed + existing = vault.load(edit_engine, edit_name) + if existing is None: + console.print( + f"[anton.warning]No connection '{datasource_name}' found in Local Vault.[/]" + ) + console.print() + return session + engine_def = registry.get(edit_engine) + if engine_def is None: + console.print( + f"[anton.warning]Unknown engine '{edit_engine}'. " + "Cannot update credentials.[/]" + ) + console.print() + return session + + console.print() + console.print( + f"[anton.cyan](anton)[/] Editing [bold]\"{datasource_name}\"[/bold]" + f" ({engine_def.display_name})." + ) + console.print("[anton.muted] Press Enter to keep the current value.[/]") + console.print() + + # Detect which fields to present (handle auth_method=choice) + active_fields = engine_def.fields + if engine_def.auth_method == "choice" and engine_def.auth_methods: + for am in engine_def.auth_methods: + am_field_names = {af.name for af in am.fields} + if any(k in am_field_names for k in existing): + active_fields = am.fields + break + if not active_fields: + active_fields = engine_def.auth_methods[0].fields + + # Start from existing values; let user update field-by-field + credentials: dict[str, str] = dict(existing) + for f in active_fields: + current = existing.get(f.name, "") + prompt_label = f"[anton.cyan](anton)[/] {f.name}" + if not f.required: + prompt_label += " [anton.muted](optional)[/]" + + if f.secret: + masked = "••••••••" if current else "" + label = ( + f"{prompt_label} [anton.muted][{masked}][/]" + if masked + else prompt_label + ) + value = Prompt.ask(label, password=True, console=console, 
default="") + if value: + credentials[f.name] = value + # else: keep existing (already in credentials) + elif current: + value = Prompt.ask( + f"{prompt_label} [anton.muted][{current}][/]", + console=console, + default=current, + ) + credentials[f.name] = value if value else current + elif f.default: + value = Prompt.ask( + f"{prompt_label} [anton.muted][{f.default}][/]", + console=console, + default=f.default, + ) + if value: + credentials[f.name] = value + else: + value = Prompt.ask(prompt_label, console=console, default="") + if value: + credentials[f.name] = value + + if engine_def.test_snippet: + while True: + console.print() + console.print("[anton.cyan](anton)[/] Got it. Testing connection…") + + # Temporarily save credentials so inject_env(flat=True) can load them, + # then restore all namespaced env vars in the finally block. + vault.save(edit_engine, edit_name, credentials) + vault.clear_ds_env() + vault.inject_env(edit_engine, edit_name, flat=True) + _register_secret_vars(engine_def) # flat names for scrubbing during test + try: + pad = await scratchpads.get_or_create("__datasource_test__") + await pad.reset() + if engine_def.pip: + await pad.install_packages([engine_def.pip]) + cell = await pad.execute(engine_def.test_snippet) + finally: + _restore_namespaced_env(vault) + + if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): + error_text = ( + cell.error or cell.stderr.strip() or cell.stdout.strip() + ) + first_line = next( + (ln for ln in error_text.splitlines() if ln.strip()), error_text + ) + console.print() + console.print("[anton.warning](anton)[/] ✗ Connection failed.") + console.print() + console.print(f" Error: {first_line}") + console.print() + retry = ( + Prompt.ask( + "[anton.cyan](anton)[/] Would you like to re-enter your credentials? 
[y/n]", + console=console, + default="n", + ) + .strip() + .lower() + ) + if retry != "y": + return session + console.print() + for f in active_fields: + if not f.secret: + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + password=True, + console=console, + default="", + ) + if value: + credentials[f.name] = value + continue + + console.print("[anton.success] ✓ Connected successfully![/]") + break + + vault.save(edit_engine, edit_name, credentials) + _restore_namespaced_env(vault) + _register_secret_vars(engine_def, engine=edit_engine, name=edit_name) + console.print() + console.print( + f' Credentials updated for [bold]"{datasource_name}"[/bold].' + ) + console.print() + console.print( + "[anton.muted] You can now ask me questions about your data.[/]" + ) + console.print() + session._history.append( + { + "role": "assistant", + "content": ( + f"I've updated the credentials for the {engine_def.display_name} connection " + f'"{datasource_name}" in the Local Vault.' + ), + } + ) + return session + + console.print() + all_engines = registry.all_engines() + if prefill: + answer = prefill + else: + console.print( + "[anton.cyan](anton)[/] Choose a data source:\n" + ) + console.print(" [bold] Primary") + console.print(" [bold] 0.[/bold] Custom datasource (connect anything via API, SQL, or MCP)\n") + console.print(" [bold] Predefined") + for i, e in enumerate(all_engines, 1): + console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") + console.print() + answer = Prompt.ask( + "[anton.cyan](anton)[/] Enter a number or type a name", + console=console, + ) + + stripped_answer = answer.strip() + known_slugs = {f"{c['engine']}-{c['name']}": c for c in vault.list_connections()} + if stripped_answer in known_slugs: + conn = known_slugs[stripped_answer] + _restore_namespaced_env(vault) + session._active_datasource = stripped_answer + recon_engine_def = registry.get(conn["engine"]) + if recon_engine_def: + _register_secret_vars(recon_engine_def, 
engine=conn["engine"], name=conn["name"]) + engine_label = recon_engine_def.display_name + else: + engine_label = conn["engine"] + console.print() + console.print( + f'[anton.success] ✓ Reconnected to [bold]"{stripped_answer}"[/bold].[/]' + ) + console.print() + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve reconnected to the {engine_label} connection "{stripped_answer}" ' + f"in the Local Vault. I can now query this data source when needed." + ), + } + ) + return session + + engine_def: DatasourceEngine | None = None + custom_source = False + + if stripped_answer.isdigit() or (stripped_answer.lstrip("-").isdigit()): + pick_num = int(stripped_answer) + if pick_num == 0: + custom_source = True + elif 1 <= pick_num <= len(all_engines): + engine_def = all_engines[pick_num - 1] + else: + console.print( + f"[anton.warning](anton)[/] '{stripped_answer}' is out of range. " + f"Please enter 0–{len(all_engines)}.[/]" + ) + console.print() + return session + + if engine_def is None and not custom_source: + engine_def = registry.find_by_name(stripped_answer) + # if exact match not found, try substring match against display and engine names + if engine_def is None: + needle = stripped_answer.lower() + candidates = [ + e + for e in all_engines + if needle in e.display_name.lower() or needle in e.engine.lower() + ] + if len(candidates) == 1: + engine_def = candidates[0] + elif len(candidates) > 1: + console.print() + console.print( + f"[anton.warning](anton)[/] '{stripped_answer}' matches multiple engines — " + "which one did you mean?" + ) + console.print() + for i, e in enumerate(candidates, 1): + console.print(f" {i}. {e.display_name}") + console.print() + pick = Prompt.ask( + "[anton.cyan](anton)[/] Enter a number", + console=console, + ).strip() + try: + engine_def = candidates[int(pick) - 1] + except (ValueError, IndexError): + console.print("[anton.warning]Invalid choice. 
Aborting.[/]") + console.print() + return session + # fuzzy match against display and engine names if exact match not found + if engine_def is None: + fuzzy_matches = registry.fuzzy_find(stripped_answer) + for suggestion in fuzzy_matches: + console.print() + console.print( + f'[anton.cyan](anton)[/] Did you mean [bold]"{suggestion.display_name}"[/bold]?' + ) + confirm = ( + Prompt.ask( + "[anton.cyan](anton)[/] [y/n]", + console=console, + default="n", + ) + .strip() + .lower() + ) + if confirm == "y": + engine_def = suggestion + break + + if engine_def is None: + custom_source = True + + if custom_source: + result = await _handle_add_custom_datasource( + console, stripped_answer if not stripped_answer.isdigit() else "", registry, session + ) + if result is None: + return session + engine_def, credentials = result + if engine_def.test_snippet: + if not await _run_connection_test( + console, scratchpads, vault, engine_def, credentials, engine_def.fields + ): + return session + conn_name = uuid.uuid4().hex[:8] + vault.save(engine_def.engine, conn_name, credentials) + slug = f"{engine_def.engine}-{conn_name}" + _restore_namespaced_env(vault) + session._active_datasource = slug + _register_secret_vars(engine_def, engine=engine_def.engine, name=conn_name) + console.print( + f' Credentials saved to Local Vault as [bold]"{slug}"[/bold].' + ) + console.print() + console.print( + "[anton.muted] You can now ask me questions about your data.[/]" + ) + console.print() + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve saved a {engine_def.display_name} connection named "{slug}" ' + f"to the Local Vault. I can now query this data source when needed." 
+ ), + } + ) + return session + + assert engine_def is not None # custom_source path always returns before this line + active_fields = engine_def.fields + if engine_def.auth_method == "choice" and engine_def.auth_methods: + console.print() + console.print( + f"[anton.cyan](anton)[/] How would you like to authenticate with " + f"[bold]{engine_def.display_name}[/]?" + ) + console.print() + for i, am in enumerate(engine_def.auth_methods, 1): + console.print(f" {i}. {am.display}") + console.print() + choice_str = Prompt.ask( + "[anton.cyan](anton)[/] Enter a number", + console=console, + ).strip() + try: + choice_idx = int(choice_str) - 1 + chosen_method = engine_def.auth_methods[choice_idx] + except (ValueError, IndexError): + console.print("[anton.warning]Invalid choice. Aborting.[/]") + console.print() + return session + active_fields = chosen_method.fields + + required_fields = [f for f in active_fields if f.required] + optional_fields = [f for f in active_fields if not f.required] + + console.print() + console.print( + f"[anton.cyan](anton)[/] To connect [bold]{engine_def.display_name}[/], " + "I'll need the following:" + ) + console.print() + + if required_fields: + console.print(" [bold]Required[/] " + "─" * 39) + for f in required_fields: + console.print( + f" • [bold]{f.name:<12}[/] [anton.muted]— {f.description}[/]" + ) + + if optional_fields: + console.print() + console.print(" [bold]Optional[/] " + "─" * 39) + for f in optional_fields: + console.print( + f" • [bold]{f.name:<12}[/] [anton.muted]— {f.description}[/]" + ) + + console.print() + + mode_answer = ( + Prompt.ask( + "[anton.cyan](anton)[/] Do you have these available? [y/n/]", + console=console, + ) + .strip() + .lower() + ) + + if mode_answer == "n": + console.print() + console.print( + "[anton.cyan](anton)[/] No problem. Which parameters do you have? " + "I'll save a partial connection now, and you can fill in the rest later " + "with [bold]/edit[/]." 
+ ) + console.print() + console.print(" Provide what you have (press enter to skip any field):") + console.print() + fields_to_collect = active_fields + partial = True + elif mode_answer == "y": + fields_to_collect = active_fields + partial = False + else: + # User gave a comma-separated list of param names — filter to those fields. + # If nothing matches (e.g. they typed a credential value by mistake), fall + # back to collecting all fields. + requested = {n.strip().lower() for n in mode_answer.split(",")} + matched = [f for f in active_fields if f.name.lower() in requested] + fields_to_collect = matched if matched else active_fields + partial = False + + console.print() + credentials: dict[str, str] = {} + + for f in fields_to_collect: + prompt_label = f"[anton.cyan](anton)[/] {f.name}" + if not f.required or partial: + prompt_label += " [anton.muted](optional, press enter to skip)[/]" + + if f.secret: + value = Prompt.ask(prompt_label, password=True, console=console, default="") + elif f.default: + value = Prompt.ask( + f"{prompt_label} [anton.muted][{f.default}][/]", + console=console, + default=f.default, + ) + else: + value = Prompt.ask(prompt_label, console=console, default="") + + if value: + credentials[f.name] = value + + if partial: + auto_name = uuid.uuid4().hex[:8] + vault.save(engine_def.engine, auto_name, credentials) + slug = f"{engine_def.engine}-{auto_name}" + console.print() + console.print( + f"[anton.muted]Partial connection saved to Local Vault as " + f'[bold]"{slug}"[/bold]. 
' + f"Run [bold]/edit {slug}[/bold] to complete it when you're ready.[/]" + ) + console.print() + return session + + if engine_def.test_snippet: + if not await _run_connection_test( + console, scratchpads, vault, engine_def, credentials, active_fields + ): + return session + + conn_name = registry.derive_name(engine_def, credentials) + if not conn_name: + conn_name = uuid.uuid4().hex[:8] + + slug = f"{engine_def.engine}-{conn_name}" + + if vault.load(engine_def.engine, conn_name) is not None: + console.print() + console.print( + f'[anton.warning](anton)[/] A connection [bold]"{slug}"[/bold] already exists.' + ) + console.print() + choice = ( + Prompt.ask( + "[anton.cyan](anton)[/] [reconnect/cancel]", + console=console, + default="cancel", + ) + .strip() + .lower() + ) + if choice != "reconnect": + console.print("[anton.muted]Cancelled.[/]") + console.print() + return session + _restore_namespaced_env(vault) + _register_secret_vars(engine_def, engine=engine_def.engine, name=conn_name) + console.print() + console.print( + f'[anton.success] ✓ Reconnected to [bold]"{slug}"[/bold].[/]' + ) + console.print() + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve reconnected to the {engine_def.display_name} connection "{slug}" ' + f"in the Local Vault. I can now query this data source when needed." 
+ ), + } + ) + return session + + vault.save(engine_def.engine, conn_name, credentials) + _restore_namespaced_env(vault) + session._active_datasource = slug + _register_secret_vars(engine_def, engine=engine_def.engine, name=conn_name) + console.print(f' Credentials saved to Local Vault as [bold]"{slug}"[/bold].') + + console.print() + console.print( + "[anton.muted] You can now ask me questions about your data.[/]" + ) + console.print() + + # Inject a brief assistant message so the LLM is aware of the new connection + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve saved a {engine_def.display_name} connection named "{slug}" ' + f"to the Local Vault. I can now query this data source when needed." + ), + } + ) + return session + + +def _handle_list_data_sources(console: Console) -> None: + """Print all saved Local Vault connections in a table with status.""" + from rich.table import Table + + vault = DataVault() + registry = DatasourceRegistry() + conns = vault.list_connections() + console.print() + if not conns: + console.print("[anton.muted]No data sources connected yet.[/]") + console.print("[anton.muted]Use /connect to add one.[/]") + console.print() + return + + table = Table(title="Local Vault — Saved Connections", show_lines=False) + table.add_column("Name", style="bold") + table.add_column("Source") + table.add_column("Status") + + for c in conns: + slug = f"{c['engine']}-{c['name']}" + engine_def = registry.get(c["engine"]) + source = engine_def.display_name if engine_def else c["engine"] + fields = vault.load(c["engine"], c["name"]) or {} + + if not fields: + status = "[yellow]incomplete[/]" + elif engine_def and engine_def.auth_method != "choice": + required = [f.name for f in engine_def.fields if f.required] + missing = [name for name in required if name not in fields] + status = "[yellow]incomplete[/]" if missing else "[green]saved[/]" + else: + # choice-auth engine or unknown engine: presence of any field = saved + status = 
"[green]saved[/]" + + table.add_row(slug, source, status) + + console.print(table) + console.print() + + +def _handle_remove_data_source(console: Console, slug: str) -> None: + """Delete a connection from the Local Vault by slug (engine-name).""" + vault = DataVault() + registry = DatasourceRegistry() + _parsed = parse_connection_slug(slug, [e.engine for e in registry.all_engines()], vault=vault) + if _parsed is None: + console.print( + f"[anton.warning]Invalid name '{slug}'. Use engine-name format.[/]" + ) + console.print() + return + engine, name = _parsed + if vault.load(engine, name) is None: + console.print(f"[anton.warning]No connection '{slug}' found.[/]") + console.print() + return + if Confirm.ask( + f"Remove '{slug}' from Local Vault?", default=False, console=console + ): + vault.delete(engine, name) + _restore_namespaced_env(vault) + console.print(f"[anton.success]Removed {slug}.[/]") + else: + console.print("[anton.muted]Cancelled.[/]") + console.print() + + +async def _handle_test_datasource( + console: Console, + scratchpads: ScratchpadManager, + slug: str, +) -> None: + """Test an existing Local Vault connection by running its test_snippet.""" + if not slug: + console.print( + "[anton.warning]Usage: /test [/]" + ) + console.print() + return + + vault = DataVault() + registry = DatasourceRegistry() + _parsed = parse_connection_slug(slug, [e.engine for e in registry.all_engines()], vault=vault) + if _parsed is None: + console.print( + f"[anton.warning]Invalid name '{slug}'. Use engine-name format.[/]" + ) + console.print() + return + engine, name = _parsed + fields = vault.load(engine, name) + if fields is None: + console.print( + f"[anton.warning]No connection '{slug}' found in Local Vault.[/]" + ) + console.print() + return + + engine_def = registry.get(engine) + if engine_def is None: + console.print( + f"[anton.warning]Unknown engine '{engine}'. 
Cannot test.[/]" + ) + console.print() + return + + if not engine_def.test_snippet: + console.print( + f"[anton.warning]No test snippet defined for '{engine}'. Cannot test.[/]" + ) + console.print() + return + + console.print() + console.print( + f"[anton.cyan](anton)[/] Testing connection [bold]{slug}[/bold]…" + ) + + vault.clear_ds_env() + vault.inject_env(engine, name, flat=True) + _register_secret_vars(engine_def) # flat names for scrubbing during test + + cell = None + try: + pad = await scratchpads.get_or_create("__datasource_test__") + await pad.reset() + if engine_def.pip: + await pad.install_packages([engine_def.pip]) + cell = await pad.execute(engine_def.test_snippet) + finally: + _restore_namespaced_env(vault) + + if cell is None or cell.error or ( + cell.stdout.strip() != "ok" and cell.stderr.strip() + ): + error_text = "" + if cell is not None: + error_text = cell.error or cell.stderr.strip() or cell.stdout.strip() + first_line = ( + next((ln for ln in error_text.splitlines() if ln.strip()), error_text) + if error_text + else "unknown error" + ) + console.print() + console.print( + f"[anton.warning](anton)[/] ✗ Connection test failed for" + f" [bold]{slug}[/bold]." 
+ ) + console.print() + console.print(f" Error: {first_line}") + else: + console.print( + f"[anton.success] ✓ Connection test passed for" + f" [bold]{slug}[/bold]![/]" + ) + console.print() + + +def _print_slash_help(console: Console) -> None: + """Print available slash commands.""" + console.print() + console.print("[anton.cyan]Available commands:[/]") +# console.print( +# " [bold]/connect[/] — Connect to a Minds server and select a mind" +# ) + console.print( + " [bold]/connect[/] — Connect a database or API to the Local Vault" + ) + console.print( + " [bold]/list[/] — List all saved data source connections" + ) + console.print(" [bold]/edit[/] — Edit a saved connection's credentials") + console.print(" [bold]/remove[/] — Remove a saved connection") + console.print(" [bold]/test[/] — Test a saved connection") + console.print( + " [bold]/setup[/] — Configure models or memory settings" + ) + console.print(" [bold]/memory[/] — Show memory status dashboard") + console.print( + " [bold]/paste[/] — Attach clipboard image to your message" + ) + console.print(" [bold]/resume[/] — Resume a previous chat session") + console.print(" [bold]/help[/] — Show this help message") + console.print(" [bold]exit[/] — Quit the chat") + console.print() + + +class _EscapeWatcher: + """Detect Escape keypress during streaming via cbreak terminal mode.""" + + def __init__(self, on_cancel: Callable[[], None] | None = None) -> None: + self.cancelled = asyncio.Event() + self._on_cancel = on_cancel + self._task: asyncio.Task | None = None + self._old_settings: list | None = None + self._stop = False + + async def __aenter__(self) -> _EscapeWatcher: + if sys.platform != "win32" and sys.stdin.isatty(): + self._task = asyncio.create_task(self._watch()) + return self + + async def __aexit__(self, *exc: object) -> None: self._stop = True if self._task is not None: self._task.cancel() @@ -2302,8 +3734,12 @@ def start(self) -> None: from rich.spinner import Spinner from rich.text import Text - 
spinner = Spinner("dots", text=Text(" Closing scratchpad processes…", style="anton.muted")) - self._live = Live(spinner, console=self._console, refresh_per_second=6, transient=True) + spinner = Spinner( + "dots", text=Text(" Closing scratchpad processes…", style="anton.muted") + ) + self._live = Live( + spinner, console=self._console, refresh_per_second=6, transient=True + ) self._live.start() def stop(self) -> None: @@ -2312,12 +3748,16 @@ def stop(self) -> None: self._live = None -def run_chat(console: Console, settings: AntonSettings, *, resume: bool = False) -> None: +def run_chat( + console: Console, settings: AntonSettings, *, resume: bool = False +) -> None: """Launch the interactive chat REPL.""" asyncio.run(_chat_loop(console, settings, resume=resume)) -async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool = False) -> None: +async def _chat_loop( + console: Console, settings: AntonSettings, *, resume: bool = False +) -> None: from anton.context.self_awareness import SelfAwarenessContext from anton.llm.client import LLMClient from anton.memory.cortex import Cortex @@ -2333,6 +3773,17 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool workspace = Workspace(settings.workspace_path) workspace.apply_env_to_process() + # Inject all Local Vault connections as namespaced DS_* env vars so every + # scratchpad subprocess inherits them. Must happen before any ChatSession is created. 
+ _dv = DataVault() + _dreg = DatasourceRegistry() + for _conn in _dv.list_connections(): + _dv.inject_env(_conn["engine"], _conn["name"]) # flat=False by default + _edef = _dreg.get(_conn["engine"]) + if _edef is not None: + _register_secret_vars(_edef, engine=_conn["engine"], name=_conn["name"]) + del _dv, _dreg + # --- Memory system (brain-inspired architecture) --- global_memory_dir = Path.home() / ".anton" / "memory" project_memory_dir = settings.workspace_path / ".anton" / "memory" @@ -2379,7 +3830,8 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool runtime_context = _build_runtime_context(settings) coding_api_key = ( - settings.anthropic_api_key if settings.coding_provider == "anthropic" + settings.anthropic_api_key + if settings.coding_provider == "anthropic" else settings.openai_api_key ) or "" session = ChatSession( @@ -2400,15 +3852,21 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool # Handle --resume flag at startup if resume: session, resumed_id = await _handle_resume( - console, settings, state, self_awareness, cortex, - workspace, session, episodic=episodic, + console, + settings, + state, + self_awareness, + cortex, + workspace, + session, + episodic=episodic, history_store=history_store, ) if resumed_id: current_session_id = resumed_id - console.print("[anton.muted] Chat with Anton. 
Type '/help' for commands or 'exit' to quit.[/]") + console.print("[anton.muted] Chat with me, type '/help' for commands or 'exit' to quit.[/]") console.print(f"[anton.cyan_dim] {'━' * 40}[/]") console.print() @@ -2436,9 +3894,11 @@ def _bottom_toolbar(): line = status + " " * gap + stats return HTML(f"\n") - pt_style = PTStyle.from_dict({ - "bottom-toolbar": "noreverse nounderline bg:default", - }) + pt_style = PTStyle.from_dict( + { + "bottom-toolbar": "noreverse nounderline bg:default", + } + ) prompt_session: PromptSession[str] = PromptSession( mouse_support=False, @@ -2455,7 +3915,11 @@ def _bottom_toolbar(): for i, engram in enumerate(pending, 1): console.print(f" [bold]{i}.[/] [{engram.kind}] {engram.text}") console.print() - confirm = console.input("[bold]Save to memory? (y/n/pick numbers):[/] ").strip().lower() + confirm = ( + console.input("[bold]Save to memory? (y/n/pick numbers):[/] ") + .strip() + .lower() + ) if confirm in ("y", "yes"): if cortex is not None: await cortex.encode(pending) @@ -2465,11 +3929,19 @@ def _bottom_toolbar(): else: # Parse number selections like "1 3" or "1,3" try: - nums = [int(x.strip()) for x in confirm.replace(",", " ").split() if x.strip().isdigit()] - selected = [pending[n - 1] for n in nums if 1 <= n <= len(pending)] + nums = [ + int(x.strip()) + for x in confirm.replace(",", " ").split() + if x.strip().isdigit() + ] + selected = [ + pending[n - 1] for n in nums if 1 <= n <= len(pending) + ] if selected and cortex is not None: await cortex.encode(selected) - console.print(f"[anton.muted]Saved {len(selected)} entries.[/]") + console.print( + f"[anton.muted]Saved {len(selected)} entries.[/]" + ) else: console.print("[anton.muted]Discarded.[/]") except (ValueError, IndexError): @@ -2522,17 +3994,28 @@ def _bottom_toolbar(): if message_content is None and stripped.startswith("/"): parts = stripped.split(maxsplit=1) cmd = parts[0].lower() - if cmd == "/connect": - session = await _handle_connect( - console, settings, 
workspace, state, - self_awareness, cortex, session, - episodic=episodic, - ) - continue - elif cmd == "/setup": +# if cmd == "/connect": +# session = await _handle_connect( +# console, +# settings, +# workspace, +# state, +# self_awareness, +# cortex, +# session, +# episodic=episodic, +# ) +# continue +# elif cmd == "/setup": + if cmd == "/setup": session = await _handle_setup( - console, settings, workspace, state, - self_awareness, cortex, session, + console, + settings, + workspace, + state, + self_awareness, + cortex, + session, episodic=episodic, history_store=history_store, session_id=current_session_id, @@ -2541,15 +4024,60 @@ def _bottom_toolbar(): elif cmd == "/memory": _handle_memory(console, settings, cortex, episodic=episodic) continue - elif cmd == "/data-connections": - session = await _handle_data_connections( - console, settings, workspace, session, + elif cmd == "/connect": + arg = parts[1].strip() if len(parts) > 1 else "" + session = await _handle_connect_datasource( + console, + session._scratchpads, + session, + prefill=arg or None, + ) + continue + elif cmd == "/list": + _handle_list_data_sources(console) + continue + elif cmd == "/remove": + arg = parts[1].strip() if len(parts) > 1 else "" + if not arg: + console.print( + "[anton.warning]Usage: /remove" + " [/]" + ) + console.print() + else: + _handle_remove_data_source(console, arg) + continue + elif cmd == "/edit": + arg = parts[1].strip() if len(parts) > 1 else "" + if not arg: + console.print( + "[anton.warning]Usage: /edit [/]" + ) + console.print() + else: + session = await _handle_connect_datasource( + console, + session._scratchpads, + session, + datasource_name=arg, + ) + continue + elif cmd == "/test": + arg = parts[1].strip() if len(parts) > 1 else "" + await _handle_test_datasource( + console, session._scratchpads, arg ) continue elif cmd == "/resume": session, resumed_id = await _handle_resume( - console, settings, state, self_awareness, cortex, - workspace, session, 
episodic=episodic, + console, + settings, + state, + self_awareness, + cortex, + workspace, + session, + episodic=episodic, history_store=history_store, ) if resumed_id: @@ -2570,7 +4098,9 @@ def _bottom_toolbar(): f"{_human_size(uploaded.size_bytes)})[/]" ) user_text = parts[1] if len(parts) > 1 else "" - message_content = _format_clipboard_image_message(uploaded, user_text) + message_content = _format_clipboard_image_message( + uploaded, user_text + ) # Fall through to turn_stream (don't continue) else: console.print("[anton.warning]No image found on clipboard.[/]") @@ -2639,6 +4169,7 @@ def _bottom_toolbar(): ) settings.anthropic_api_key = None from anton.cli import _ensure_api_key + _ensure_api_key(settings) session = _rebuild_session( settings=settings, diff --git a/anton/cli.py b/anton/cli.py index a701342..365973a 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -103,7 +103,9 @@ def _ensure_dependencies(console: Console) -> None: ): import subprocess - console.print(f"[anton.muted] Running: uv pip install {' '.join(missing)}[/]") + console.print( + f"[anton.muted] Running: uv pip install {' '.join(missing)}[/]" + ) result = subprocess.run( [uv, "pip", "install", "--python", sys.executable, *missing], capture_output=True, @@ -113,12 +115,18 @@ def _ensure_dependencies(console: Console) -> None: _reexec() else: console.print(f"[anton.error] Install failed:[/]") - console.print(result.stderr.decode() if result.stderr else result.stdout.decode()) + console.print( + result.stderr.decode() if result.stderr else result.stdout.decode() + ) if install_script.is_file(): if sys.platform == "win32": - console.print(f"\n[anton.muted] Or run the install script: powershell -File {install_script}[/]") + console.print( + f"\n[anton.muted] Or run the install script: powershell -File {install_script}[/]" + ) else: - console.print(f"\n[anton.muted] Or run the install script: sh {install_script}[/]") + console.print( + f"\n[anton.muted] Or run the install script: sh 
{install_script}[/]" + ) raise typer.Exit(0) elif install_script.is_file(): console.print(f"To install all dependencies, run:") @@ -133,12 +141,17 @@ def _ensure_dependencies(console: Console) -> None: console.print(f" [bold]pip install {' '.join(missing)}[/]") console.print() if sys.platform == "win32": - console.print("[anton.muted]Or reinstall anton: irm https://raw.githubusercontent.com/mindsdb/anton/main/install.ps1 | iex[/]") + console.print( + "[anton.muted]Or reinstall anton: irm https://raw.githubusercontent.com/mindsdb/anton/main/install.ps1 | iex[/]" + ) else: - console.print('[anton.muted]Or reinstall anton: curl -sSf https://raw.githubusercontent.com/mindsdb/anton/main/install.sh | sh && export PATH="$HOME/.local/bin:$PATH"[/]') + console.print( + '[anton.muted]Or reinstall anton: curl -sSf https://raw.githubusercontent.com/mindsdb/anton/main/install.sh | sh && export PATH="$HOME/.local/bin:$PATH"[/]' + ) console.print() raise typer.Exit(1) + app = typer.Typer( name="anton", help="Anton — a self-evolving autonomous system", @@ -206,6 +219,7 @@ def main( settings.resolve_workspace(folder) from anton.updater import check_and_update + if check_and_update(console, settings): # Re-exec with the freshly installed code so no old modules remain in memory. 
_reexec() @@ -214,98 +228,315 @@ def main( ctx.obj["settings"] = settings if ctx.invoked_subcommand is None: - from anton.channel.branding import render_banner from anton.chat import run_chat - render_banner(console) _ensure_workspace(settings) - _ensure_api_key(settings) + if not _has_api_key(settings): + _onboard(settings) + else: + from anton.channel.branding import render_banner + render_banner(console) run_chat(console, settings, resume=resume) def _has_api_key(settings) -> bool: - """Check if all configured providers have API keys.""" + """Check if any LLM provider is fully configured.""" providers = {settings.planning_provider, settings.coding_provider} for p in providers: - if p == "anthropic" and not (settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY")): + if p == "anthropic" and not ( + settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY") + ): return False - if p in ("openai", "openai-compatible") and not (settings.openai_api_key or os.environ.get("OPENAI_API_KEY")): + if p in ("openai", "openai-compatible") and not ( + settings.openai_api_key or os.environ.get("OPENAI_API_KEY") + ): return False return True -def _ensure_api_key(settings) -> None: - """Prompt the user to configure a provider and API key if none is set.""" - if _has_api_key(settings): - return +def _onboard(settings) -> None: + """First-time onboarding: animated robot talking the intro + LLM provider selection.""" + import sys + import time from rich.prompt import Prompt + from anton import __version__ from anton.workspace import Workspace ws = Workspace(Path.home()) - - if settings.minds_enabled: - _ensure_minds_api_key(settings, ws) + g = "anton.glow" + + _INTRO_LINES = [ + "Hi Boss! 
I'm Anton, your AI coworker.", + "", + "For the best experience, I recommend Minds-Cloud as your LLM Provider:", + "", + " \u2713 Smart model routing", + " \u2713 Faster responses", + " \u2713 Cost optimized", + " \u2713 Secure data connectors", + ] + + if sys.stdout.isatty(): + _animate_onboard(console, __version__, _INTRO_LINES, settings=settings, ws=ws) else: - _ensure_anthropic_api_key(settings, ws) + # Static fallback for non-interactive terminals + from anton.channel.branding import render_banner - # Reload env vars into the process so the scratchpad subprocess inherits them - ws.apply_env_to_process() + render_banner(console, animate=False) + console.print() + for line in _INTRO_LINES: + console.print(line) + +def _ensure_api_key(settings) -> None: + if not _has_api_key(settings): + _onboard(settings) + + +def _animate_onboard(console, version: str, intro_lines: list[str], *, settings, ws) -> None: + """Animate the robot talking while typing out the intro text below.""" + import time + + from rich.live import Live + from rich.text import Text + + from anton.channel.branding import ( + _MOUTH_SMILE, + _MOUTH_TALK, + _build_robot_text, + pick_tagline, + ) + + tagline = pick_tagline() + char_delay = 0.02 + line_pause = 0.15 + char_count = 0 # drives mouth animation + + def _build_frame(mouth: str, typed_lines: list[str]) -> Text: + """Build robot + separator + typed text as a single renderable.""" + frame = _build_robot_text(mouth, "\u2661\u2661\u2661\u2661") + frame.append(f" {'━' * 40}\n", style="bold cyan") + frame.append(f" v{version} \u2014 \"{tagline}\"\n", style="dim") + frame.append("\n") + frame.append("anton> ", style="bold cyan") + for line in typed_lines: + frame.append(line) + return frame + + with Live( + _build_frame(_MOUTH_SMILE, []), + console=console, + refresh_per_second=30, + transient=True, + ) as live: + time.sleep(0.4) + + typed_so_far: list[str] = [] + + for line_idx, line in enumerate(intro_lines): + if line == "": + 
typed_so_far.append("\n") + live.update(_build_frame(_MOUTH_SMILE, typed_so_far)) + time.sleep(line_pause) + continue + + # Type out each character + current = "" + for ch in line: + current += ch + char_count += 1 + mouth = _MOUTH_TALK[char_count % 2] + live.update(_build_frame(mouth, typed_so_far + [current])) + time.sleep(char_delay) + + typed_so_far.append(current + "\n") + live.update(_build_frame(_MOUTH_SMILE, typed_so_far)) + time.sleep(line_pause) + + # Hold final frame briefly + time.sleep(0.3) + + # Print the static final state + from anton.channel.branding import _render_robot_static + + _render_robot_static(console, "\u2661\u2661\u2661\u2661") + console.print(f"[anton.glow] {'━' * 40}[/]") + console.print(f" v{version} \u2014 [anton.muted]\"{tagline}\"[/]") console.print() - console.print(f"[anton.success]Saved to {ws.env_path}[/]") + console.print("[anton.cyan]anton>[/] ", end="") + first_text = True + for line in intro_lines: + if line == "": + if not first_text: + console.print() + elif line.startswith(" \u2713"): + first_text = False + console.print(f" [anton.success]\u2713[/] {line[4:]}") + else: + first_text = False + console.print(line) + + console.print() + console.print(f"[anton.glow] {'━' * 40}[/]") + console.print() + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") + console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") console.print() + while True: + choice = Prompt.ask( + "Choose LLM Provider", + choices=["1", "2", "3"], + default="1", + console=console, + ) -def _ensure_anthropic_api_key(settings, ws) -> None: - """Prompt for Anthropic API key (default flow).""" - from rich.prompt import Prompt + try: + if choice == "1": + _setup_minds(settings, ws) + elif choice == "2": + _setup_minds(settings, ws, default_url=None) + else: + _setup_other_provider(settings, 
ws) + break # success + except _SetupRetry: + console.print() + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") + console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") + console.print() + continue + + # Reload env vars so the scratchpad subprocess inherits them + ws.apply_env_to_process() + # Summary console.print() - console.print("[anton.cyan]Anthropic configuration[/]") + console.print(f"[anton.glow] {'━' * 40}[/]") + console.print() + provider_label = settings.planning_provider + model_label = settings.planning_model + if provider_label == "openai-compatible": + if settings.minds_url and "mdb.ai" in settings.minds_url: + provider_label = "Minds-Cloud" + else: + provider_label = "Minds-Enterprise Server" + model_label = "smart_router" + console.print(f" [anton.muted]Provider:[/] [anton.cyan]{provider_label}[/]") + console.print(f" [anton.muted]Model:[/] [anton.cyan]{model_label}[/]") console.print() - api_key = Prompt.ask("Anthropic API key", console=console) - if not api_key.strip(): - console.print("[anton.error]No API key provided. 
Exiting.[/]") - raise typer.Exit(1) - api_key = api_key.strip() - settings.anthropic_api_key = api_key - settings.planning_provider = "anthropic" - settings.coding_provider = "anthropic" - settings.planning_model = "claude-sonnet-4-6" - settings.coding_model = "claude-haiku-4-5-20251001" - ws.set_secret("ANTON_ANTHROPIC_API_KEY", api_key) - ws.set_secret("ANTON_PLANNING_PROVIDER", "anthropic") - ws.set_secret("ANTON_CODING_PROVIDER", "anthropic") - ws.set_secret("ANTON_PLANNING_MODEL", "claude-sonnet-4-6") - ws.set_secret("ANTON_CODING_MODEL", "claude-haiku-4-5-20251001") +class _SetupRetry(Exception): + """Raised by setup functions to go back to provider selection.""" + pass -def _ensure_minds_api_key(settings, ws) -> None: - """Prompt for Minds API key and configure LLM endpoints (opt-in flow).""" - from rich.prompt import Prompt +def _setup_prompt(label: str, default: str | None = None, is_password: bool = False) -> str: + """Prompt for input with ESC-to-go-back and a bottom toolbar hint. + + Returns the user's input string. + Raises _SetupRetry if the user presses ESC. + Works both from sync context (onboarding) and async context (/setup). 
+ """ + import asyncio + + from prompt_toolkit import PromptSession + from prompt_toolkit.formatted_text import HTML + from prompt_toolkit.key_binding import KeyBindings + from prompt_toolkit.styles import Style as PTStyle + + _esc_pressed = False + + bindings = KeyBindings() + + @bindings.add("escape") + def _on_esc(event): + nonlocal _esc_pressed + _esc_pressed = True + event.app.exit(result="") + + pt_style = PTStyle.from_dict({ + "bottom-toolbar": "noreverse nounderline bg:default", + }) + + def _toolbar(): + return HTML("") + + suffix = f" ({default}): " if default else ": " + session: PromptSession[str] = PromptSession( + mouse_support=False, + bottom_toolbar=_toolbar, + style=pt_style, + key_bindings=bindings, + is_password=is_password, + ) + + # Use async prompt if inside a running event loop, sync otherwise + try: + asyncio.get_running_loop() + in_async = True + except RuntimeError: + in_async = False + + if in_async: + # We're inside an async context (e.g. /setup from chat loop) + # Run prompt_toolkit in a thread to avoid nested event loop conflict + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(session.prompt, f" {label}{suffix}") + result = future.result() + else: + result = session.prompt(f" {label}{suffix}") + + if _esc_pressed: + console.print(" [anton.muted]Going back...[/]") + raise _SetupRetry() + + if not result and default: + return default + return result + + +def _setup_minds(settings, ws, *, default_url: str | None = "https://mdb.ai") -> None: + """Set up Minds as the LLM provider (cloud or enterprise).""" + from rich.prompt import Confirm, Prompt + + import webbrowser - console.print() - console.print("[anton.cyan]Minds configuration[/]") console.print() - api_key = Prompt.ask("Minds API key", console=console) - if not api_key.strip(): - console.print("[anton.error]No API key provided. 
Exiting.[/]") - raise typer.Exit(1) - api_key = api_key.strip() + is_cloud = default_url == "https://mdb.ai" - minds_url = Prompt.ask( - "Minds URL", - default="https://mdb.ai", - console=console, - ).strip() - if not minds_url.startswith("http://") and not minds_url.startswith("https://"): - minds_url = "https://" + minds_url - minds_url = minds_url.rstrip("/") + if is_cloud: + minds_url = "https://mdb.ai" + else: + minds_url = _setup_prompt("Server URL", default=default_url).strip() + if not minds_url.startswith("http://") and not minds_url.startswith("https://"): + minds_url = "https://" + minds_url + minds_url = minds_url.rstrip("/") + + if is_cloud: + console.print(" [anton.muted]If you don't have an API key yet, we'll help you create one — it takes a few seconds.[/]") + console.print() + has_key = Confirm.ask( + " Do you have an mdb.ai API key?", + default=True, + console=console, + ) + if not has_key: + webbrowser.open(f"{minds_url}/apiKeys") + console.print() + + while True: + api_key = _setup_prompt("API key", is_password=True) + if api_key.strip(): + break + console.print(" [anton.warning]Please enter your API key.[/]") + api_key = api_key.strip() # Store Minds credentials settings.minds_api_key = api_key @@ -313,15 +544,35 @@ def _ensure_minds_api_key(settings, ws) -> None: ws.set_secret("ANTON_MINDS_API_KEY", api_key) ws.set_secret("ANTON_MINDS_URL", minds_url) - # Test if the Minds server supports LLM endpoints (_code_/_reason_) - # (silenced: was printing "Testing LLM endpoints..." 
and "not available" messages) + # Test connection with a spinner from anton.chat import _minds_test_llm - llm_ok = _minds_test_llm(minds_url, api_key, verify=True) - if not llm_ok: - llm_ok = _minds_test_llm(minds_url, api_key, verify=False) + + from rich.live import Live + from rich.spinner import Spinner + + ssl_verify = True + llm_ok = False + + with Live(Spinner("dots", text=" Connecting...", style="anton.cyan"), console=console, transient=True): + llm_ok = _minds_test_llm(minds_url, api_key, verify=True) + if not llm_ok: + llm_ok_no_ssl = _minds_test_llm(minds_url, api_key, verify=False) + if llm_ok_no_ssl: + ssl_verify = False + llm_ok = True + + if llm_ok and not ssl_verify: + console.print(" [anton.warning]SSL certificate verification failed.[/]") + skip_ssl = Confirm.ask( + " Continue without SSL verification?", + default=False, + console=console, + ) + if not skip_ssl: + llm_ok = False if llm_ok: - console.print("[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]") + console.print(" [anton.success]Connected[/]") base_url = f"{minds_url}/api/v1" settings.openai_api_key = api_key settings.openai_base_url = base_url @@ -329,15 +580,139 @@ def _ensure_minds_api_key(settings, ws) -> None: settings.coding_provider = "openai-compatible" settings.planning_model = "_reason_" settings.coding_model = "_code_" + settings.minds_ssl_verify = ssl_verify ws.set_secret("ANTON_OPENAI_API_KEY", api_key) ws.set_secret("ANTON_OPENAI_BASE_URL", base_url) ws.set_secret("ANTON_PLANNING_PROVIDER", "openai-compatible") ws.set_secret("ANTON_CODING_PROVIDER", "openai-compatible") ws.set_secret("ANTON_PLANNING_MODEL", "_reason_") ws.set_secret("ANTON_CODING_MODEL", "_code_") + if not ssl_verify: + ws.set_secret("ANTON_MINDS_SSL_VERIFY", "false") + else: + console.print(" [anton.error]Could not connect. 
Check your API key and URL.[/]") + retry = Confirm.ask(" Try again?", default=True, console=console) + if retry: + _setup_minds(settings, ws, default_url=default_url) + else: + raise _SetupRetry() + + +def _setup_other_provider(settings, ws) -> None: + """Set up Anthropic or OpenAI as the LLM provider.""" + from rich.text import Text + + console.print() + for label, idx in [("Anthropic", "1"), ("OpenAI", "2")]: + line = Text() + line.append(f" {idx} ", style="bold") + line.append(label, style="anton.cyan") + console.print(line) + console.print() + + choice = _setup_prompt("Provider", default="1").strip().lower() + + if choice in ("1", "anthropic"): + _setup_anthropic(settings, ws) + elif choice in ("2", "openai"): + _setup_openai(settings, ws) else: - # LLM endpoints not available — fall back to Anthropic - _ensure_anthropic_api_key(settings, ws) + console.print(f" [anton.warning]Unknown provider '{choice}', using Anthropic.[/]") + _setup_anthropic(settings, ws) + + +def _validate_with_spinner(console, label: str, fn) -> None: + """Run a validation function with a spinner, print result.""" + from rich.live import Live + from rich.spinner import Spinner + + with Live(Spinner("dots", text=f" Validating {label}...", style="anton.cyan"), console=console, transient=True): + fn() + console.print(f" [anton.success]Validated[/] [anton.muted]{label}[/]") + + +def _setup_anthropic(settings, ws) -> None: + """Set up Anthropic with a single model for both reasoning and coding.""" + from rich.prompt import Confirm + + console.print() + while True: + api_key = _setup_prompt("API key", is_password=True) + if api_key.strip(): + break + console.print(" [anton.warning]Please enter your API key.[/]") + api_key = api_key.strip() + + model = _setup_prompt("Model", default="claude-sonnet-4-6").strip() + + try: + def _test(): + import anthropic + client = anthropic.Anthropic(api_key=api_key) + client.messages.create(model=model, max_tokens=1, messages=[{"role": "user", "content": 
"ping"}]) + + _validate_with_spinner(console, model, _test) + except Exception as exc: + console.print(f" [anton.error]Failed:[/] {exc}") + retry = Confirm.ask(" Try again?", default=True, console=console) + if retry: + _setup_anthropic(settings, ws) + return + else: + raise _SetupRetry() + + settings.anthropic_api_key = api_key + settings.planning_provider = "anthropic" + settings.coding_provider = "anthropic" + settings.planning_model = model + settings.coding_model = model + ws.set_secret("ANTON_ANTHROPIC_API_KEY", api_key) + ws.set_secret("ANTON_PLANNING_PROVIDER", "anthropic") + ws.set_secret("ANTON_CODING_PROVIDER", "anthropic") + ws.set_secret("ANTON_PLANNING_MODEL", model) + ws.set_secret("ANTON_CODING_MODEL", model) + + +def _setup_openai(settings, ws) -> None: + """Set up OpenAI with a single model for both reasoning and coding.""" + from rich.prompt import Confirm + + console.print() + while True: + api_key = _setup_prompt("API key", is_password=True) + if api_key.strip(): + break + console.print(" [anton.warning]Please enter your API key.[/]") + api_key = api_key.strip() + + model = _setup_prompt("Model", default="gpt-4o").strip() + + try: + def _test(): + import openai + client = openai.OpenAI(api_key=api_key) + client.chat.completions.create(model=model, max_tokens=1, messages=[{"role": "user", "content": "ping"}]) + + _validate_with_spinner(console, model, _test) + except Exception as exc: + console.print(f" [anton.error]Failed:[/] {exc}") + retry = Confirm.ask(" Try again?", default=True, console=console) + if retry: + _setup_openai(settings, ws) + return + else: + raise _SetupRetry() + + settings.openai_api_key = api_key + settings.planning_provider = "openai" + settings.coding_provider = "openai" + settings.planning_model = model + settings.coding_model = model + ws.set_secret("ANTON_OPENAI_API_KEY", api_key) + ws.set_secret("ANTON_PLANNING_PROVIDER", "openai") + ws.set_secret("ANTON_CODING_PROVIDER", "openai") + 
ws.set_secret("ANTON_PLANNING_MODEL", model) + ws.set_secret("ANTON_CODING_MODEL", model) @app.command("setup") @@ -345,7 +720,7 @@ def setup(ctx: typer.Context) -> None: """Configure provider, model, and API key.""" settings = _get_settings(ctx) _ensure_workspace(settings) - _ensure_api_key(settings) + _onboard(settings) console.print("[anton.success]Setup complete.[/]") @@ -440,3 +815,144 @@ def list_learnings(ctx: typer.Context) -> None: def version() -> None: """Show Anton version.""" console.print(f"Anton v{__version__}") + + +@app.command("connect") +def connect_data_source( + ctx: typer.Context, + slug: str = typer.Argument( + default="", help="Existing connection slug to reconnect (e.g. postgres-mydb)." + ), +) -> None: + """Connect a database or API to the Local Vault. + + Pass an existing connection slug (e.g. postgres-mydb) to reconnect using + stored credentials without re-entering them. Use /edit to + update credentials for an existing connection. + """ + import asyncio + + from anton.chat import ChatSession, _handle_connect_datasource + from anton.llm.client import LLMClient + from anton.scratchpad import ScratchpadManager + + settings = _get_settings(ctx) + _ensure_workspace(settings) + _ensure_api_key(settings) + + llm_client = LLMClient.from_settings(settings) + scratchpads = ScratchpadManager( + coding_provider=settings.coding_provider, + coding_model=settings.coding_model, + coding_api_key=( + settings.anthropic_api_key + if settings.coding_provider == "anthropic" + else settings.openai_api_key + ) + or "", + ) + session = ChatSession(llm_client) + + async def _run() -> None: + await _handle_connect_datasource( + console, + scratchpads, + session, + datasource_name=slug or None, + ) + await scratchpads.close_all() + + asyncio.run(_run()) + + +@app.command("list") +def list_data_sources(ctx: typer.Context) -> None: + """List all saved data source connections in the Local Vault.""" + from anton.chat import _handle_list_data_sources + + 
_handle_list_data_sources(console) + + +@app.command("edit") +def edit_data_source( + ctx: typer.Context, + name: str = typer.Argument(..., help="Connection slug to edit (e.g. postgres-mydb)."), +) -> None: + """Edit credentials for an existing Local Vault connection.""" + import asyncio + + from anton.chat import ChatSession, _handle_connect_datasource + from anton.llm.client import LLMClient + from anton.scratchpad import ScratchpadManager + + settings = _get_settings(ctx) + _ensure_workspace(settings) + _ensure_api_key(settings) + + llm_client = LLMClient.from_settings(settings) + scratchpads = ScratchpadManager( + coding_provider=settings.coding_provider, + coding_model=settings.coding_model, + coding_api_key=( + settings.anthropic_api_key + if settings.coding_provider == "anthropic" + else settings.openai_api_key + ) + or "", + ) + session = ChatSession(llm_client) + + async def _run() -> None: + await _handle_connect_datasource( + console, + scratchpads, + session, + datasource_name=name, + ) + await scratchpads.close_all() + + asyncio.run(_run()) + + +@app.command("remove") +def remove_data_source( + ctx: typer.Context, + name: str = typer.Argument(..., help="Connection slug to remove (e.g. postgres-mydb)."), +) -> None: + """Remove a saved connection from the Local Vault.""" + from anton.chat import _handle_remove_data_source + + _handle_remove_data_source(console, name) + + +@app.command("test") +def test_data_source( + ctx: typer.Context, + name: str = typer.Argument(..., help="Connection slug to test (e.g. 
postgres-mydb)."), +) -> None: + """Test a saved Local Vault connection using its test snippet.""" + import asyncio + + from anton.chat import _handle_test_datasource + from anton.scratchpad import ScratchpadManager + + settings = _get_settings(ctx) + _ensure_workspace(settings) + _ensure_api_key(settings) + + scratchpads = ScratchpadManager( + coding_provider=settings.coding_provider, + coding_model=settings.coding_model, + coding_api_key=( + settings.anthropic_api_key + if settings.coding_provider == "anthropic" + else settings.openai_api_key + ) + or "", + ) + + async def _run() -> None: + await _handle_test_datasource(console, scratchpads, name) + await scratchpads.close_all() + + asyncio.run(_run()) diff --git a/anton/config/settings.py b/anton/config/settings.py index 1aedfc2..ea4f7d0 100644 --- a/anton/config/settings.py +++ b/anton/config/settings.py @@ -51,7 +51,7 @@ class AntonSettings(BaseSettings): disable_autoupdates: bool = False # Minds datasource integration - minds_enabled: bool = False # opt-in: use Minds server as LLM provider + minds_enabled: bool = True # use Minds server as LLM provider minds_api_key: str | None = None minds_url: str = "https://mdb.ai" minds_mind_name: str | None = None diff --git a/anton/data_vault.py b/anton/data_vault.py new file mode 100644 index 0000000..61f55cc --- /dev/null +++ b/anton/data_vault.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import json +import os +import re +from datetime import datetime, timezone +from pathlib import Path + + +def _sanitize(value: str) -> str: + """Strip characters unsafe for file names, keep alphanumeric, dash, underscore.""" + return re.sub(r"[^\w\-]", "_", value).strip("_") + + +def _slug_env_prefix(engine: str, name: str) -> str: + """Return the DS_ prefix for a namespaced connection env var. 
+ + Examples: + engine="postgres", name="prod_db" → "DS_POSTGRES_PROD_DB" + engine="hubspot", name="main" → "DS_HUBSPOT_MAIN" + engine="postgres", name="prod-db.eu" → "DS_POSTGRES_PROD_DB_EU" + """ + raw = f"{engine}-{name}" + return "DS_" + re.sub(r"[^\w]", "_", raw).upper() + + +class DataVault: + """Manages data source connection credentials in ~/.anton/data_vault/.""" + + def __init__(self, vault_dir: Path | None = None) -> None: + self._dir = vault_dir or Path("~/.anton/data_vault").expanduser() + + def _path_for(self, engine: str, name: str) -> Path: + return self._dir / f"{_sanitize(engine)}-{_sanitize(name)}" + + def _ensure_dir(self) -> None: + self._dir.mkdir(parents=True, exist_ok=True) + self._dir.chmod(0o700) + + def save(self, engine: str, name: str, credentials: dict[str, str]) -> Path: + """Write credentials as JSON atomically. Creates vault dir if needed.""" + self._ensure_dir() + path = self._path_for(engine, name) + data = { + "engine": engine, + "name": name, + "created_at": datetime.now(timezone.utc).isoformat(), + "fields": credentials, + } + tmp = path.with_suffix(".tmp") + tmp.write_text(json.dumps(data, indent=2), encoding="utf-8") + tmp.chmod(0o600) + tmp.rename(path) + return path + + def load(self, engine: str, name: str) -> dict[str, str] | None: + """Return the fields dict for a connection, or None if not found.""" + path = self._path_for(engine, name) + if not path.is_file(): + return None + try: + data = json.loads(path.read_text(encoding="utf-8")) + return data.get("fields", {}) + except (json.JSONDecodeError, OSError): + return None + + def delete(self, engine: str, name: str) -> bool: + """Remove a connection file. 
Returns True if it existed.""" + path = self._path_for(engine, name) + if path.is_file(): + path.unlink() + return True + return False + + def list_connections(self) -> list[dict[str, str]]: + """Return [{engine, name, created_at}] for all stored connections.""" + if not self._dir.is_dir(): + return [] + results: list[dict[str, str]] = [] + for path in sorted(self._dir.iterdir()): + if not path.is_file(): + continue + try: + data = json.loads(path.read_text(encoding="utf-8")) + results.append( + { + "engine": data.get("engine", ""), + "name": data.get("name", ""), + "created_at": data.get("created_at", ""), + } + ) + except (json.JSONDecodeError, OSError): + continue + return results + + def inject_env(self, engine: str, name: str, *, flat: bool = False) -> list[str] | None: + """Load credentials and set DS_* environment variables. + + Default (flat=False): injects namespaced vars, e.g. DS_POSTGRES_PROD_DB__HOST. + flat=True: injects legacy flat vars, e.g. DS_HOST — use only during + single-connection test_snippet execution. + + Returns the list of env var names set, or None if connection not found. + """ + fields = self.load(engine, name) + if fields is None: + return None + var_names: list[str] = [] + if flat: + for key, value in fields.items(): + var = f"DS_{key.upper()}" + os.environ[var] = value + var_names.append(var) + else: + prefix = _slug_env_prefix(engine, name) + for key, value in fields.items(): + var = f"{prefix}__{key.upper()}" + os.environ[var] = value + var_names.append(var) + return var_names + + def clear_ds_env(self) -> None: + """Remove all DS_* variables from os.environ.""" + ds_keys = [k for k in os.environ if k.startswith("DS_")] + for key in ds_keys: + del os.environ[key] + + def next_connection_number(self, engine: str) -> int: + """Return the next auto-increment number for an engine (1-based). + + Used when naming partial connections: postgresql-1, postgresql-2, etc. 
+ """ + prefix = _sanitize(engine) + "-" + if not self._dir.is_dir(): + return 1 + existing = [ + p.name + for p in self._dir.iterdir() + if p.is_file() and p.name.startswith(prefix) + ] + max_n = 0 + for fname in existing: + suffix = fname[len(prefix) :] + if suffix.isdigit(): + max_n = max(max_n, int(suffix)) + return max_n + 1 diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py new file mode 100644 index 0000000..0596a20 --- /dev/null +++ b/anton/datasource_registry.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import difflib +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Union + +import yaml + + +@dataclass +class DatasourceField: + name: str + required: bool = True + secret: bool = False + description: str = "" + default: str = "" + + +@dataclass +class AuthMethod: + name: str + display: str + fields: list[DatasourceField] = field(default_factory=list) + + +@dataclass +class DatasourceEngine: + engine: str + display_name: str + pip: str = "" + name_from: Union[str, list[str]] = "" + fields: list[DatasourceField] = field(default_factory=list) + # "choice" means the user must pick from auth_methods before collecting fields + # empty string means no choice, just collect fields from the top-level "fields" list + auth_method: str = "" + auth_methods: list[AuthMethod] = field(default_factory=list) + test_snippet: str = "" + + +# Matches a level-2 heading followed by a ```yaml fenced block. 
+_YAML_BLOCK_RE = re.compile( + r"^##\s+(.+?)\s*$\n(.*?)^```yaml\n(.*?)^```", + re.MULTILINE | re.DOTALL, +) + + +def _parse_fields(raw: list) -> list[DatasourceField]: + result: list[DatasourceField] = [] + for f in raw or []: + if not isinstance(f, dict): + continue + result.append( + DatasourceField( + name=f.get("name", ""), + required=bool(f.get("required", True)), + secret=bool(f.get("secret", False)), + description=f.get("description", ""), + default=str(f.get("default", "")), + ) + ) + return result + + +def _parse_file(path: Path) -> dict[str, DatasourceEngine]: + """Extract engine definitions from a datasources.md file.""" + if not path.is_file(): + return {} + text = path.read_text(encoding="utf-8") + engines: dict[str, DatasourceEngine] = {} + + for match in _YAML_BLOCK_RE.finditer(text): + yaml_text = match.group(3) + try: + data = yaml.safe_load(yaml_text) + except yaml.YAMLError as exc: + import sys + + print( + f"[anton] Warning: skipping malformed YAML block in {path}: {exc}", + file=sys.stderr, + ) + continue + if not isinstance(data, dict) or "engine" not in data: + continue + + raw_auth_methods = data.get("auth_methods", []) or [] + auth_methods: list[AuthMethod] = [] + for am in raw_auth_methods: + if not isinstance(am, dict): + continue + auth_methods.append( + AuthMethod( + name=am.get("name", ""), + display=am.get("display", am.get("name", "")), + fields=_parse_fields(am.get("fields", [])), + ) + ) + + engine_slug = str(data["engine"]) + engines[engine_slug] = DatasourceEngine( + engine=engine_slug, + display_name=str(data.get("display_name", engine_slug)), + pip=str(data.get("pip", "")), + name_from=data.get("name_from", ""), + fields=_parse_fields(data.get("fields", [])), + auth_method=str(data.get("auth_method", "")), + auth_methods=auth_methods, + test_snippet=str(data.get("test_snippet", "")), + ) + + return engines + + +class DatasourceRegistry: + """Parsed registry of all available data source engines.""" + + _BUILTIN_PATH: Path = 
Path(__file__).parent.parent / "datasources.md" + _USER_PATH: Path = Path("~/.anton/datasources.md").expanduser() + + def __init__(self) -> None: + self._engines: dict[str, DatasourceEngine] = {} + self._load() + + def _load(self) -> None: + self._engines = _parse_file(self._BUILTIN_PATH) + for slug, engine in _parse_file(self._USER_PATH).items(): + self._engines[slug] = engine + + def reload(self) -> None: + """Reload datasource definitions from disk.""" + self._load() + + def validate_file(self, path: Path) -> dict[str, DatasourceEngine]: + """Parse a datasources.md file and return its engine definitions.""" + return _parse_file(path) + + def get(self, engine_slug: str) -> DatasourceEngine | None: + return self._engines.get(engine_slug) + + def find_by_name(self, display_name: str) -> DatasourceEngine | None: + """Case-insensitive match on display_name or engine slug.""" + needle = display_name.strip().lower() + for engine in self._engines.values(): + if engine.display_name.lower() == needle or engine.engine.lower() == needle: + return engine + matches = [ + e + for e in self._engines.values() + if needle in e.display_name.lower() or needle in e.engine.lower() + ] + return matches[0] if len(matches) == 1 else None + + def fuzzy_find(self, text: str) -> list[DatasourceEngine]: + """Return engines whose name/slug closely matches *text* (fuzzy, for typo tolerance).""" + + def _normalize(s: str) -> str: + return re.sub(r"[\s\-_]", "", s).lower() + + needle = _normalize(text) + # Build a map from normalized key → engine (display_name takes priority) + key_to_engine: dict[str, DatasourceEngine] = {} + for engine in self._engines.values(): + key_to_engine[_normalize(engine.display_name)] = engine + # Don't overwrite display_name key with slug key + slug_key = _normalize(engine.engine) + if slug_key not in key_to_engine: + key_to_engine[slug_key] = engine + + close_keys = difflib.get_close_matches( + needle, key_to_engine.keys(), n=3, cutoff=0.6 + ) + # Deduplicate while 
preserving order + seen: set[str] = set() + results: list[DatasourceEngine] = [] + for k in close_keys: + eng = key_to_engine[k] + if eng.engine not in seen: + seen.add(eng.engine) + results.append(eng) + return results + + def all_engines(self) -> list[DatasourceEngine]: + return sorted(self._engines.values(), key=lambda e: e.display_name) + + def derive_name( + self, engine_def: DatasourceEngine, credentials: dict[str, str] + ) -> str: + """Derive a default connection name from name_from field(s).""" + name_from = engine_def.name_from + if not name_from: + return "" + if isinstance(name_from, str): + return credentials.get(name_from, "") + parts = [credentials.get(f, "") for f in name_from if credentials.get(f)] + return "_".join(parts) diff --git a/anton/llm/prompts.py b/anton/llm/prompts.py index adbf923..2ba485c 100644 --- a/anton/llm/prompts.py +++ b/anton/llm/prompts.py @@ -57,7 +57,15 @@ tool-call loop inside scratchpad code. The LLM reasons and calls your tools iteratively. \ handle_tool(name, inputs) is a plain sync function returning a string result. Use this for \ multi-step AI workflows like classification, extraction, or analysis with structured outputs. -- All .anton/.env secrets are available as environment variables (os.environ). +- All .anton/.env variables are available as environment variables (os.environ). +- Connected data source credentials are injected as namespaced environment \ +variables in the form DS___ \ +(e.g. DS_POSTGRES_PROD_DB__HOST, DS_POSTGRES_PROD_DB__PASSWORD, \ +DS_HUBSPOT_MAIN__ACCESS_TOKEN). Use those variables directly in scratchpad \ +code and never read ~/.anton/data_vault/ files directly. +- Flat variables like DS_HOST or DS_PASSWORD are used only temporarily \ +during internal connection test snippets. Do not assume they exist during \ +normal chat/runtime execution. 
- When the user asks how you solved something or wants to see your work, use the scratchpad \ dump action — it shows a clean notebook-style summary without wasting tokens on reformatting. - Always use print() to produce output — scratchpad captures stdout. diff --git a/datasources.md b/datasources.md new file mode 100644 index 0000000..2cafc88 --- /dev/null +++ b/datasources.md @@ -0,0 +1,655 @@ +# Datasource Knowledge + +Anton reads this file when connecting data sources. For each source, the YAML +block defines the fields Python collects. The prose below describes auth flows, +common errors, and how to handle OAuth2 — Anton handles those using the scratchpad. + +Credentials are injected as `DS_` environment variables +before any scratchpad code runs. Never embed raw values in code strings. + +--- + +## PostgreSQL + +```yaml +engine: postgres +display_name: PostgreSQL +pip: psycopg2-binary +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your database server" } + - { name: port, required: true, secret: false, description: "port number", default: "5432" } + - { name: database, required: true, secret: false, description: "name of the database to connect" } + - { name: user, required: true, secret: false, description: "database username" } + - { name: password, required: true, secret: true, description: "database password" } + - { name: schema, required: false, secret: false, description: "defaults to public if not set" } + - { name: ssl, required: false, secret: false, description: "enable SSL (true/false)" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], port=os.environ.get('DS_PORT','5432'), + dbname=os.environ['DS_DATABASE'], user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +Common errors: "password authentication failed" → wrong password or user. 
+"could not connect to server" → wrong host/port or firewall blocking. + +--- + +## MySQL + +```yaml +engine: mysql +display_name: MySQL +pip: mysql-connector-python +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your MySQL server" } + - { name: port, required: true, secret: false, description: "port number", default: "3306" } + - { name: database, required: true, secret: false, description: "database name to connect" } + - { name: user, required: true, secret: false, description: "MySQL username" } + - { name: password, required: true, secret: true, description: "MySQL password" } + - { name: ssl, required: false, secret: false, description: "enable SSL (true/false)" } + - { name: ssl_ca, required: false, secret: false, description: "path to CA certificate file" } + - { name: ssl_cert, required: false, secret: false, description: "path to client certificate file" } + - { name: ssl_key, required: false, secret: false, description: "path to client private key file" } +test_snippet: | + import mysql.connector, os + conn = mysql.connector.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '3306')), + database=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +--- + +## Snowflake + +```yaml +engine: snowflake +display_name: Snowflake +pip: snowflake-connector-python +name_from: [account, database] +auth_method: choice +auth_methods: + - name: password + display: "Username / Password" + fields: + - { name: account, required: true, secret: false, description: "Snowflake account identifier (e.g. 
xy12345.us-east-1 or orgname-accountname)" } + - { name: user, required: true, secret: false, description: "Snowflake username" } + - { name: password, required: true, secret: true, description: "Snowflake password" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to PUBLIC)" } + - { name: warehouse, required: false, secret: false, description: "warehouse to use for queries" } + - { name: role, required: false, secret: false, description: "role to assume" } + - name: key_pair + display: "Key-Pair Authentication" + fields: + - { name: account, required: true, secret: false, description: "Snowflake account identifier" } + - { name: user, required: true, secret: false, description: "Snowflake username" } + - { name: private_key, required: true, secret: true, description: "PEM-formatted private key content" } + - { name: private_key_passphrase, required: false, secret: true, description: "passphrase for encrypted private key" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: schema, required: false, secret: false, description: "schema name" } + - { name: warehouse, required: false, secret: false, description: "warehouse to use" } + - { name: role, required: false, secret: false, description: "role to assume" } +test_snippet: | + import snowflake.connector, os + conn = snowflake.connector.connect( + account=os.environ['DS_ACCOUNT'], + user=os.environ['DS_USER'], + password=os.environ.get('DS_PASSWORD', ''), + database=os.environ.get('DS_DATABASE', ''), + schema=os.environ.get('DS_SCHEMA', 'PUBLIC'), + warehouse=os.environ.get('DS_WAREHOUSE', ''), + ) + conn.close() + print("ok") +``` + +Account identifier: Admin → Accounts → hover your account name to reveal the identifier. +Format is either `-` or `..`. 
+ +--- + +## Google BigQuery + +```yaml +engine: bigquery +display_name: Google BigQuery +pip: google-cloud-bigquery +name_from: [project_id, dataset] +fields: + - { name: project_id, required: true, secret: false, description: "GCP project ID containing your BigQuery datasets" } + - { name: dataset, required: true, secret: false, description: "BigQuery dataset name" } + - { name: service_account_json, required: false, secret: true, description: "contents of service account JSON key file (paste the full JSON)" } + - { name: service_account_keys, required: false, secret: false, description: "path to service account JSON key file on disk" } +test_snippet: | + import json, os + from google.cloud import bigquery + from google.oauth2 import service_account + + sa_json = os.environ.get('DS_SERVICE_ACCOUNT_JSON', '') + if sa_json: + creds = service_account.Credentials.from_service_account_info( + json.loads(sa_json), + scopes=['https://www.googleapis.com/auth/bigquery.readonly'], + ) + client = bigquery.Client(project=os.environ['DS_PROJECT_ID'], credentials=creds) + else: + client = bigquery.Client(project=os.environ['DS_PROJECT_ID']) + list(client.list_datasets()) + print("ok") +``` + +To create a service account key: GCP Console → IAM → Service Accounts → your account → +Keys → Add Key → JSON. Grant the account `BigQuery Data Viewer` + `BigQuery Job User` roles. 
+ +--- + +## Microsoft SQL Server + +```yaml +engine: mssql +display_name: Microsoft SQL Server +pip: pymssql +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of the SQL Server (for Azure use server field instead)" } + - { name: port, required: true, secret: false, description: "port number", default: "1433" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "SQL Server username" } + - { name: password, required: true, secret: true, description: "SQL Server password" } + - { name: server, required: false, secret: false, description: "server name — use for named instances or Azure SQL (e.g. myserver.database.windows.net)" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to dbo)" } +test_snippet: | + import pymssql, os + conn = pymssql.connect( + server=os.environ.get('DS_SERVER') or os.environ['DS_HOST'], + port=os.environ.get('DS_PORT', '1433'), + database=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +For Azure SQL Database use `server` field with value like `myserver.database.windows.net`. +For Windows Authentication omit user/password and ensure pymssql is built with Kerberos support. + +--- + +## Amazon Redshift + +```yaml +engine: redshift +display_name: Amazon Redshift +pip: psycopg2-binary +name_from: [host, database] +fields: + - { name: host, required: true, secret: false, description: "Redshift cluster endpoint (e.g. 
mycluster.abc123.us-east-1.redshift.amazonaws.com)" } + - { name: port, required: true, secret: false, description: "port number", default: "5439" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "Redshift username" } + - { name: password, required: true, secret: true, description: "Redshift password" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to public)" } + - { name: sslmode, required: false, secret: false, description: "SSL mode: require (default), verify-ca, verify-full, disable", default: "require" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '5439')), + dbname=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + sslmode=os.environ.get('DS_SSLMODE', 'require'), + ) + conn.close() + print("ok") +``` + +Redshift is PostgreSQL-compatible. Find the cluster endpoint in AWS Console → +Redshift → Clusters → your cluster → Endpoint (omit the port suffix). 
+ +--- + +## Databricks + +```yaml +engine: databricks +display_name: Databricks +pip: databricks-sql-connector +name_from: [server_hostname, catalog] +fields: + - { name: server_hostname, required: true, secret: false, description: "server hostname for the cluster or SQL warehouse (from JDBC/ODBC connection string)" } + - { name: http_path, required: true, secret: false, description: "HTTP path of the cluster or SQL warehouse" } + - { name: access_token, required: true, secret: true, description: "Databricks personal access token" } + - { name: catalog, required: false, secret: false, description: "Unity Catalog name (defaults to hive_metastore)" } + - { name: schema, required: false, secret: false, description: "schema (database) to use" } + - { name: session_configuration, required: false, secret: false, description: "Spark session configuration as key=value pairs" } +test_snippet: | + from databricks import sql as dbsql + import os + conn = dbsql.connect( + server_hostname=os.environ['DS_SERVER_HOSTNAME'], + http_path=os.environ['DS_HTTP_PATH'], + access_token=os.environ['DS_ACCESS_TOKEN'], + catalog=os.environ.get('DS_CATALOG', ''), + schema=os.environ.get('DS_SCHEMA', ''), + ) + conn.close() + print("ok") +``` + +Personal access token: User Settings → Developer → Access Tokens → Generate New Token. +HTTP path and server hostname: SQL Warehouses → your warehouse → Connection Details tab. 
+ +--- + +## MariaDB + +```yaml +engine: mariadb +display_name: MariaDB +pip: mysql-connector-python +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your MariaDB server" } + - { name: port, required: true, secret: false, description: "port number", default: "3306" } + - { name: database, required: true, secret: false, description: "database name to connect" } + - { name: user, required: true, secret: false, description: "MariaDB username" } + - { name: password, required: true, secret: true, description: "MariaDB password" } + - { name: ssl, required: false, secret: false, description: "enable SSL (true/false)" } + - { name: ssl_ca, required: false, secret: false, description: "path to CA certificate file" } + - { name: ssl_cert, required: false, secret: false, description: "path to client certificate file" } + - { name: ssl_key, required: false, secret: false, description: "path to client private key file" } +test_snippet: | + import mysql.connector, os + conn = mysql.connector.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '3306')), + database=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +MariaDB is wire-compatible with MySQL, so the mysql-connector-python driver works for both. 
+ +--- + +## HubSpot + +```yaml +engine: hubspot +display_name: HubSpot +pip: hubspot-api-client +auth_method: choice +auth_methods: + - name: pat + display: "Private App Token (recommended)" + fields: + - { name: access_token, required: true, secret: true, description: "HubSpot Private App token (starts with pat-na1-)" } + - name: oauth2 + display: "OAuth2 (for multi-account or publishable apps)" + fields: + - { name: client_id, required: true, secret: false, description: "OAuth2 client ID" } + - { name: client_secret, required: true, secret: true, description: "OAuth2 client secret" } + oauth2: + auth_url: https://app.hubspot.com/oauth/authorize + token_url: https://api.hubapi.com/oauth/v1/token + scopes: [crm.objects.contacts.read, crm.objects.deals.read] + store_fields: [access_token, refresh_token] +test_snippet: | + import hubspot, os + client = hubspot.Client.create(access_token=os.environ['DS_ACCESS_TOKEN']) + client.crm.contacts.basic_api.get_page(limit=1) + print("ok") +``` + +For Private App Token: HubSpot → Settings → Integrations → Private Apps → Create. +Recommended scopes: `crm.objects.contacts.read`, `crm.objects.deals.read`, `crm.objects.companies.read`. + +For OAuth2: collect client_id and client_secret, then use the scratchpad to: + +1. Build the authorization URL using `auth_url` + params above +2. Start a local HTTP server on port 8099 to catch the callback +3. Open the URL in the user's browser with `webbrowser.open()` +4. Extract the `code` from the callback, POST to `token_url` for tokens +5. 
Return `access_token` and `refresh_token` to store in wallet + +--- + +## Oracle Database + +```yaml +engine: oracle_database +display_name: Oracle Database +pip: oracledb +name_from: [host, service_name] +fields: + - { name: user, required: true, secret: false, description: "Oracle database username" } + - { name: password, required: true, secret: true, description: "Oracle database password" } + - { name: host, required: true, secret: false, description: "hostname or IP address of the Oracle server" } + - { name: port, required: true, secret: false, description: "port number (default 1521)", default: "1521" } + - { name: service_name, required: false, secret: false, description: "Oracle service name (preferred over SID)" } + - { name: sid, required: false, secret: false, description: "Oracle SID — use service_name if possible" } + - { name: dsn, required: false, secret: false, description: "full DSN string — overrides host/port/service_name" } + - { name: auth_mode, required: false, secret: false, description: "authorization mode (e.g. SYSDBA)" } +test_snippet: | + import oracledb, os + dsn = os.environ.get('DS_DSN') or oracledb.makedsn( + os.environ.get('DS_HOST', 'localhost'), + os.environ.get('DS_PORT', '1521'), + service_name=os.environ.get('DS_SERVICE_NAME', ''), + ) + conn = oracledb.connect(user=os.environ['DS_USER'], password=os.environ['DS_PASSWORD'], dsn=dsn) + conn.close() + print("ok") +``` + +oracledb runs in thin mode by default (no Oracle Client libraries needed). +Set `auth_mode` to `SYSDBA` or `SYSOPER` for privileged connections. 
+ +--- + +## DuckDB + +```yaml +engine: duckdb +display_name: DuckDB +pip: duckdb +name_from: database +fields: + - { name: database, required: false, secret: false, description: "path to DuckDB database file; omit or use :memory: for in-memory database", default: ":memory:" } + - { name: motherduck_token, required: false, secret: true, description: "MotherDuck access token for connecting to a MotherDuck cloud database" } + - { name: read_only, required: false, secret: false, description: "open in read-only mode (true/false)", default: "false" } +test_snippet: | + import duckdb, os + db_path = os.environ.get('DS_DATABASE', ':memory:') + token = os.environ.get('DS_MOTHERDUCK_TOKEN', '') + if token: + conn = duckdb.connect(f'md:{db_path}?motherduck_token={token}') + else: + conn = duckdb.connect(db_path) + conn.execute('SELECT 1').fetchone() + conn.close() + print("ok") +``` + +For MotherDuck, use `database` as the MotherDuck database name (e.g. `my_db`) and provide +the access token. For local files, provide the path to a `.duckdb` file. 
+ +--- + +## pgvector + +```yaml +engine: pgvector +display_name: pgvector +pip: pgvector psycopg2-binary +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your PostgreSQL server with pgvector extension" } + - { name: port, required: true, secret: false, description: "port number", default: "5432" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "database username" } + - { name: password, required: true, secret: true, description: "database password" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to public)" } + - { name: sslmode, required: false, secret: false, description: "SSL mode: prefer, require, disable" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '5432')), + dbname=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + cur = conn.cursor() + cur.execute("SELECT extname FROM pg_extension WHERE extname = 'vector'") + if not cur.fetchone(): + raise RuntimeError("pgvector extension not installed — run: CREATE EXTENSION vector;") + conn.close() + print("ok") +``` + +pgvector must be installed in the PostgreSQL instance: `CREATE EXTENSION IF NOT EXISTS vector;` +Managed options: Supabase, Neon, and AWS RDS for PostgreSQL all support pgvector. 
+ +--- + +## ChromaDB + +```yaml +engine: chromadb +display_name: ChromaDB +pip: chromadb +name_from: host +fields: + - { name: host, required: true, secret: false, description: "ChromaDB server host for HTTP client mode (omit for local in-process mode)" } + - { name: port, required: true, secret: false, description: "ChromaDB server port", default: "8000" } + - { name: persist_directory, required: false, secret: false, description: "local directory for persistent storage (local mode only)" } +test_snippet: | + import chromadb, os + host = os.environ.get('DS_HOST', '') + if host: + client = chromadb.HttpClient( + host=host, + port=int(os.environ.get('DS_PORT', '8000')), + ) + else: + persist_dir = os.environ.get('DS_PERSIST_DIRECTORY', '') + if persist_dir: + client = chromadb.PersistentClient(path=persist_dir) + else: + client = chromadb.EphemeralClient() + client.heartbeat() + print("ok") +``` + +Three modes: HTTP client (connect to a running ChromaDB server), persistent local (file-backed), +or ephemeral in-memory. For production, run `chroma run` to start the HTTP server. 
+ +--- + +## Salesforce + +```yaml +engine: salesforce +display_name: Salesforce +pip: salesforce_api +name_from: username +fields: + - { name: username, required: true, secret: false, description: "Salesforce account username (email)" } + - { name: password, required: true, secret: true, description: "Salesforce account password" } + - { name: client_id, required: true, secret: false, description: "consumer key from the connected app" } + - { name: client_secret, required: true, secret: true, description: "consumer secret from the connected app" } + - { name: is_sandbox, required: false, secret: false, description: "true to connect to sandbox, false for production", default: "false" } +test_snippet: | + import salesforce_api, os + sf = salesforce_api.Salesforce( + username=os.environ['DS_USERNAME'], + password=os.environ['DS_PASSWORD'], + client_id=os.environ['DS_CLIENT_ID'], + client_secret=os.environ['DS_CLIENT_SECRET'], + is_sandbox=os.environ.get('DS_IS_SANDBOX', 'false').lower() == 'true', + ) + sf.query('SELECT Id FROM Account LIMIT 1') + print("ok") +``` + +To get client_id and client_secret: Setup → Apps → App Manager → New Connected App. +Enable OAuth, add callback URL, select scopes (api, refresh_token). + +--- + +## Shopify + +```yaml +engine: shopify +display_name: Shopify +pip: ShopifyAPI +name_from: shop_url +fields: + - { name: shop_url, required: true, secret: false, description: "your Shopify store URL (e.g. 
mystore.myshopify.com)" } + - { name: client_id, required: true, secret: false, description: "client ID (API key) of the custom app" } + - { name: client_secret, required: true, secret: true, description: "client secret (API secret key) of the custom app" } +test_snippet: | + import shopify, os + shop_url = os.environ['DS_SHOP_URL'].rstrip('/') + if not shop_url.startswith('https://'): + shop_url = f'https://{shop_url}' + session = shopify.Session(shop_url, '2024-01', os.environ['DS_CLIENT_SECRET']) + shopify.ShopifyResource.activate_session(session) + shopify.Shop.current() + shopify.ShopifyResource.clear_session() + print("ok") +``` + +Create a custom app: Shopify Admin → Settings → Apps → Develop apps → Create an app. +Grant required API permissions (read_products, read_orders, etc.) then install the app. + +--- + +## NetSuite + +```yaml +engine: netsuite +display_name: NetSuite +pip: requests-oauthlib>=1.3.1 +name_from: account_id +fields: + - { name: account_id, required: true, secret: false, description: "NetSuite account/realm ID (e.g. 123456_SB1)" } + - { name: consumer_key, required: true, secret: true, description: "OAuth consumer key for the NetSuite integration" } + - { name: consumer_secret, required: true, secret: true, description: "OAuth consumer secret for the NetSuite integration" } + - { name: token_id, required: true, secret: true, description: "Token ID generated for the integration role" } + - { name: token_secret, required: true, secret: true, description: "Token secret generated for the integration role" } + - { name: rest_domain, required: false, secret: false, description: "REST domain override (defaults to https://<account_id>.suitetalk.api.netsuite.com)" } + - { name: record_types, required: false, secret: false, description: "Comma-separated NetSuite record types to expose (e.g. 
customer,item,salesOrder)" } +test_snippet: | + import os + from requests_oauthlib import OAuth1Session + account_id = os.environ['DS_ACCOUNT_ID'] + rest_domain = os.environ.get('DS_REST_DOMAIN') or f'https://{account_id.lower().replace("_", "-")}.suitetalk.api.netsuite.com' + url = f'{rest_domain.rstrip("/")}/services/rest/record/v1/metadata-catalog/' + session = OAuth1Session( + client_key=os.environ['DS_CONSUMER_KEY'], + client_secret=os.environ['DS_CONSUMER_SECRET'], + resource_owner_key=os.environ['DS_TOKEN_ID'], + resource_owner_secret=os.environ['DS_TOKEN_SECRET'], + realm=account_id, + signature_method='HMAC-SHA256', + ) + r = session.get(url, headers={'Prefer': 'transient'}) + assert r.status_code < 400, f'HTTP {r.status_code}: {r.text[:200]}' + print("ok") +``` + +NetSuite uses OAuth 1.0a Token-Based Authentication (TBA). Create an integration record in +NetSuite (Setup → Integration → Manage Integrations), then generate token credentials via +Setup → Users/Roles → Access Tokens. The account ID can be found in Setup → Company → Company Information. + +--- + +## Big Commerce + +```yaml +engine: bigcommerce +display_name: Big Commerce +pip: httpx +name_from: store_hash +fields: + - { name: api_base, required: true, secret: false, description: "Base URL of the BigCommerce API (e.g. https://api.bigcommerce.com/stores/0fh0fh0fh0/v3/)" } + - { name: access_token, required: true, secret: true, description: "API token for authenticating with BigCommerce" } +test_snippet: | + import httpx, os + api_base = os.environ['DS_API_BASE'].rstrip('/') + access_token = os.environ['DS_ACCESS_TOKEN'] + headers = {'X-Auth-Token': access_token} + r = httpx.get(f'{api_base}/catalog/products', headers=headers) + assert r.status_code < 400, f'HTTP {r.status_code}: {r.text[:200]}' + print("ok") +``` + +BigCommerce API tokens can be created in the BigCommerce control panel under Advanced Settings → API Accounts. 
Choose "Create API Account", then select "V2/V3 API Token" and grant the necessary permissions (e.g. "Products: Read-Only" to access product data). + +--- + +## TimescaleDB + +```yaml +engine: timescaledb +display_name: TimescaleDB +pip: psycopg2-binary +name_from: [host, database] +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of the TimescaleDB server" } + - { name: port, required: true, secret: false, description: "port number", default: "5432" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "database username" } + - { name: password, required: true, secret: true, description: "database password" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to public)" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '5432')), + dbname=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + cur = conn.cursor() + cur.execute("SELECT extname FROM pg_extension WHERE extname = 'timescaledb'") + if not cur.fetchone(): + raise RuntimeError("timescaledb extension not found — is TimescaleDB installed?") + conn.close() + print("ok") +``` + +TimescaleDB is a PostgreSQL extension for time-series data. Managed options include Timescale Cloud +and self-hosted PostgreSQL with the TimescaleDB extension installed. 
+ +--- + +## Email + +```yaml +engine: email +display_name: Email +name_from: email +fields: + - { name: email, required: true, secret: false, description: "email address to connect" } + - { name: password, required: true, secret: true, description: "email account password or app-specific password" } + - { name: imap_server, required: false, secret: false, description: "IMAP server hostname", default: "imap.gmail.com" } + - { name: smtp_server, required: false, secret: false, description: "SMTP server hostname", default: "smtp.gmail.com" } + - { name: smtp_port, required: false, secret: false, description: "SMTP port", default: "587" } +test_snippet: | + import imaplib, os + imap = imaplib.IMAP4_SSL(os.environ.get('DS_IMAP_SERVER', 'imap.gmail.com')) + imap.login(os.environ['DS_EMAIL'], os.environ['DS_PASSWORD']) + imap.logout() + print("ok") +``` + +For Gmail, enable IMAP in Settings → See all settings → Forwarding and POP/IMAP, then use an +App Password (Google Account → Security → 2-Step Verification → App passwords) instead of your +account password. For other providers, set imap_server and smtp_server accordingly. + +--- + +## Adding a new data source + +Follow the YAML format above. Add to `~/.anton/datasources.md` (user overrides). +Anton merges user overrides on top of the built-in registry at startup. 
diff --git a/pyproject.toml b/pyproject.toml index 99ff96f..f648bcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "rich>=13.0", "prompt-toolkit>=3.0", "packaging>=21.0", + "pyyaml>=6.0", ] [project.optional-dependencies] diff --git a/tests/test_datasource.py b/tests/test_datasource.py new file mode 100644 index 0000000..3989cdf --- /dev/null +++ b/tests/test_datasource.py @@ -0,0 +1,2026 @@ +from __future__ import annotations + +import io +import json +import os +from textwrap import dedent +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from rich.console import Console + +from anton.chat import ( + ChatSession, + _DS_KNOWN_VARS, + _DS_SECRET_VARS, + _build_datasource_context, + _handle_add_custom_datasource, + _handle_connect_datasource, + _handle_list_data_sources, + _handle_remove_data_source, + _handle_test_datasource, + _register_secret_vars, + _restore_namespaced_env, + _scrub_credentials, + parse_connection_slug, +) +from anton.cli import app as _cli_app +from anton.data_vault import DataVault, _slug_env_prefix +from anton.datasource_registry import ( + DatasourceEngine, + DatasourceRegistry, + _parse_file, +) + + +# ───────────────────────────────────────────────────────────────────────────── +# Fixtures +# ───────────────────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def vault_dir(tmp_path): + return tmp_path / "data_vault" + + +@pytest.fixture() +def vault(vault_dir): + return DataVault(vault_dir=vault_dir) + + +@pytest.fixture() +def datasources_md(tmp_path): + """Write a minimal datasources.md and return its path.""" + path = tmp_path / "datasources.md" + path.write_text(dedent("""\ + ## PostgreSQL + + ```yaml + engine: postgresql + display_name: PostgreSQL + pip: psycopg2-binary + name_from: database + fields: + - name: host + required: true + description: hostname or IP + - name: port + required: true + default: "5432" + description: port number + - 
name: database + required: true + description: database name + - name: user + required: true + description: username + - name: password + required: true + secret: true + description: password + - name: schema + required: false + description: defaults to public + test_snippet: | + import psycopg2 + conn = psycopg2.connect( + host=os.environ["DS_HOST"], + port=os.environ["DS_PORT"], + dbname=os.environ["DS_DATABASE"], + user=os.environ["DS_USER"], + password=os.environ["DS_PASSWORD"], + ) + conn.close() + print("ok") + ``` + + ## HubSpot + + ```yaml + engine: hubspot + display_name: HubSpot + pip: hubspot-api-client + name_from: access_token + auth_method: choice + auth_methods: + - name: private_app + display: Private App token (recommended) + fields: + - name: access_token + required: true + secret: true + description: pat-na1-xxx token + - name: oauth + display: OAuth 2.0 + fields: + - name: client_id + required: true + description: OAuth client ID + - name: client_secret + required: true + secret: true + description: OAuth client secret + test_snippet: | + print("ok") + ``` + """)) + return path + + +@pytest.fixture() +def registry(datasources_md): + """Registry pointing at our temp datasources.md, no user overrides.""" + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = _parse_file(datasources_md) + return reg + + +@pytest.fixture() +def make_session(): + """Factory that creates a fresh ChatSession with mocked scratchpads.""" + def _factory(): + mock_llm = AsyncMock() + session = ChatSession(mock_llm) + session._scratchpads = AsyncMock() + return session + return _factory + + +@pytest.fixture() +def make_cell(): + """Factory that creates a MagicMock scratchpad execution cell.""" + def _factory(stdout="ok", stderr="", error=None): + cell = MagicMock() + cell.stdout = stdout + cell.stderr = stderr + cell.error = error + return cell + return _factory + + +@pytest.fixture(autouse=True) +def clean_ds_state(): + """Clear _DS_SECRET_VARS, 
_DS_KNOWN_VARS, and all DS_* env vars around each test.""" + def _clean(): + _DS_SECRET_VARS.clear() + _DS_KNOWN_VARS.clear() + for k in list(os.environ): + if k.startswith("DS_"): + del os.environ[k] + + _clean() + yield + _clean() + + +# ───────────────────────────────────────────────────────────────────────────── +# DataVault — save / load / delete +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDataVaultSaveLoad: + def test_save_creates_file(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "db.example.com", "port": "5432"}) + assert (vault_dir / "postgresql-prod_db").is_file() + + def test_save_file_permissions(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "db.example.com"}) + path = vault_dir / "postgresql-prod_db" + mode = oct(path.stat().st_mode)[-3:] + assert mode == "600" + + def test_vault_dir_permissions(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "db.example.com"}) + mode = oct(vault_dir.stat().st_mode)[-3:] + assert mode == "700" + + def test_load_returns_fields(self, vault): + creds = {"host": "db.example.com", "port": "5432", "password": "secret"} + vault.save("postgresql", "prod_db", creds) + assert vault.load("postgresql", "prod_db") == creds + + def test_load_missing_returns_none(self, vault): + assert vault.load("postgresql", "nonexistent") is None + + def test_load_corrupt_file_returns_none(self, vault, vault_dir): + vault._ensure_dir() + (vault_dir / "postgresql-bad").write_text("not json") + assert vault.load("postgresql", "bad") is None + + def test_save_overwrites_existing(self, vault): + vault.save("postgresql", "prod_db", {"host": "old.host"}) + vault.save("postgresql", "prod_db", {"host": "new.host"}) + assert vault.load("postgresql", "prod_db") == {"host": "new.host"} + + def test_delete_existing(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "x"}) + result = vault.delete("postgresql", 
"prod_db") + assert result is True + assert not (vault_dir / "postgresql-prod_db").is_file() + + def test_delete_missing_returns_false(self, vault): + assert vault.delete("postgresql", "ghost") is False + + def test_special_chars_sanitized_in_filename(self, vault, vault_dir): + vault.save("postgresql", "my db/prod", {"host": "x"}) + files = list(vault_dir.iterdir()) + assert len(files) == 1 + assert "/" not in files[0].name + + def test_json_contains_metadata(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "x"}) + raw = json.loads((vault_dir / "postgresql-prod_db").read_text()) + assert raw["engine"] == "postgresql" + assert raw["name"] == "prod_db" + assert "created_at" in raw + assert raw["fields"] == {"host": "x"} + + +# ───────────────────────────────────────────────────────────────────────────── +# DataVault — list_connections +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDataVaultListConnections: + def test_empty_vault(self, vault): + assert vault.list_connections() == [] + + def test_lists_all_connections(self, vault): + vault.save("postgresql", "prod_db", {"host": "a"}) + vault.save("hubspot", "main", {"access_token": "pat-xxx"}) + conns = vault.list_connections() + engines = {c["engine"] for c in conns} + assert engines == {"postgresql", "hubspot"} + + def test_skips_corrupt_files(self, vault, vault_dir): + vault._ensure_dir() + vault.save("postgresql", "good", {"host": "x"}) + (vault_dir / "postgresql-bad").write_text("{{not json") + conns = vault.list_connections() + assert len(conns) == 1 + assert conns[0]["name"] == "good" + + def test_vault_dir_missing_returns_empty(self, vault): + assert vault.list_connections() == [] + + +# ───────────────────────────────────────────────────────────────────────────── +# DataVault — inject_env / clear_ds_env +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDataVaultEnvInjection: + def 
test_inject_sets_ds_vars(self, vault): + vault.save("postgresql", "prod_db", {"host": "db.example.com", "password": "s3cr3t"}) + var_names = vault.inject_env("postgresql", "prod_db") + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + assert os.environ.get("DS_POSTGRESQL_PROD_DB__PASSWORD") == "s3cr3t" + assert set(var_names) == {"DS_POSTGRESQL_PROD_DB__HOST", "DS_POSTGRESQL_PROD_DB__PASSWORD"} + + def test_inject_missing_returns_none(self, vault): + assert vault.inject_env("postgresql", "ghost") is None + + def test_clear_removes_ds_vars(self, vault): + vault.save("postgresql", "prod_db", {"host": "x"}) + vault.inject_env("postgresql", "prod_db") + vault.clear_ds_env() + assert "DS_POSTGRESQL_PROD_DB__HOST" not in os.environ + + def test_clear_leaves_non_ds_vars(self, vault, monkeypatch): + monkeypatch.setenv("MY_VAR", "untouched") + vault.clear_ds_env() + assert os.environ.get("MY_VAR") == "untouched" + + def test_inject_uppercases_field_names(self, vault): + vault.save("postgresql", "prod_db", {"access_token": "tok123"}) + vault.inject_env("postgresql", "prod_db") + assert os.environ.get("DS_POSTGRESQL_PROD_DB__ACCESS_TOKEN") == "tok123" + + def test_inject_flat_mode_sets_flat_vars(self, vault): + """flat=True injects legacy DS_FIELD vars, not namespaced ones.""" + vault.save("postgresql", "prod_db", {"host": "db.example.com"}) + var_names = vault.inject_env("postgresql", "prod_db", flat=True) + assert os.environ.get("DS_HOST") == "db.example.com" + assert "DS_POSTGRESQL_PROD_DB__HOST" not in os.environ + assert set(var_names) == {"DS_HOST"} + + def test_two_same_type_connections_no_collision(self, vault): + """Two connections of the same engine type coexist without overwriting each other.""" + vault.save("postgres", "prod_db", {"host": "prod.example.com"}) + vault.save("postgres", "analytics", {"host": "analytics.example.com"}) + vault.inject_env("postgres", "prod_db") + vault.inject_env("postgres", "analytics") + assert 
os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "prod.example.com" + assert os.environ.get("DS_POSTGRES_ANALYTICS__HOST") == "analytics.example.com" + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") != os.environ.get("DS_POSTGRES_ANALYTICS__HOST") + + def test_different_engines_no_collision(self, vault): + """Connections from different engines coexist simultaneously.""" + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + vault.inject_env("postgres", "prod_db") + vault.inject_env("hubspot", "main") + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "pg.example.com" + assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" + + def test_slug_env_prefix_sanitizes_special_chars(self, vault): + """Special characters in names produce correct namespaced vars.""" + assert _slug_env_prefix("postgres", "prod-db.eu") == "DS_POSTGRES_PROD_DB_EU" + vault.save("postgres", "prod-db.eu", {"host": "eu.pg.com"}) + vault.inject_env("postgres", "prod-db.eu") + assert os.environ.get("DS_POSTGRES_PROD_DB_EU__HOST") == "eu.pg.com" + + +# ───────────────────────────────────────────────────────────────────────────── +# DataVault — next_connection_number +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDataVaultNextConnectionNumber: + def test_returns_one_when_empty(self, vault): + assert vault.next_connection_number("postgresql") == 1 + + def test_increments_past_existing(self, vault): + vault.save("postgresql", "1", {"host": "a"}) + vault.save("postgresql", "2", {"host": "b"}) + assert vault.next_connection_number("postgresql") == 3 + + def test_ignores_named_connections(self, vault): + # "prod_db" is not a digit — should not affect numbering + vault.save("postgresql", "prod_db", {"host": "a"}) + assert vault.next_connection_number("postgresql") == 1 + + def test_does_not_confuse_engines(self, vault): + vault.save("hubspot", "1", {"access_token": 
"x"}) + vault.save("hubspot", "2", {"access_token": "y"}) + assert vault.next_connection_number("postgresql") == 1 + + +# ───────────────────────────────────────────────────────────────────────────── +# DatasourceRegistry — lookup +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDatasourceRegistry: + def test_get_by_slug(self, registry): + engine = registry.get("postgresql") + assert engine is not None + assert engine.display_name == "PostgreSQL" + + def test_get_missing_returns_none(self, registry): + assert registry.get("mysql") is None + + @pytest.mark.parametrize("query", ["PostgreSQL", "postgresql", "POSTGRESQL"]) + def test_find_by_name_variants(self, registry, query): + assert registry.find_by_name(query) is not None + + def test_find_unknown_returns_none(self, registry): + assert registry.find_by_name("MySQL") is None + + def test_all_engines_sorted(self, registry): + engines = registry.all_engines() + names = [e.display_name for e in engines] + assert names == sorted(names) + + def test_fields_parsed_correctly(self, registry): + engine = registry.get("postgresql") + field_names = [f.name for f in engine.fields] + assert "host" in field_names + assert "password" in field_names + + def test_secret_flag_on_password(self, registry): + engine = registry.get("postgresql") + pw = next(f for f in engine.fields if f.name == "password") + assert pw.secret is True + + def test_required_flag(self, registry): + engine = registry.get("postgresql") + schema = next(f for f in engine.fields if f.name == "schema") + assert schema.required is False + + def test_default_value_on_port(self, registry): + engine = registry.get("postgresql") + port = next(f for f in engine.fields if f.name == "port") + assert port.default == "5432" + + def test_pip_field(self, registry): + engine = registry.get("postgresql") + assert engine.pip == "psycopg2-binary" + + def test_test_snippet_present(self, registry): + engine = 
registry.get("postgresql") + assert 'print("ok")' in engine.test_snippet + + def test_auth_method_choice_parsed(self, registry): + engine = registry.get("hubspot") + assert engine.auth_method == "choice" + assert len(engine.auth_methods) == 2 + method_names = [m.name for m in engine.auth_methods] + assert "private_app" in method_names + assert "oauth" in method_names + + def test_auth_method_fields_parsed(self, registry): + engine = registry.get("hubspot") + private = next(m for m in engine.auth_methods if m.name == "private_app") + assert len(private.fields) == 1 + assert private.fields[0].name == "access_token" + assert private.fields[0].secret is True + + def test_validate_file_returns_engines(self, registry, datasources_md): + result = registry.validate_file(datasources_md) + assert "postgresql" in result + assert "hubspot" in result + assert result["postgresql"].display_name == "PostgreSQL" + + def test_validate_file_missing_returns_empty(self, registry, tmp_path): + result = registry.validate_file(tmp_path / "nonexistent.md") + assert result == {} + + def test_reload_picks_up_new_engine(self, tmp_path): + md = tmp_path / "datasources.md" + md.write_text(dedent("""\ + ## MySQL + + ```yaml + engine: mysql + display_name: MySQL + pip: pymysql + name_from: database + fields: + - name: host + required: true + description: hostname + test_snippet: | + print("ok") + ``` + """)) + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + reg._BUILTIN_PATH = md + reg._USER_PATH = tmp_path / "user.md" + reg.reload() + assert reg.get("mysql") is not None + assert reg.get("mysql").display_name == "MySQL" + + +# ───────────────────────────────────────────────────────────────────────────── +# DatasourceRegistry — derive_name +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDeriveConnectionName: + def test_single_field_name_from(self, registry): + engine = registry.get("postgresql") # name_from: database + name 
= registry.derive_name(engine, {"database": "prod_db", "host": "x"}) + assert name == "prod_db" + + def test_missing_name_from_field_returns_empty(self, registry): + engine = registry.get("postgresql") + assert registry.derive_name(engine, {"host": "x"}) == "" + + def test_no_name_from_returns_empty(self): + engine = DatasourceEngine(engine="test", display_name="Test", name_from="") + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + assert reg.derive_name(engine, {"host": "x"}) == "" + + def test_list_name_from(self): + engine = DatasourceEngine( + engine="test", + display_name="Test", + name_from=["host", "database"], + ) + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + name = reg.derive_name(engine, {"host": "db.example.com", "database": "prod"}) + assert name == "db.example.com_prod" + + def test_list_name_from_skips_missing(self): + engine = DatasourceEngine( + engine="test", + display_name="Test", + name_from=["host", "database"], + ) + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + assert reg.derive_name(engine, {"host": "db.example.com"}) == "db.example.com" + + +# ───────────────────────────────────────────────────────────────────────────── +# DatasourceRegistry — user overrides +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDatasourceRegistryUserOverrides: + def test_user_override_wins(self, tmp_path, datasources_md): + """A user-defined engine with same slug overrides the builtin.""" + user_md = tmp_path / "user_datasources.md" + user_md.write_text(dedent("""\ + ## PostgreSQL + + ```yaml + engine: postgresql + display_name: PostgreSQL (custom) + pip: psycopg2 + fields: + - name: host + required: true + description: custom host field + test_snippet: print("ok") + ``` + """)) + + builtin = _parse_file(datasources_md) + user = _parse_file(user_md) + merged = {**builtin, **user} + + assert merged["postgresql"].display_name == 
"PostgreSQL (custom)" + assert merged["postgresql"].pip == "psycopg2" + + def test_missing_user_file_falls_back_to_builtin(self, tmp_path): + assert _parse_file(tmp_path / "nonexistent.md") == {} + + +# ───────────────────────────────────────────────────────────────────────────── +# _handle_connect_datasource — integration-style (mocked I/O) +# ───────────────────────────────────────────────────────────────────────────── + + +class TestHandleConnectDatasource: + """Test the slash-command handler with mocked prompts and scratchpad.""" + + @pytest.mark.asyncio + async def test_unknown_engine_returns_early(self, registry, vault_dir, make_session): + """Typing an unknown engine name aborts without saving anything.""" + session = make_session() + console = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", return_value="MySQL"), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + assert result is session + assert DataVault(vault_dir=vault_dir).list_connections() == [] + + @pytest.mark.asyncio + async def test_partial_save_on_n_answer(self, registry, vault_dir, make_session): + """Answering 'n' saves partial credentials and returns without testing.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + prompt_responses = iter(["PostgreSQL", "n", "db.example.com", "", "", "", "", ""]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + assert conns[0]["engine"] == "postgresql" + assert conns[0]["name"].isdigit() + 
session._scratchpads.get_or_create.assert_not_called() + + @pytest.mark.asyncio + async def test_successful_connection_saves_and_injects_history( + self, registry, vault_dir, make_session, make_cell + ): + """Happy path: test passes, credentials saved, history entry added.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "y", + "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load("postgresql", conns[0]["name"]) + assert saved is not None + assert saved["host"] == "db.example.com" + assert saved["password"] == "s3cr3t" + assert result._history + last = result._history[-1] + assert last["role"] == "assistant" + assert "postgresql" in last["content"].lower() + + @pytest.mark.asyncio + async def test_failed_test_offers_retry(self, registry, vault_dir, make_session, make_cell): + """Connection test failure prompts for retry; success on second attempt saves.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(side_effect=[ + make_cell(stdout="", stderr="password authentication failed"), + make_cell(stdout="ok"), + ]) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "y", + "db.example.com", "5432", "prod_db", "alice", "wrongpassword", "", + "y", # retry? 
+ "correctpassword", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load("postgresql", conns[0]["name"]) + assert saved is not None + assert saved["password"] == "correctpassword" + + @pytest.mark.asyncio + async def test_failed_test_no_retry_returns_without_saving( + self, registry, vault_dir, make_session, make_cell + ): + """Declining retry on failed test leaves vault empty.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="", error="connection refused")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "y", + "db.example.com", "5432", "prod_db", "alice", "badpass", "", + "n", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + assert vault.list_connections() == [] + assert not result._history + + @pytest.mark.asyncio + async def test_ds_env_injected_after_successful_connect( + self, registry, vault_dir, make_session, make_cell + ): + """After a successful connect, namespaced DS_* vars are injected.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "y", + 
"db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + # name_from=database → name="prod_db" → prefix DS_POSTGRESQL_PROD_DB + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + + @pytest.mark.asyncio + async def test_auth_method_choice_selects_fields( + self, registry, vault_dir, make_session, make_cell + ): + """Selecting an auth method filters to that method's fields only.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter(["HubSpot", "1", "y", "pat-na1-abc123"]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load("hubspot", conns[0]["name"]) + assert saved is not None + # Only private_app fields collected — no client_id or client_secret + assert "access_token" in saved + assert "client_id" not in saved + assert "client_secret" not in saved + + @pytest.mark.asyncio + async def test_selective_field_collection( + self, registry, vault_dir, make_session, make_cell + ): + """Typing 'host,user,password' collects only those three fields.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = 
AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "host,user,password", + "db.example.com", "alice", "s3cr3t", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load("postgresql", conns[0]["name"]) + assert saved is not None + assert set(saved.keys()) == {"host", "user", "password"} + + +# ───────────────────────────────────────────────────────────────────────────── +# Credential scrubbing +# ───────────────────────────────────────────────────────────────────────────── + + +class TestCredentialScrubbing: + """_scrub_credentials and _register_secret_vars — flat and namespaced modes.""" + + def test_register_secret_vars_adds_secret_fields(self, registry): + """Secret fields are added to _DS_SECRET_VARS; non-secret fields are not.""" + pg = registry.get("postgresql") + assert pg is not None + _register_secret_vars(pg) + assert "DS_PASSWORD" in _DS_SECRET_VARS + assert "DS_HOST" not in _DS_SECRET_VARS + assert "DS_PORT" not in _DS_SECRET_VARS + + def test_scrub_replaces_registered_secret_value(self, monkeypatch): + """A registered secret value is replaced with its placeholder.""" + _DS_SECRET_VARS.add("DS_ACCESS_TOKEN") + monkeypatch.setenv("DS_ACCESS_TOKEN", "supersecrettoken123") + result = _scrub_credentials("token is supersecrettoken123 here") + assert "supersecrettoken123" not in result + assert "[DS_ACCESS_TOKEN]" in result + + def test_scrub_leaves_non_secret_field_readable(self, registry, monkeypatch): + """Non-secret DS_* values (host, port) are left untouched.""" + pg = registry.get("postgresql") + assert pg is not None + 
_register_secret_vars(pg) + monkeypatch.setenv("DS_HOST", "mydbhostname") + monkeypatch.setenv("DS_PASSWORD", "s3cr3tpassword99") + result = _scrub_credentials("host=mydbhostname pass=s3cr3tpassword99") + assert "mydbhostname" in result + assert "s3cr3tpassword99" not in result + assert "[DS_PASSWORD]" in result + + def test_scrub_skips_short_values(self, monkeypatch): + """Registered secrets are always scrubbed regardless of length.""" + _DS_SECRET_VARS.add("DS_PASSWORD") + monkeypatch.setenv("DS_PASSWORD", "short") + result = _scrub_credentials("password=short") + assert "short" not in result + assert "[DS_PASSWORD]" in result + + def test_scrub_fallback_redacts_unknown_long_ds_vars(self, monkeypatch): + """Long DS_* vars not in _DS_SECRET_VARS are scrubbed as a safety fallback.""" + monkeypatch.setenv("DS_WEBHOOK_SECRET", "wh_sec_abcdefgh1234") + result = _scrub_credentials("secret=wh_sec_abcdefgh1234 here") + assert "wh_sec_abcdefgh1234" not in result + assert "[DS_WEBHOOK_SECRET]" in result + + @pytest.mark.asyncio + async def test_register_and_scrub_on_connect(self, registry, vault_dir, monkeypatch): + """After _handle_connect_datasource, the new secret var is immediately scrubbed.""" + vault = DataVault(vault_dir=vault_dir) + session = MagicMock() + session._history = [] + session._cortex = None + + pad = AsyncMock() + pad.execute = AsyncMock( + return_value=MagicMock(stdout="ok", stderr="", error=None) + ) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + secret_pw = "supersecretpassword999" + prompt_responses = iter([ + "PostgreSQL", "y", + "db.host.com", "5432", "mydb", "alice", secret_pw, "public", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(MagicMock(), session._scratchpads, session) + + # name_from=database → 
name="mydb" → DS_POSTGRESQL_MYDB__PASSWORD + namespaced_pw_var = "DS_POSTGRESQL_MYDB__PASSWORD" + assert namespaced_pw_var in _DS_SECRET_VARS + monkeypatch.setenv(namespaced_pw_var, secret_pw) + result = _scrub_credentials(f"error: auth failed with {secret_pw}") + assert secret_pw not in result + assert f"[{namespaced_pw_var}]" in result + + # ── Namespaced mode ────────────────────────────────────────────────────── + + def test_register_with_slug_uses_namespaced_keys(self, registry): + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__HOST" not in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__HOST" in _DS_KNOWN_VARS + + def test_register_without_slug_uses_flat_keys(self, registry): + pg = registry.get("postgresql") + _register_secret_vars(pg) # no engine/name → flat mode + assert "DS_PASSWORD" in _DS_SECRET_VARS + assert "DS_HOST" not in _DS_SECRET_VARS + + def test_scrub_replaces_namespaced_secret_value(self, registry, monkeypatch): + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + secret = "namespacedpassword123" + monkeypatch.setenv("DS_POSTGRESQL_PROD_DB__PASSWORD", secret) + result = _scrub_credentials(f"error: {secret}") + assert secret not in result + assert "[DS_POSTGRESQL_PROD_DB__PASSWORD]" in result + + def test_scrub_leaves_namespaced_non_secret_readable(self, registry, monkeypatch): + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + monkeypatch.setenv("DS_POSTGRESQL_PROD_DB__HOST", "mydbhostname") + result = _scrub_credentials("host=mydbhostname") + assert "mydbhostname" in result + + +# ───────────────────────────────────────────────────────────────────────────── +# Active datasource scoping +# ───────────────────────────────────────────────────────────────────────────── + + +class 
TestActiveDatasourceScoping: + """Tests for active datasource routing and multi-source context building.""" + + def test_active_datasource_defaults_to_none(self, make_session): + session = make_session() + assert session._active_datasource is None + + @pytest.mark.asyncio + async def test_reconnect_sets_active_datasource(self, vault_dir, make_session): + """Reconnecting to a slug via prefill sets session._active_datasource.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("hubspot", "2", {"access_token": "pat-xxx"}) + + session = make_session() + console = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry"), + ): + result = await _handle_connect_datasource( + console, session._scratchpads, session, prefill="hubspot-2" + ) + + assert result._active_datasource == "hubspot-2" + + @pytest.mark.asyncio + async def test_reconnect_all_namespaced_vars_available(self, vault_dir, make_session): + """After reconnect, ALL saved connections remain available as namespaced vars.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("oracle", "1", {"host": "oracle.host", "user": "admin", "password": "orapass"}) + vault.save("hubspot", "2", {"access_token": "pat-xxx"}) + + vault.inject_env("oracle", "1") + vault.inject_env("hubspot", "2") + assert os.environ.get("DS_ORACLE_1__HOST") == "oracle.host" + assert os.environ.get("DS_HUBSPOT_2__ACCESS_TOKEN") == "pat-xxx" + + session = make_session() + console = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry"), + ): + result = await _handle_connect_datasource( + console, session._scratchpads, session, prefill="hubspot-2" + ) + + assert "DS_HOST" not in os.environ + assert "DS_ACCESS_TOKEN" not in os.environ + assert os.environ.get("DS_ORACLE_1__HOST") == "oracle.host" + assert os.environ.get("DS_HUBSPOT_2__ACCESS_TOKEN") == "pat-xxx" + assert result._active_datasource == "hubspot-2" + + def 
test_build_datasource_context_no_filter(self, vault_dir): + """Without active_only, all vault entries appear in the context.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("oracle", "1", {"host": "oracle.host"}) + vault.save("hubspot", "2", {"access_token": "pat-xxx"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "oracle-1" in ctx + assert "hubspot-2" in ctx + + def test_build_datasource_context_active_only_filters(self, vault_dir): + """With active_only set, only the matching slug appears.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("oracle", "1", {"host": "oracle.host"}) + vault.save("hubspot", "2", {"access_token": "pat-xxx"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context(active_only="hubspot-2") + + assert "hubspot-2" in ctx + assert "oracle-1" not in ctx + + def test_build_datasource_context_active_only_empty_when_no_match(self, vault_dir): + """If active_only doesn't match any slug, the section has no entries.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("oracle", "1", {"host": "oracle.host"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context(active_only="hubspot-99") + + assert "oracle-1" not in ctx + + def test_build_datasource_context_shows_namespaced_vars(self, vault_dir): + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com", "password": "s3cr3t"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "DS_POSTGRES_PROD_DB__HOST" in ctx + assert "DS_POSTGRES_PROD_DB__PASSWORD" in ctx + assert "DS_HOST" not in ctx + + def test_build_datasource_context_shows_slug_and_engine_label(self, vault_dir): + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = 
_build_datasource_context() + + assert "postgres-prod_db" in ctx + assert "(postgres)" in ctx + + def test_multi_source_context_shows_both_connections(self, vault_dir): + """Both connections are visible in the context with their namespaced vars.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "postgres-prod_db" in ctx + assert "DS_POSTGRES_PROD_DB__HOST" in ctx + assert "hubspot-main" in ctx + assert "DS_HUBSPOT_MAIN__ACCESS_TOKEN" in ctx + + +# ───────────────────────────────────────────────────────────────────────────── +# CLI command registration +# ───────────────────────────────────────────────────────────────────────────── + + +class TestCliCommandRegistration: + @pytest.mark.parametrize("cmd_name", [ + "connect", + "list", + "edit", + "remove", + "test", + ]) + def test_command_registered(self, cmd_name): + names = [cmd.name for cmd in _cli_app.registered_commands] + assert cmd_name in names + + +# ───────────────────────────────────────────────────────────────────────────── +# _handle_list_data_sources +# ───────────────────────────────────────────────────────────────────────────── + + +class TestHandleListDataSources: + def test_empty_vault_shows_message(self, vault_dir): + console = MagicMock() + with patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)): + _handle_list_data_sources(console) + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "No data sources" in printed or "connect" in printed + + def test_complete_connection_shows_saved_with_engine_name(self, vault_dir, registry): + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "db.example.com", "port": "5432", + "database": "prod", "user": "alice", "password": "s3cr3t", + }) + + buf = 
io.StringIO() + rich_console = Console(file=buf, highlight=False, markup=False) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + _handle_list_data_sources(rich_console) + + output = buf.getvalue() + assert "postgresql-prod_db" in output + assert "saved" in output.lower() + assert "PostgreSQL" in output # engine display_name shown + + def test_incomplete_connection_shows_incomplete(self, vault_dir, registry): + vault = DataVault(vault_dir=vault_dir) + # Missing required fields: database, user, password + vault.save("postgresql", "partial", {"host": "db.example.com"}) + + buf = io.StringIO() + rich_console = Console(file=buf, highlight=False, markup=False) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + _handle_list_data_sources(rich_console) + + output = buf.getvalue() + assert "incomplete" in output.lower() + + +# ───────────────────────────────────────────────────────────────────────────── +# _handle_test_datasource +# ───────────────────────────────────────────────────────────────────────────── + + +class TestHandleTestDatasource: + @pytest.mark.asyncio + async def test_success_path(self, vault_dir, registry, make_cell): + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "db.example.com", "port": "5432", + "database": "prod", "user": "alice", "password": "s3cr3t", + }) + console = MagicMock() + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + scratchpads = AsyncMock() + scratchpads.get_or_create = AsyncMock(return_value=pad) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "postgresql-prod_db") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "✓" in 
printed or "passed" in printed.lower() + + @pytest.mark.asyncio + async def test_failure_path(self, vault_dir, registry, make_cell): + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "db.example.com", "port": "5432", + "database": "prod", "user": "alice", "password": "wrongpass", + }) + console = MagicMock() + pad = AsyncMock() + pad.execute = AsyncMock( + return_value=make_cell(stdout="", stderr="password authentication failed") + ) + scratchpads = AsyncMock() + scratchpads.get_or_create = AsyncMock(return_value=pad) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "postgresql-prod_db") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "✗" in printed or "failed" in printed.lower() + + @pytest.mark.asyncio + async def test_unknown_connection(self, vault_dir, registry): + vault = DataVault(vault_dir=vault_dir) + console = MagicMock() + scratchpads = AsyncMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "postgresql-ghost") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "not found" in printed.lower() or "No connection" in printed + + @pytest.mark.asyncio + async def test_empty_slug_shows_usage(self, vault_dir, registry): + console = MagicMock() + scratchpads = AsyncMock() + + with ( + patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "Usage" in printed or "test" in printed + + +# ───────────────────────────────────────────────────────────────────────────── +# 
Edit flow +# ───────────────────────────────────────────────────────────────────────────── + + +class TestEditDatasourceFlow: + @pytest.mark.asyncio + async def test_existing_values_loaded(self, registry, vault_dir, make_session, make_cell): + """Edit shows existing non-secret values as defaults.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "old.host", "port": "5432", + "database": "prod", "user": "alice", "password": "oldpass", + }) + + session = make_session() + console = MagicMock() + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_values = iter([ + "old.host", "5432", "prod", "alice", "newpass", "", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_values)), + ): + await _handle_connect_datasource( + console, session._scratchpads, session, datasource_name="postgresql-prod_db" + ) + + saved = vault.load("postgresql", "prod_db") + assert saved is not None + assert saved["host"] == "old.host" + assert saved["password"] == "newpass" + + @pytest.mark.asyncio + async def test_enter_preserves_secret_value(self, registry, vault_dir, make_session, make_cell): + """Pressing Enter on a secret field keeps the existing value.""" + vault = DataVault(vault_dir=vault_dir) + original_pass = "original_secret_pass" + vault.save("postgresql", "prod_db", { + "host": "db.host", "port": "5432", + "database": "prod", "user": "alice", "password": original_pass, + }) + + session = make_session() + console = MagicMock() + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_values = iter([ + "db.host", "5432", "prod", "alice", + "", # password — Enter = keep original + 
"", # schema + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_values)), + ): + await _handle_connect_datasource( + console, session._scratchpads, session, datasource_name="postgresql-prod_db" + ) + + saved = vault.load("postgresql", "prod_db") + assert saved is not None + assert saved["password"] == original_pass + + @pytest.mark.asyncio + async def test_unknown_slug_returns_session(self, registry, vault_dir, make_session): + """Editing a non-existent slug returns the session unchanged.""" + vault = DataVault(vault_dir=vault_dir) + session = make_session() + console = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + result = await _handle_connect_datasource( + console, session._scratchpads, session, datasource_name="postgresql-ghost" + ) + + assert result is session + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "not found" in printed.lower() or "No connection" in printed + + +# ───────────────────────────────────────────────────────────────────────────── +# Remove flow +# ───────────────────────────────────────────────────────────────────────────── + + +class TestRemoveDatasourceFlow: + def test_confirmation_yes_deletes(self, vault, registry): + vault.save("postgresql", "prod_db", {"host": "x"}) + console = Console(quiet=True) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Confirm.ask", return_value=True), + ): + _handle_remove_data_source(console, "postgresql-prod_db") + + assert vault.load("postgresql", "prod_db") is None + + def test_confirmation_no_preserves(self, vault, registry): + vault.save("postgresql", "prod_db", {"host": "x"}) + console = Console(quiet=True) + + with ( 
+ patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Confirm.ask", return_value=False), + ): + _handle_remove_data_source(console, "postgresql-prod_db") + + assert vault.load("postgresql", "prod_db") is not None + + def test_unknown_name_shows_message(self, vault_dir, registry): + vault = DataVault(vault_dir=vault_dir) + console = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + _handle_remove_data_source(console, "postgresql-ghost") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "not found" in printed.lower() or "No connection" in printed + + def test_invalid_format_shows_warning(self, vault_dir): + vault = DataVault(vault_dir=vault_dir) + console = MagicMock() + + with patch("anton.chat.DataVault", return_value=vault): + _handle_remove_data_source(console, "nohyphen") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "Invalid" in printed or "engine-name" in printed + + +# ───────────────────────────────────────────────────────────────────────────── +# Environment activation — collision-free behavior +# ───────────────────────────────────────────────────────────────────────────── + + +class TestEnvActivationCollisionFree: + @pytest.mark.asyncio + async def test_connect_clears_previous_ds_vars( + self, registry, vault_dir, make_session, make_cell, monkeypatch + ): + """After a successful new connect, stale DS_* vars are cleared.""" + monkeypatch.setenv("DS_ACCESS_TOKEN", "old-token") + vault = DataVault(vault_dir=vault_dir) + session = make_session() + console = MagicMock() + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "y", + "db.example.com", "5432", "prod_db", "alice", 
"s3cr3t", "", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + assert "DS_ACCESS_TOKEN" not in os.environ + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + + @pytest.mark.asyncio + async def test_two_same_type_connections_no_collision( + self, registry, vault_dir, make_session + ): + """Both same-type connections remain available as distinct namespaced vars.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "db1", { + "host": "host1.example.com", "port": "5432", + "database": "db1", "user": "u1", "password": "p1", + }) + vault.save("postgresql", "db2", { + "host": "host2.example.com", "port": "5432", + "database": "db2", "user": "u2", "password": "p2", + }) + + session = make_session() + console = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_connect_datasource( + console, session._scratchpads, session, + prefill="postgresql-db2", + ) + + assert os.environ.get("DS_POSTGRESQL_DB1__HOST") == "host1.example.com" + assert os.environ.get("DS_POSTGRESQL_DB2__HOST") == "host2.example.com" + assert "DS_HOST" not in os.environ + assert "DS_DATABASE" not in os.environ + + +# ───────────────────────────────────────────────────────────────────────────── +# Datasource slash-command behavior +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDatasourceSlashCommandBehavior: + @pytest.mark.asyncio + async def test_test_data_source_no_arg_shows_usage(self, vault_dir, registry): + console = MagicMock() + scratchpads = AsyncMock() + + with ( + patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), + 
patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "Usage" in printed or "test" in printed + + @pytest.mark.asyncio + async def test_edit_data_source_no_arg_safe(self, vault_dir, registry, make_session): + """datasource_name=None triggers new-connect flow without crash.""" + session = make_session() + console = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", return_value="UnknownEngine"), + ): + updated = await _handle_connect_datasource( + console, session._scratchpads, session, + datasource_name=None, + ) + + assert updated is not None + + +# ───────────────────────────────────────────────────────────────────────────── +# _slug_env_prefix +# ───────────────────────────────────────────────────────────────────────────── + + +class TestParseConnectionSlug: + ENGINES = ["postgresql", "sql-server", "google-big-query"] + + def test_simple_engine(self): + assert parse_connection_slug("postgresql-prod_db", self.ENGINES) == ("postgresql", "prod_db") + + def test_hyphenated_engine(self): + assert parse_connection_slug("sql-server-prod-db", self.ENGINES) == ("sql-server", "prod-db") + + def test_longest_prefix_wins(self): + engines = ["google", "google-big-query"] + assert parse_connection_slug("google-big-query-main", engines) == ("google-big-query", "main") + + def test_ambiguous_resolves_to_longest(self): + engines = ["sql", "sql-server"] + assert parse_connection_slug("sql-server-1", engines) == ("sql-server", "1") + + def test_invalid_slug_no_match(self): + assert parse_connection_slug("unknown-engine-name", self.ENGINES) is None + + def test_slug_with_empty_name_part(self): + assert parse_connection_slug("postgresql-", self.ENGINES) is None + + def 
test_fallback_to_vault_for_custom_engine(self, tmp_path): + """Custom engine not in registry is resolved via vault fallback.""" + vault = DataVault(vault_dir=tmp_path / "vault") + vault.save("my_custom_db", "prod", {"host": "localhost"}) + result = parse_connection_slug( + "my_custom_db-prod", + known_engines=["postgresql"], + vault=vault, + ) + assert result == ("my_custom_db", "prod") + + def test_registry_match_takes_priority_over_vault(self, tmp_path): + """Registry prefix match wins even when vault also has the slug.""" + vault = DataVault(vault_dir=tmp_path / "vault") + vault.save("postgresql", "prod", {"host": "localhost"}) + result = parse_connection_slug( + "postgresql-prod", + known_engines=["postgresql"], + vault=vault, + ) + assert result == ("postgresql", "prod") + + def test_no_match_returns_none_with_vault(self, tmp_path): + """Truly unknown slug returns None even with vault supplied.""" + vault = DataVault(vault_dir=tmp_path / "vault") + result = parse_connection_slug( + "ghost-engine-1", + known_engines=["postgresql"], + vault=vault, + ) + assert result is None + + def test_no_vault_still_returns_none_for_unknown(self): + """Backward compat: no vault arg, unknown engine still returns None.""" + assert parse_connection_slug("custom-1", known_engines=["postgresql"]) is None + + +class TestSlugEnvPrefix: + def test_basic_engine_and_name(self): + assert _slug_env_prefix("postgres", "prod_db") == "DS_POSTGRES_PROD_DB" + + def test_hubspot_main(self): + assert _slug_env_prefix("hubspot", "main") == "DS_HUBSPOT_MAIN" + + def test_sanitizes_hyphen_and_dot(self): + assert _slug_env_prefix("postgres", "prod-db.eu") == "DS_POSTGRES_PROD_DB_EU" + + def test_numeric_name(self): + assert _slug_env_prefix("postgresql", "1") == "DS_POSTGRESQL_1" + + +# ───────────────────────────────────────────────────────────────────────────── +# Temporary flat activation and restoration +# ───────────────────────────────────────────────────────────────────────────── + + +class 
TestTemporaryFlatExecution: + """Tests that flat vars are used only during test_snippet, then restored.""" + + def test_restore_namespaced_env_clears_flat_and_reinjects(self, vault_dir): + """_restore_namespaced_env replaces flat vars with namespaced vars.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "analytics", {"host": "analytics.example.com"}) + + vault.inject_env("postgres", "analytics", flat=True) + assert os.environ.get("DS_HOST") == "analytics.example.com" + assert "DS_POSTGRES_ANALYTICS__HOST" not in os.environ + + with patch("anton.chat.DataVault", return_value=vault): + _restore_namespaced_env(vault) + + assert "DS_HOST" not in os.environ + assert os.environ.get("DS_POSTGRES_ANALYTICS__HOST") == "analytics.example.com" + + def test_restore_namespaced_env_reinjects_all_connections(self, vault_dir): + """_restore_namespaced_env restores ALL saved connections, not just one.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "prod.example.com"}) + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + + vault.inject_env("postgres", "prod_db", flat=True) + + with patch("anton.chat.DataVault", return_value=vault): + _restore_namespaced_env(vault) + + assert "DS_HOST" not in os.environ + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "prod.example.com" + assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" + + @pytest.mark.asyncio + async def test_test_datasource_injects_flat_then_restores_namespaced( + self, vault_dir, registry + ): + """_handle_test_datasource uses flat vars during snippet, then restores namespaced.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "s3cr3t", + }) + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + vault.inject_env("postgresql", "prod_db") + vault.inject_env("hubspot", "main") + + env_during_test: 
dict = {} + + async def capture_execute(snippet): + env_during_test["DS_HOST"] = os.environ.get("DS_HOST") + env_during_test["DS_POSTGRESQL_PROD_DB__HOST"] = os.environ.get( + "DS_POSTGRESQL_PROD_DB__HOST" + ) + return MagicMock(stdout="ok", stderr="", error=None) + + pad = AsyncMock() + pad.execute = capture_execute + pad.reset = AsyncMock() + pad.install_packages = AsyncMock() + + scratchpads = AsyncMock() + scratchpads.get_or_create = AsyncMock(return_value=pad) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(MagicMock(), scratchpads, "postgresql-prod_db") + + # During execution: flat var was set, namespaced was absent + assert env_during_test["DS_HOST"] == "pg.example.com" + assert env_during_test["DS_POSTGRESQL_PROD_DB__HOST"] is None + + # After execution: flat vars gone, namespaced restored + assert "DS_HOST" not in os.environ + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "pg.example.com" + assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" + + +# ───────────────────────────────────────────────────────────────────────────── +# Stale registration state regression tests +# ───────────────────────────────────────────────────────────────────────────── + + +class TestStaleDsRegistrationState: + """Regression tests: _DS_SECRET_VARS/_DS_KNOWN_VARS must mirror vault contents.""" + + def test_remove_clears_stale_secret_vars(self, vault_dir, registry): + """After removing a connection, its secret var names leave _DS_SECRET_VARS.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "s3cr3t", + }) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" in _DS_SECRET_VARS + assert 
"DS_POSTGRESQL_PROD_DB__PASSWORD" in _DS_KNOWN_VARS + + vault.delete("postgresql", "prod_db") + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" not in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" not in _DS_KNOWN_VARS + + def test_edit_connection_refreshes_secret_vars(self, vault_dir, registry): + """Overwriting a connection via vault.save rebuilds registration without duplication.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "old-pass", + }) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + secret_key = "DS_POSTGRESQL_PROD_DB__PASSWORD" + assert secret_key in _DS_SECRET_VARS + count_before = len(_DS_SECRET_VARS) + + # Simulate edit: overwrite with new credentials + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "new-pass", + }) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + assert secret_key in _DS_SECRET_VARS + assert len(_DS_SECRET_VARS) == count_before + assert os.environ.get(secret_key) == "new-pass" + + def test_reconnect_no_duplicate_secret_vars(self, vault_dir, registry): + """Calling _restore_namespaced_env multiple times does not grow _DS_SECRET_VARS.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "s3cr3t", + }) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + count_after_first = len(_DS_SECRET_VARS) + known_after_first = len(_DS_KNOWN_VARS) + + with 
patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + assert len(_DS_SECRET_VARS) == count_after_first + assert len(_DS_KNOWN_VARS) == known_after_first + + +# ───────────────────────────────────────────────────────────────────────────── +# TestAddCustomDatasourceFlow +# ───────────────────────────────────────────────────────────────────────────── + + +class TestAddCustomDatasourceFlow: + """Tests for _handle_add_custom_datasource field-collection logic.""" + + def _make_llm_response( + self, fields: list[dict], display_name: str = "MyDB" + ) -> str: + """Return a JSON string mimicking the LLM's plan() response.""" + import json as _json + return _json.dumps( + {"display_name": display_name, "pip": "", "fields": fields} + ) + + def _make_registry(self, tmp_path): + """Return a minimal registry mock that accepts any slug.""" + reg = MagicMock() + reg.validate_file.return_value = {"mydb": MagicMock()} + reg.reload.return_value = None + reg.get.return_value = None # triggers inline fallback + return reg + + def _make_llm(self, json_text: str): + """Return an AsyncMock LLM whose plan() returns json_text.""" + llm = AsyncMock() + response = MagicMock() + response.content = json_text + llm.plan = AsyncMock(return_value=response) + return llm + + def _mock_ds_path(self, mock_path_cls, tmp_path): + """Wire Path mock so datasources.md writes go to tmp_path.""" + mock_path_cls.return_value.expanduser.return_value = tmp_path / "datasources.md" + + @pytest.mark.asyncio + async def test_missing_required_non_secret_field_prompts_user( + self, tmp_path, make_session + ): + """Required non-secret field without inline value triggers Prompt.ask.""" + session = make_session() + session._llm = self._make_llm(self._make_llm_response([ + { + "name": "host", "value": "", "secret": False, + "required": True, "description": "hostname", + }, + ])) + console = MagicMock() + registry = self._make_registry(tmp_path) + + # first 
response: initial auth question; second: the host field prompt + prompt_responses = iter(["I want to connect to mydb", "localhost"]) + + with ( + patch( + "rich.prompt.Prompt.ask", + side_effect=lambda *a, **kw: next(prompt_responses), + ), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_add_custom_datasource( + console, "mydb", registry, session + ) + + assert result is not None + _, credentials = result + assert credentials["host"] == "localhost" + + @pytest.mark.asyncio + async def test_missing_required_secret_field_prompts_user( + self, tmp_path, make_session + ): + """Required secret field without inline value triggers password prompt.""" + session = make_session() + session._llm = self._make_llm(self._make_llm_response([ + { + "name": "api_key", "value": "", "secret": True, + "required": True, "description": "API key", + }, + ])) + console = MagicMock() + registry = self._make_registry(tmp_path) + + password_calls = [] + + def fake_prompt(*args, **kwargs): + if kwargs.get("password"): + password_calls.append(kwargs) + return "mysecret" + return "I want to connect" + + with ( + patch("rich.prompt.Prompt.ask", side_effect=fake_prompt), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_add_custom_datasource( + console, "mydb", registry, session + ) + + assert result is not None + _, credentials = result + assert credentials["api_key"] == "mysecret" + assert len(password_calls) == 1, "expected exactly one password prompt" + + @pytest.mark.asyncio + async def test_incomplete_custom_datasource_not_saved( + self, tmp_path, make_session + ): + """Empty responses for all required fields causes a hard stop (None).""" + session = make_session() + session._llm = self._make_llm(self._make_llm_response([ + { + "name": "host", "value": "", "secret": False, + "required": True, "description": "hostname", + }, + { + "name": "api_key", 
"value": "", "secret": True, + "required": True, "description": "API key", + }, + ])) + console = MagicMock() + registry = self._make_registry(tmp_path) + + # User presses Enter (empty) for every prompt + prompt_responses = iter(["I want to connect", "", ""]) + + with ( + patch( + "rich.prompt.Prompt.ask", + side_effect=lambda *a, **kw: next(prompt_responses), + ), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_add_custom_datasource( + console, "mydb", registry, session + ) + + # Must return None — caller must not call vault.save() + assert result is None + + +class TestCustomDatasourceConnectFlow: + """Tests for the custom_source path in _handle_connect_datasource: + test_snippet is run before saving, and failures prevent saving.""" + + # ── helpers (mirrors TestAddCustomDatasourceFlow) ──────────────────── + + def _make_llm_response( + self, + fields: list[dict], + display_name: str = "My API Service", + test_snippet: str = "", + ) -> str: + import json as _json + return _json.dumps({ + "display_name": display_name, + "pip": "", + "test_snippet": test_snippet, + "fields": fields, + }) + + def _make_registry(self, tmp_path): + reg = MagicMock() + reg.all_engines.return_value = [] + reg.find_by_name.return_value = None + reg.fuzzy_find.return_value = [] + reg.validate_file.return_value = {"my_api_service": MagicMock()} + reg.reload.return_value = None + reg.get.return_value = None # triggers inline fallback engine_def + return reg + + def _make_llm(self, json_text: str): + llm = AsyncMock() + response = MagicMock() + response.content = json_text + llm.plan = AsyncMock(return_value=response) + return llm + + def _mock_ds_path(self, mock_path_cls, tmp_path): + mock_path_cls.return_value.expanduser.return_value = tmp_path / "datasources.md" + + # ── tests ───────────────────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_custom_with_test_snippet_success( + 
self, vault_dir, make_session, make_cell, tmp_path + ): + """Custom datasource with test_snippet: test passes → connection saved.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + session._llm = self._make_llm(self._make_llm_response( + [{"name": "api_key", "value": "", "secret": True, "required": True, "description": "API key"}], + test_snippet="print('ok')", + )) + + prompt_responses = iter([ + "0", # choose custom + "My API Service", # tool name + "I have an API key", # auth description + "my_secret_key", # api_key (secret prompt) + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load(conns[0]["engine"], conns[0]["name"]) + assert saved is not None + assert saved.get("api_key") == "my_secret_key" + assert result._history + assert result._history[-1]["role"] == "assistant" + pad.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_custom_with_test_snippet_fail_no_retry( + self, vault_dir, make_session, make_cell, tmp_path + ): + """Custom datasource: test fails and user declines retry → not saved.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="", stderr="connection refused")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + session._llm = 
self._make_llm(self._make_llm_response( + [{"name": "api_key", "value": "", "secret": True, "required": True, "description": "API key"}], + test_snippet="print('ok')", + )) + + prompt_responses = iter([ + "0", + "My API Service", + "I have an API key", + "bad_key", # api_key + "n", # retry? + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_connect_datasource(console, session._scratchpads, session) + + assert vault.list_connections() == [] + assert not result._history + + @pytest.mark.asyncio + async def test_custom_with_test_snippet_fail_retry_success( + self, vault_dir, make_session, make_cell, tmp_path + ): + """Custom datasource: test fails, user retries with corrected creds → saved.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(side_effect=[ + make_cell(stdout="", stderr="invalid key"), + make_cell(stdout="ok"), + ]) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + session._llm = self._make_llm(self._make_llm_response( + [{"name": "api_key", "value": "", "secret": True, "required": True, "description": "API key"}], + test_snippet="print('ok')", + )) + + prompt_responses = iter([ + "0", + "My API Service", + "I have an API key", + "bad_key", # api_key first attempt + "y", # retry? 
+ "good_key", # api_key retry + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load(conns[0]["engine"], conns[0]["name"]) + assert saved is not None + assert saved.get("api_key") == "good_key" + assert result._history + + @pytest.mark.asyncio + async def test_custom_without_test_snippet_saves( + self, vault_dir, make_session, make_cell, tmp_path + ): + """Custom datasource without test_snippet: saves directly, no scratchpad call.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + session._llm = self._make_llm(self._make_llm_response( + [{"name": "api_key", "value": "", "secret": True, "required": True, "description": "API key"}], + test_snippet="", + )) + + prompt_responses = iter([ + "0", + "My API Service", + "I have an API key", + "my_key", # api_key + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + pad.execute.assert_not_called() diff --git a/tests/test_scrubbing.py b/tests/test_scrubbing.py new file mode 100644 index 0000000..cad690c --- /dev/null +++ 
b/tests/test_scrubbing.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import os +from unittest.mock import patch + +import pytest + +from anton.chat import ( + _DS_KNOWN_VARS, + _DS_SECRET_VARS, + _scrub_credentials, +) + + +@pytest.fixture(autouse=True) +def clean_ds_state(): + """Clear _DS_SECRET_VARS, _DS_KNOWN_VARS, and all DS_* env vars around each test.""" + def _clean(): + _DS_SECRET_VARS.clear() + _DS_KNOWN_VARS.clear() + for k in list(os.environ): + if k.startswith("DS_"): + del os.environ[k] + + _clean() + yield + _clean() + + +class TestScrubCredentials: + """Focused regression tests for _scrub_credentials short-secret handling.""" + + def test_registered_6char_secret_scrubbed(self, monkeypatch): + """A 6-character registered secret is scrubbed regardless of length.""" + _DS_SECRET_VARS.add("DS_PASSWORD") + monkeypatch.setenv("DS_PASSWORD", "abc123") + result = _scrub_credentials("auth failed: abc123") + assert "abc123" not in result + assert "[DS_PASSWORD]" in result + + def test_registered_8char_secret_scrubbed(self, monkeypatch): + """An 8-character registered secret is scrubbed (was at the old threshold).""" + _DS_SECRET_VARS.add("DS_API_KEY") + monkeypatch.setenv("DS_API_KEY", "tok12345") + result = _scrub_credentials("token=tok12345 rejected") + assert "tok12345" not in result + assert "[DS_API_KEY]" in result + + def test_registered_1char_secret_scrubbed(self, monkeypatch): + """A 1-character registered secret is scrubbed.""" + _DS_SECRET_VARS.add("DS_SECRET") + monkeypatch.setenv("DS_SECRET", "x") + result = _scrub_credentials("value=x here") + assert "=x " not in result + assert "[DS_SECRET]" in result + + def test_non_secret_var_not_scrubbed(self, monkeypatch): + """A known but non-secret DS_* var (e.g. 
DS_HOST) stays readable.""" + _DS_KNOWN_VARS.add("DS_HOST") + monkeypatch.setenv("DS_HOST", "mydbhostname") + result = _scrub_credentials("host=mydbhostname") + assert "mydbhostname" in result + + def test_unknown_short_ds_var_not_scrubbed(self, monkeypatch): + """Unknown DS_* vars with short values are NOT scrubbed (heuristic threshold).""" + monkeypatch.setenv("DS_ENABLE_FEATURE", "on") + result = _scrub_credentials("flag=on active") + assert "on" in result diff --git a/uv.lock b/uv.lock index 7838860..bba6b2b 100644 --- a/uv.lock +++ b/uv.lock @@ -49,6 +49,7 @@ dependencies = [ { name = "prompt-toolkit" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyyaml" }, { name = "rich" }, { name = "typer" }, ] @@ -73,6 +74,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24" }, + { name = "pyyaml", specifier = ">=6.0" }, { name = "rich", specifier = ">=13.0" }, { name = "typer", specifier = ">=0.9" }, ] @@ -609,6 +611,61 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, ] +[[package]] +name = "pyyaml" +version = "6.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/16/a95b6757765b7b031c9374925bb718d55e0a9ba8a1b6a12d25962ea44347/pyyaml-6.0.3-cp311-cp311-macosx_10_13_x86_64.whl", 
hash = "sha256:44edc647873928551a01e7a563d7452ccdebee747728c1080d881d68af7b997e", size = 185826, upload-time = "2025-09-25T21:31:58.655Z" }, + { url = "https://files.pythonhosted.org/packages/16/19/13de8e4377ed53079ee996e1ab0a9c33ec2faf808a4647b7b4c0d46dd239/pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:652cb6edd41e718550aad172851962662ff2681490a8a711af6a4d288dd96824", size = 175577, upload-time = "2025-09-25T21:32:00.088Z" }, + { url = "https://files.pythonhosted.org/packages/0c/62/d2eb46264d4b157dae1275b573017abec435397aa59cbcdab6fc978a8af4/pyyaml-6.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10892704fc220243f5305762e276552a0395f7beb4dbf9b14ec8fd43b57f126c", size = 775556, upload-time = "2025-09-25T21:32:01.31Z" }, + { url = "https://files.pythonhosted.org/packages/10/cb/16c3f2cf3266edd25aaa00d6c4350381c8b012ed6f5276675b9eba8d9ff4/pyyaml-6.0.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:850774a7879607d3a6f50d36d04f00ee69e7fc816450e5f7e58d7f17f1ae5c00", size = 882114, upload-time = "2025-09-25T21:32:03.376Z" }, + { url = "https://files.pythonhosted.org/packages/71/60/917329f640924b18ff085ab889a11c763e0b573da888e8404ff486657602/pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d", size = 806638, upload-time = "2025-09-25T21:32:04.553Z" }, + { url = "https://files.pythonhosted.org/packages/dd/6f/529b0f316a9fd167281a6c3826b5583e6192dba792dd55e3203d3f8e655a/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37d57ad971609cf3c53ba6a7e365e40660e3be0e5175fa9f2365a379d6095a", size = 767463, upload-time = "2025-09-25T21:32:06.152Z" }, + { url = "https://files.pythonhosted.org/packages/f2/6a/b627b4e0c1dd03718543519ffb2f1deea4a1e6d42fbab8021936a4d22589/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash 
= "sha256:37503bfbfc9d2c40b344d06b2199cf0e96e97957ab1c1b546fd4f87e53e5d3e4", size = 794986, upload-time = "2025-09-25T21:32:07.367Z" }, + { url = "https://files.pythonhosted.org/packages/45/91/47a6e1c42d9ee337c4839208f30d9f09caa9f720ec7582917b264defc875/pyyaml-6.0.3-cp311-cp311-win32.whl", hash = "sha256:8098f252adfa6c80ab48096053f512f2321f0b998f98150cea9bd23d83e1467b", size = 142543, upload-time = "2025-09-25T21:32:08.95Z" }, + { url = "https://files.pythonhosted.org/packages/da/e3/ea007450a105ae919a72393cb06f122f288ef60bba2dc64b26e2646fa315/pyyaml-6.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:9f3bfb4965eb874431221a3ff3fdcddc7e74e3b07799e0e84ca4a0f867d449bf", size = 158763, upload-time = "2025-09-25T21:32:09.96Z" }, + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, + { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", 
size = 844011, upload-time = "2025-09-25T21:32:15.21Z" }, + { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, + { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, + { url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" }, + { url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, + { url = 
"https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" }, + { url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" }, + { url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" }, + { url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" }, + { url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" }, + { url = "https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" }, + { url = 
"https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" }, + { url = "https://files.pythonhosted.org/packages/de/94/980b50a6531b3019e45ddeada0626d45fa85cbe22300844a7983285bed3b/pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26", size = 137427, upload-time = "2025-09-25T21:32:32.58Z" }, + { url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" }, + { url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" }, + { url = "https://files.pythonhosted.org/packages/9d/8c/f4bd7f6465179953d3ac9bc44ac1a8a3e6122cf8ada906b4f96c60172d43/pyyaml-6.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:8d1fab6bb153a416f9aeb4b8763bc0f22a5586065f86f7664fc23339fc1c1fac", size = 181814, upload-time = "2025-09-25T21:32:35.712Z" }, + { url = "https://files.pythonhosted.org/packages/bd/9c/4d95bb87eb2063d20db7b60faa3840c1b18025517ae857371c4dd55a6b3a/pyyaml-6.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:34d5fcd24b8445fadc33f9cf348c1047101756fd760b4dacb5c3e99755703310", size = 173809, upload-time = "2025-09-25T21:32:36.789Z" }, + { url = 
"https://files.pythonhosted.org/packages/92/b5/47e807c2623074914e29dabd16cbbdd4bf5e9b2db9f8090fa64411fc5382/pyyaml-6.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:501a031947e3a9025ed4405a168e6ef5ae3126c59f90ce0cd6f2bfc477be31b7", size = 766454, upload-time = "2025-09-25T21:32:37.966Z" }, + { url = "https://files.pythonhosted.org/packages/02/9e/e5e9b168be58564121efb3de6859c452fccde0ab093d8438905899a3a483/pyyaml-6.0.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b3bc83488de33889877a0f2543ade9f70c67d66d9ebb4ac959502e12de895788", size = 836355, upload-time = "2025-09-25T21:32:39.178Z" }, + { url = "https://files.pythonhosted.org/packages/88/f9/16491d7ed2a919954993e48aa941b200f38040928474c9e85ea9e64222c3/pyyaml-6.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c458b6d084f9b935061bc36216e8a69a7e293a2f1e68bf956dcd9e6cbcd143f5", size = 794175, upload-time = "2025-09-25T21:32:40.865Z" }, + { url = "https://files.pythonhosted.org/packages/dd/3f/5989debef34dc6397317802b527dbbafb2b4760878a53d4166579111411e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7c6610def4f163542a622a73fb39f534f8c101d690126992300bf3207eab9764", size = 755228, upload-time = "2025-09-25T21:32:42.084Z" }, + { url = "https://files.pythonhosted.org/packages/d7/ce/af88a49043cd2e265be63d083fc75b27b6ed062f5f9fd6cdc223ad62f03e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5190d403f121660ce8d1d2c1bb2ef1bd05b5f68533fc5c2ea899bd15f4399b35", size = 789194, upload-time = "2025-09-25T21:32:43.362Z" }, + { url = "https://files.pythonhosted.org/packages/23/20/bb6982b26a40bb43951265ba29d4c246ef0ff59c9fdcdf0ed04e0687de4d/pyyaml-6.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:4a2e8cebe2ff6ab7d1050ecd59c25d4c8bd7e6f400f5f82b96557ac0abafd0ac", size = 156429, upload-time = "2025-09-25T21:32:57.844Z" }, + { url = 
"https://files.pythonhosted.org/packages/f4/f4/a4541072bb9422c8a883ab55255f918fa378ecf083f5b85e87fc2b4eda1b/pyyaml-6.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:93dda82c9c22deb0a405ea4dc5f2d0cda384168e466364dec6255b293923b2f3", size = 143912, upload-time = "2025-09-25T21:32:59.247Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f9/07dd09ae774e4616edf6cda684ee78f97777bdd15847253637a6f052a62f/pyyaml-6.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:02893d100e99e03eda1c8fd5c441d8c60103fd175728e23e431db1b589cf5ab3", size = 189108, upload-time = "2025-09-25T21:32:44.377Z" }, + { url = "https://files.pythonhosted.org/packages/4e/78/8d08c9fb7ce09ad8c38ad533c1191cf27f7ae1effe5bb9400a46d9437fcf/pyyaml-6.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c1ff362665ae507275af2853520967820d9124984e0f7466736aea23d8611fba", size = 183641, upload-time = "2025-09-25T21:32:45.407Z" }, + { url = "https://files.pythonhosted.org/packages/7b/5b/3babb19104a46945cf816d047db2788bcaf8c94527a805610b0289a01c6b/pyyaml-6.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6adc77889b628398debc7b65c073bcb99c4a0237b248cacaf3fe8a557563ef6c", size = 831901, upload-time = "2025-09-25T21:32:48.83Z" }, + { url = "https://files.pythonhosted.org/packages/8b/cc/dff0684d8dc44da4d22a13f35f073d558c268780ce3c6ba1b87055bb0b87/pyyaml-6.0.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a80cb027f6b349846a3bf6d73b5e95e782175e52f22108cfa17876aaeff93702", size = 861132, upload-time = "2025-09-25T21:32:50.149Z" }, + { url = "https://files.pythonhosted.org/packages/b1/5e/f77dc6b9036943e285ba76b49e118d9ea929885becb0a29ba8a7c75e29fe/pyyaml-6.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00c4bdeba853cc34e7dd471f16b4114f4162dc03e6b7afcc2128711f0eca823c", size = 839261, upload-time = "2025-09-25T21:32:51.808Z" }, + { url = 
"https://files.pythonhosted.org/packages/ce/88/a9db1376aa2a228197c58b37302f284b5617f56a5d959fd1763fb1675ce6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e1674c3ef6f541c35191caae2d429b967b99e02040f5ba928632d9a7f0f065", size = 805272, upload-time = "2025-09-25T21:32:52.941Z" }, + { url = "https://files.pythonhosted.org/packages/da/92/1446574745d74df0c92e6aa4a7b0b3130706a4142b2d1a5869f2eaa423c6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65", size = 829923, upload-time = "2025-09-25T21:32:54.537Z" }, + { url = "https://files.pythonhosted.org/packages/f0/7a/1c7270340330e575b92f397352af856a8c06f230aa3e76f86b39d01b416a/pyyaml-6.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4ad1906908f2f5ae4e5a8ddfce73c320c2a1429ec52eafd27138b7f1cbe341c9", size = 174062, upload-time = "2025-09-25T21:32:55.767Z" }, + { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, +] + [[package]] name = "rich" version = "14.3.3"