From be8758e13beecac1b1116db6404367c254ef9870 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Thu, 19 Mar 2026 15:37:25 +0100 Subject: [PATCH 01/70] Add datasource markdown and data vault managing --- anton/data_vault.py | 135 ++++++++++++++++++++++++++++++++++++++++++++ datasources.md | 94 ++++++++++++++++++++++++++++++ 2 files changed, 229 insertions(+) create mode 100644 anton/data_vault.py create mode 100644 datasources.md diff --git a/anton/data_vault.py b/anton/data_vault.py new file mode 100644 index 0000000..6b28ea4 --- /dev/null +++ b/anton/data_vault.py @@ -0,0 +1,135 @@ +"""Local Vault — stores data source connection credentials. + +Each connection is saved as a JSON file in ~/.anton/data_vault/. +Files are named {engine}-{connection_name} and contain the credential +fields in plain text. Files are only read at runtime when Anton needs +to establish or test a connection. +""" + +from __future__ import annotations + +import json +import os +import re +from datetime import datetime, timezone +from pathlib import Path + + +def _sanitize(value: str) -> str: + """Strip characters unsafe for file names, keep alphanumeric, dash, underscore.""" + return re.sub(r"[^\w\-]", "_", value).strip("_") + + +class DataVault: + """Manages data source connection credentials in ~/.anton/data_vault/.""" + + def __init__(self, vault_dir: Path | None = None) -> None: + self._dir = vault_dir or Path("~/.anton/data_vault").expanduser() + + # ── Paths ───────────────────────────────────────────────────── + + def _path_for(self, engine: str, name: str) -> Path: + return self._dir / f"{_sanitize(engine)}-{_sanitize(name)}" + + def _ensure_dir(self) -> None: + self._dir.mkdir(parents=True, exist_ok=True) + self._dir.chmod(0o700) + + # ── CRUD ────────────────────────────────────────────────────── + + def save(self, engine: str, name: str, credentials: dict[str, str]) -> Path: + """Write credentials as JSON. 
Creates vault dir if needed.""" + self._ensure_dir() + path = self._path_for(engine, name) + data = { + "engine": engine, + "name": name, + "created_at": datetime.now(timezone.utc).isoformat(), + "fields": credentials, + } + path.write_text(json.dumps(data, indent=2), encoding="utf-8") + path.chmod(0o600) + return path + + def load(self, engine: str, name: str) -> dict[str, str] | None: + """Return the fields dict for a connection, or None if not found.""" + path = self._path_for(engine, name) + if not path.is_file(): + return None + try: + data = json.loads(path.read_text(encoding="utf-8")) + return data.get("fields", {}) + except (json.JSONDecodeError, OSError): + return None + + def delete(self, engine: str, name: str) -> bool: + """Remove a connection file. Returns True if it existed.""" + path = self._path_for(engine, name) + if path.is_file(): + path.unlink() + return True + return False + + def list_connections(self) -> list[dict[str, str]]: + """Return [{engine, name, created_at}] for all stored connections.""" + if not self._dir.is_dir(): + return [] + results: list[dict[str, str]] = [] + for path in sorted(self._dir.iterdir()): + if not path.is_file(): + continue + try: + data = json.loads(path.read_text(encoding="utf-8")) + results.append({ + "engine": data.get("engine", ""), + "name": data.get("name", ""), + "created_at": data.get("created_at", ""), + }) + except (json.JSONDecodeError, OSError): + continue + return results + + # ── Environment injection ────────────────────────────────────── + + def inject_env(self, engine: str, name: str) -> list[str] | None: + """Load credentials and set DS_ in os.environ. + + Returns the list of env var names set, or None if connection not found. + Call this before a scratchpad exec that needs the data source. 
+ """ + fields = self.load(engine, name) + if fields is None: + return None + var_names: list[str] = [] + for key, value in fields.items(): + var = f"DS_{key.upper()}" + os.environ[var] = value + var_names.append(var) + return var_names + + def clear_ds_env(self) -> None: + """Remove all DS_* variables from os.environ.""" + ds_keys = [k for k in os.environ if k.startswith("DS_")] + for key in ds_keys: + del os.environ[key] + + # ── Helpers ─────────────────────────────────────────────────── + + def next_connection_number(self, engine: str) -> int: + """Return the next auto-increment number for an engine (1-based). + + Used when naming partial connections: postgresql-1, postgresql-2, etc. + """ + prefix = _sanitize(engine) + "-" + if not self._dir.is_dir(): + return 1 + existing = [ + p.name for p in self._dir.iterdir() + if p.is_file() and p.name.startswith(prefix) + ] + max_n = 0 + for fname in existing: + suffix = fname[len(prefix):] + if suffix.isdigit(): + max_n = max(max_n, int(suffix)) + return max_n + 1 diff --git a/datasources.md b/datasources.md new file mode 100644 index 0000000..0450078 --- /dev/null +++ b/datasources.md @@ -0,0 +1,94 @@ +# Datasource Knowledge + +Anton reads this file when connecting data sources. For each source, the YAML +block defines the fields Python collects. The prose below describes auth flows, +common errors, and how to handle OAuth2 — Anton handles those using the scratchpad. + +When a connection is saved, write a brief topic file to +`.anton/memory/topics/datasource-{name}.md` so future sessions know +how to query this source. + +Credentials are injected as `DS_` environment variables +before any scratchpad code runs. Never embed raw values in code strings. 
+ +--- + +## PostgreSQL +```yaml +engine: postgres +display_name: PostgreSQL +pip: psycopg2-binary +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your database server" } + - { name: port, required: false, secret: false, description: "port number", default: "5432" } + - { name: database, required: true, secret: false, description: "name of the database to connect" } + - { name: user, required: true, secret: false, description: "database username" } + - { name: password, required: true, secret: true, description: "database password" } + - { name: schema, required: false, secret: false, description: "defaults to public if not set" } + - { name: ssl, required: false, secret: false, description: "enable SSL (true/false)" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], port=os.environ.get('DS_PORT','5432'), + dbname=os.environ['DS_DATABASE'], user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +Common errors: "password authentication failed" → wrong password or user. +"could not connect to server" → wrong host/port or firewall blocking. 
+ +--- + +## HubSpot +```yaml +engine: hubspot +display_name: HubSpot +pip: hubspot-api-client +name_from: [] +auth_method: choice +fields: [] +auth_methods: + - name: pat + display: "Private App Token (recommended)" + fields: + - { name: access_token, required: true, secret: true, description: "HubSpot Private App token (starts with pat-na1-)" } + - name: oauth2 + display: "OAuth2 (for multi-account or publishable apps)" + fields: + - { name: client_id, required: true, secret: false, description: "OAuth2 client ID" } + - { name: client_secret, required: true, secret: true, description: "OAuth2 client secret" } + oauth2: + auth_url: https://app.hubspot.com/oauth/authorize + token_url: https://api.hubapi.com/oauth/v1/token + scopes: [crm.objects.contacts.read, crm.objects.deals.read] + store_fields: [access_token, refresh_token] +test_snippet: | + import hubspot, os + client = hubspot.Client.create(access_token=os.environ['DS_ACCESS_TOKEN']) + client.crm.contacts.basic_api.get_page(limit=1) + print("ok") +``` + +For Private App Token: HubSpot → Settings → Integrations → Private Apps → Create. +Recommended scopes: `crm.objects.contacts.read`, `crm.objects.deals.read`, `crm.objects.companies.read`. + +For OAuth2: collect client_id and client_secret, then use the scratchpad to: +1. Build the authorization URL using `auth_url` + params above +2. Start a local HTTP server on port 8099 to catch the callback +3. Open the URL in the user's browser with `webbrowser.open()` +4. Extract the `code` from the callback, POST to `token_url` for tokens +5. Return `access_token` and `refresh_token` to store in wallet + +--- + +## Snowflake +... + +## Adding a new data source + +Follow the YAML format above. Add to `~/.anton/datasources.md` (user overrides). +Anton merges user overrides on top of the built-in registry at startup. 
\ No newline at end of file From 5b87bc38e5c779aade544d11c9cb5ac7628d5846 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Thu, 19 Mar 2026 15:37:58 +0100 Subject: [PATCH 02/70] Add datasource registry for managing the datasource connections --- anton/datasource_registry.py | 132 +++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 anton/datasource_registry.py diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py new file mode 100644 index 0000000..507984e --- /dev/null +++ b/anton/datasource_registry.py @@ -0,0 +1,132 @@ +"""Datasource registry — parses datasources.md engine definitions. + +Reads YAML blocks from the built-in datasources.md (project root) and +merges user overrides from ~/.anton/datasources.md on top. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Union + +import yaml + + +@dataclass +class DatasourceField: + name: str + required: bool = True + secret: bool = False + description: str = "" + default: str = "" + + +@dataclass +class DatasourceEngine: + engine: str + display_name: str + pip: str = "" + name_from: Union[str, list[str]] = "" + fields: list[DatasourceField] = field(default_factory=list) + test_snippet: str = "" + + +#Parse the file for engine defintions and extract the YAML blocks. 
+_YAML_BLOCK_RE = re.compile( + r"^##\s+(.+?)\s*$\n(.*?)^```yaml\n(.*?)^```", + re.MULTILINE | re.DOTALL, +) + + +def _parse_file(path: Path) -> dict[str, DatasourceEngine]: + """Extract engine definitions from a datasources.md file.""" + if not path.is_file(): + return {} + text = path.read_text(encoding="utf-8") + engines: dict[str, DatasourceEngine] = {} + + for match in _YAML_BLOCK_RE.finditer(text): + yaml_text = match.group(3) + try: + data = yaml.safe_load(yaml_text) + except yaml.YAMLError: + continue + if not isinstance(data, dict) or "engine" not in data: + continue + + raw_fields = data.get("fields", []) or [] + parsed_fields: list[DatasourceField] = [] + for f in raw_fields: + if not isinstance(f, dict): + continue + parsed_fields.append(DatasourceField( + name=f.get("name", ""), + required=bool(f.get("required", True)), + secret=bool(f.get("secret", False)), + description=f.get("description", ""), + default=str(f.get("default", "")), + )) + + engine_slug = str(data["engine"]) + engines[engine_slug] = DatasourceEngine( + engine=engine_slug, + display_name=str(data.get("display_name", engine_slug)), + pip=str(data.get("pip", "")), + name_from=data.get("name_from", ""), + fields=parsed_fields, + test_snippet=str(data.get("test_snippet", "")), + ) + + return engines + + +class DatasourceRegistry: + """Parsed registry of all available data source engines.""" + + # Default connection definition + _BUILTIN_PATH: Path = Path(__file__).parent.parent / "datasources.md" + # If user adds new connection + _USER_PATH: Path = Path("~/.anton/datasources.md").expanduser() + + def __init__(self) -> None: + self._engines: dict[str, DatasourceEngine] = {} + self._load() + + def _load(self) -> None: + self._engines = _parse_file(self._BUILTIN_PATH) + # Merge user overrides on top + for slug, engine in _parse_file(self._USER_PATH).items(): + self._engines[slug] = engine + + def get(self, engine_slug: str) -> DatasourceEngine | None: + """Look up an engine by its slug (e.g. 
'postgres').""" + return self._engines.get(engine_slug) + + def find_by_name(self, display_name: str) -> DatasourceEngine | None: + """Case-insensitive match on display_name (e.g. 'PostgreSQL').""" + needle = display_name.strip().lower() + for engine in self._engines.values(): + if engine.display_name.lower() == needle: + return engine + # Also match the slug directly + if engine.engine.lower() == needle: + return engine + return None + + def all_engines(self) -> list[DatasourceEngine]: + return sorted(self._engines.values(), key=lambda e: e.display_name) + + def derive_name(self, engine_def: DatasourceEngine, credentials: dict[str, str]) -> str: + """Derive a default connection name from name_from field(s). + + Falls back to the engine slug if the field is missing or empty. + """ + name_from = engine_def.name_from + if not name_from: + return engine_def.engine + if isinstance(name_from, str): + return credentials.get(name_from, engine_def.engine) + parts = [credentials.get(f, "") for f in name_from if credentials.get(f)] + return "_".join(parts) if parts else engine_def.engine From 7180089431092ad8de53c0eb1e6c69e1bcf07eb1 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Thu, 19 Mar 2026 15:38:55 +0100 Subject: [PATCH 03/70] Data vauld updates --- anton/data_vault.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/anton/data_vault.py b/anton/data_vault.py index 6b28ea4..9ab73d4 100644 --- a/anton/data_vault.py +++ b/anton/data_vault.py @@ -26,7 +26,6 @@ class DataVault: def __init__(self, vault_dir: Path | None = None) -> None: self._dir = vault_dir or Path("~/.anton/data_vault").expanduser() - # ── Paths ───────────────────────────────────────────────────── def _path_for(self, engine: str, name: str) -> Path: return self._dir / f"{_sanitize(engine)}-{_sanitize(name)}" @@ -35,7 +34,6 @@ def _ensure_dir(self) -> None: self._dir.mkdir(parents=True, exist_ok=True) self._dir.chmod(0o700) - # ── CRUD ────────────────────────────────────────────────────── def 
save(self, engine: str, name: str, credentials: dict[str, str]) -> Path: """Write credentials as JSON. Creates vault dir if needed.""" @@ -89,7 +87,6 @@ def list_connections(self) -> list[dict[str, str]]: continue return results - # ── Environment injection ────────────────────────────────────── def inject_env(self, engine: str, name: str) -> list[str] | None: """Load credentials and set DS_ in os.environ. @@ -113,7 +110,6 @@ def clear_ds_env(self) -> None: for key in ds_keys: del os.environ[key] - # ── Helpers ─────────────────────────────────────────────────── def next_connection_number(self, engine: str) -> int: """Return the next auto-increment number for an engine (1-based). From eea24b2cb2d8ff365b2a23c39fedb10400d997c1 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Thu, 19 Mar 2026 15:39:34 +0100 Subject: [PATCH 04/70] Update cli command --- anton/cli.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/anton/cli.py b/anton/cli.py index a701342..decd701 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -440,3 +440,13 @@ def list_learnings(ctx: typer.Context) -> None: def version() -> None: """Show Anton version.""" console.print(f"Anton v{__version__}") + + +@app.command("connect-datasource") +def connect_datasource(ctx: typer.Context) -> None: + """Connect a new datasource.""" + #from anton.datasources import connect_new_datasource + print("Datasource connection flow is not implemented yet.") + settings = _get_settings(ctx) + _ensure_workspace(settings) + #connect_new_datasource(console, settings) \ No newline at end of file From df1ebebeae459d16462a2245099d73d30138e9e2 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Thu, 19 Mar 2026 17:41:06 +0100 Subject: [PATCH 05/70] Fixing the session handling when user makes errors --- anton/chat.py | 262 +++++++++++++++++++++++++++++++++-- anton/datasource_registry.py | 67 +++++---- datasources.md | 6 - pyproject.toml | 1 + uv.lock | 57 ++++++++ 5 files changed, 354 insertions(+), 39 
deletions(-) diff --git a/anton/chat.py b/anton/chat.py index b86e36c..1ee1c1d 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2182,18 +2182,259 @@ def _human_size(nbytes: int) -> str: return f"{nbytes:.1f}TB" +async def _handle_connect_datasource( + console: Console, + scratchpads: ScratchpadManager, + session: "ChatSession", +) -> "ChatSession": + """Interactive flow for connecting a new data source to the Local Vault.""" + from rich.prompt import Prompt + + from anton.data_vault import DataVault + from anton.datasource_registry import DatasourceRegistry + + vault = DataVault() + registry = DatasourceRegistry() + + # ── Step 1: ask which engine ────────────────────────────────── + console.print() + engine_names = ", ".join(e.display_name for e in registry.all_engines()) + answer = Prompt.ask( + f"[anton.cyan](anton)[/] Which data source would you like to connect?\n" + f" [anton.muted](e.g. {engine_names})[/]\n", + console=console, + ) + engine_def = registry.find_by_name(answer.strip()) + if engine_def is None: + console.print(f"[anton.warning]Unknown data source '{answer}'. Available:[/]") + for e in registry.all_engines(): + console.print(f" • {e.display_name}") + console.print() + return session + + # ── Step 2a: auth method choice (if engine requires it) ─────── + active_fields = engine_def.fields + if engine_def.auth_method == "choice" and engine_def.auth_methods: + console.print() + console.print( + f"[anton.cyan](anton)[/] How would you like to authenticate with " + f"[bold]{engine_def.display_name}[/]?" + ) + console.print() + for i, am in enumerate(engine_def.auth_methods, 1): + console.print(f" {i}. {am.display}") + console.print() + choice_str = Prompt.ask( + "[anton.cyan](anton)[/] Enter a number", + console=console, + ).strip() + try: + choice_idx = int(choice_str) - 1 + chosen_method = engine_def.auth_methods[choice_idx] + except (ValueError, IndexError): + console.print("[anton.warning]Invalid choice. 
Aborting.[/]") + console.print() + return session + active_fields = chosen_method.fields + + # ── Step 2b: show fields ────────────────────────────────────── + required_fields = [f for f in active_fields if f.required] + optional_fields = [f for f in active_fields if not f.required] + + console.print() + console.print( + f"[anton.cyan](anton)[/] To connect [bold]{engine_def.display_name}[/], " + "I'll need the following:" + ) + console.print() + + if required_fields: + console.print(" [bold]Required[/] " + "─" * 39) + for f in required_fields: + console.print(f" • [bold]{f.name:<12}[/] [anton.muted]— {f.description}[/]") + + if optional_fields: + console.print() + console.print(" [bold]Optional[/] " + "─" * 39) + for f in optional_fields: + console.print(f" • [bold]{f.name:<12}[/] [anton.muted]— {f.description}[/]") + + console.print() + + # ── Step 3: determine collection mode ──────────────────────── + mode_answer = Prompt.ask( + "[anton.cyan](anton)[/] Do you have these available? [y/n/]", + console=console, + ).strip().lower() + + if mode_answer == "n": + console.print() + console.print( + "[anton.cyan](anton)[/] No problem. Which parameters do you have? " + "I'll save a partial connection now, and you can fill in the rest later " + "with [bold]/edit-data-source[/]." + ) + console.print() + console.print(" Provide what you have (press enter to skip any field):") + console.print() + fields_to_collect = active_fields + partial = True + elif mode_answer == "y": + fields_to_collect = active_fields + partial = False + else: + # User gave a comma-separated list of param names — filter to those fields. + # If nothing matches (e.g. they typed a credential value by mistake), fall + # back to collecting all fields. 
+ requested = {n.strip().lower() for n in mode_answer.split(",")} + matched = [f for f in active_fields if f.name.lower() in requested] + fields_to_collect = matched if matched else active_fields + partial = False + + # ── Step 4: collect credentials ────────────────────────────── + console.print() + credentials: dict[str, str] = {} + + for f in fields_to_collect: + prompt_label = f"[anton.cyan](anton)[/] {f.name}" + if not f.required or partial: + prompt_label += " [anton.muted](optional, press enter to skip)[/]" + + if f.secret: + value = Prompt.ask(prompt_label, password=True, console=console, default="") + elif f.default: + value = Prompt.ask( + f"{prompt_label} [anton.muted][{f.default}][/]", + console=console, + default=f.default, + ) + else: + value = Prompt.ask(prompt_label, console=console, default="") + + if value: + credentials[f.name] = value + + # ── Partial save ───────────────────────────────────────────── + if partial: + n = vault.next_connection_number(engine_def.engine) + auto_name = str(n) + vault.save(engine_def.engine, auto_name, credentials) + slug = f"{engine_def.engine}-{auto_name}" + console.print() + console.print( + f"[anton.muted]Partial connection saved to Local Vault as " + f"[bold]\"{slug}\"[/bold]. " + f"Run [bold]/edit-data-source {slug}[/bold] to complete it when you're ready.[/]" + ) + console.print() + return session + + # ── Step 5: test connection ─────────────────────────────────── + if engine_def.test_snippet: + while True: + console.print() + console.print("[anton.cyan](anton)[/] Got it. 
Testing connection…") + + # Temporarily inject DS_* into os.environ (test before committing to vault) + import os as _os + for key, value in credentials.items(): + _os.environ[f"DS_{key.upper()}"] = value + + try: + pad = await scratchpads.get_or_create("__datasource_test__") + await pad.reset() # fresh subprocess inherits current os.environ + + if engine_def.pip: + await pad.install_packages([engine_def.pip]) + + cell = await pad.execute(engine_def.test_snippet) + finally: + # Always clean up DS_* regardless of outcome + ds_keys = [k for k in _os.environ if k.startswith("DS_")] + for k in ds_keys: + del _os.environ[k] + + if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): + error_text = cell.error or cell.stderr.strip() or cell.stdout.strip() + # Show first meaningful line of the error + first_line = next( + (ln for ln in error_text.splitlines() if ln.strip()), error_text + ) + console.print() + console.print("[anton.warning](anton)[/] ✗ Connection failed.") + console.print() + console.print(f" Error: {first_line}") + console.print() + + retry = Prompt.ask( + "[anton.cyan](anton)[/] Would you like to re-enter your credentials? 
[y/n]", + console=console, + default="n", + ).strip().lower() + + if retry != "y": + return session + + # Re-collect secret fields only + console.print() + for f in engine_def.fields: + if not f.secret: + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + password=True, + console=console, + default="", + ) + if value: + credentials[f.name] = value + continue # retry test + + # Success + console.print("[anton.success] ✓ Connected successfully![/]") + break + + # ── Step 6: save + write topic ─────────────────────────────── + conn_name = registry.derive_name(engine_def, credentials) + if not conn_name or conn_name == engine_def.engine: + # Fall back to auto-number if name_from didn't resolve + n = vault.next_connection_number(engine_def.engine) + conn_name = str(n) + + vault.save(engine_def.engine, conn_name, credentials) + slug = f"{engine_def.engine}-{conn_name}" + console.print( + f" Credentials saved to Local Vault as [bold]\"{slug}\"[/bold]." + ) + + console.print() + console.print("[anton.muted] You can now ask me questions about your data.[/]") + console.print() + + # Inject a brief assistant message so the LLM is aware of the new connection + session._history.append({ + "role": "assistant", + "content": ( + f"I've saved a {engine_def.display_name} connection named \"{slug}\" " + f"to the Local Vault. I can now query this data source when needed." 
+ ), + }) + return session + + def _print_slash_help(console: Console) -> None: """Print available slash commands.""" console.print() console.print("[anton.cyan]Available commands:[/]") - console.print(" [bold]/connect[/] — Connect to a Minds server and select a mind") - console.print(" [bold]/data-connections[/] — View and manage stored keys and connections") - console.print(" [bold]/setup[/] — Configure models or memory settings") - console.print(" [bold]/memory[/] — Show memory status dashboard") - console.print(" [bold]/paste[/] — Attach clipboard image to your message") - console.print(" [bold]/resume[/] — Resume a previous chat session") - console.print(" [bold]/help[/] — Show this help message") - console.print(" [bold]exit[/] — Quit the chat") + console.print(" [bold]/connect[/] — Connect to a Minds server and select a mind") + console.print(" [bold]/connect-data-source[/] — Connect a database or API to the Local Vault") + console.print(" [bold]/data-connections[/] — View and manage stored keys and connections") + console.print(" [bold]/setup[/] — Configure models or memory settings") + console.print(" [bold]/memory[/] — Show memory status dashboard") + console.print(" [bold]/paste[/] — Attach clipboard image to your message") + console.print(" [bold]/resume[/] — Resume a previous chat session") + console.print(" [bold]/help[/] — Show this help message") + console.print(" [bold]exit[/] — Quit the chat") console.print() @@ -2541,6 +2782,11 @@ def _bottom_toolbar(): elif cmd == "/memory": _handle_memory(console, settings, cortex, episodic=episodic) continue + elif cmd == "/connect-data-source": + session = await _handle_connect_datasource( + console, session._scratchpads, session, + ) + continue elif cmd == "/data-connections": session = await _handle_data_connections( console, settings, workspace, session, diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py index 507984e..530b482 100644 --- a/anton/datasource_registry.py +++ 
b/anton/datasource_registry.py @@ -23,6 +23,13 @@ class DatasourceField: default: str = "" +@dataclass +class AuthMethod: + name: str + display: str + fields: list[DatasourceField] = field(default_factory=list) + + @dataclass class DatasourceEngine: engine: str @@ -30,16 +37,34 @@ class DatasourceEngine: pip: str = "" name_from: Union[str, list[str]] = "" fields: list[DatasourceField] = field(default_factory=list) + # "choice" means the user must pick from auth_methods before collecting fields + auth_method: str = "" + auth_methods: list[AuthMethod] = field(default_factory=list) test_snippet: str = "" -#Parse the file for engine defintions and extract the YAML blocks. +# Matches a level-2 heading followed by a ```yaml fenced block. _YAML_BLOCK_RE = re.compile( r"^##\s+(.+?)\s*$\n(.*?)^```yaml\n(.*?)^```", re.MULTILINE | re.DOTALL, ) +def _parse_fields(raw: list) -> list[DatasourceField]: + result: list[DatasourceField] = [] + for f in raw or []: + if not isinstance(f, dict): + continue + result.append(DatasourceField( + name=f.get("name", ""), + required=bool(f.get("required", True)), + secret=bool(f.get("secret", False)), + description=f.get("description", ""), + default=str(f.get("default", "")), + )) + return result + + def _parse_file(path: Path) -> dict[str, DatasourceEngine]: """Extract engine definitions from a datasources.md file.""" if not path.is_file(): @@ -56,17 +81,15 @@ def _parse_file(path: Path) -> dict[str, DatasourceEngine]: if not isinstance(data, dict) or "engine" not in data: continue - raw_fields = data.get("fields", []) or [] - parsed_fields: list[DatasourceField] = [] - for f in raw_fields: - if not isinstance(f, dict): + raw_auth_methods = data.get("auth_methods", []) or [] + auth_methods: list[AuthMethod] = [] + for am in raw_auth_methods: + if not isinstance(am, dict): continue - parsed_fields.append(DatasourceField( - name=f.get("name", ""), - required=bool(f.get("required", True)), - secret=bool(f.get("secret", False)), - 
description=f.get("description", ""), - default=str(f.get("default", "")), + auth_methods.append(AuthMethod( + name=am.get("name", ""), + display=am.get("display", am.get("name", "")), + fields=_parse_fields(am.get("fields", [])), )) engine_slug = str(data["engine"]) @@ -75,7 +98,9 @@ def _parse_file(path: Path) -> dict[str, DatasourceEngine]: display_name=str(data.get("display_name", engine_slug)), pip=str(data.get("pip", "")), name_from=data.get("name_from", ""), - fields=parsed_fields, + fields=_parse_fields(data.get("fields", [])), + auth_method=str(data.get("auth_method", "")), + auth_methods=auth_methods, test_snippet=str(data.get("test_snippet", "")), ) @@ -85,9 +110,7 @@ def _parse_file(path: Path) -> dict[str, DatasourceEngine]: class DatasourceRegistry: """Parsed registry of all available data source engines.""" - # Default connection definition _BUILTIN_PATH: Path = Path(__file__).parent.parent / "datasources.md" - # If user adds new connection _USER_PATH: Path = Path("~/.anton/datasources.md").expanduser() def __init__(self) -> None: @@ -96,21 +119,18 @@ def __init__(self) -> None: def _load(self) -> None: self._engines = _parse_file(self._BUILTIN_PATH) - # Merge user overrides on top for slug, engine in _parse_file(self._USER_PATH).items(): self._engines[slug] = engine def get(self, engine_slug: str) -> DatasourceEngine | None: - """Look up an engine by its slug (e.g. 'postgres').""" return self._engines.get(engine_slug) def find_by_name(self, display_name: str) -> DatasourceEngine | None: - """Case-insensitive match on display_name (e.g. 
'PostgreSQL').""" + """Case-insensitive match on display_name or engine slug.""" needle = display_name.strip().lower() for engine in self._engines.values(): if engine.display_name.lower() == needle: return engine - # Also match the slug directly if engine.engine.lower() == needle: return engine return None @@ -119,14 +139,11 @@ def all_engines(self) -> list[DatasourceEngine]: return sorted(self._engines.values(), key=lambda e: e.display_name) def derive_name(self, engine_def: DatasourceEngine, credentials: dict[str, str]) -> str: - """Derive a default connection name from name_from field(s). - - Falls back to the engine slug if the field is missing or empty. - """ + """Derive a default connection name from name_from field(s).""" name_from = engine_def.name_from if not name_from: - return engine_def.engine + return "" if isinstance(name_from, str): - return credentials.get(name_from, engine_def.engine) + return credentials.get(name_from, "") parts = [credentials.get(f, "") for f in name_from if credentials.get(f)] - return "_".join(parts) if parts else engine_def.engine + return "_".join(parts) diff --git a/datasources.md b/datasources.md index 0450078..c63074a 100644 --- a/datasources.md +++ b/datasources.md @@ -4,10 +4,6 @@ Anton reads this file when connecting data sources. For each source, the YAML block defines the fields Python collects. The prose below describes auth flows, common errors, and how to handle OAuth2 — Anton handles those using the scratchpad. -When a connection is saved, write a brief topic file to -`.anton/memory/topics/datasource-{name}.md` so future sessions know -how to query this source. - Credentials are injected as `DS_` environment variables before any scratchpad code runs. Never embed raw values in code strings. @@ -48,9 +44,7 @@ Common errors: "password authentication failed" → wrong password or user. 
engine: hubspot display_name: HubSpot pip: hubspot-api-client -name_from: [] auth_method: choice -fields: [] auth_methods: - name: pat display: "Private App Token (recommended)" diff --git a/pyproject.toml b/pyproject.toml index 99ff96f..f648bcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "rich>=13.0", "prompt-toolkit>=3.0", "packaging>=21.0", + "pyyaml>=6.0", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index 7838860..bba6b2b 100644 --- a/uv.lock +++ b/uv.lock @@ -49,6 +49,7 @@ dependencies = [ { name = "prompt-toolkit" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyyaml" }, { name = "rich" }, { name = "typer" }, ] @@ -73,6 +74,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24" }, + { name = "pyyaml", specifier = ">=6.0" }, { name = "rich", specifier = ">=13.0" }, { name = "typer", specifier = ">=0.9" }, ] @@ -609,6 +611,61 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, ] +[[package]] +name = "pyyaml" +version = "6.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/16/a95b6757765b7b031c9374925bb718d55e0a9ba8a1b6a12d25962ea44347/pyyaml-6.0.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = 
"sha256:44edc647873928551a01e7a563d7452ccdebee747728c1080d881d68af7b997e", size = 185826, upload-time = "2025-09-25T21:31:58.655Z" }, + { url = "https://files.pythonhosted.org/packages/16/19/13de8e4377ed53079ee996e1ab0a9c33ec2faf808a4647b7b4c0d46dd239/pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:652cb6edd41e718550aad172851962662ff2681490a8a711af6a4d288dd96824", size = 175577, upload-time = "2025-09-25T21:32:00.088Z" }, + { url = "https://files.pythonhosted.org/packages/0c/62/d2eb46264d4b157dae1275b573017abec435397aa59cbcdab6fc978a8af4/pyyaml-6.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10892704fc220243f5305762e276552a0395f7beb4dbf9b14ec8fd43b57f126c", size = 775556, upload-time = "2025-09-25T21:32:01.31Z" }, + { url = "https://files.pythonhosted.org/packages/10/cb/16c3f2cf3266edd25aaa00d6c4350381c8b012ed6f5276675b9eba8d9ff4/pyyaml-6.0.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:850774a7879607d3a6f50d36d04f00ee69e7fc816450e5f7e58d7f17f1ae5c00", size = 882114, upload-time = "2025-09-25T21:32:03.376Z" }, + { url = "https://files.pythonhosted.org/packages/71/60/917329f640924b18ff085ab889a11c763e0b573da888e8404ff486657602/pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d", size = 806638, upload-time = "2025-09-25T21:32:04.553Z" }, + { url = "https://files.pythonhosted.org/packages/dd/6f/529b0f316a9fd167281a6c3826b5583e6192dba792dd55e3203d3f8e655a/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37d57ad971609cf3c53ba6a7e365e40660e3be0e5175fa9f2365a379d6095a", size = 767463, upload-time = "2025-09-25T21:32:06.152Z" }, + { url = "https://files.pythonhosted.org/packages/f2/6a/b627b4e0c1dd03718543519ffb2f1deea4a1e6d42fbab8021936a4d22589/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:37503bfbfc9d2c40b344d06b2199cf0e96e97957ab1c1b546fd4f87e53e5d3e4", size = 794986, upload-time = "2025-09-25T21:32:07.367Z" }, + { url = "https://files.pythonhosted.org/packages/45/91/47a6e1c42d9ee337c4839208f30d9f09caa9f720ec7582917b264defc875/pyyaml-6.0.3-cp311-cp311-win32.whl", hash = "sha256:8098f252adfa6c80ab48096053f512f2321f0b998f98150cea9bd23d83e1467b", size = 142543, upload-time = "2025-09-25T21:32:08.95Z" }, + { url = "https://files.pythonhosted.org/packages/da/e3/ea007450a105ae919a72393cb06f122f288ef60bba2dc64b26e2646fa315/pyyaml-6.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:9f3bfb4965eb874431221a3ff3fdcddc7e74e3b07799e0e84ca4a0f867d449bf", size = 158763, upload-time = "2025-09-25T21:32:09.96Z" }, + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, + { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", 
size = 844011, upload-time = "2025-09-25T21:32:15.21Z" }, + { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, + { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, + { url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" }, + { url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, + { url = 
"https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" }, + { url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" }, + { url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" }, + { url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" }, + { url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" }, + { url = "https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" }, + { url = 
"https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" }, + { url = "https://files.pythonhosted.org/packages/de/94/980b50a6531b3019e45ddeada0626d45fa85cbe22300844a7983285bed3b/pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26", size = 137427, upload-time = "2025-09-25T21:32:32.58Z" }, + { url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" }, + { url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" }, + { url = "https://files.pythonhosted.org/packages/9d/8c/f4bd7f6465179953d3ac9bc44ac1a8a3e6122cf8ada906b4f96c60172d43/pyyaml-6.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:8d1fab6bb153a416f9aeb4b8763bc0f22a5586065f86f7664fc23339fc1c1fac", size = 181814, upload-time = "2025-09-25T21:32:35.712Z" }, + { url = "https://files.pythonhosted.org/packages/bd/9c/4d95bb87eb2063d20db7b60faa3840c1b18025517ae857371c4dd55a6b3a/pyyaml-6.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:34d5fcd24b8445fadc33f9cf348c1047101756fd760b4dacb5c3e99755703310", size = 173809, upload-time = "2025-09-25T21:32:36.789Z" }, + { url = 
"https://files.pythonhosted.org/packages/92/b5/47e807c2623074914e29dabd16cbbdd4bf5e9b2db9f8090fa64411fc5382/pyyaml-6.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:501a031947e3a9025ed4405a168e6ef5ae3126c59f90ce0cd6f2bfc477be31b7", size = 766454, upload-time = "2025-09-25T21:32:37.966Z" }, + { url = "https://files.pythonhosted.org/packages/02/9e/e5e9b168be58564121efb3de6859c452fccde0ab093d8438905899a3a483/pyyaml-6.0.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b3bc83488de33889877a0f2543ade9f70c67d66d9ebb4ac959502e12de895788", size = 836355, upload-time = "2025-09-25T21:32:39.178Z" }, + { url = "https://files.pythonhosted.org/packages/88/f9/16491d7ed2a919954993e48aa941b200f38040928474c9e85ea9e64222c3/pyyaml-6.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c458b6d084f9b935061bc36216e8a69a7e293a2f1e68bf956dcd9e6cbcd143f5", size = 794175, upload-time = "2025-09-25T21:32:40.865Z" }, + { url = "https://files.pythonhosted.org/packages/dd/3f/5989debef34dc6397317802b527dbbafb2b4760878a53d4166579111411e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7c6610def4f163542a622a73fb39f534f8c101d690126992300bf3207eab9764", size = 755228, upload-time = "2025-09-25T21:32:42.084Z" }, + { url = "https://files.pythonhosted.org/packages/d7/ce/af88a49043cd2e265be63d083fc75b27b6ed062f5f9fd6cdc223ad62f03e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5190d403f121660ce8d1d2c1bb2ef1bd05b5f68533fc5c2ea899bd15f4399b35", size = 789194, upload-time = "2025-09-25T21:32:43.362Z" }, + { url = "https://files.pythonhosted.org/packages/23/20/bb6982b26a40bb43951265ba29d4c246ef0ff59c9fdcdf0ed04e0687de4d/pyyaml-6.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:4a2e8cebe2ff6ab7d1050ecd59c25d4c8bd7e6f400f5f82b96557ac0abafd0ac", size = 156429, upload-time = "2025-09-25T21:32:57.844Z" }, + { url = 
"https://files.pythonhosted.org/packages/f4/f4/a4541072bb9422c8a883ab55255f918fa378ecf083f5b85e87fc2b4eda1b/pyyaml-6.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:93dda82c9c22deb0a405ea4dc5f2d0cda384168e466364dec6255b293923b2f3", size = 143912, upload-time = "2025-09-25T21:32:59.247Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f9/07dd09ae774e4616edf6cda684ee78f97777bdd15847253637a6f052a62f/pyyaml-6.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:02893d100e99e03eda1c8fd5c441d8c60103fd175728e23e431db1b589cf5ab3", size = 189108, upload-time = "2025-09-25T21:32:44.377Z" }, + { url = "https://files.pythonhosted.org/packages/4e/78/8d08c9fb7ce09ad8c38ad533c1191cf27f7ae1effe5bb9400a46d9437fcf/pyyaml-6.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c1ff362665ae507275af2853520967820d9124984e0f7466736aea23d8611fba", size = 183641, upload-time = "2025-09-25T21:32:45.407Z" }, + { url = "https://files.pythonhosted.org/packages/7b/5b/3babb19104a46945cf816d047db2788bcaf8c94527a805610b0289a01c6b/pyyaml-6.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6adc77889b628398debc7b65c073bcb99c4a0237b248cacaf3fe8a557563ef6c", size = 831901, upload-time = "2025-09-25T21:32:48.83Z" }, + { url = "https://files.pythonhosted.org/packages/8b/cc/dff0684d8dc44da4d22a13f35f073d558c268780ce3c6ba1b87055bb0b87/pyyaml-6.0.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a80cb027f6b349846a3bf6d73b5e95e782175e52f22108cfa17876aaeff93702", size = 861132, upload-time = "2025-09-25T21:32:50.149Z" }, + { url = "https://files.pythonhosted.org/packages/b1/5e/f77dc6b9036943e285ba76b49e118d9ea929885becb0a29ba8a7c75e29fe/pyyaml-6.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00c4bdeba853cc34e7dd471f16b4114f4162dc03e6b7afcc2128711f0eca823c", size = 839261, upload-time = "2025-09-25T21:32:51.808Z" }, + { url = 
"https://files.pythonhosted.org/packages/ce/88/a9db1376aa2a228197c58b37302f284b5617f56a5d959fd1763fb1675ce6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e1674c3ef6f541c35191caae2d429b967b99e02040f5ba928632d9a7f0f065", size = 805272, upload-time = "2025-09-25T21:32:52.941Z" }, + { url = "https://files.pythonhosted.org/packages/da/92/1446574745d74df0c92e6aa4a7b0b3130706a4142b2d1a5869f2eaa423c6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65", size = 829923, upload-time = "2025-09-25T21:32:54.537Z" }, + { url = "https://files.pythonhosted.org/packages/f0/7a/1c7270340330e575b92f397352af856a8c06f230aa3e76f86b39d01b416a/pyyaml-6.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4ad1906908f2f5ae4e5a8ddfce73c320c2a1429ec52eafd27138b7f1cbe341c9", size = 174062, upload-time = "2025-09-25T21:32:55.767Z" }, + { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, +] + [[package]] name = "rich" version = "14.3.3" From bd897ed24277b43a7db6d2fcbfde6f6b483d187b Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Thu, 19 Mar 2026 19:14:36 +0100 Subject: [PATCH 06/70] Add datasource tests --- tests/test_datasource.py | 731 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 731 insertions(+) create mode 100644 tests/test_datasource.py diff --git a/tests/test_datasource.py b/tests/test_datasource.py new file mode 100644 index 0000000..467f071 --- /dev/null +++ b/tests/test_datasource.py @@ -0,0 +1,731 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +from textwrap import dedent +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from anton.data_vault import 
DataVault +from anton.datasource_registry import ( + AuthMethod, + DatasourceEngine, + DatasourceField, + DatasourceRegistry, +) + + +# ───────────────────────────────────────────────────────────────────────────── +# Fixtures +# ───────────────────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def vault_dir(tmp_path): + return tmp_path / "data_vault" + + +@pytest.fixture() +def vault(vault_dir): + return DataVault(vault_dir=vault_dir) + + +@pytest.fixture() +def datasources_md(tmp_path): + """Write a minimal datasources.md and return its path.""" + path = tmp_path / "datasources.md" + path.write_text(dedent("""\ + ## PostgreSQL + + ```yaml + engine: postgresql + display_name: PostgreSQL + pip: psycopg2-binary + name_from: database + fields: + - name: host + required: true + description: hostname or IP + - name: port + required: true + default: "5432" + description: port number + - name: database + required: true + description: database name + - name: user + required: true + description: username + - name: password + required: true + secret: true + description: password + - name: schema + required: false + description: defaults to public + test_snippet: | + import psycopg2 + conn = psycopg2.connect( + host=os.environ["DS_HOST"], + port=os.environ["DS_PORT"], + dbname=os.environ["DS_DATABASE"], + user=os.environ["DS_USER"], + password=os.environ["DS_PASSWORD"], + ) + conn.close() + print("ok") + ``` + + ## HubSpot + + ```yaml + engine: hubspot + display_name: HubSpot + pip: hubspot-api-client + name_from: access_token + auth_method: choice + auth_methods: + - name: private_app + display: Private App token (recommended) + fields: + - name: access_token + required: true + secret: true + description: pat-na1-xxx token + - name: oauth + display: OAuth 2.0 + fields: + - name: client_id + required: true + description: OAuth client ID + - name: client_secret + required: true + secret: true + description: OAuth client secret + 
test_snippet: | + print("ok") + ``` + """)) + return path + + +@pytest.fixture() +def registry(datasources_md, tmp_path): + """Registry pointing at our temp datasources.md, no user overrides.""" + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + from anton.datasource_registry import _parse_file + reg._engines = _parse_file(datasources_md) + return reg + + +# ───────────────────────────────────────────────────────────────────────────── +# DataVault — save / load / delete +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDataVaultSaveLoad: + def test_save_creates_file(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "db.example.com", "port": "5432"}) + assert (vault_dir / "postgresql-prod_db").is_file() + + def test_save_file_permissions(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "db.example.com"}) + path = vault_dir / "postgresql-prod_db" + mode = oct(path.stat().st_mode)[-3:] + assert mode == "600" + + def test_vault_dir_permissions(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "db.example.com"}) + mode = oct(vault_dir.stat().st_mode)[-3:] + assert mode == "700" + + def test_load_returns_fields(self, vault): + creds = {"host": "db.example.com", "port": "5432", "password": "secret"} + vault.save("postgresql", "prod_db", creds) + loaded = vault.load("postgresql", "prod_db") + assert loaded == creds + + def test_load_missing_returns_none(self, vault): + assert vault.load("postgresql", "nonexistent") is None + + def test_load_corrupt_file_returns_none(self, vault, vault_dir): + vault._ensure_dir() + (vault_dir / "postgresql-bad").write_text("not json") + assert vault.load("postgresql", "bad") is None + + def test_save_overwrites_existing(self, vault): + vault.save("postgresql", "prod_db", {"host": "old.host"}) + vault.save("postgresql", "prod_db", {"host": "new.host"}) + assert vault.load("postgresql", "prod_db") == 
{"host": "new.host"} + + def test_delete_existing(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "x"}) + result = vault.delete("postgresql", "prod_db") + assert result is True + assert not (vault_dir / "postgresql-prod_db").is_file() + + def test_delete_missing_returns_false(self, vault): + assert vault.delete("postgresql", "ghost") is False + + def test_special_chars_sanitized_in_filename(self, vault, vault_dir): + vault.save("postgresql", "my db/prod", {"host": "x"}) + files = list(vault_dir.iterdir()) + assert len(files) == 1 + assert "/" not in files[0].name + + def test_json_contains_metadata(self, vault, vault_dir): + vault.save("postgresql", "prod_db", {"host": "x"}) + raw = json.loads((vault_dir / "postgresql-prod_db").read_text()) + assert raw["engine"] == "postgresql" + assert raw["name"] == "prod_db" + assert "created_at" in raw + assert raw["fields"] == {"host": "x"} + + +# ───────────────────────────────────────────────────────────────────────────── +# DataVault — list_connections +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDataVaultListConnections: + def test_empty_vault(self, vault): + assert vault.list_connections() == [] + + def test_lists_all_connections(self, vault): + vault.save("postgresql", "prod_db", {"host": "a"}) + vault.save("hubspot", "main", {"access_token": "pat-xxx"}) + conns = vault.list_connections() + engines = {c["engine"] for c in conns} + assert engines == {"postgresql", "hubspot"} + + def test_skips_corrupt_files(self, vault, vault_dir): + vault._ensure_dir() + vault.save("postgresql", "good", {"host": "x"}) + (vault_dir / "postgresql-bad").write_text("{{not json") + conns = vault.list_connections() + assert len(conns) == 1 + assert conns[0]["name"] == "good" + + def test_vault_dir_missing_returns_empty(self, vault): + # vault_dir was never created + assert vault.list_connections() == [] + + +# 
───────────────────────────────────────────────────────────────────────────── +# DataVault — inject_env / clear_ds_env +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDataVaultEnvInjection: + def test_inject_sets_ds_vars(self, vault): + vault.save("postgresql", "prod_db", {"host": "db.example.com", "password": "s3cr3t"}) + var_names = vault.inject_env("postgresql", "prod_db") + assert os.environ.get("DS_HOST") == "db.example.com" + assert os.environ.get("DS_PASSWORD") == "s3cr3t" + assert set(var_names) == {"DS_HOST", "DS_PASSWORD"} + # Cleanup + vault.clear_ds_env() + + def test_inject_missing_returns_none(self, vault): + result = vault.inject_env("postgresql", "ghost") + assert result is None + + def test_clear_removes_ds_vars(self, vault): + vault.save("postgresql", "prod_db", {"host": "x"}) + vault.inject_env("postgresql", "prod_db") + vault.clear_ds_env() + assert "DS_HOST" not in os.environ + + def test_clear_leaves_non_ds_vars(self, vault): + os.environ["MY_VAR"] = "untouched" + vault.clear_ds_env() + assert os.environ.get("MY_VAR") == "untouched" + del os.environ["MY_VAR"] + + def test_inject_uppercases_field_names(self, vault): + vault.save("postgresql", "prod_db", {"access_token": "tok123"}) + vault.inject_env("postgresql", "prod_db") + assert os.environ.get("DS_ACCESS_TOKEN") == "tok123" + vault.clear_ds_env() + + +# ───────────────────────────────────────────────────────────────────────────── +# DataVault — next_connection_number +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDataVaultNextConnectionNumber: + def test_returns_one_when_empty(self, vault): + assert vault.next_connection_number("postgresql") == 1 + + def test_increments_past_existing(self, vault): + vault.save("postgresql", "1", {"host": "a"}) + vault.save("postgresql", "2", {"host": "b"}) + assert vault.next_connection_number("postgresql") == 3 + + def test_ignores_named_connections(self, 
vault): + # "prod_db" is not a digit — should not affect numbering + vault.save("postgresql", "prod_db", {"host": "a"}) + assert vault.next_connection_number("postgresql") == 1 + + def test_does_not_confuse_engines(self, vault): + vault.save("hubspot", "1", {"access_token": "x"}) + vault.save("hubspot", "2", {"access_token": "y"}) + # postgresql counter is independent + assert vault.next_connection_number("postgresql") == 1 + + +# ───────────────────────────────────────────────────────────────────────────── +# DatasourceRegistry — lookup +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDatasourceRegistry: + def test_get_by_slug(self, registry): + engine = registry.get("postgresql") + assert engine is not None + assert engine.display_name == "PostgreSQL" + + def test_get_missing_returns_none(self, registry): + assert registry.get("mysql") is None + + def test_find_by_name_exact(self, registry): + assert registry.find_by_name("PostgreSQL") is not None + + def test_find_by_name_case_insensitive(self, registry): + assert registry.find_by_name("postgresql") is not None + assert registry.find_by_name("POSTGRESQL") is not None + + def test_find_by_slug(self, registry): + # engine slug is also accepted + assert registry.find_by_name("postgresql") is not None + + def test_find_unknown_returns_none(self, registry): + assert registry.find_by_name("MySQL") is None + + def test_all_engines_sorted(self, registry): + engines = registry.all_engines() + names = [e.display_name for e in engines] + assert names == sorted(names) + + def test_fields_parsed_correctly(self, registry): + engine = registry.get("postgresql") + field_names = [f.name for f in engine.fields] + assert "host" in field_names + assert "password" in field_names + + def test_secret_flag_on_password(self, registry): + engine = registry.get("postgresql") + pw = next(f for f in engine.fields if f.name == "password") + assert pw.secret is True + + def 
test_required_flag(self, registry): + engine = registry.get("postgresql") + schema = next(f for f in engine.fields if f.name == "schema") + assert schema.required is False + + def test_default_value_on_port(self, registry): + engine = registry.get("postgresql") + port = next(f for f in engine.fields if f.name == "port") + assert port.default == "5432" + + def test_pip_field(self, registry): + engine = registry.get("postgresql") + assert engine.pip == "psycopg2-binary" + + def test_test_snippet_present(self, registry): + engine = registry.get("postgresql") + assert "print(\"ok\")" in engine.test_snippet + + def test_auth_method_choice_parsed(self, registry): + engine = registry.get("hubspot") + assert engine.auth_method == "choice" + assert len(engine.auth_methods) == 2 + method_names = [m.name for m in engine.auth_methods] + assert "private_app" in method_names + assert "oauth" in method_names + + def test_auth_method_fields_parsed(self, registry): + engine = registry.get("hubspot") + private = next(m for m in engine.auth_methods if m.name == "private_app") + assert len(private.fields) == 1 + assert private.fields[0].name == "access_token" + assert private.fields[0].secret is True + + +# ───────────────────────────────────────────────────────────────────────────── +# DatasourceRegistry — derive_name +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDeriveConnectionName: + def test_single_field_name_from(self, registry): + engine = registry.get("postgresql") # name_from: database + name = registry.derive_name(engine, {"database": "prod_db", "host": "x"}) + assert name == "prod_db" + + def test_missing_name_from_field_returns_empty(self, registry): + engine = registry.get("postgresql") + name = registry.derive_name(engine, {"host": "x"}) # no "database" + assert name == "" + + def test_no_name_from_returns_empty(self): + engine = DatasourceEngine(engine="test", display_name="Test", name_from="") + reg = 
DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + name = reg.derive_name(engine, {"host": "x"}) + assert name == "" + + def test_list_name_from(self): + engine = DatasourceEngine( + engine="test", + display_name="Test", + name_from=["host", "database"], + ) + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + name = reg.derive_name(engine, {"host": "db.example.com", "database": "prod"}) + assert name == "db.example.com_prod" + + def test_list_name_from_skips_missing(self): + engine = DatasourceEngine( + engine="test", + display_name="Test", + name_from=["host", "database"], + ) + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + name = reg.derive_name(engine, {"host": "db.example.com"}) + assert name == "db.example.com" + + +# ───────────────────────────────────────────────────────────────────────────── +# DatasourceRegistry — user overrides +# ───────────────────────────────────────────────────────────────────────────── + + +class TestDatasourceRegistryUserOverrides: + def test_user_override_wins(self, tmp_path, datasources_md): + """A user-defined engine with same slug overrides the builtin.""" + user_md = tmp_path / "user_datasources.md" + user_md.write_text(dedent("""\ + ## PostgreSQL + + ```yaml + engine: postgresql + display_name: PostgreSQL (custom) + pip: psycopg2 + fields: + - name: host + required: true + description: custom host field + test_snippet: print("ok") + ``` + """)) + + from anton.datasource_registry import _parse_file + + builtin = _parse_file(datasources_md) + user = _parse_file(user_md) + merged = {**builtin, **user} + + assert merged["postgresql"].display_name == "PostgreSQL (custom)" + assert merged["postgresql"].pip == "psycopg2" + + def test_missing_user_file_falls_back_to_builtin(self, tmp_path, datasources_md): + from anton.datasource_registry import _parse_file + + user_engines = _parse_file(tmp_path / "nonexistent.md") + assert user_engines == {} + + +# 
───────────────────────────────────────────────────────────────────────────── +# _handle_connect_datasource — integration-style (mocked I/O) +# ───────────────────────────────────────────────────────────────────────────── + + +class TestHandleConnectDatasource: + """Test the slash-command handler with mocked prompts and scratchpad.""" + + def _make_session(self): + from anton.chat import ChatSession + + mock_llm = AsyncMock() + session = ChatSession(mock_llm) + session._scratchpads = AsyncMock() + return session + + def _make_cell(self, stdout="ok", stderr="", error=None): + cell = MagicMock() + cell.stdout = stdout + cell.stderr = stderr + cell.error = error + return cell + + @pytest.mark.asyncio + async def test_unknown_engine_returns_early(self, registry, vault_dir, capsys): + """Typing an unknown engine name aborts without saving anything.""" + from anton.chat import _handle_connect_datasource + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", return_value="MySQL"), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + assert result is session # unchanged session + assert DataVault(vault_dir=vault_dir).list_connections() == [] + + @pytest.mark.asyncio + async def test_partial_save_on_n_answer(self, registry, vault_dir): + """Answering 'n' saves partial credentials and returns without testing.""" + from anton.chat import _handle_connect_datasource + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + + vault = DataVault(vault_dir=vault_dir) + prompt_responses = iter(["PostgreSQL", "n", "db.example.com", "", "", "", "", ""]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + 
patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + assert conns[0]["engine"] == "postgresql" + # Partial connections get auto-numbered names + assert conns[0]["name"].isdigit() + # Scratchpad was NOT used for testing + session._scratchpads.get_or_create.assert_not_called() + + @pytest.mark.asyncio + async def test_successful_connection_saves_and_injects_history(self, registry, vault_dir): + """Happy path: test passes, credentials saved, history entry added.""" + from anton.chat import _handle_connect_datasource + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", # engine choice + "y", # have all credentials + "db.example.com", # host + "5432", # port + "prod_db", # database + "alice", # user + "s3cr3t", # password + "", # schema (optional) + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + # Credentials saved + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load("postgresql", conns[0]["name"]) + assert saved["host"] == "db.example.com" + assert saved["password"] == "s3cr3t" + + # History entry injected + assert result._history + last = result._history[-1] + assert last["role"] == "assistant" + assert "postgresql" in last["content"].lower() + + @pytest.mark.asyncio + async def test_failed_test_offers_retry(self, 
registry, vault_dir): + """Connection test failure prompts for retry; success on second attempt saves.""" + from anton.chat import _handle_connect_datasource + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + fail_cell = self._make_cell(stdout="", stderr="password authentication failed") + ok_cell = self._make_cell(stdout="ok") + pad = AsyncMock() + pad.execute = AsyncMock(side_effect=[fail_cell, ok_cell]) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", # engine + "y", # have all creds + "db.example.com", # host + "5432", # port + "prod_db", # database + "alice", # user + "wrongpassword", # password (first attempt - fails) + "", # schema + "y", # retry? + "correctpassword", # new password + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + # Should have saved after second attempt + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load("postgresql", conns[0]["name"]) + assert saved["password"] == "correctpassword" + + @pytest.mark.asyncio + async def test_failed_test_no_retry_returns_without_saving(self, registry, vault_dir): + """Declining retry on failed test leaves vault empty.""" + from anton.chat import _handle_connect_datasource + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + fail_cell = self._make_cell(stdout="", error="connection refused") + pad = AsyncMock() + pad.execute = AsyncMock(return_value=fail_cell) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "y", + "db.example.com", 
"5432", "prod_db", "alice", "badpass", "", + "n", # don't retry + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + result = await _handle_connect_datasource(console, session._scratchpads, session) + + assert vault.list_connections() == [] + # No history injection since save never happened + assert not result._history + + @pytest.mark.asyncio + async def test_ds_env_cleaned_up_after_test(self, registry, vault_dir): + """DS_* env vars must not leak into os.environ after connection testing.""" + from anton.chat import _handle_connect_datasource + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "y", + "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + # No DS_* variables should remain + leaked = [k for k in os.environ if k.startswith("DS_")] + assert leaked == [], f"DS_* vars leaked: {leaked}" + + @pytest.mark.asyncio + async def test_auth_method_choice_selects_fields(self, registry, vault_dir): + """Selecting an auth method filters to that method's fields only.""" + from anton.chat import _handle_connect_datasource + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = 
AsyncMock(return_value=self._make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "HubSpot", # engine + "1", # auth method: private_app + "y", # have all creds + "pat-na1-abc123", # access_token + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load("hubspot", conns[0]["name"]) + # Only private_app fields collected — no client_id or client_secret + assert "access_token" in saved + assert "client_id" not in saved + assert "client_secret" not in saved + + @pytest.mark.asyncio + async def test_selective_field_collection(self, registry, vault_dir): + """Typing 'host,user,password' collects only those three fields.""" + from anton.chat import _handle_connect_datasource + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", # engine + "host,user,password", # selective list + "db.example.com", # host + "alice", # user + "s3cr3t", # password + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load("postgresql", conns[0]["name"]) + assert set(saved.keys()) == {"host", "user", 
"password"} \ No newline at end of file From dd1aa366f317b0770fd32afcf9ba20637b79d33c Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Thu, 19 Mar 2026 21:01:19 +0100 Subject: [PATCH 07/70] Fixing issues with credentials saving and usage in the scratchpads --- anton/chat.py | 394 +++++++++++++++++++++++++++++-- anton/cli.py | 34 ++- anton/data_vault.py | 8 +- anton/datasource_registry.py | 18 +- anton/llm/prompts.py | 3 + datasources.md | 446 ++++++++++++++++++++++++++++++++++- tests/test_datasource.py | 142 ++++++++++- 7 files changed, 1008 insertions(+), 37 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 1ee1c1d..add55d6 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2,6 +2,8 @@ import asyncio import os +import json as _json +import re as _re import sys import time from collections.abc import AsyncIterator, Callable @@ -40,6 +42,10 @@ format_cell_result, prepare_scratchpad_exec, ) +from anton.data_vault import DataVault +from anton.datasource_registry import DatasourceEngine, DatasourceField, DatasourceRegistry, _parse_file as _ds_parse_file +from rich.prompt import Confirm, Prompt + if TYPE_CHECKING: from rich.console import Console @@ -172,6 +178,10 @@ async def _build_system_prompt(self, user_message: str = "") -> str: md_context = self._workspace.build_anton_md_context() if md_context: prompt += md_context + # Inject connected datasource context without credentials + ds_ctx = _build_datasource_context() + if ds_ctx: + prompt += ds_ctx return prompt # Packages the LLM is most likely to care about when writing scratchpad code. 
@@ -403,6 +413,7 @@ async def turn(self, user_input: str | list[dict]) -> str: except Exception as exc: result_text = f"Tool '{tc.name}' failed: {exc}" + result_text = _scrub_credentials(result_text) result_text = _apply_error_tracking( result_text, tc.name, error_streak, resilience_nudged, ) @@ -648,6 +659,7 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato tool=tc.name, ) + result_text = _scrub_credentials(result_text) result_text = _apply_error_tracking( result_text, tc.name, error_streak, resilience_nudged, ) @@ -897,6 +909,77 @@ def _apply_error_tracking( return result_text +# DS_* var names whose values are known to be secret (passwords, tokens, keys). +# Populated at startup and after each successful connect. +_DS_SECRET_VARS: set[str] = set() + +# DS_* var names for **ALL** fields of registered engines. +_DS_KNOWN_VARS: set[str] = set() + + +def _register_secret_vars(engine_def: "DatasourceEngine") -> None: + """Record which DS_* var names correspond to secret fields for engine_def. + """ + all_fields = list(engine_def.fields) + for am in (engine_def.auth_methods or []): + all_fields.extend(am.fields) + for f in all_fields: + key = f"DS_{f.name.upper()}" + _DS_KNOWN_VARS.add(key) + if f.secret: + _DS_SECRET_VARS.add(key) + + +def _scrub_credentials(text: str) -> str: + """Remove secret DS_* values from scratchpad output before it reaches the LLM. + + Only redacts vars registered as secret via _register_secret_vars (driven by + DatasourceField.secret=true in datasources.md). Non-secret fields of known + engines (DS_HOST, DS_PORT, DS_BASE_URL, …) are left readable so the LLM can + reason about connection errors. For truly unknown DS_* vars (custom engines + not yet in the registry) the fallback scrubs any long value — conservative + but safe. 
+ """ + for key in _DS_SECRET_VARS: + value = os.environ.get(key, "") + if not value or len(value) <= 8: + continue + text = text.replace(value, f"[{key}]") + for key, value in os.environ.items(): + if not key.startswith("DS_") or key in _DS_KNOWN_VARS: + continue + if not value or len(value) <= 8: + continue + text = text.replace(value, f"[{key}]") + return text + + +def _build_datasource_context() -> str: + """Build a system-prompt section listing available DS_* env vars by name. + + Shows the LLM what data sources are connected and which environment + variable names to use — without exposing any credential values. + """ + try: + vault = DataVault() + conns = vault.list_connections() + except Exception: + return "" + if not conns: + return "" + lines = ["\n\n## Connected Data Sources"] + lines.append( + "Credentials are pre-injected as DS_* environment variables. " + "Use them directly in scratchpad code. " + "Never read ~/.anton/data_vault/ files directly.\n" + ) + for c in conns: + fields = vault.load(c["engine"], c["name"]) or {} + var_names = ", ".join(f"DS_{k.upper()}" for k in fields) + lines.append(f"- `{c['engine']}-{c['name']}` → {var_names}") + return "\n".join(lines) + + def _build_runtime_context(settings: AntonSettings) -> str: """Build runtime context string including Minds datasource info if configured.""" ctx = ( @@ -1642,7 +1725,7 @@ async def _handle_data_connections( session: ChatSession, ) -> ChatSession: """View and manage stored keys and connections across global and project vaults.""" - from rich.prompt import Confirm, Prompt + from rich.prompt import Prompt from anton.workspace import Workspace as _Workspace @@ -2182,6 +2265,156 @@ def _human_size(nbytes: int) -> str: return f"{nbytes:.1f}TB" +async def _handle_add_custom_datasource( + console: Console, + name: str, + registry, + session: "ChatSession", +): + """Ask the user how they authenticate, use the LLM to identify fields, save definition.""" + + console.print() + user_answer = 
Prompt.ask( + f"[anton.cyan](anton)[/] '{name}' isn't in my built-in list.\n" + f" How do you authenticate with it? " + f"Describe what you have or paste credentials directly", + console=console, + ) + if not user_answer.strip(): + return None + + console.print() + console.print("[anton.muted] Got it — working out the connection details…[/]") + + try: + response = await session._llm.plan( + system="You are a data source connection expert.", + messages=[{ + "role": "user", + "content": ( + f"The user wants to connect to '{name}' and said: {user_answer}\n\n" + "Return ONLY valid JSON (no markdown fences, no commentary):\n" + '{"display_name":"Human-readable name","pip":"pip-package or empty string",' + '"fields":[{"name":"snake_case_name","value":"value if given inline else empty",' + '"secret":true or false,"required":true or false,"description":"what it is"}]}' + ), + }], + max_tokens=1024, + ) + text = response.content.strip() + text = _re.sub(r"^```[^\n]*\n|```\s*$", "", text, flags=_re.MULTILINE).strip() + data = _json.loads(text) + except Exception: + console.print("[anton.warning] Couldn't identify connection details. 
Try again.[/]") + console.print() + return None + + raw_fields = data.get("fields") or [] + fields: list[DatasourceField] = [] + for f in raw_fields: + if not isinstance(f, dict) or not f.get("name"): + continue + fields.append(DatasourceField( + name=f["name"], + required=bool(f.get("required", True)), + secret=bool(f.get("secret", False)), + description=str(f.get("description", "")), + )) + + if not fields: + console.print("[anton.warning] Couldn't identify any connection fields.[/]") + console.print() + return None + + display_name = str(data.get("display_name", name)) + pip_pkg = str(data.get("pip", "")) + + # Show summary + console.print() + console.print(" [bold]── What I'll save ──────────────────────────[/]") + credentials: dict[str, str] = {} + for f, raw in zip(fields, raw_fields): + inline_value = str(raw.get("value", "")).strip() + if f.secret and inline_value: + console.print(f" • [bold]{f.name:<14}[/] (secret — provided, stored securely)") + credentials[f.name] = inline_value + elif f.secret: + console.print(f" • [bold]{f.name:<14}[/] (secret — I'll ask for this)") + else: + val_display = inline_value or "[anton.muted][/]" + console.print(f" • [bold]{f.name:<14}[/] {val_display}") + if inline_value: + credentials[f.name] = inline_value + console.print() + + # Prompt for any secret fields not provided inline + for f, raw in zip(fields, raw_fields): + if not f.secret: + continue + if str(raw.get("value", "")).strip(): + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + password=True, + console=console, + default="", + ) + if value: + credentials[f.name] = value + + if not credentials: + console.print("[anton.warning] No credentials collected. 
Aborting.[/]") + console.print() + return None + + # Build engine slug and write definition to ~/.anton/datasources.md + slug = _re.sub(r"[^\w]", "_", display_name.lower()).strip("_") + field_lines = "\n".join( + f" - {{ name: {f.name}, required: {str(f.required).lower()}, " + f"secret: {str(f.secret).lower()}, description: \"{f.description}\" }}" + for f in fields + ) + yaml_block = ( + f"\n---\n\n## {display_name}\n" + "```yaml\n" + f"engine: {slug}\n" + f"display_name: {display_name}\n" + + (f"pip: {pip_pkg}\n" if pip_pkg else "") + + f"fields:\n{field_lines}\n" + "```\n" + ) + user_ds_path = Path("~/.anton/datasources.md").expanduser() + tmp_path = user_ds_path.with_suffix(".tmp") + + # Write to temp, validate it parses, then rename atomically + existing = user_ds_path.read_text(encoding="utf-8") if user_ds_path.is_file() else "" + tmp_path.write_text(existing + yaml_block, encoding="utf-8") + + parsed = _ds_parse_file(tmp_path) + if slug in parsed: + import shutil + shutil.move(str(tmp_path), str(user_ds_path)) + else: + tmp_path.unlink(missing_ok=True) + console.print( + "[anton.warning]Could not validate engine definition — " + "credentials saved but engine not written to datasources.md.[/]" + ) + + registry._load() + engine_def = registry.get(slug) + if engine_def is None: + # Fallback: construct inline so the flow can continue even if parse failed + engine_def = DatasourceEngine( + engine=slug, + display_name=display_name, + pip=pip_pkg, + fields=fields, + ) + + return engine_def, credentials + + async def _handle_connect_datasource( console: Console, scratchpads: ScratchpadManager, @@ -2190,13 +2423,9 @@ async def _handle_connect_datasource( """Interactive flow for connecting a new data source to the Local Vault.""" from rich.prompt import Prompt - from anton.data_vault import DataVault - from anton.datasource_registry import DatasourceRegistry - vault = DataVault() registry = DatasourceRegistry() - # ── Step 1: ask which engine 
────────────────────────────────── console.print() engine_names = ", ".join(e.display_name for e in registry.all_engines()) answer = Prompt.ask( @@ -2206,11 +2435,52 @@ async def _handle_connect_datasource( ) engine_def = registry.find_by_name(answer.strip()) if engine_def is None: - console.print(f"[anton.warning]Unknown data source '{answer}'. Available:[/]") - for e in registry.all_engines(): - console.print(f" • {e.display_name}") - console.print() - return session + # Check whether the input is ambiguous before treating it as unknown + needle = answer.strip().lower() + candidates = [ + e for e in registry.all_engines() + if needle in e.display_name.lower() or needle in e.engine.lower() + ] + if len(candidates) > 1: + console.print() + console.print( + f"[anton.warning](anton)[/] '{answer}' matches multiple engines — " + "which one did you mean?" + ) + console.print() + for i, e in enumerate(candidates, 1): + console.print(f" {i}. {e.display_name}") + console.print() + pick = Prompt.ask( + "[anton.cyan](anton)[/] Enter a number", + console=console, + ).strip() + try: + engine_def = candidates[int(pick) - 1] + except (ValueError, IndexError): + console.print("[anton.warning]Invalid choice. Aborting.[/]") + console.print() + return session + else: + result = await _handle_add_custom_datasource(console, answer.strip(), registry, session) + if result is None: + return session + engine_def, credentials = result + conn_num = vault.next_connection_number(engine_def.engine) + vault.save(engine_def.engine, str(conn_num), credentials) + slug = f"{engine_def.engine}-{conn_num}" + console.print( + f" Credentials saved to Local Vault as [bold]\"{slug}\"[/bold]." + ) + console.print() + session._history.append({ + "role": "assistant", + "content": ( + f"I've saved a {engine_def.display_name} connection named \"{slug}\" " + f"to the Local Vault. I can now query this data source when needed." 
+ ), + }) + return session # ── Step 2a: auth method choice (if engine requires it) ─────── active_fields = engine_def.fields @@ -2237,7 +2507,6 @@ async def _handle_connect_datasource( return session active_fields = chosen_method.fields - # ── Step 2b: show fields ────────────────────────────────────── required_fields = [f for f in active_fields if f.required] optional_fields = [f for f in active_fields if not f.required] @@ -2291,7 +2560,6 @@ async def _handle_connect_datasource( fields_to_collect = matched if matched else active_fields partial = False - # ── Step 4: collect credentials ────────────────────────────── console.print() credentials: dict[str, str] = {} @@ -2314,7 +2582,6 @@ async def _handle_connect_datasource( if value: credentials[f.name] = value - # ── Partial save ───────────────────────────────────────────── if partial: n = vault.next_connection_number(engine_def.engine) auto_name = str(n) @@ -2329,7 +2596,6 @@ async def _handle_connect_datasource( console.print() return session - # ── Step 5: test connection ─────────────────────────────────── if engine_def.test_snippet: while True: console.print() @@ -2377,7 +2643,7 @@ async def _handle_connect_datasource( # Re-collect secret fields only console.print() - for f in engine_def.fields: + for f in active_fields: if not f.secret: continue value = Prompt.ask( @@ -2388,20 +2654,20 @@ async def _handle_connect_datasource( ) if value: credentials[f.name] = value - continue # retry test + # Try again with updated credentials + continue # Success console.print("[anton.success] ✓ Connected successfully![/]") break - # ── Step 6: save + write topic ─────────────────────────────── conn_name = registry.derive_name(engine_def, credentials) - if not conn_name or conn_name == engine_def.engine: - # Fall back to auto-number if name_from didn't resolve + if not conn_name: n = vault.next_connection_number(engine_def.engine) conn_name = str(n) vault.save(engine_def.engine, conn_name, credentials) + 
_register_secret_vars(engine_def) slug = f"{engine_def.engine}-{conn_name}" console.print( f" Credentials saved to Local Vault as [bold]\"{slug}\"[/bold]." @@ -2422,12 +2688,62 @@ async def _handle_connect_datasource( return session +def _handle_list_data_sources(console: Console) -> None: + """Print all saved Local Vault connections with their DS_* var names.""" + vault = DataVault() + conns = vault.list_connections() + console.print() + if not conns: + console.print("[anton.muted]No data sources connected yet.[/]") + console.print("[anton.muted]Use /connect-data-source to add one.[/]") + console.print() + return + console.print("[anton.cyan]Connected data sources:[/]") + console.print() + for c in conns: + fields = vault.load(c["engine"], c["name"]) or {} + var_names = ", ".join(f"DS_{k.upper()}" for k in fields) + slug = f"{c['engine']}-{c['name']}" + complete = "✓" if var_names else "⚠ incomplete" + console.print(f" [bold]{slug}[/] {complete}") + if var_names: + console.print(f" [anton.muted] {var_names}[/]") + console.print() + + +def _handle_remove_data_source(console: Console, slug: str) -> None: + """Delete a connection from the Local Vault by slug (engine-name).""" + vault = DataVault() + parts = slug.split("-", 1) + if len(parts) != 2: + console.print( + f"[anton.warning]Invalid name '{slug}'. 
Use engine-name format.[/]" + ) + console.print() + return + engine, name = parts + if vault.load(engine, name) is None: + console.print(f"[anton.warning]No connection '{slug}' found.[/]") + console.print() + return + if Confirm.ask( + f"Remove '{slug}' from Local Vault?", default=False, console=console + ): + vault.delete(engine, name) + console.print(f"[anton.success]Removed {slug}.[/]") + else: + console.print("[anton.muted]Cancelled.[/]") + console.print() + + def _print_slash_help(console: Console) -> None: """Print available slash commands.""" console.print() console.print("[anton.cyan]Available commands:[/]") console.print(" [bold]/connect[/] — Connect to a Minds server and select a mind") console.print(" [bold]/connect-data-source[/] — Connect a database or API to the Local Vault") + console.print(" [bold]/list-data-sources[/] — List all saved data source connections") + console.print(" [bold]/remove-data-source[/] — Remove a saved connection") console.print(" [bold]/data-connections[/] — View and manage stored keys and connections") console.print(" [bold]/setup[/] — Configure models or memory settings") console.print(" [bold]/memory[/] — Show memory status dashboard") @@ -2574,6 +2890,17 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool workspace = Workspace(settings.workspace_path) workspace.apply_env_to_process() + # Inject all Local Vault connections as DS_* env vars so every scratchpad + # subprocess inherits them. Must happen before any ChatSession is created. 
+ _dv = DataVault() + _dreg = DatasourceRegistry() + for _conn in _dv.list_connections(): + _dv.inject_env(_conn["engine"], _conn["name"]) + _edef = _dreg.get(_conn["engine"]) + if _edef is not None: + _register_secret_vars(_edef) + del _dv, _dreg + # --- Memory system (brain-inspired architecture) --- global_memory_dir = Path.home() / ".anton" / "memory" project_memory_dir = settings.workspace_path / ".anton" / "memory" @@ -2787,6 +3114,35 @@ def _bottom_toolbar(): console, session._scratchpads, session, ) continue + elif cmd == "/list-data-sources": + _handle_list_data_sources(console) + continue + elif cmd == "/remove-data-source": + arg = parts[1].strip() if len(parts) > 1 else "" + if not arg: + console.print( + "[anton.warning]Usage: /remove-data-source" + " [/]" + ) + else: + _handle_remove_data_source(console, arg) + continue + elif cmd == "/edit-data-source": + arg = parts[1].strip() if len(parts) > 1 else "" + console.print( + f"[anton.muted]/edit-data-source is not yet implemented. 
" + f"To update '{arg}', use /remove-data-source {arg}" + f" then /connect-data-source.[/]" + ) + console.print() + continue + elif cmd == "/test-data-source": + console.print( + "[anton.muted]/test-data-source is not yet" + " implemented.[/]" + ) + console.print() + continue elif cmd == "/data-connections": session = await _handle_data_connections( console, settings, workspace, session, diff --git a/anton/cli.py b/anton/cli.py index decd701..033de34 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -442,11 +442,33 @@ def version() -> None: console.print(f"Anton v{__version__}") -@app.command("connect-datasource") -def connect_datasource(ctx: typer.Context) -> None: - """Connect a new datasource.""" - #from anton.datasources import connect_new_datasource - print("Datasource connection flow is not implemented yet.") +@app.command("connect-data-source") +def connect_data_source(ctx: typer.Context) -> None: + """Connect a database or API to the Local Vault.""" + import asyncio + + from anton.chat import ChatSession, _handle_connect_datasource + from anton.llm.client import LLMClient + from anton.scratchpad import ScratchpadManager + settings = _get_settings(ctx) _ensure_workspace(settings) - #connect_new_datasource(console, settings) \ No newline at end of file + _ensure_api_key(settings) + + llm_client = LLMClient.from_settings(settings) + scratchpads = ScratchpadManager( + coding_provider=settings.coding_provider, + coding_model=settings.coding_model, + coding_api_key=( + settings.anthropic_api_key + if settings.coding_provider == "anthropic" + else settings.openai_api_key + ) or "", + ) + session = ChatSession(llm_client) + + async def _run() -> None: + updated = await _handle_connect_datasource(console, scratchpads, session) + await updated._scratchpads.close_all() + + asyncio.run(_run()) \ No newline at end of file diff --git a/anton/data_vault.py b/anton/data_vault.py index 9ab73d4..560b06b 100644 --- a/anton/data_vault.py +++ b/anton/data_vault.py @@ -36,7 +36,7 
@@ def _ensure_dir(self) -> None: def save(self, engine: str, name: str, credentials: dict[str, str]) -> Path: - """Write credentials as JSON. Creates vault dir if needed.""" + """Write credentials as JSON atomically. Creates vault dir if needed.""" self._ensure_dir() path = self._path_for(engine, name) data = { @@ -45,8 +45,10 @@ def save(self, engine: str, name: str, credentials: dict[str, str]) -> Path: "created_at": datetime.now(timezone.utc).isoformat(), "fields": credentials, } - path.write_text(json.dumps(data, indent=2), encoding="utf-8") - path.chmod(0o600) + tmp = path.with_suffix(".tmp") + tmp.write_text(json.dumps(data, indent=2), encoding="utf-8") + tmp.chmod(0o600) + tmp.rename(path) return path def load(self, engine: str, name: str) -> dict[str, str] | None: diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py index 530b482..f34b334 100644 --- a/anton/datasource_registry.py +++ b/anton/datasource_registry.py @@ -38,6 +38,7 @@ class DatasourceEngine: name_from: Union[str, list[str]] = "" fields: list[DatasourceField] = field(default_factory=list) # "choice" means the user must pick from auth_methods before collecting fields + # empty string means no choice, just collect fields from the top-level "fields" list auth_method: str = "" auth_methods: list[AuthMethod] = field(default_factory=list) test_snippet: str = "" @@ -76,7 +77,9 @@ def _parse_file(path: Path) -> dict[str, DatasourceEngine]: yaml_text = match.group(3) try: data = yaml.safe_load(yaml_text) - except yaml.YAMLError: + except yaml.YAMLError as exc: + import sys + print(f"[anton] Warning: skipping malformed YAML block in {path}: {exc}", file=sys.stderr) continue if not isinstance(data, dict) or "engine" not in data: continue @@ -126,14 +129,17 @@ def get(self, engine_slug: str) -> DatasourceEngine | None: return self._engines.get(engine_slug) def find_by_name(self, display_name: str) -> DatasourceEngine | None: - """Case-insensitive match on display_name or engine 
slug.""" + """Case-insensitive match on display_name or engine slug. + """ needle = display_name.strip().lower() for engine in self._engines.values(): - if engine.display_name.lower() == needle: + if engine.display_name.lower() == needle or engine.engine.lower() == needle: return engine - if engine.engine.lower() == needle: - return engine - return None + matches = [ + e for e in self._engines.values() + if needle in e.display_name.lower() or needle in e.engine.lower() + ] + return matches[0] if len(matches) == 1 else None def all_engines(self) -> list[DatasourceEngine]: return sorted(self._engines.values(), key=lambda e: e.display_name) diff --git a/anton/llm/prompts.py b/anton/llm/prompts.py index adbf923..c26eef1 100644 --- a/anton/llm/prompts.py +++ b/anton/llm/prompts.py @@ -58,6 +58,9 @@ handle_tool(name, inputs) is a plain sync function returning a string result. Use this for \ multi-step AI workflows like classification, extraction, or analysis with structured outputs. - All .anton/.env secrets are available as environment variables (os.environ). +- Data source credentials are injected as DS_ environment \ +variables (e.g. DS_HOST, DS_PASSWORD, DS_ACCESS_TOKEN). Use them directly \ +in scratchpad code — never read ~/.anton/data_vault/ files directly. - When the user asks how you solved something or wants to see your work, use the scratchpad \ dump action — it shows a clean notebook-style summary without wasting tokens on reformatting. - Always use print() to produce output — scratchpad captures stdout. diff --git a/datasources.md b/datasources.md index c63074a..029fcc3 100644 --- a/datasources.md +++ b/datasources.md @@ -10,6 +10,7 @@ before any scratchpad code runs. Never embed raw values in code strings. --- ## PostgreSQL + ```yaml engine: postgres display_name: PostgreSQL @@ -39,7 +40,107 @@ Common errors: "password authentication failed" → wrong password or user. 
--- +## MySQL + +```yaml +engine: mysql +display_name: MySQL +pip: mysql-connector-python +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your MySQL server" } + - { name: port, required: false, secret: false, description: "port number", default: "3306" } + - { name: database, required: true, secret: false, description: "database name to connect" } + - { name: user, required: true, secret: false, description: "MySQL username" } + - { name: password, required: true, secret: true, description: "MySQL password" } + - { name: ssl, required: false, secret: false, description: "enable SSL (true/false)" } + - { name: ssl_ca, required: false, secret: false, description: "path to CA certificate file" } + - { name: ssl_cert, required: false, secret: false, description: "path to client certificate file" } + - { name: ssl_key, required: false, secret: false, description: "path to client private key file" } +test_snippet: | + import mysql.connector, os + conn = mysql.connector.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '3306')), + database=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +--- + +## MariaDB + +```yaml +engine: mariadb +display_name: MariaDB +pip: mysql-connector-python +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your MariaDB server" } + - { name: port, required: false, secret: false, description: "port number", default: "3306" } + - { name: database, required: true, secret: false, description: "database name to connect" } + - { name: user, required: true, secret: false, description: "MariaDB username" } + - { name: password, required: true, secret: true, description: "MariaDB password" } + - { name: ssl, required: false, secret: false, description: "enable SSL (true/false)" } + - { name: ssl_ca, required: false, 
secret: false, description: "path to CA certificate file" } + - { name: ssl_cert, required: false, secret: false, description: "path to client certificate file" } + - { name: ssl_key, required: false, secret: false, description: "path to client private key file" } +test_snippet: | + import mysql.connector, os + conn = mysql.connector.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '3306')), + database=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +MariaDB is wire-compatible with MySQL, so the mysql-connector-python driver works for both. + +--- + +## Microsoft SQL Server + +```yaml +engine: mssql +display_name: Microsoft SQL Server +pip: pymssql +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of the SQL Server (for Azure use server field instead)" } + - { name: port, required: false, secret: false, description: "port number", default: "1433" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "SQL Server username" } + - { name: password, required: true, secret: true, description: "SQL Server password" } + - { name: server, required: false, secret: false, description: "server name — use for named instances or Azure SQL (e.g. myserver.database.windows.net)" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to dbo)" } +test_snippet: | + import pymssql, os + conn = pymssql.connect( + server=os.environ.get('DS_SERVER') or os.environ['DS_HOST'], + port=os.environ.get('DS_PORT', '1433'), + database=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +For Azure SQL Database use `server` field with value like `myserver.database.windows.net`. 
+For Windows Authentication omit user/password and ensure pymssql is built with Kerberos support. + +--- + ## HubSpot + ```yaml engine: hubspot display_name: HubSpot @@ -71,6 +172,7 @@ For Private App Token: HubSpot → Settings → Integrations → Private Apps Recommended scopes: `crm.objects.contacts.read`, `crm.objects.deals.read`, `crm.objects.companies.read`. For OAuth2: collect client_id and client_secret, then use the scratchpad to: + 1. Build the authorization URL using `auth_url` + params above 2. Start a local HTTP server on port 8099 to catch the callback 3. Open the URL in the user's browser with `webbrowser.open()` @@ -80,9 +182,349 @@ For OAuth2: collect client_id and client_secret, then use the scratchpad to: --- ## Snowflake -... + +```yaml +engine: snowflake +display_name: Snowflake +pip: snowflake-connector-python +name_from: [account, database] +auth_method: choice +auth_methods: + - name: password + display: "Username / Password" + fields: + - { name: account, required: true, secret: false, description: "Snowflake account identifier (e.g. 
xy12345.us-east-1 or orgname-accountname)" }
+      - { name: user, required: true, secret: false, description: "Snowflake username" }
+      - { name: password, required: true, secret: true, description: "Snowflake password" }
+      - { name: database, required: true, secret: false, description: "database name" }
+      - { name: schema, required: false, secret: false, description: "schema name (defaults to PUBLIC)" }
+      - { name: warehouse, required: false, secret: false, description: "warehouse to use for queries" }
+      - { name: role, required: false, secret: false, description: "role to assume" }
+  - name: key_pair
+    display: "Key-Pair Authentication"
+    fields:
+      - { name: account, required: true, secret: false, description: "Snowflake account identifier" }
+      - { name: user, required: true, secret: false, description: "Snowflake username" }
+      - { name: private_key, required: true, secret: true, description: "PEM-formatted private key content" }
+      - { name: private_key_passphrase, required: false, secret: true, description: "passphrase for encrypted private key" }
+      - { name: database, required: true, secret: false, description: "database name" }
+      - { name: schema, required: false, secret: false, description: "schema name" }
+      - { name: warehouse, required: false, secret: false, description: "warehouse to use" }
+      - { name: role, required: false, secret: false, description: "role to assume" }
+test_snippet: |
+  import snowflake.connector, os
+  conn = snowflake.connector.connect(
+      account=os.environ['DS_ACCOUNT'],
+      user=os.environ['DS_USER'],
+      password=os.environ.get('DS_PASSWORD', ''),
+      database=os.environ.get('DS_DATABASE', ''),
+      schema=os.environ.get('DS_SCHEMA', 'PUBLIC'),
+      warehouse=os.environ.get('DS_WAREHOUSE', ''),
+  )
+  conn.close()
+  print("ok")
+```
+
+Account identifier: Admin → Accounts → hover your account name to reveal the identifier.
+Format is either `orgname-accountname` or `account_locator.region.cloud` (e.g. `xy12345.us-east-1`).
+ +--- + +## Oracle Database + +```yaml +engine: oracle_database +display_name: Oracle Database +pip: oracledb +name_from: [host, service_name] +fields: + - { name: user, required: true, secret: false, description: "Oracle database username" } + - { name: password, required: true, secret: true, description: "Oracle database password" } + - { name: host, required: false, secret: false, description: "hostname or IP address of the Oracle server" } + - { name: port, required: false, secret: false, description: "port number (default 1521)", default: "1521" } + - { name: service_name, required: false, secret: false, description: "Oracle service name (preferred over SID)" } + - { name: sid, required: false, secret: false, description: "Oracle SID — use service_name if possible" } + - { name: dsn, required: false, secret: false, description: "full DSN string — overrides host/port/service_name" } + - { name: auth_mode, required: false, secret: false, description: "authorization mode (e.g. SYSDBA)" } +test_snippet: | + import oracledb, os + dsn = os.environ.get('DS_DSN') or oracledb.makedsn( + os.environ.get('DS_HOST', 'localhost'), + os.environ.get('DS_PORT', '1521'), + service_name=os.environ.get('DS_SERVICE_NAME', ''), + ) + conn = oracledb.connect(user=os.environ['DS_USER'], password=os.environ['DS_PASSWORD'], dsn=dsn) + conn.close() + print("ok") +``` + +oracledb runs in thin mode by default (no Oracle Client libraries needed). +Set `auth_mode` to `SYSDBA` or `SYSOPER` for privileged connections. 
+ +--- + +## Google BigQuery + +```yaml +engine: bigquery +display_name: Google BigQuery +pip: google-cloud-bigquery +name_from: [project_id, dataset] +fields: + - { name: project_id, required: true, secret: false, description: "GCP project ID containing your BigQuery datasets" } + - { name: dataset, required: true, secret: false, description: "BigQuery dataset name" } + - { name: service_account_json, required: false, secret: true, description: "contents of service account JSON key file (paste the full JSON)" } + - { name: service_account_keys, required: false, secret: false, description: "path to service account JSON key file on disk" } +test_snippet: | + import json, os + from google.cloud import bigquery + from google.oauth2 import service_account + + sa_json = os.environ.get('DS_SERVICE_ACCOUNT_JSON', '') + if sa_json: + creds = service_account.Credentials.from_service_account_info( + json.loads(sa_json), + scopes=['https://www.googleapis.com/auth/bigquery.readonly'], + ) + client = bigquery.Client(project=os.environ['DS_PROJECT_ID'], credentials=creds) + else: + client = bigquery.Client(project=os.environ['DS_PROJECT_ID']) + list(client.list_datasets()) + print("ok") +``` + +To create a service account key: GCP Console → IAM → Service Accounts → your account → +Keys → Add Key → JSON. Grant the account `BigQuery Data Viewer` + `BigQuery Job User` roles. + +--- + +## Amazon Redshift + +```yaml +engine: redshift +display_name: Amazon Redshift +pip: psycopg2-binary +name_from: [host, database] +fields: + - { name: host, required: true, secret: false, description: "Redshift cluster endpoint (e.g. 
mycluster.abc123.us-east-1.redshift.amazonaws.com)" } + - { name: port, required: false, secret: false, description: "port number", default: "5439" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "Redshift username" } + - { name: password, required: true, secret: true, description: "Redshift password" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to public)" } + - { name: sslmode, required: false, secret: false, description: "SSL mode: require (default), verify-ca, verify-full, disable", default: "require" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '5439')), + dbname=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + sslmode=os.environ.get('DS_SSLMODE', 'require'), + ) + conn.close() + print("ok") +``` + +Redshift is PostgreSQL-compatible. Find the cluster endpoint in AWS Console → +Redshift → Clusters → your cluster → Endpoint (omit the port suffix). 
+ +--- + +## Databricks + +```yaml +engine: databricks +display_name: Databricks +pip: databricks-sql-connector +name_from: [server_hostname, catalog] +fields: + - { name: server_hostname, required: true, secret: false, description: "server hostname for the cluster or SQL warehouse (from JDBC/ODBC connection string)" } + - { name: http_path, required: true, secret: false, description: "HTTP path of the cluster or SQL warehouse" } + - { name: access_token, required: true, secret: true, description: "Databricks personal access token" } + - { name: catalog, required: false, secret: false, description: "Unity Catalog name (defaults to hive_metastore)" } + - { name: schema, required: false, secret: false, description: "schema (database) to use" } + - { name: session_configuration, required: false, secret: false, description: "Spark session configuration as key=value pairs" } +test_snippet: | + from databricks import sql as dbsql + import os + conn = dbsql.connect( + server_hostname=os.environ['DS_SERVER_HOSTNAME'], + http_path=os.environ['DS_HTTP_PATH'], + access_token=os.environ['DS_ACCESS_TOKEN'], + catalog=os.environ.get('DS_CATALOG', ''), + schema=os.environ.get('DS_SCHEMA', ''), + ) + conn.close() + print("ok") +``` + +Personal access token: User Settings → Developer → Access Tokens → Generate New Token. +HTTP path and server hostname: SQL Warehouses → your warehouse → Connection Details tab. 
+ +--- + +## DuckDB + +```yaml +engine: duckdb +display_name: DuckDB +pip: duckdb +name_from: database +fields: + - { name: database, required: false, secret: false, description: "path to DuckDB database file; omit or use :memory: for in-memory database", default: ":memory:" } + - { name: motherduck_token, required: false, secret: true, description: "MotherDuck access token for connecting to a MotherDuck cloud database" } + - { name: read_only, required: false, secret: false, description: "open in read-only mode (true/false)", default: "false" } +test_snippet: | + import duckdb, os + db_path = os.environ.get('DS_DATABASE', ':memory:') + token = os.environ.get('DS_MOTHERDUCK_TOKEN', '') + if token: + conn = duckdb.connect(f'md:{db_path}?motherduck_token={token}') + else: + conn = duckdb.connect(db_path) + conn.execute('SELECT 1').fetchone() + conn.close() + print("ok") +``` + +For MotherDuck, use `database` as the MotherDuck database name (e.g. `my_db`) and provide +the access token. For local files, provide the path to a `.duckdb` file. 
+ +--- + +## pgvector + +```yaml +engine: pgvector +display_name: pgvector +pip: pgvector psycopg2-binary +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your PostgreSQL server with pgvector extension" } + - { name: port, required: false, secret: false, description: "port number", default: "5432" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "database username" } + - { name: password, required: true, secret: true, description: "database password" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to public)" } + - { name: sslmode, required: false, secret: false, description: "SSL mode: prefer, require, disable" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '5432')), + dbname=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + cur = conn.cursor() + cur.execute("SELECT extname FROM pg_extension WHERE extname = 'vector'") + if not cur.fetchone(): + raise RuntimeError("pgvector extension not installed — run: CREATE EXTENSION vector;") + conn.close() + print("ok") +``` + +pgvector must be installed in the PostgreSQL instance: `CREATE EXTENSION IF NOT EXISTS vector;` +Managed options: Supabase, Neon, and AWS RDS for PostgreSQL all support pgvector. 
+ +--- + +## ChromaDB + +```yaml +engine: chromadb +display_name: ChromaDB +pip: chromadb +name_from: host +fields: + - { name: host, required: false, secret: false, description: "ChromaDB server host for HTTP client mode (omit for local in-process mode)" } + - { name: port, required: false, secret: false, description: "ChromaDB server port", default: "8000" } + - { name: persist_directory, required: false, secret: false, description: "local directory for persistent storage (local mode only)" } +test_snippet: | + import chromadb, os + host = os.environ.get('DS_HOST', '') + if host: + client = chromadb.HttpClient( + host=host, + port=int(os.environ.get('DS_PORT', '8000')), + ) + else: + persist_dir = os.environ.get('DS_PERSIST_DIRECTORY', '') + if persist_dir: + client = chromadb.PersistentClient(path=persist_dir) + else: + client = chromadb.EphemeralClient() + client.heartbeat() + print("ok") +``` + +Three modes: HTTP client (connect to a running ChromaDB server), persistent local (file-backed), +or ephemeral in-memory. For production, run `chroma run` to start the HTTP server. 
+ +--- + +## Salesforce + +```yaml +engine: salesforce +display_name: Salesforce +pip: salesforce_api +name_from: username +fields: + - { name: username, required: true, secret: false, description: "Salesforce account username (email)" } + - { name: password, required: true, secret: true, description: "Salesforce account password" } + - { name: client_id, required: true, secret: false, description: "consumer key from the connected app" } + - { name: client_secret, required: true, secret: true, description: "consumer secret from the connected app" } + - { name: is_sandbox, required: false, secret: false, description: "true to connect to sandbox, false for production", default: "false" } +test_snippet: | + import salesforce_api, os + sf = salesforce_api.Salesforce( + username=os.environ['DS_USERNAME'], + password=os.environ['DS_PASSWORD'], + client_id=os.environ['DS_CLIENT_ID'], + client_secret=os.environ['DS_CLIENT_SECRET'], + is_sandbox=os.environ.get('DS_IS_SANDBOX', 'false').lower() == 'true', + ) + sf.query('SELECT Id FROM Account LIMIT 1') + print("ok") +``` + +To get client_id and client_secret: Setup → Apps → App Manager → New Connected App. +Enable OAuth, add callback URL, select scopes (api, refresh_token). + +--- + +## Shopify + +```yaml +engine: shopify +display_name: Shopify +pip: ShopifyAPI +name_from: shop_url +fields: + - { name: shop_url, required: true, secret: false, description: "your Shopify store URL (e.g. 
mystore.myshopify.com)" } + - { name: client_id, required: true, secret: false, description: "client ID (API key) of the custom app" } + - { name: client_secret, required: true, secret: true, description: "client secret (API secret key) of the custom app" } +test_snippet: | + import shopify, os + shop_url = os.environ['DS_SHOP_URL'].rstrip('/') + if not shop_url.startswith('https://'): + shop_url = f'https://{shop_url}' + session = shopify.Session(shop_url, '2024-01', os.environ['DS_CLIENT_SECRET']) + shopify.ShopifyResource.activate_session(session) + shopify.Shop.current() + shopify.ShopifyResource.clear_session() + print("ok") +``` + +Create a custom app: Shopify Admin → Settings → Apps → Develop apps → Create an app. +Grant required API permissions (read_products, read_orders, etc.) then install the app. + +--- ## Adding a new data source Follow the YAML format above. Add to `~/.anton/datasources.md` (user overrides). -Anton merges user overrides on top of the built-in registry at startup. \ No newline at end of file +Anton merges user overrides on top of the built-in registry at startup. 
diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 467f071..1706e3e 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -728,4 +728,144 @@ async def test_selective_field_collection(self, registry, vault_dir): conns = vault.list_connections() assert len(conns) == 1 saved = vault.load("postgresql", conns[0]["name"]) - assert set(saved.keys()) == {"host", "user", "password"} \ No newline at end of file + assert set(saved.keys()) == {"host", "user", "password"} + + +# ───────────────────────────────────────────────────────────────────────────── +# Credential scrubbing +# ───────────────────────────────────────────────────────────────────────────── + + +class TestCredentialScrubbing: + """_scrub_credentials and _register_secret_vars.""" + + def setup_method(self): + # Reset the module-level sets before each test + from anton.chat import _DS_KNOWN_VARS, _DS_SECRET_VARS + _DS_SECRET_VARS.clear() + _DS_KNOWN_VARS.clear() + + def test_register_secret_vars_adds_secret_fields(self, registry): + """Secret fields are added to _DS_SECRET_VARS; non-secret fields are not.""" + from anton.chat import _DS_SECRET_VARS, _register_secret_vars + + pg = registry.get("postgresql") + assert pg is not None + _register_secret_vars(pg) + + assert "DS_PASSWORD" in _DS_SECRET_VARS + # host and port are not secret in the fixture definition + assert "DS_HOST" not in _DS_SECRET_VARS + assert "DS_PORT" not in _DS_SECRET_VARS + + def test_scrub_replaces_registered_secret_value(self): + """A registered secret value is replaced with its placeholder.""" + import os + from anton.chat import _DS_SECRET_VARS, _scrub_credentials + + _DS_SECRET_VARS.add("DS_ACCESS_TOKEN") + os.environ["DS_ACCESS_TOKEN"] = "supersecrettoken123" + try: + result = _scrub_credentials("token is supersecrettoken123 here") + assert "supersecrettoken123" not in result + assert "[DS_ACCESS_TOKEN]" in result + finally: + del os.environ["DS_ACCESS_TOKEN"] + 
_DS_SECRET_VARS.discard("DS_ACCESS_TOKEN") + + def test_scrub_leaves_non_secret_field_readable(self, registry): + """Non-secret DS_* values (host, port) are left untouched.""" + import os + from anton.chat import _register_secret_vars, _scrub_credentials + + pg = registry.get("postgresql") + assert pg is not None + _register_secret_vars(pg) + + os.environ["DS_HOST"] = "db.example.com" + os.environ["DS_PASSWORD"] = "s3cr3tpassword99" + try: + result = _scrub_credentials("host=db.example.com pass=s3cr3tpassword99") + assert "db.example.com" in result # host left readable + assert "s3cr3tpassword99" not in result # password redacted + assert "[DS_PASSWORD]" in result + finally: + del os.environ["DS_HOST"] + del os.environ["DS_PASSWORD"] + + def test_scrub_skips_short_values(self): + """Values of 8 characters or fewer are not scrubbed (e.g. port numbers).""" + import os + from anton.chat import _DS_SECRET_VARS, _scrub_credentials + + _DS_SECRET_VARS.add("DS_PASSWORD") + os.environ["DS_PASSWORD"] = "short" # 5 chars — under threshold + try: + result = _scrub_credentials("password=short") + assert "short" in result + finally: + del os.environ["DS_PASSWORD"] + _DS_SECRET_VARS.discard("DS_PASSWORD") + + def test_scrub_fallback_redacts_unknown_long_ds_vars(self): + """Long DS_* vars not in _DS_SECRET_VARS are scrubbed as a safety fallback.""" + import os + from anton.chat import _scrub_credentials + + # _DS_SECRET_VARS is empty (cleared in setup_method) + os.environ["DS_WEBHOOK_SECRET"] = "wh_sec_abcdefgh1234" + try: + result = _scrub_credentials("secret=wh_sec_abcdefgh1234 here") + assert "wh_sec_abcdefgh1234" not in result + assert "[DS_WEBHOOK_SECRET]" in result + finally: + del os.environ["DS_WEBHOOK_SECRET"] + + @pytest.mark.asyncio + async def test_register_and_scrub_on_connect(self, registry, vault_dir): + """After _handle_connect_datasource, the new secret var is immediately scrubbed.""" + import os + from unittest.mock import AsyncMock, MagicMock, patch + + from 
anton.chat import _DS_SECRET_VARS, _handle_connect_datasource, _scrub_credentials + + vault = DataVault(vault_dir=vault_dir) + + session = MagicMock() + session._history = [] + session._cortex = None + + pad = AsyncMock() + pad.execute = AsyncMock( + return_value=MagicMock(stdout="ok", stderr="", error=None) + ) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + secret_pw = "supersecretpassword999" + prompt_responses = iter([ + "PostgreSQL", # engine + "y", # have all credentials + "db.host.com", # host + "5432", # port + "mydb", # database + "alice", # user + secret_pw, # password + "public", # schema (optional, skip) + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(MagicMock(), session._scratchpads, session) + + # After connect, password should be in the secret set and scrubbed + assert "DS_PASSWORD" in _DS_SECRET_VARS + os.environ["DS_PASSWORD"] = secret_pw + try: + result = _scrub_credentials(f"error: auth failed with {secret_pw}") + assert secret_pw not in result + assert "[DS_PASSWORD]" in result + finally: + del os.environ["DS_PASSWORD"] \ No newline at end of file From 06bb5753dd0a06d546ccced1304f7c9b38f99ce2 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 20 Mar 2026 11:46:03 +0100 Subject: [PATCH 08/70] Implement edit and reconnect --- anton/chat.py | 1196 +++++++++++++++++++++++++--------- anton/cli.py | 65 +- anton/data_vault.py | 29 +- anton/datasource_registry.py | 50 +- 4 files changed, 980 insertions(+), 360 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index add55d6..9f95e24 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -43,10 +43,14 @@ prepare_scratchpad_exec, ) from anton.data_vault import DataVault -from anton.datasource_registry import DatasourceEngine, DatasourceField, DatasourceRegistry, 
_parse_file as _ds_parse_file +from anton.datasource_registry import ( + DatasourceEngine, + DatasourceField, + DatasourceRegistry, + _parse_file as _ds_parse_file, +) from rich.prompt import Confirm, Prompt - if TYPE_CHECKING: from rich.console import Console @@ -102,7 +106,11 @@ def __init__( self._console = console self._history: list[dict] = list(initial_history) if initial_history else [] self._pending_memory_confirmations: list = [] - self._turn_count = sum(1 for m in self._history if m.get("role") == "user") if initial_history else 0 + self._turn_count = ( + sum(1 for m in self._history if m.get("role") == "user") + if initial_history + else 0 + ) self._history_store = history_store self._session_id = session_id self._cancel_event = asyncio.Event() @@ -141,17 +149,19 @@ def repair_history(self) -> None: ] if not tool_ids: return - self._history.append({ - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": tid, - "content": "Cancelled by user.", - } - for tid in tool_ids - ], - }) + self._history.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tid, + "content": "Cancelled by user.", + } + for tid in tool_ids + ], + } + ) def _persist_history(self) -> None: """Save current history to disk if a history store is configured.""" @@ -161,7 +171,9 @@ def _persist_history(self) -> None: async def _build_system_prompt(self, user_message: str = "") -> str: prompt = CHAT_SYSTEM_PROMPT.format( runtime_context=self._runtime_context, - visualizations_section=build_visualizations_prompt(self._proactive_dashboards), + visualizations_section=build_visualizations_prompt( + self._proactive_dashboards + ), ) # Inject memory context (replaces old self_awareness) if self._cortex is not None: @@ -186,27 +198,60 @@ async def _build_system_prompt(self, user_message: str = "") -> str: # Packages the LLM is most likely to care about when writing scratchpad code. 
_NOTABLE_PACKAGES: set[str] = { - "numpy", "pandas", "matplotlib", "seaborn", "scipy", "scikit-learn", - "requests", "httpx", "aiohttp", "beautifulsoup4", "lxml", - "pillow", "sympy", "networkx", "sqlalchemy", "pydantic", - "rich", "tqdm", "click", "fastapi", "flask", "django", - "openai", "anthropic", "tiktoken", "transformers", "torch", - "polars", "pyarrow", "openpyxl", "xlsxwriter", - "plotly", "bokeh", "altair", - "pytest", "hypothesis", - "yaml", "pyyaml", "toml", "tomli", "tomllib", - "jinja2", "markdown", "pygments", - "cryptography", "paramiko", "boto3", + "numpy", + "pandas", + "matplotlib", + "seaborn", + "scipy", + "scikit-learn", + "requests", + "httpx", + "aiohttp", + "beautifulsoup4", + "lxml", + "pillow", + "sympy", + "networkx", + "sqlalchemy", + "pydantic", + "rich", + "tqdm", + "click", + "fastapi", + "flask", + "django", + "openai", + "anthropic", + "tiktoken", + "transformers", + "torch", + "polars", + "pyarrow", + "openpyxl", + "xlsxwriter", + "plotly", + "bokeh", + "altair", + "pytest", + "hypothesis", + "yaml", + "pyyaml", + "toml", + "tomli", + "tomllib", + "jinja2", + "markdown", + "pygments", + "cryptography", + "paramiko", + "boto3", } def _build_tools(self) -> list[dict]: scratchpad_tool = dict(SCRATCHPAD_TOOL) pkg_list = self._scratchpads._available_packages if pkg_list: - notable = sorted( - p for p in pkg_list - if p.lower() in self._NOTABLE_PACKAGES - ) + notable = sorted(p for p in pkg_list if p.lower() in self._NOTABLE_PACKAGES) if notable: pkg_line = ", ".join(notable) extra = f"\n\nInstalled packages ({len(pkg_list)} total, notable: {pkg_line})." 
@@ -218,7 +263,9 @@ def _build_tools(self) -> list[dict]: if self._cortex is not None: wisdom = self._cortex.get_scratchpad_context() if wisdom: - scratchpad_tool["description"] += f"\n\nLessons from past sessions:\n{wisdom}" + scratchpad_tool[ + "description" + ] += f"\n\nLessons from past sessions:\n{wisdom}" tools = [scratchpad_tool] if self._cortex is not None: @@ -226,6 +273,7 @@ def _build_tools(self) -> list[dict]: elif self._self_awareness is not None: # Legacy fallback from anton.tools import MEMORIZE_TOOL as _MT + tools.append(_MT) if self._episodic is not None and self._episodic.enabled: tools.append(RECALL_TOOL) @@ -263,8 +311,7 @@ async def _summarize_history(self) -> None: if not isinstance(content, list): break has_tool_result = any( - isinstance(b, dict) and b.get("type") == "tool_result" - for b in content + isinstance(b, dict) and b.get("type") == "tool_result" for b in content ) if not has_tool_result: break @@ -295,9 +342,13 @@ async def _summarize_history(self) -> None: if block.get("type") == "text": lines.append(f"[{role}]: {block['text'][:1000]}") elif block.get("type") == "tool_use": - lines.append(f"[{role}/tool_use]: {block.get('name', '')}({str(block.get('input', ''))[:500]})") + lines.append( + f"[{role}/tool_use]: {block.get('name', '')}({str(block.get('input', ''))[:500]})" + ) elif block.get("type") == "tool_result": - lines.append(f"[tool_result]: {str(block.get('content', ''))[:500]}") + lines.append( + f"[tool_result]: {str(block.get('content', ''))[:500]}" + ) old_text = "\n".join(lines) # Cap at ~8000 chars to avoid overloading the summarizer @@ -330,7 +381,11 @@ async def _summarize_history(self) -> None: # If the recent portion starts with a user message, insert a minimal # assistant separator to avoid consecutive user messages (API error). 
if recent_turns and recent_turns[0].get("role") == "user": - self._history = [summary_msg, {"role": "assistant", "content": "Understood."}, *recent_turns] + self._history = [ + summary_msg, + {"role": "assistant", "content": "Understood."}, + *recent_turns, + ] else: self._history = [summary_msg] + recent_turns @@ -377,15 +432,19 @@ async def turn(self, user_input: str | list[dict]) -> str: while response.tool_calls: tool_round += 1 if tool_round > _MAX_TOOL_ROUNDS: - self._history.append({"role": "assistant", "content": response.content or ""}) - self._history.append({ - "role": "user", - "content": ( - f"SYSTEM: You have used {_MAX_TOOL_ROUNDS} tool-call rounds on this turn. " - "Stop retrying. Summarize what you accomplished and what failed, " - "then tell the user what they can do to unblock the issue." - ), - }) + self._history.append( + {"role": "assistant", "content": response.content or ""} + ) + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: You have used {_MAX_TOOL_ROUNDS} tool-call rounds on this turn. " + "Stop retrying. Summarize what you accomplished and what failed, " + "then tell the user what they can do to unblock the issue." 
+ ), + } + ) response = await self._llm.plan( system=system, messages=self._history, @@ -397,12 +456,14 @@ async def turn(self, user_input: str | list[dict]) -> str: if response.content: assistant_content.append({"type": "text", "text": response.content}) for tc in response.tool_calls: - assistant_content.append({ - "type": "tool_use", - "id": tc.id, - "name": tc.name, - "input": tc.input, - }) + assistant_content.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.input, + } + ) self._history.append({"role": "assistant", "content": assistant_content}) # Process each tool call via registry @@ -415,14 +476,19 @@ async def turn(self, user_input: str | list[dict]) -> str: result_text = _scrub_credentials(result_text) result_text = _apply_error_tracking( - result_text, tc.name, error_streak, resilience_nudged, + result_text, + tc.name, + error_streak, + resilience_nudged, ) - tool_results.append({ - "type": "tool_result", - "tool_use_id": tc.id, - "content": result_text, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": tc.id, + "content": result_text, + } + ) self._history.append({"role": "user", "content": tool_results}) @@ -457,13 +523,17 @@ async def turn(self, user_input: str | list[dict]) -> str: return reply - async def turn_stream(self, user_input: str | list[dict]) -> AsyncIterator[StreamEvent]: + async def turn_stream( + self, user_input: str | list[dict] + ) -> AsyncIterator[StreamEvent]: """Streaming version of turn(). 
Yields events as they arrive.""" self._history.append({"role": "user", "content": user_input}) # Log user input to episodic memory if self._episodic is not None: - content = user_input if isinstance(user_input, str) else str(user_input)[:2000] + content = ( + user_input if isinstance(user_input, str) else str(user_input)[:2000] + ) self._episodic.log_turn(self._turn_count + 1, "user", content) user_msg_str = user_input if isinstance(user_input, str) else "" @@ -476,7 +546,9 @@ async def turn_stream(self, user_input: str | list[dict]) -> AsyncIterator[Strea # Log assistant response to episodic memory if self._episodic is not None and assistant_text_parts: self._episodic.log_turn( - self._turn_count + 1, "assistant", "".join(assistant_text_parts)[:2000], + self._turn_count + 1, + "assistant", + "".join(assistant_text_parts)[:2000], ) # Identity extraction (Default Mode Network — every 5 turns) @@ -488,7 +560,9 @@ async def turn_stream(self, user_input: str | list[dict]) -> AsyncIterator[Strea # Periodic memory vacuum (Systems Consolidation) self._cortex.maybe_vacuum() - async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterator[StreamEvent]: + async def _stream_and_handle_tools( + self, user_message: str = "" + ) -> AsyncIterator[StreamEvent]: """Stream one LLM call, handle tool loops, yield all events.""" system = await self._build_system_prompt(user_message) tools = self._build_tools() @@ -531,7 +605,10 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato llm_response = response.response # Proactive compaction - if not _compacted_this_turn and llm_response.usage.context_pressure > _CONTEXT_PRESSURE_THRESHOLD: + if ( + not _compacted_this_turn + and llm_response.usage.context_pressure > _CONTEXT_PRESSURE_THRESHOLD + ): await self._summarize_history() self._compact_scratchpads() _compacted_this_turn = True @@ -554,15 +631,19 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato 
tool_round += 1 if tool_round > _MAX_TOOL_ROUNDS: _max_rounds_hit = True - self._history.append({"role": "assistant", "content": llm_response.content or ""}) - self._history.append({ - "role": "user", - "content": ( - f"SYSTEM: You have used {_MAX_TOOL_ROUNDS} tool-call rounds on this turn. " - "Stop retrying. Summarize what you accomplished and what failed, " - "then tell the user what they can do to unblock the issue." - ), - }) + self._history.append( + {"role": "assistant", "content": llm_response.content or ""} + ) + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: You have used {_MAX_TOOL_ROUNDS} tool-call rounds on this turn. " + "Stop retrying. Summarize what you accomplished and what failed, " + "then tell the user what they can do to unblock the issue." + ), + } + ) async for event in self._llm.plan_stream( system=system, messages=self._history, @@ -573,15 +654,21 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # Build assistant message with content blocks assistant_content: list[dict] = [] if llm_response.content: - assistant_content.append({"type": "text", "text": llm_response.content}) + assistant_content.append( + {"type": "text", "text": llm_response.content} + ) for tc in llm_response.tool_calls: - assistant_content.append({ - "type": "tool_use", - "id": tc.id, - "name": tc.name, - "input": tc.input, - }) - self._history.append({"role": "assistant", "content": assistant_content}) + assistant_content.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.input, + } + ) + self._history.append( + {"role": "assistant", "content": assistant_content} + ) # Process each tool call tool_results: list[dict] = [] @@ -590,7 +677,9 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato if self._episodic is not None: tc_desc = str(tc.input)[:2000] self._episodic.log_turn( - self._turn_count + 1, "tool_call", tc_desc, + self._turn_count + 1, + 
"tool_call", + tc_desc, tool=tc.name, ) @@ -601,7 +690,13 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato if isinstance(prep, str): result_text = prep else: - pad, code, description, estimated_time, estimated_seconds = prep + ( + pad, + code, + description, + estimated_time, + estimated_seconds, + ) = prep # Signal intent + ETA before execution begins yield StreamTaskProgress( phase="scratchpad_start", @@ -609,8 +704,10 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato eta_seconds=estimated_seconds, ) import time as _time + _sp_t0 = _time.monotonic() from anton.scratchpad import Cell + cell = None async for item in pad.execute_streaming( code, @@ -631,18 +728,26 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato message=description or "Done", eta_seconds=_sp_elapsed, ) - result_text = format_cell_result(cell) if cell else "No result produced." + result_text = ( + format_cell_result(cell) + if cell + else "No result produced." + ) # Log scratchpad cell to episodic memory if self._episodic is not None and cell is not None: self._episodic.log_turn( - self._turn_count + 1, "scratchpad", + self._turn_count + 1, + "scratchpad", (cell.stdout or "")[:2000], description=description, ) else: result_text = await dispatch_tool(self, tc.name, tc.input) - if tc.name == "scratchpad" and tc.input.get("action") == "dump": + if ( + tc.name == "scratchpad" + and tc.input.get("action") == "dump" + ): yield StreamToolResult(content=result_text) result_text = ( "The full notebook has been displayed to the user above. 
" @@ -655,25 +760,34 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # Log tool result to episodic memory if self._episodic is not None: self._episodic.log_turn( - self._turn_count + 1, "tool_result", result_text[:2000], + self._turn_count + 1, + "tool_result", + result_text[:2000], tool=tc.name, ) result_text = _scrub_credentials(result_text) result_text = _apply_error_tracking( - result_text, tc.name, error_streak, resilience_nudged, + result_text, + tc.name, + error_streak, + resilience_nudged, ) - tool_results.append({ - "type": "tool_result", - "tool_use_id": tc.id, - "content": result_text, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": tc.id, + "content": result_text, + } + ) self._history.append({"role": "user", "content": tool_results}) # Signal that tools are done and LLM is now analyzing - yield StreamTaskProgress(phase="analyzing", message="Analyzing results...") + yield StreamTaskProgress( + phase="analyzing", message="Analyzing results..." + ) # Stream follow-up response = None @@ -708,7 +822,11 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato llm_response = response.response # Proactive compaction during tool loop - if not _compacted_this_turn and llm_response.usage.context_pressure > _CONTEXT_PRESSURE_THRESHOLD: + if ( + not _compacted_this_turn + and llm_response.usage.context_pressure + > _CONTEXT_PRESSURE_THRESHOLD + ): await self._summarize_history() self._compact_scratchpads() _compacted_this_turn = True @@ -728,18 +846,20 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato if continuation >= _MAX_CONTINUATIONS: # Budget exhausted — ask LLM to diagnose and present to user - self._history.append({ - "role": "user", - "content": ( - "SYSTEM: You have attempted to complete this task multiple times " - "but verification indicates it is still not done. Do NOT try again. " - "Instead:\n" - "1. 
Summarize exactly what was accomplished so far.\n" - "2. Identify the specific blocker or failure preventing completion.\n" - "3. Suggest concrete next steps the user can take to unblock this.\n" - "Be honest and specific — do not be vague about what went wrong." - ), - }) + self._history.append( + { + "role": "user", + "content": ( + "SYSTEM: You have attempted to complete this task multiple times " + "but verification indicates it is still not done. Do NOT try again. " + "Instead:\n" + "1. Summarize exactly what was accomplished so far.\n" + "2. Identify the specific blocker or failure preventing completion.\n" + "3. Suggest concrete next steps the user can take to unblock this.\n" + "Be honest and specific — do not be vague about what went wrong." + ), + } + ) yield StreamTaskProgress( phase="analyzing", message="Diagnosing incomplete task..." ) @@ -754,13 +874,15 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # Ask the LLM to self-assess completion. # Use a copy of history with a trailing user message so models # that don't support assistant-prefill won't reject the request. - verify_messages = list(self._history) + [{ - "role": "user", - "content": ( - "SYSTEM: Evaluate whether the task the user originally requested " - "has been fully completed based on the conversation above." - ), - }] + verify_messages = list(self._history) + [ + { + "role": "user", + "content": ( + "SYSTEM: Evaluate whether the task the user originally requested " + "has been fully completed based on the conversation above." + ), + } + ] verification = await self._llm.plan( system=( "You are a task-completion verifier. 
Given the conversation, determine " @@ -786,15 +908,17 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato if "STATUS: STUCK" in status_text: # Stuck — inject diagnosis request and let the LLM explain reason = (verification.content or "").strip() - self._history.append({ - "role": "user", - "content": ( - f"SYSTEM: Task verification determined this task is stuck.\n" - f"Verifier assessment: {reason}\n\n" - "Explain to the user what went wrong, what you tried, and " - "suggest specific next steps they can take to unblock this." - ), - }) + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: Task verification determined this task is stuck.\n" + f"Verifier assessment: {reason}\n\n" + "Explain to the user what went wrong, what you tried, and " + "suggest specific next steps they can take to unblock this." + ), + } + ) yield StreamTaskProgress( phase="analyzing", message="Diagnosing blocked task..." ) @@ -808,16 +932,18 @@ async def _stream_and_handle_tools(self, user_message: str = "") -> AsyncIterato # INCOMPLETE — continue working continuation += 1 reason = (verification.content or "").strip() - self._history.append({ - "role": "user", - "content": ( - f"SYSTEM: Task verification determined this task is not yet complete " - f"(attempt {continuation}/{_MAX_CONTINUATIONS}).\n" - f"Verifier assessment: {reason}\n\n" - "Continue working on the original request. Pick up where you left off " - "and finish the remaining work. Do not repeat work already done." - ), - }) + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: Task verification determined this task is not yet complete " + f"(attempt {continuation}/{_MAX_CONTINUATIONS}).\n" + f"Verifier assessment: {reason}\n\n" + "Continue working on the original request. Pick up where you left off " + "and finish the remaining work. Do not repeat work already done." 
+ ), + } + ) yield StreamTaskProgress( phase="analyzing", message=f"Task incomplete — continuing ({continuation}/{_MAX_CONTINUATIONS})...", @@ -918,10 +1044,9 @@ def _apply_error_tracking( def _register_secret_vars(engine_def: "DatasourceEngine") -> None: - """Record which DS_* var names correspond to secret fields for engine_def. - """ + """Record which DS_* var names correspond to secret fields for engine_def.""" all_fields = list(engine_def.fields) - for am in (engine_def.auth_methods or []): + for am in engine_def.auth_methods or []: all_fields.extend(am.fields) for f in all_fields: key = f"DS_{f.name.upper()}" @@ -989,15 +1114,16 @@ def _build_runtime_context(settings: AntonSettings) -> str: f"- Workspace: {settings.workspace_path}\n" f"- Memory mode: {settings.memory_mode}" ) - if settings.minds_api_key and (settings.minds_mind_name or settings.minds_datasource): + if settings.minds_api_key and ( + settings.minds_mind_name or settings.minds_datasource + ): engine = settings.minds_datasource_engine or "unknown" ctx += f"\n\n**CONNECTED MIND (Minds):**\n" if settings.minds_mind_name: ctx += f"- Mind: {settings.minds_mind_name}\n" if settings.minds_datasource: ctx += ( - f"- Datasource: {settings.minds_datasource}\n" - f"- Engine: {engine}\n" + f"- Datasource: {settings.minds_datasource}\n" f"- Engine: {engine}\n" ) ctx += ( f"- Minds URL: {settings.minds_url}\n" @@ -1039,7 +1165,8 @@ def _rebuild_session( runtime_context = _build_runtime_context(settings) api_key = ( - settings.anthropic_api_key if settings.coding_provider == "anthropic" + settings.anthropic_api_key + if settings.coding_provider == "anthropic" else settings.openai_api_key ) or "" return ChatSession( @@ -1089,17 +1216,34 @@ def _show_scope(label: str, hc) -> int: identity = hc.recall_identity() rules = hc.recall_rules() lessons_raw = hc._read_full_lessons() - rule_count = sum(1 for ln in rules.splitlines() if ln.strip().startswith("- ")) if rules else 0 - lesson_count = sum(1 for ln in 
lessons_raw.splitlines() if ln.strip().startswith("- ")) if lessons_raw else 0 + rule_count = ( + sum(1 for ln in rules.splitlines() if ln.strip().startswith("- ")) + if rules + else 0 + ) + lesson_count = ( + sum(1 for ln in lessons_raw.splitlines() if ln.strip().startswith("- ")) + if lessons_raw + else 0 + ) topics: list[str] = [] if hc._topics_dir.is_dir(): - topics = [p.stem for p in sorted(hc._topics_dir.iterdir()) if p.suffix == ".md"] + topics = [ + p.stem for p in sorted(hc._topics_dir.iterdir()) if p.suffix == ".md" + ] console.print(f" [anton.cyan]{label}[/] [dim]({hc._dir})[/]") if identity: - entries = [ln.strip()[2:] for ln in identity.splitlines() if ln.strip().startswith("- ")] + entries = [ + ln.strip()[2:] + for ln in identity.splitlines() + if ln.strip().startswith("- ") + ] if entries: - console.print(f" Identity: {', '.join(entries[:3])}" + (" ..." if len(entries) > 3 else "")) + console.print( + f" Identity: {', '.join(entries[:3])}" + + (" ..." if len(entries) > 3 else "") + ) else: console.print(" Identity: [dim](set)[/]") else: @@ -1230,7 +1374,9 @@ async def _handle_resume( new_session._turn_count = sum(1 for m in history if m.get("role") == "user") console.print() - console.print(f"[anton.success]Resumed session from {selected['date']} ({selected['turns']} turns)[/]") + console.print( + f"[anton.success]Resumed session from {selected['date']} ({selected['turns']} turns)[/]" + ) console.print() return new_session, sid @@ -1272,9 +1418,16 @@ async def _handle_setup( return session elif top_choice == "1": return await _handle_setup_models( - console, settings, workspace, state, - self_awareness, cortex, session, episodic=episodic, - history_store=history_store, session_id=session_id, + console, + settings, + workspace, + state, + self_awareness, + cortex, + session, + episodic=episodic, + history_store=history_store, + session_id=session_id, ) else: _handle_setup_memory(console, settings, workspace, cortex, episodic=episodic) @@ -1311,11 
+1464,19 @@ async def _handle_setup_models( # --- Provider --- providers = {"1": "anthropic", "2": "openai", "3": "openai-compatible"} - current_num = {"anthropic": "1", "openai": "2", "openai-compatible": "3"}.get(settings.planning_provider, "1") + current_num = {"anthropic": "1", "openai": "2", "openai-compatible": "3"}.get( + settings.planning_provider, "1" + ) console.print("[anton.cyan]Available providers:[/]") - console.print(r" [bold]1[/] Anthropic (Claude) [dim]\[recommended][/]") - console.print(r" [bold]2[/] OpenAI (GPT / o-series) [dim]\[experimental][/]") - console.print(r" [bold]3[/] OpenAI-compatible (custom endpoint) [dim]\[experimental][/]") + console.print( + r" [bold]1[/] Anthropic (Claude) [dim]\[recommended][/]" + ) + console.print( + r" [bold]2[/] OpenAI (GPT / o-series) [dim]\[experimental][/]" + ) + console.print( + r" [bold]3[/] OpenAI-compatible (custom endpoint) [dim]\[experimental][/]" + ) console.print() choice = Prompt.ask( @@ -1343,7 +1504,9 @@ async def _handle_setup_models( # --- API key --- key_attr = "anthropic_api_key" if provider == "anthropic" else "openai_api_key" current_key = getattr(settings, key_attr) or "" - masked = current_key[:4] + "..." + current_key[-4:] if len(current_key) > 8 else "***" + masked = ( + current_key[:4] + "..." 
+ current_key[-4:] if len(current_key) > 8 else "***" + ) console.print() api_key = Prompt.ask( f"API key for {provider.title()} [dim](Enter to keep {masked})[/]", @@ -1362,12 +1525,20 @@ async def _handle_setup_models( console.print() planning_model = Prompt.ask( "Planning model", - default=settings.planning_model if provider == settings.planning_provider else default_planning, + default=( + settings.planning_model + if provider == settings.planning_provider + else default_planning + ), console=console, ) coding_model = Prompt.ask( "Coding model", - default=settings.coding_model if provider == settings.coding_provider else default_coding, + default=( + settings.coding_model + if provider == settings.coding_provider + else default_coding + ), console=console, ) @@ -1391,7 +1562,9 @@ async def _handle_setup_models( final_key = getattr(settings, key_attr) if not final_key: console.print() - console.print(f"[anton.error]No API key set for {provider}. Configuration not applied.[/]") + console.print( + f"[anton.error]No API key set for {provider}. 
Configuration not applied.[/]" + ) console.print() return session @@ -1428,9 +1601,15 @@ def _handle_setup_memory( # --- Memory mode --- console.print(" Memory mode:") - console.print(r" [bold]1[/] Autopilot — Anton decides what to remember [dim]\[recommended][/]") - console.print(r" [bold]2[/] Co-pilot — save obvious, confirm ambiguous [dim]\[selective][/]") - console.print(r" [bold]3[/] Off — never save memory (still reads existing) [dim]\[suppressed][/]") + console.print( + r" [bold]1[/] Autopilot — Anton decides what to remember [dim]\[recommended][/]" + ) + console.print( + r" [bold]2[/] Co-pilot — save obvious, confirm ambiguous [dim]\[selective][/]" + ) + console.print( + r" [bold]3[/] Off — never save memory (still reads existing) [dim]\[suppressed][/]" + ) console.print() mode_map = {"1": "autopilot", "2": "copilot", "3": "off"} @@ -1453,7 +1632,9 @@ def _handle_setup_memory( if episodic is not None: console.print() ep_status = "ON" if episodic.enabled else "OFF" - console.print(f" Episodic memory (conversation archive): Currently [bold]{ep_status}[/]") + console.print( + f" Episodic memory (conversation archive): Currently [bold]{ep_status}[/]" + ) toggle = Prompt.ask( " Toggle episodic memory? 
(y/n)", choices=["y", "n"], @@ -1464,7 +1645,9 @@ def _handle_setup_memory( new_state = not episodic.enabled episodic.enabled = new_state settings.episodic_memory = new_state - workspace.set_secret("ANTON_EPISODIC_MEMORY", "true" if new_state else "false") + workspace.set_secret( + "ANTON_EPISODIC_MEMORY", "true" if new_state else "false" + ) console.print(f" Episodic memory: [bold]{'ON' if new_state else 'OFF'}[/]") console.print() @@ -1547,7 +1730,10 @@ def _describe_minds_connection_error(err: Exception) -> tuple[str, str]: "Connection failed during TLS certificate verification.", "Common reasons: a self-signed, expired, or otherwise untrusted certificate.", ) - if isinstance(reason, (TimeoutError, socket.timeout)) or "timed out" in str(reason).lower(): + if ( + isinstance(reason, (TimeoutError, socket.timeout)) + or "timed out" in str(reason).lower() + ): return ( "Connection failed because the request timed out.", "Common reasons: the server is slow or unavailable, the URL is wrong, or there is a network path issue.", @@ -1590,7 +1776,10 @@ def _minds_request( req.add_header("Content-Type", "application/json") req.add_header("Accept", "application/json") # Browser-like headers to avoid Cloudflare bot detection - req.add_header("User-Agent", "Mozilla/5.0 (compatible; Anton/1.0; +https://github.com/mindsdb/anton)") + req.add_header( + "User-Agent", + "Mozilla/5.0 (compatible; Anton/1.0; +https://github.com/mindsdb/anton)", + ) req.add_header("Accept-Language", "en-US,en;q=0.9") req.add_header("Accept-Encoding", "identity") req.add_header("Connection", "keep-alive") @@ -1618,7 +1807,9 @@ def _minds_list_minds(base_url: str, api_key: str, verify: bool = True) -> list[ return data.get("minds", data if isinstance(data, list) else []) -def _minds_get_mind(base_url: str, api_key: str, mind_name: str, verify: bool = True) -> dict | None: +def _minds_get_mind( + base_url: str, api_key: str, mind_name: str, verify: bool = True +) -> dict | None: """Fetch a single mind's 
details from a Minds server.""" import json as _json @@ -1661,7 +1852,9 @@ def _minds_refresh_knowledge(settings: AntonSettings, cortex) -> None: cortex.project_hc._encode_with_lock(topic_path, topic_content, mode="write") -def _minds_list_datasources(base_url: str, api_key: str, verify: bool = True) -> list[dict]: +def _minds_list_datasources( + base_url: str, api_key: str, verify: bool = True +) -> list[dict]: """Fetch datasource list from a Minds server using stdlib urllib.""" import json as _json @@ -1680,11 +1873,13 @@ def _minds_test_llm(base_url: str, api_key: str, verify: bool = True) -> bool: import json as _json url = f"{base_url}/api/v1/chat/completions" - payload = _json.dumps({ - "model": "_code_", - "messages": [{"role": "user", "content": "ping"}], - "max_tokens": 1, - }).encode() + payload = _json.dumps( + { + "model": "_code_", + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 1, + } + ).encode() try: _minds_request(url, api_key, method="POST", payload=payload, verify=verify) @@ -1694,14 +1889,22 @@ def _minds_test_llm(base_url: str, api_key: str, verify: bool = True) -> bool: _MINDS_KEYS = { - "ANTON_MINDS_API_KEY", "ANTON_MINDS_URL", "ANTON_MINDS_MIND_NAME", - "ANTON_MINDS_DATASOURCE", "ANTON_MINDS_DATASOURCE_ENGINE", "ANTON_MINDS_SSL_VERIFY", + "ANTON_MINDS_API_KEY", + "ANTON_MINDS_URL", + "ANTON_MINDS_MIND_NAME", + "ANTON_MINDS_DATASOURCE", + "ANTON_MINDS_DATASOURCE_ENGINE", + "ANTON_MINDS_SSL_VERIFY", } _LLM_KEYS = { - "ANTON_PLANNING_PROVIDER", "ANTON_CODING_PROVIDER", - "ANTON_PLANNING_MODEL", "ANTON_CODING_MODEL", - "ANTON_ANTHROPIC_API_KEY", "ANTON_OPENAI_API_KEY", "ANTON_OPENAI_BASE_URL", + "ANTON_PLANNING_PROVIDER", + "ANTON_CODING_PROVIDER", + "ANTON_PLANNING_MODEL", + "ANTON_CODING_MODEL", + "ANTON_ANTHROPIC_API_KEY", + "ANTON_OPENAI_API_KEY", + "ANTON_OPENAI_BASE_URL", } _SECRET_PATTERNS = ("KEY", "TOKEN", "SECRET", "PAT", "PASSWORD") @@ -1736,7 +1939,9 @@ async def _handle_data_connections( # Merge with source 
tags: project keys override global for display, # but we track where each lives for writes/removals. - all_keys: dict[str, tuple[str, str, str]] = {} # key -> (value, source, scope_label) + all_keys: dict[str, tuple[str, str, str]] = ( + {} + ) # key -> (value, source, scope_label) for k, v in global_env.items(): all_keys[k] = (v, "global", "~/.anton/.env") for k, v in project_env.items(): @@ -1746,7 +1951,9 @@ async def _handle_data_connections( if not all_keys: console.print("[anton.warning]No connections or secrets configured.[/]") - console.print("[anton.muted]Use /connect to set up a Minds connection, or ask Anton to store a key.[/]") + console.print( + "[anton.muted]Use /connect to set up a Minds connection, or ask Anton to store a key.[/]" + ) console.print() return session @@ -1754,7 +1961,11 @@ def _print_table() -> list[tuple[str, str, str, str]]: """Print grouped key table and return flat list for menu selection.""" minds = {k: all_keys[k] for k in sorted(all_keys) if k in _MINDS_KEYS} llm = {k: all_keys[k] for k in sorted(all_keys) if k in _LLM_KEYS} - other = {k: all_keys[k] for k in sorted(all_keys) if k not in _MINDS_KEYS and k not in _LLM_KEYS} + other = { + k: all_keys[k] + for k in sorted(all_keys) + if k not in _MINDS_KEYS and k not in _LLM_KEYS + } flat: list[tuple[str, str, str, str]] = [] # (key, value, source, scope_label) idx = 1 @@ -1762,7 +1973,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: if minds: console.print("[anton.cyan]Minds Connection[/]") for k, (v, src, lbl) in minds.items(): - console.print(f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]") + console.print( + f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]" + ) flat.append((k, v, src, lbl)) idx += 1 console.print() @@ -1770,7 +1983,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: if llm: console.print("[anton.cyan]LLM Configuration[/]") for k, (v, src, lbl) in llm.items(): - console.print(f" [bold]{idx}[/] {k} = 
{_display_value(k, v)} [dim]({lbl})[/]") + console.print( + f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]" + ) flat.append((k, v, src, lbl)) idx += 1 console.print() @@ -1778,7 +1993,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: if other: console.print("[anton.cyan]Other Integrations[/]") for k, (v, src, lbl) in other.items(): - console.print(f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]") + console.print( + f" [bold]{idx}[/] {k} = {_display_value(k, v)} [dim]({lbl})[/]" + ) flat.append((k, v, src, lbl)) idx += 1 console.print() @@ -1796,7 +2013,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: console.print(" [bold]q[/] Back") console.print() - action = Prompt.ask("Select", choices=["1", "2", "3", "q"], default="q", console=console) + action = Prompt.ask( + "Select", choices=["1", "2", "3", "q"], default="q", console=console + ) if action == "q": console.print() @@ -1855,7 +2074,9 @@ def _print_table() -> list[tuple[str, str, str, str]]: continue key, _, src, lbl = flat[pick_idx] - if not Confirm.ask(f"Remove {key} from {lbl}?", default=False, console=console): + if not Confirm.ask( + f"Remove {key} from {lbl}?", default=False, console=console + ): console.print("[anton.muted]Cancelled.[/]") console.print() continue @@ -1869,14 +2090,20 @@ def _print_table() -> list[tuple[str, str, str, str]]: elif action == "3": # --- Add --- console.print() - new_key = Prompt.ask("Key name (e.g. HUBSPOT_API_KEY)", console=console).strip() + new_key = Prompt.ask( + "Key name (e.g. HUBSPOT_API_KEY)", console=console + ).strip() if not new_key: console.print("[anton.warning]Key name cannot be empty.[/]") console.print() continue if new_key in all_keys: - if not Confirm.ask(f"{new_key} already exists. Overwrite?", default=False, console=console): + if not Confirm.ask( + f"{new_key} already exists. 
Overwrite?", + default=False, + console=console, + ): console.print("[anton.muted]Cancelled.[/]") console.print() continue @@ -1899,7 +2126,11 @@ def _print_table() -> list[tuple[str, str, str, str]]: console=console, ) target_ws = global_ws if scope == "global" else workspace - scope_label = "~/.anton/.env" if scope == "global" else f"{workspace.base}/.anton/.env" + scope_label = ( + "~/.anton/.env" + if scope == "global" + else f"{workspace.base}/.anton/.env" + ) target_ws.set_secret(new_key, new_val) target_ws.apply_env_to_process() all_keys[new_key] = (new_val, scope, scope_label) @@ -2007,7 +2238,11 @@ async def _handle_connect( name = mind.get("name", "?") ds_list = mind.get("datasources", []) ds_count = len(ds_list) - ds_label = f"{ds_count} datasource{'s' if ds_count != 1 else ''}" if ds_count else "no datasources" + ds_label = ( + f"{ds_count} datasource{'s' if ds_count != 1 else ''}" + if ds_count + else "no datasources" + ) console.print(f" [bold]{i}[/] {name} [dim]({ds_label})[/]") console.print() @@ -2041,7 +2276,9 @@ async def _handle_connect( # --- Resolve engine type from datasources list --- if ds_name: try: - all_datasources = _minds_list_datasources(minds_url, api_key, verify=ssl_verify) + all_datasources = _minds_list_datasources( + minds_url, api_key, verify=ssl_verify + ) for ds in all_datasources: if ds.get("name") == ds_name: ds_engine = ds.get("engine", "unknown") @@ -2075,7 +2312,9 @@ async def _handle_connect( llm_ok = _minds_test_llm(minds_url, api_key, verify=ssl_verify) if llm_ok: - console.print("[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]") + console.print( + "[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]" + ) base_url = f"{minds_url.rstrip('/')}/api/v1" settings.openai_api_key = api_key settings.openai_base_url = base_url @@ -2091,7 +2330,9 @@ async def _handle_connect( global_ws.set_secret("ANTON_CODING_MODEL", "_code_") else: # Check if Anthropic key is already 
configured - has_anthropic = settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY") + has_anthropic = settings.anthropic_api_key or os.environ.get( + "ANTHROPIC_API_KEY" + ) if not has_anthropic: anthropic_key = Prompt.ask("Anthropic API key (for LLM)", console=console) if anthropic_key.strip(): @@ -2108,14 +2349,21 @@ async def _handle_connect( global_ws.set_secret("ANTON_CODING_MODEL", "claude-haiku-4-5-20251001") console.print("[anton.success]Anthropic API key saved.[/]") else: - console.print("[anton.warning]No API key provided — LLM calls will not work.[/]") + console.print( + "[anton.warning]No API key provided — LLM calls will not work.[/]" + ) global_ws.apply_env_to_process() console.print() return _rebuild_session( - settings=settings, state=state, self_awareness=self_awareness, - cortex=cortex, workspace=workspace, console=console, episodic=episodic, + settings=settings, + state=state, + self_awareness=self_awareness, + cortex=cortex, + workspace=workspace, + console=console, + episodic=episodic, ) @@ -2151,30 +2399,55 @@ def _format_file_message(text: str, paths: list[Path], console: Console) -> str: # Skip very large files (>500KB) — just reference them if size > 512_000: - parts.append(f"\n\n(File too large to inline — {_human_size(size)}. " - f"Use the scratchpad to read it.)\n") + parts.append( + f'\n\n(File too large to inline — {_human_size(size)}. ' + f"Use the scratchpad to read it.)\n" + ) continue # Skip binary-looking files - if suffix in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".ico", ".webp", - ".pdf", ".zip", ".tar", ".gz", ".exe", ".dll", ".so", - ".pyc", ".pyo", ".whl", ".egg", ".db", ".sqlite"): - parts.append(f"\n\n(Binary file — {_human_size(size)}. 
" - f"Use the scratchpad to process it.)\n") + if suffix in ( + ".png", + ".jpg", + ".jpeg", + ".gif", + ".bmp", + ".ico", + ".webp", + ".pdf", + ".zip", + ".tar", + ".gz", + ".exe", + ".dll", + ".so", + ".pyc", + ".pyo", + ".whl", + ".egg", + ".db", + ".sqlite", + ): + parts.append( + f'\n\n(Binary file — {_human_size(size)}. ' + f"Use the scratchpad to process it.)\n" + ) continue try: content = p.read_text(errors="replace") except Exception: - parts.append(f"\n\n(Could not read file.)\n") + parts.append(f'\n\n(Could not read file.)\n') continue - parts.append(f"\n\n{content}\n") + parts.append(f'\n\n{content}\n') return "\n".join(parts) -def _format_clipboard_image_message(uploaded: object, user_text: str = "") -> list[dict]: +def _format_clipboard_image_message( + uploaded: object, user_text: str = "" +) -> list[dict]: """Build a multimodal LLM message for a clipboard image upload. Returns a list of content blocks (image + text) so the LLM can see @@ -2183,7 +2456,11 @@ def _format_clipboard_image_message(uploaded: object, user_text: str = "") -> li """ import base64 - text = user_text.strip() if user_text else "I've pasted an image from my clipboard. Analyze it." + text = ( + user_text.strip() + if user_text + else "I've pasted an image from my clipboard. Analyze it." + ) text += ( f"\n\nThe image is also saved at: {uploaded.path}\n" f"({uploaded.width}x{uploaded.height}, {_human_size(uploaded.size_bytes)}). " @@ -2229,6 +2506,7 @@ async def _ensure_clipboard(console: Console) -> bool: return False console.print("[anton.muted]Installing Pillow...[/]") import subprocess + proc = await asyncio.get_event_loop().run_in_executor( None, lambda: subprocess.run( @@ -2251,7 +2529,9 @@ async def _ensure_clipboard(console: Console) -> bool: ), ) if proc.returncode == 0: - console.print("[anton.success]Pillow installed. Clipboard is now available.[/]") + console.print( + "[anton.success]Pillow installed. 
Clipboard is now available.[/]" + ) return True console.print("[anton.error]Failed to install Pillow.[/]") return False @@ -2289,23 +2569,27 @@ async def _handle_add_custom_datasource( try: response = await session._llm.plan( system="You are a data source connection expert.", - messages=[{ - "role": "user", - "content": ( - f"The user wants to connect to '{name}' and said: {user_answer}\n\n" - "Return ONLY valid JSON (no markdown fences, no commentary):\n" - '{"display_name":"Human-readable name","pip":"pip-package or empty string",' - '"fields":[{"name":"snake_case_name","value":"value if given inline else empty",' - '"secret":true or false,"required":true or false,"description":"what it is"}]}' - ), - }], + messages=[ + { + "role": "user", + "content": ( + f"The user wants to connect to '{name}' and said: {user_answer}\n\n" + "Return ONLY valid JSON (no markdown fences, no commentary):\n" + '{"display_name":"Human-readable name","pip":"pip-package or empty string",' + '"fields":[{"name":"snake_case_name","value":"value if given inline else empty",' + '"secret":true or false,"required":true or false,"description":"what it is"}]}' + ), + } + ], max_tokens=1024, ) text = response.content.strip() text = _re.sub(r"^```[^\n]*\n|```\s*$", "", text, flags=_re.MULTILINE).strip() data = _json.loads(text) except Exception: - console.print("[anton.warning] Couldn't identify connection details. Try again.[/]") + console.print( + "[anton.warning] Couldn't identify connection details. 
Try again.[/]" + ) console.print() return None @@ -2314,12 +2598,14 @@ async def _handle_add_custom_datasource( for f in raw_fields: if not isinstance(f, dict) or not f.get("name"): continue - fields.append(DatasourceField( - name=f["name"], - required=bool(f.get("required", True)), - secret=bool(f.get("secret", False)), - description=str(f.get("description", "")), - )) + fields.append( + DatasourceField( + name=f["name"], + required=bool(f.get("required", True)), + secret=bool(f.get("secret", False)), + description=str(f.get("description", "")), + ) + ) if not fields: console.print("[anton.warning] Couldn't identify any connection fields.[/]") @@ -2336,10 +2622,14 @@ async def _handle_add_custom_datasource( for f, raw in zip(fields, raw_fields): inline_value = str(raw.get("value", "")).strip() if f.secret and inline_value: - console.print(f" • [bold]{f.name:<14}[/] (secret — provided, stored securely)") + console.print( + f" • [bold]{f.name:<14}[/] (secret — provided, stored securely)" + ) credentials[f.name] = inline_value elif f.secret: - console.print(f" • [bold]{f.name:<14}[/] (secret — I'll ask for this)") + console.print( + f" • [bold]{f.name:<14}[/] (secret — I'll ask for this)" + ) else: val_display = inline_value or "[anton.muted][/]" console.print(f" • [bold]{f.name:<14}[/] {val_display}") @@ -2371,7 +2661,7 @@ async def _handle_add_custom_datasource( slug = _re.sub(r"[^\w]", "_", display_name.lower()).strip("_") field_lines = "\n".join( f" - {{ name: {f.name}, required: {str(f.required).lower()}, " - f"secret: {str(f.secret).lower()}, description: \"{f.description}\" }}" + f'secret: {str(f.secret).lower()}, description: "{f.description}" }}' for f in fields ) yaml_block = ( @@ -2387,12 +2677,15 @@ async def _handle_add_custom_datasource( tmp_path = user_ds_path.with_suffix(".tmp") # Write to temp, validate it parses, then rename atomically - existing = user_ds_path.read_text(encoding="utf-8") if user_ds_path.is_file() else "" + existing = ( + 
user_ds_path.read_text(encoding="utf-8") if user_ds_path.is_file() else "" + ) tmp_path.write_text(existing + yaml_block, encoding="utf-8") parsed = _ds_parse_file(tmp_path) if slug in parsed: import shutil + shutil.move(str(tmp_path), str(user_ds_path)) else: tmp_path.unlink(missing_ok=True) @@ -2419,13 +2712,151 @@ async def _handle_connect_datasource( console: Console, scratchpads: ScratchpadManager, session: "ChatSession", + datasource_name: str | None = None, ) -> "ChatSession": - """Interactive flow for connecting a new data source to the Local Vault.""" - from rich.prompt import Prompt + """ + Connect a data source by entering credentials, either for a new name or re-entering for an existing one. + """ vault = DataVault() registry = DatasourceRegistry() + # ── /edit-data-source path: re-enter credentials for an existing slug ───── + if datasource_name is not None: + slug_parts = datasource_name.split("-", 1) + if len(slug_parts) != 2: + console.print( + f"[anton.warning]Invalid slug '{datasource_name}'. " + "Expected format: engine-name.[/]" + ) + console.print() + return session + edit_engine, edit_name = slug_parts + if vault.load(edit_engine, edit_name) is None: + console.print( + f"[anton.warning]No connection '{datasource_name}' found in Local Vault.[/]" + ) + console.print() + return session + engine_def = registry.get(edit_engine) + if engine_def is None: + console.print( + f"[anton.warning]Unknown engine '{edit_engine}'. " + "Cannot update credentials.[/]" + ) + console.print() + return session + + console.print() + console.print( + f"[anton.cyan](anton)[/] Updating credentials for " + f'[bold]"{datasource_name}"[/bold] ({engine_def.display_name}).' 
+ ) + console.print() + + credentials: dict[str, str] = {} + for f in engine_def.fields: + prompt_label = f"[anton.cyan](anton)[/] {f.name}" + if not f.required: + prompt_label += " [anton.muted](optional, press enter to skip)[/]" + if f.secret: + value = Prompt.ask( + prompt_label, password=True, console=console, default="" + ) + elif f.default: + value = Prompt.ask( + f"{prompt_label} [anton.muted][{f.default}][/]", + console=console, + default=f.default, + ) + else: + value = Prompt.ask(prompt_label, console=console, default="") + if value: + credentials[f.name] = value + + if engine_def.test_snippet: + while True: + console.print() + console.print("[anton.cyan](anton)[/] Got it. Testing connection…") + import os as _os + + for key, value in credentials.items(): + _os.environ[f"DS_{key.upper()}"] = value + try: + pad = await scratchpads.get_or_create("__datasource_test__") + await pad.reset() + if engine_def.pip: + await pad.install_packages([engine_def.pip]) + cell = await pad.execute(engine_def.test_snippet) + finally: + ds_keys = [k for k in _os.environ if k.startswith("DS_")] + for k in ds_keys: + del _os.environ[k] + + if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): + error_text = ( + cell.error or cell.stderr.strip() or cell.stdout.strip() + ) + first_line = next( + (ln for ln in error_text.splitlines() if ln.strip()), error_text + ) + console.print() + console.print("[anton.warning](anton)[/] ✗ Connection failed.") + console.print() + console.print(f" Error: {first_line}") + console.print() + retry = ( + Prompt.ask( + "[anton.cyan](anton)[/] Would you like to re-enter your credentials? 
[y/n]", + console=console, + default="n", + ) + .strip() + .lower() + ) + if retry != "y": + return session + console.print() + for f in engine_def.fields: + if not f.secret: + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + password=True, + console=console, + default="", + ) + if value: + credentials[f.name] = value + continue + + console.print("[anton.success] ✓ Connected successfully![/]") + break + + vault.save(edit_engine, edit_name, credentials) + vault.inject_env(edit_engine, edit_name) + _register_secret_vars(engine_def) + console.print() + console.print( + f' Credentials updated for [bold]"{datasource_name}"[/bold].' + ) + console.print() + console.print( + "[anton.muted] You can now ask me questions about your data.[/]" + ) + console.print() + session._history.append( + { + "role": "assistant", + "content": ( + f"I've updated the credentials for the {engine_def.display_name} connection " + f'"{datasource_name}" in the Local Vault.' + ), + } + ) + return session + + # ── Normal flow: connect a new (or reconnect an existing) data source ───── console.print() engine_names = ", ".join(e.display_name for e in registry.all_engines()) answer = Prompt.ask( @@ -2433,18 +2864,48 @@ async def _handle_connect_datasource( f" [anton.muted](e.g. 
{engine_names})[/]\n", console=console, ) - engine_def = registry.find_by_name(answer.strip()) + + # ── Reconnect path: user typed an existing vault slug ───────────────────── + stripped_answer = answer.strip() + known_slugs = {f"{c['engine']}-{c['name']}": c for c in vault.list_connections()} + if stripped_answer in known_slugs: + conn = known_slugs[stripped_answer] + vault.inject_env(conn["engine"], conn["name"]) + recon_engine_def = registry.get(conn["engine"]) + if recon_engine_def: + _register_secret_vars(recon_engine_def) + engine_label = recon_engine_def.display_name + else: + engine_label = conn["engine"] + console.print() + console.print( + f'[anton.success] ✓ Reconnected to [bold]"{stripped_answer}"[/bold].[/]' + ) + console.print() + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve reconnected to the {engine_label} connection "{stripped_answer}" ' + f"in the Local Vault. I can now query this data source when needed." + ), + } + ) + return session + + engine_def = registry.find_by_name(stripped_answer) if engine_def is None: # Check whether the input is ambiguous before treating it as unknown - needle = answer.strip().lower() + needle = stripped_answer.lower() candidates = [ - e for e in registry.all_engines() + e + for e in registry.all_engines() if needle in e.display_name.lower() or needle in e.engine.lower() ] if len(candidates) > 1: console.print() console.print( - f"[anton.warning](anton)[/] '{answer}' matches multiple engines — " + f"[anton.warning](anton)[/] '{stripped_answer}' matches multiple engines — " "which one did you mean?" 
) console.print() @@ -2462,7 +2923,9 @@ async def _handle_connect_datasource( console.print() return session else: - result = await _handle_add_custom_datasource(console, answer.strip(), registry, session) + result = await _handle_add_custom_datasource( + console, stripped_answer, registry, session + ) if result is None: return session engine_def, credentials = result @@ -2470,16 +2933,18 @@ async def _handle_connect_datasource( vault.save(engine_def.engine, str(conn_num), credentials) slug = f"{engine_def.engine}-{conn_num}" console.print( - f" Credentials saved to Local Vault as [bold]\"{slug}\"[/bold]." + f' Credentials saved to Local Vault as [bold]"{slug}"[/bold].' ) console.print() - session._history.append({ - "role": "assistant", - "content": ( - f"I've saved a {engine_def.display_name} connection named \"{slug}\" " - f"to the Local Vault. I can now query this data source when needed." - ), - }) + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve saved a {engine_def.display_name} connection named "{slug}" ' + f"to the Local Vault. I can now query this data source when needed." + ), + } + ) return session # ── Step 2a: auth method choice (if engine requires it) ─────── @@ -2520,21 +2985,29 @@ async def _handle_connect_datasource( if required_fields: console.print(" [bold]Required[/] " + "─" * 39) for f in required_fields: - console.print(f" • [bold]{f.name:<12}[/] [anton.muted]— {f.description}[/]") + console.print( + f" • [bold]{f.name:<12}[/] [anton.muted]— {f.description}[/]" + ) if optional_fields: console.print() console.print(" [bold]Optional[/] " + "─" * 39) for f in optional_fields: - console.print(f" • [bold]{f.name:<12}[/] [anton.muted]— {f.description}[/]") + console.print( + f" • [bold]{f.name:<12}[/] [anton.muted]— {f.description}[/]" + ) console.print() # ── Step 3: determine collection mode ──────────────────────── - mode_answer = Prompt.ask( - "[anton.cyan](anton)[/] Do you have these available? 
[y/n/]", - console=console, - ).strip().lower() + mode_answer = ( + Prompt.ask( + "[anton.cyan](anton)[/] Do you have these available? [y/n/]", + console=console, + ) + .strip() + .lower() + ) if mode_answer == "n": console.print() @@ -2590,7 +3063,7 @@ async def _handle_connect_datasource( console.print() console.print( f"[anton.muted]Partial connection saved to Local Vault as " - f"[bold]\"{slug}\"[/bold]. " + f'[bold]"{slug}"[/bold]. ' f"Run [bold]/edit-data-source {slug}[/bold] to complete it when you're ready.[/]" ) console.print() @@ -2603,6 +3076,7 @@ async def _handle_connect_datasource( # Temporarily inject DS_* into os.environ (test before committing to vault) import os as _os + for key, value in credentials.items(): _os.environ[f"DS_{key.upper()}"] = value @@ -2632,11 +3106,15 @@ async def _handle_connect_datasource( console.print(f" Error: {first_line}") console.print() - retry = Prompt.ask( - "[anton.cyan](anton)[/] Would you like to re-enter your credentials? [y/n]", - console=console, - default="n", - ).strip().lower() + retry = ( + Prompt.ask( + "[anton.cyan](anton)[/] Would you like to re-enter your credentials? [y/n]", + console=console, + default="n", + ) + .strip() + .lower() + ) if retry != "y": return session @@ -2666,25 +3144,65 @@ async def _handle_connect_datasource( n = vault.next_connection_number(engine_def.engine) conn_name = str(n) + slug = f"{engine_def.engine}-{conn_name}" + + if vault.load(engine_def.engine, conn_name) is not None: + console.print() + console.print( + f'[anton.warning](anton)[/] A connection [bold]"{slug}"[/bold] already exists.' 
+ ) + console.print() + choice = ( + Prompt.ask( + "[anton.cyan](anton)[/] [reconnect/cancel]", + console=console, + default="cancel", + ) + .strip() + .lower() + ) + if choice != "reconnect": + console.print("[anton.muted]Cancelled.[/]") + console.print() + return session + vault.inject_env(engine_def.engine, conn_name) + _register_secret_vars(engine_def) + console.print() + console.print( + f'[anton.success] ✓ Reconnected to [bold]"{slug}"[/bold].[/]' + ) + console.print() + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve reconnected to the {engine_def.display_name} connection "{slug}" ' + f"in the Local Vault. I can now query this data source when needed." + ), + } + ) + return session + vault.save(engine_def.engine, conn_name, credentials) _register_secret_vars(engine_def) - slug = f"{engine_def.engine}-{conn_name}" - console.print( - f" Credentials saved to Local Vault as [bold]\"{slug}\"[/bold]." - ) + console.print(f' Credentials saved to Local Vault as [bold]"{slug}"[/bold].') console.print() - console.print("[anton.muted] You can now ask me questions about your data.[/]") + console.print( + "[anton.muted] You can now ask me questions about your data.[/]" + ) console.print() # Inject a brief assistant message so the LLM is aware of the new connection - session._history.append({ - "role": "assistant", - "content": ( - f"I've saved a {engine_def.display_name} connection named \"{slug}\" " - f"to the Local Vault. I can now query this data source when needed." - ), - }) + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve saved a {engine_def.display_name} connection named "{slug}" ' + f"to the Local Vault. I can now query this data source when needed." 
+ ), + } + ) return session @@ -2740,14 +3258,26 @@ def _print_slash_help(console: Console) -> None: """Print available slash commands.""" console.print() console.print("[anton.cyan]Available commands:[/]") - console.print(" [bold]/connect[/] — Connect to a Minds server and select a mind") - console.print(" [bold]/connect-data-source[/] — Connect a database or API to the Local Vault") - console.print(" [bold]/list-data-sources[/] — List all saved data source connections") + console.print( + " [bold]/connect[/] — Connect to a Minds server and select a mind" + ) + console.print( + " [bold]/connect-data-source[/] — Connect a database or API to the Local Vault" + ) + console.print( + " [bold]/list-data-sources[/] — List all saved data source connections" + ) console.print(" [bold]/remove-data-source[/] — Remove a saved connection") - console.print(" [bold]/data-connections[/] — View and manage stored keys and connections") - console.print(" [bold]/setup[/] — Configure models or memory settings") + console.print( + " [bold]/data-connections[/] — View and manage stored keys and connections" + ) + console.print( + " [bold]/setup[/] — Configure models or memory settings" + ) console.print(" [bold]/memory[/] — Show memory status dashboard") - console.print(" [bold]/paste[/] — Attach clipboard image to your message") + console.print( + " [bold]/paste[/] — Attach clipboard image to your message" + ) console.print(" [bold]/resume[/] — Resume a previous chat session") console.print(" [bold]/help[/] — Show this help message") console.print(" [bold]exit[/] — Quit the chat") @@ -2859,8 +3389,12 @@ def start(self) -> None: from rich.spinner import Spinner from rich.text import Text - spinner = Spinner("dots", text=Text(" Closing scratchpad processes…", style="anton.muted")) - self._live = Live(spinner, console=self._console, refresh_per_second=6, transient=True) + spinner = Spinner( + "dots", text=Text(" Closing scratchpad processes…", style="anton.muted") + ) + self._live = Live( 
+ spinner, console=self._console, refresh_per_second=6, transient=True + ) self._live.start() def stop(self) -> None: @@ -2869,12 +3403,16 @@ def stop(self) -> None: self._live = None -def run_chat(console: Console, settings: AntonSettings, *, resume: bool = False) -> None: +def run_chat( + console: Console, settings: AntonSettings, *, resume: bool = False +) -> None: """Launch the interactive chat REPL.""" asyncio.run(_chat_loop(console, settings, resume=resume)) -async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool = False) -> None: +async def _chat_loop( + console: Console, settings: AntonSettings, *, resume: bool = False +) -> None: from anton.context.self_awareness import SelfAwarenessContext from anton.llm.client import LLMClient from anton.memory.cortex import Cortex @@ -2947,7 +3485,8 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool runtime_context = _build_runtime_context(settings) coding_api_key = ( - settings.anthropic_api_key if settings.coding_provider == "anthropic" + settings.anthropic_api_key + if settings.coding_provider == "anthropic" else settings.openai_api_key ) or "" session = ChatSession( @@ -2968,15 +3507,22 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool # Handle --resume flag at startup if resume: session, resumed_id = await _handle_resume( - console, settings, state, self_awareness, cortex, - workspace, session, episodic=episodic, + console, + settings, + state, + self_awareness, + cortex, + workspace, + session, + episodic=episodic, history_store=history_store, ) if resumed_id: current_session_id = resumed_id - - console.print("[anton.muted] Chat with Anton. Type '/help' for commands or 'exit' to quit.[/]") + console.print( + "[anton.muted] Chat with Anton. 
Type '/help' for commands or 'exit' to quit.[/]" + ) console.print(f"[anton.cyan_dim] {'━' * 40}[/]") console.print() @@ -3004,9 +3550,11 @@ def _bottom_toolbar(): line = status + " " * gap + stats return HTML(f"\n") - pt_style = PTStyle.from_dict({ - "bottom-toolbar": "noreverse nounderline bg:default", - }) + pt_style = PTStyle.from_dict( + { + "bottom-toolbar": "noreverse nounderline bg:default", + } + ) prompt_session: PromptSession[str] = PromptSession( mouse_support=False, @@ -3023,7 +3571,11 @@ def _bottom_toolbar(): for i, engram in enumerate(pending, 1): console.print(f" [bold]{i}.[/] [{engram.kind}] {engram.text}") console.print() - confirm = console.input("[bold]Save to memory? (y/n/pick numbers):[/] ").strip().lower() + confirm = ( + console.input("[bold]Save to memory? (y/n/pick numbers):[/] ") + .strip() + .lower() + ) if confirm in ("y", "yes"): if cortex is not None: await cortex.encode(pending) @@ -3033,11 +3585,19 @@ def _bottom_toolbar(): else: # Parse number selections like "1 3" or "1,3" try: - nums = [int(x.strip()) for x in confirm.replace(",", " ").split() if x.strip().isdigit()] - selected = [pending[n - 1] for n in nums if 1 <= n <= len(pending)] + nums = [ + int(x.strip()) + for x in confirm.replace(",", " ").split() + if x.strip().isdigit() + ] + selected = [ + pending[n - 1] for n in nums if 1 <= n <= len(pending) + ] if selected and cortex is not None: await cortex.encode(selected) - console.print(f"[anton.muted]Saved {len(selected)} entries.[/]") + console.print( + f"[anton.muted]Saved {len(selected)} entries.[/]" + ) else: console.print("[anton.muted]Discarded.[/]") except (ValueError, IndexError): @@ -3092,15 +3652,25 @@ def _bottom_toolbar(): cmd = parts[0].lower() if cmd == "/connect": session = await _handle_connect( - console, settings, workspace, state, - self_awareness, cortex, session, + console, + settings, + workspace, + state, + self_awareness, + cortex, + session, episodic=episodic, ) continue elif cmd == "/setup": 
session = await _handle_setup( - console, settings, workspace, state, - self_awareness, cortex, session, + console, + settings, + workspace, + state, + self_awareness, + cortex, + session, episodic=episodic, history_store=history_store, session_id=current_session_id, @@ -3111,7 +3681,9 @@ def _bottom_toolbar(): continue elif cmd == "/connect-data-source": session = await _handle_connect_datasource( - console, session._scratchpads, session, + console, + session._scratchpads, + session, ) continue elif cmd == "/list-data-sources": @@ -3124,34 +3696,47 @@ def _bottom_toolbar(): "[anton.warning]Usage: /remove-data-source" " [/]" ) - else: _handle_remove_data_source(console, arg) continue elif cmd == "/edit-data-source": arg = parts[1].strip() if len(parts) > 1 else "" - console.print( - f"[anton.muted]/edit-data-source is not yet implemented. " - f"To update '{arg}', use /remove-data-source {arg}" - f" then /connect-data-source.[/]" - ) - console.print() + if not arg: + console.print( + "[anton.warning]Usage: /edit-data-source [/]" + ) + console.print() + else: + session = await _handle_connect_datasource( + console, + session._scratchpads, + session, + datasource_name=arg, + ) continue elif cmd == "/test-data-source": console.print( - "[anton.muted]/test-data-source is not yet" - " implemented.[/]" + "[anton.muted]/test-data-source is not yet" " implemented.[/]" ) console.print() continue elif cmd == "/data-connections": session = await _handle_data_connections( - console, settings, workspace, session, + console, + settings, + workspace, + session, ) continue elif cmd == "/resume": session, resumed_id = await _handle_resume( - console, settings, state, self_awareness, cortex, - workspace, session, episodic=episodic, + console, + settings, + state, + self_awareness, + cortex, + workspace, + session, + episodic=episodic, history_store=history_store, ) if resumed_id: @@ -3172,7 +3757,9 @@ def _bottom_toolbar(): f"{_human_size(uploaded.size_bytes)})[/]" ) user_text = 
parts[1] if len(parts) > 1 else "" - message_content = _format_clipboard_image_message(uploaded, user_text) + message_content = _format_clipboard_image_message( + uploaded, user_text + ) # Fall through to turn_stream (don't continue) else: console.print("[anton.warning]No image found on clipboard.[/]") @@ -3241,6 +3828,7 @@ def _bottom_toolbar(): ) settings.anthropic_api_key = None from anton.cli import _ensure_api_key + _ensure_api_key(settings) session = _rebuild_session( settings=settings, diff --git a/anton/cli.py b/anton/cli.py index 033de34..fa922e5 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -103,7 +103,9 @@ def _ensure_dependencies(console: Console) -> None: ): import subprocess - console.print(f"[anton.muted] Running: uv pip install {' '.join(missing)}[/]") + console.print( + f"[anton.muted] Running: uv pip install {' '.join(missing)}[/]" + ) result = subprocess.run( [uv, "pip", "install", "--python", sys.executable, *missing], capture_output=True, @@ -113,12 +115,18 @@ def _ensure_dependencies(console: Console) -> None: _reexec() else: console.print(f"[anton.error] Install failed:[/]") - console.print(result.stderr.decode() if result.stderr else result.stdout.decode()) + console.print( + result.stderr.decode() if result.stderr else result.stdout.decode() + ) if install_script.is_file(): if sys.platform == "win32": - console.print(f"\n[anton.muted] Or run the install script: powershell -File {install_script}[/]") + console.print( + f"\n[anton.muted] Or run the install script: powershell -File {install_script}[/]" + ) else: - console.print(f"\n[anton.muted] Or run the install script: sh {install_script}[/]") + console.print( + f"\n[anton.muted] Or run the install script: sh {install_script}[/]" + ) raise typer.Exit(0) elif install_script.is_file(): console.print(f"To install all dependencies, run:") @@ -133,12 +141,17 @@ def _ensure_dependencies(console: Console) -> None: console.print(f" [bold]pip install {' '.join(missing)}[/]") console.print() if 
sys.platform == "win32": - console.print("[anton.muted]Or reinstall anton: irm https://raw.githubusercontent.com/mindsdb/anton/main/install.ps1 | iex[/]") + console.print( + "[anton.muted]Or reinstall anton: irm https://raw.githubusercontent.com/mindsdb/anton/main/install.ps1 | iex[/]" + ) else: - console.print('[anton.muted]Or reinstall anton: curl -sSf https://raw.githubusercontent.com/mindsdb/anton/main/install.sh | sh && export PATH="$HOME/.local/bin:$PATH"[/]') + console.print( + '[anton.muted]Or reinstall anton: curl -sSf https://raw.githubusercontent.com/mindsdb/anton/main/install.sh | sh && export PATH="$HOME/.local/bin:$PATH"[/]' + ) console.print() raise typer.Exit(1) + app = typer.Typer( name="anton", help="Anton — a self-evolving autonomous system", @@ -206,6 +219,7 @@ def main( settings.resolve_workspace(folder) from anton.updater import check_and_update + if check_and_update(console, settings): # Re-exec with the freshly installed code so no old modules remain in memory. _reexec() @@ -227,9 +241,13 @@ def _has_api_key(settings) -> bool: """Check if all configured providers have API keys.""" providers = {settings.planning_provider, settings.coding_provider} for p in providers: - if p == "anthropic" and not (settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY")): + if p == "anthropic" and not ( + settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY") + ): return False - if p in ("openai", "openai-compatible") and not (settings.openai_api_key or os.environ.get("OPENAI_API_KEY")): + if p in ("openai", "openai-compatible") and not ( + settings.openai_api_key or os.environ.get("OPENAI_API_KEY") + ): return False return True @@ -316,12 +334,15 @@ def _ensure_minds_api_key(settings, ws) -> None: # Test if the Minds server supports LLM endpoints (_code_/_reason_) # (silenced: was printing "Testing LLM endpoints..." 
and "not available" messages) from anton.chat import _minds_test_llm + llm_ok = _minds_test_llm(minds_url, api_key, verify=True) if not llm_ok: llm_ok = _minds_test_llm(minds_url, api_key, verify=False) if llm_ok: - console.print("[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]") + console.print( + "[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]" + ) base_url = f"{minds_url}/api/v1" settings.openai_api_key = api_key settings.openai_base_url = base_url @@ -443,8 +464,18 @@ def version() -> None: @app.command("connect-data-source") -def connect_data_source(ctx: typer.Context) -> None: - """Connect a database or API to the Local Vault.""" +def connect_data_source( + ctx: typer.Context, + slug: str = typer.Argument( + default="", help="Existing connection slug to reconnect (e.g. postgres-mydb)." + ), +) -> None: + """Connect a database or API to the Local Vault. + + Pass an existing connection slug (e.g. postgres-mydb) to reconnect using + stored credentials without re-entering them. Use /edit-data-source to + update credentials for an existing connection. + """ import asyncio from anton.chat import ChatSession, _handle_connect_datasource @@ -463,12 +494,18 @@ def connect_data_source(ctx: typer.Context) -> None: settings.anthropic_api_key if settings.coding_provider == "anthropic" else settings.openai_api_key - ) or "", + ) + or "", ) session = ChatSession(llm_client) async def _run() -> None: - updated = await _handle_connect_datasource(console, scratchpads, session) + updated = await _handle_connect_datasource( + console, + scratchpads, + session, + datasource_name=slug or None, + ) await updated._scratchpads.close_all() - asyncio.run(_run()) \ No newline at end of file + asyncio.run(_run()) diff --git a/anton/data_vault.py b/anton/data_vault.py index 560b06b..8ce1d25 100644 --- a/anton/data_vault.py +++ b/anton/data_vault.py @@ -1,11 +1,3 @@ -"""Local Vault — stores data source connection credentials. 
- -Each connection is saved as a JSON file in ~/.anton/data_vault/. -Files are named {engine}-{connection_name} and contain the credential -fields in plain text. Files are only read at runtime when Anton needs -to establish or test a connection. -""" - from __future__ import annotations import json @@ -26,7 +18,6 @@ class DataVault: def __init__(self, vault_dir: Path | None = None) -> None: self._dir = vault_dir or Path("~/.anton/data_vault").expanduser() - def _path_for(self, engine: str, name: str) -> Path: return self._dir / f"{_sanitize(engine)}-{_sanitize(name)}" @@ -34,7 +25,6 @@ def _ensure_dir(self) -> None: self._dir.mkdir(parents=True, exist_ok=True) self._dir.chmod(0o700) - def save(self, engine: str, name: str, credentials: dict[str, str]) -> Path: """Write credentials as JSON atomically. Creates vault dir if needed.""" self._ensure_dir() @@ -80,16 +70,17 @@ def list_connections(self) -> list[dict[str, str]]: continue try: data = json.loads(path.read_text(encoding="utf-8")) - results.append({ - "engine": data.get("engine", ""), - "name": data.get("name", ""), - "created_at": data.get("created_at", ""), - }) + results.append( + { + "engine": data.get("engine", ""), + "name": data.get("name", ""), + "created_at": data.get("created_at", ""), + } + ) except (json.JSONDecodeError, OSError): continue return results - def inject_env(self, engine: str, name: str) -> list[str] | None: """Load credentials and set DS_ in os.environ. @@ -112,7 +103,6 @@ def clear_ds_env(self) -> None: for key in ds_keys: del os.environ[key] - def next_connection_number(self, engine: str) -> int: """Return the next auto-increment number for an engine (1-based). 
@@ -122,12 +112,13 @@ def next_connection_number(self, engine: str) -> int: if not self._dir.is_dir(): return 1 existing = [ - p.name for p in self._dir.iterdir() + p.name + for p in self._dir.iterdir() if p.is_file() and p.name.startswith(prefix) ] max_n = 0 for fname in existing: - suffix = fname[len(prefix):] + suffix = fname[len(prefix) :] if suffix.isdigit(): max_n = max(max_n, int(suffix)) return max_n + 1 diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py index f34b334..b106833 100644 --- a/anton/datasource_registry.py +++ b/anton/datasource_registry.py @@ -1,9 +1,3 @@ -"""Datasource registry — parses datasources.md engine definitions. - -Reads YAML blocks from the built-in datasources.md (project root) and -merges user overrides from ~/.anton/datasources.md on top. -""" - from __future__ import annotations import re @@ -56,13 +50,15 @@ def _parse_fields(raw: list) -> list[DatasourceField]: for f in raw or []: if not isinstance(f, dict): continue - result.append(DatasourceField( - name=f.get("name", ""), - required=bool(f.get("required", True)), - secret=bool(f.get("secret", False)), - description=f.get("description", ""), - default=str(f.get("default", "")), - )) + result.append( + DatasourceField( + name=f.get("name", ""), + required=bool(f.get("required", True)), + secret=bool(f.get("secret", False)), + description=f.get("description", ""), + default=str(f.get("default", "")), + ) + ) return result @@ -79,7 +75,11 @@ def _parse_file(path: Path) -> dict[str, DatasourceEngine]: data = yaml.safe_load(yaml_text) except yaml.YAMLError as exc: import sys - print(f"[anton] Warning: skipping malformed YAML block in {path}: {exc}", file=sys.stderr) + + print( + f"[anton] Warning: skipping malformed YAML block in {path}: {exc}", + file=sys.stderr, + ) continue if not isinstance(data, dict) or "engine" not in data: continue @@ -89,11 +89,13 @@ def _parse_file(path: Path) -> dict[str, DatasourceEngine]: for am in raw_auth_methods: if not 
isinstance(am, dict): continue - auth_methods.append(AuthMethod( - name=am.get("name", ""), - display=am.get("display", am.get("name", "")), - fields=_parse_fields(am.get("fields", [])), - )) + auth_methods.append( + AuthMethod( + name=am.get("name", ""), + display=am.get("display", am.get("name", "")), + fields=_parse_fields(am.get("fields", [])), + ) + ) engine_slug = str(data["engine"]) engines[engine_slug] = DatasourceEngine( @@ -129,14 +131,14 @@ def get(self, engine_slug: str) -> DatasourceEngine | None: return self._engines.get(engine_slug) def find_by_name(self, display_name: str) -> DatasourceEngine | None: - """Case-insensitive match on display_name or engine slug. - """ + """Case-insensitive match on display_name or engine slug.""" needle = display_name.strip().lower() for engine in self._engines.values(): if engine.display_name.lower() == needle or engine.engine.lower() == needle: return engine matches = [ - e for e in self._engines.values() + e + for e in self._engines.values() if needle in e.display_name.lower() or needle in e.engine.lower() ] return matches[0] if len(matches) == 1 else None @@ -144,7 +146,9 @@ def find_by_name(self, display_name: str) -> DatasourceEngine | None: def all_engines(self) -> list[DatasourceEngine]: return sorted(self._engines.values(), key=lambda e: e.display_name) - def derive_name(self, engine_def: DatasourceEngine, credentials: dict[str, str]) -> str: + def derive_name( + self, engine_def: DatasourceEngine, credentials: dict[str, str] + ) -> str: """Derive a default connection name from name_from field(s).""" name_from = engine_def.name_from if not name_from: From d1eed7f0590c6ff36199fd87b9103db085fcc716 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 20 Mar 2026 12:03:33 +0100 Subject: [PATCH 09/70] Cleanup connection variables if not active/connected --- anton/chat.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 
9f95e24..a9f9817 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -114,6 +114,7 @@ def __init__( self._history_store = history_store self._session_id = session_id self._cancel_event = asyncio.Event() + self._active_datasource: str | None = None # slug like "hubspot-2" self._scratchpads = ScratchpadManager( coding_provider=coding_provider, coding_model=getattr(llm_client, "coding_model", ""), @@ -191,7 +192,7 @@ async def _build_system_prompt(self, user_message: str = "") -> str: if md_context: prompt += md_context # Inject connected datasource context without credentials - ds_ctx = _build_datasource_context() + ds_ctx = _build_datasource_context(active_only=self._active_datasource) if ds_ctx: prompt += ds_ctx return prompt @@ -1079,11 +1080,13 @@ def _scrub_credentials(text: str) -> str: return text -def _build_datasource_context() -> str: +def _build_datasource_context(active_only: str | None = None) -> str: """Build a system-prompt section listing available DS_* env vars by name. Shows the LLM what data sources are connected and which environment variable names to use — without exposing any credential values. + + If active_only is set, only the matching slug is included. """ try: vault = DataVault() @@ -1099,9 +1102,12 @@ def _build_datasource_context() -> str: "Never read ~/.anton/data_vault/ files directly.\n" ) for c in conns: + slug = f"{c['engine']}-{c['name']}" + if active_only and slug != active_only: + continue fields = vault.load(c["engine"], c["name"]) or {} var_names = ", ".join(f"DS_{k.upper()}" for k in fields) - lines.append(f"- `{c['engine']}-{c['name']}` → {var_names}") + lines.append(f"- `{slug}` → {var_names}") return "\n".join(lines) @@ -2713,6 +2719,7 @@ async def _handle_connect_datasource( scratchpads: ScratchpadManager, session: "ChatSession", datasource_name: str | None = None, + prefill: str | None = None, ) -> "ChatSession": """ Connect a data source by entering credentials, either for a new name or re-entering for an existing one. 
@@ -2859,18 +2866,23 @@ async def _handle_connect_datasource( # ── Normal flow: connect a new (or reconnect an existing) data source ───── console.print() engine_names = ", ".join(e.display_name for e in registry.all_engines()) - answer = Prompt.ask( - f"[anton.cyan](anton)[/] Which data source would you like to connect?\n" - f" [anton.muted](e.g. {engine_names})[/]\n", - console=console, - ) + if prefill: + answer = prefill + else: + answer = Prompt.ask( + f"[anton.cyan](anton)[/] Which data source would you like to connect?\n" + f" [anton.muted](e.g. {engine_names})[/]\n", + console=console, + ) # ── Reconnect path: user typed an existing vault slug ───────────────────── stripped_answer = answer.strip() known_slugs = {f"{c['engine']}-{c['name']}": c for c in vault.list_connections()} if stripped_answer in known_slugs: conn = known_slugs[stripped_answer] + vault.clear_ds_env() vault.inject_env(conn["engine"], conn["name"]) + session._active_datasource = stripped_answer recon_engine_def = registry.get(conn["engine"]) if recon_engine_def: _register_secret_vars(recon_engine_def) @@ -3680,10 +3692,12 @@ def _bottom_toolbar(): _handle_memory(console, settings, cortex, episodic=episodic) continue elif cmd == "/connect-data-source": + arg = parts[1].strip() if len(parts) > 1 else "" session = await _handle_connect_datasource( console, session._scratchpads, session, + prefill=arg or None, ) continue elif cmd == "/list-data-sources": From c91534cb126384baf67f10b29e7aac5d5988cd3a Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 20 Mar 2026 12:06:05 +0100 Subject: [PATCH 10/70] Add tests to test active connection --- tests/test_datasource.py | 122 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 121 insertions(+), 1 deletion(-) diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 1706e3e..63a3d4f 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -868,4 +868,124 @@ async def test_register_and_scrub_on_connect(self, 
registry, vault_dir): assert secret_pw not in result assert "[DS_PASSWORD]" in result finally: - del os.environ["DS_PASSWORD"] \ No newline at end of file + del os.environ["DS_PASSWORD"] + + +# ───────────────────────────────────────────────────────────────────────────── +# Active datasource scoping +# ───────────────────────────────────────────────────────────────────────────── + + +class TestActiveDatasourceScoping: + """Tests for /connect-data-source isolating a single datasource.""" + + def _make_session(self): + from anton.chat import ChatSession + + mock_llm = AsyncMock() + session = ChatSession(mock_llm) + session._scratchpads = AsyncMock() + return session + + def test_active_datasource_defaults_to_none(self): + session = self._make_session() + assert session._active_datasource is None + + @pytest.mark.asyncio + async def test_reconnect_sets_active_datasource(self, vault_dir): + """Reconnecting to a slug via prefill sets session._active_datasource.""" + from anton.chat import _handle_connect_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("hubspot", "2", {"access_token": "pat-xxx"}) + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry"), + ): + result = await _handle_connect_datasource( + console, session._scratchpads, session, prefill="hubspot-2" + ) + + assert result._active_datasource == "hubspot-2" + + @pytest.mark.asyncio + async def test_reconnect_clears_other_ds_vars(self, vault_dir): + """Reconnecting to one slug removes DS_* vars from all other connections.""" + from anton.chat import _handle_connect_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("oracle", "1", {"host": "oracle.host", "user": "admin", "password": "orapass"}) + vault.save("hubspot", "2", {"access_token": "pat-xxx"}) + + # Simulate startup: inject all connections + vault.inject_env("oracle", "1") + 
vault.inject_env("hubspot", "2") + assert os.environ.get("DS_HOST") == "oracle.host" + assert os.environ.get("DS_ACCESS_TOKEN") == "pat-xxx" + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + + try: + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry"), + ): + await _handle_connect_datasource( + console, session._scratchpads, session, prefill="hubspot-2" + ) + + # Oracle vars must be gone; HubSpot var must be present + assert "DS_HOST" not in os.environ + assert "DS_USER" not in os.environ + assert "DS_PASSWORD" not in os.environ + assert os.environ.get("DS_ACCESS_TOKEN") == "pat-xxx" + finally: + vault.clear_ds_env() + + def test_build_datasource_context_no_filter(self, vault_dir): + """Without active_only, all vault entries appear in the context.""" + from anton.chat import _build_datasource_context + + vault = DataVault(vault_dir=vault_dir) + vault.save("oracle", "1", {"host": "oracle.host"}) + vault.save("hubspot", "2", {"access_token": "pat-xxx"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "oracle-1" in ctx + assert "hubspot-2" in ctx + + def test_build_datasource_context_active_only_filters(self, vault_dir): + """With active_only set, only the matching slug appears.""" + from anton.chat import _build_datasource_context + + vault = DataVault(vault_dir=vault_dir) + vault.save("oracle", "1", {"host": "oracle.host"}) + vault.save("hubspot", "2", {"access_token": "pat-xxx"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context(active_only="hubspot-2") + + assert "hubspot-2" in ctx + assert "oracle-1" not in ctx + + def test_build_datasource_context_active_only_empty_when_no_match(self, vault_dir): + """If active_only doesn't match any slug, the section has no entries.""" + from anton.chat import _build_datasource_context + + vault = DataVault(vault_dir=vault_dir) + 
vault.save("oracle", "1", {"host": "oracle.host"})
+
+        with patch("anton.chat.DataVault", return_value=vault):
+            ctx = _build_datasource_context(active_only="hubspot-99")
+
+        # Header is present but no datasource lines
+        assert "oracle-1" not in ctx
\ No newline at end of file

From 64cbe6ba78bde5d9b0fbf11956edf9ceb60470f5 Mon Sep 17 00:00:00 2001
From: Konstantin Sivakov
Date: Fri, 20 Mar 2026 16:55:28 +0100
Subject: [PATCH 11/70] Better handling of same engine connections and maintain
 the active one in Anton's session

---
 anton/chat.py            |  281 ++++++++--
 anton/cli.py             |   93 ++++
 anton/data_vault.py      |   36 +-
 tests/test_datasource.py | 1121 ++++++++++++++++++++++++++++++++++++--
 4 files changed, 1436 insertions(+), 95 deletions(-)

diff --git a/anton/chat.py b/anton/chat.py
index a9f9817..cb802a4 100644
--- a/anton/chat.py
+++ b/anton/chat.py
@@ -42,7 +42,7 @@
     format_cell_result,
     prepare_scratchpad_exec,
 )
-from anton.data_vault import DataVault
+from anton.data_vault import DataVault, _slug_env_prefix
 from anton.datasource_registry import (
     DatasourceEngine,
     DatasourceField,
@@ -1044,13 +1044,23 @@ def _apply_error_tracking(
 _DS_KNOWN_VARS: set[str] = set()
 
 
-def _register_secret_vars(engine_def: "DatasourceEngine") -> None:
-    """Record which DS_* var names correspond to secret fields for engine_def."""
+def _register_secret_vars(
+    engine_def: "DatasourceEngine", *, engine: str = "", name: str = ""
+) -> None:
+    """Record which DS_* var names correspond to known/secret fields for engine_def.
+
+    If engine and name are given, registers namespaced vars (DS_ENGINE_NAME__FIELD).
+    Otherwise registers flat vars (DS_FIELD) — for temporary test_snippet execution.
+ """ all_fields = list(engine_def.fields) for am in engine_def.auth_methods or []: all_fields.extend(am.fields) for f in all_fields: - key = f"DS_{f.name.upper()}" + if engine and name: + prefix = _slug_env_prefix(engine, name) + key = f"{prefix}__{f.name.upper()}" + else: + key = f"DS_{f.name.upper()}" _DS_KNOWN_VARS.add(key) if f.secret: _DS_SECRET_VARS.add(key) @@ -1097,8 +1107,9 @@ def _build_datasource_context(active_only: str | None = None) -> str: return "" lines = ["\n\n## Connected Data Sources"] lines.append( - "Credentials are pre-injected as DS_* environment variables. " - "Use them directly in scratchpad code. " + "Credentials are pre-injected as namespaced DS___ " + "environment variables. Use them directly in scratchpad code " + "(e.g. DS_POSTGRES_PROD_DB__HOST). " "Never read ~/.anton/data_vault/ files directly.\n" ) for c in conns: @@ -1106,11 +1117,25 @@ def _build_datasource_context(active_only: str | None = None) -> str: if active_only and slug != active_only: continue fields = vault.load(c["engine"], c["name"]) or {} - var_names = ", ".join(f"DS_{k.upper()}" for k in fields) - lines.append(f"- `{slug}` → {var_names}") + prefix = _slug_env_prefix(c["engine"], c["name"]) + var_names = ", ".join(f"{prefix}__{k.upper()}" for k in fields) + lines.append(f"- `{slug}` ({c['engine']}) → {var_names}") return "\n".join(lines) +def _restore_namespaced_env(vault: DataVault) -> None: + """Clear all DS_* vars, then reinject every saved connection as namespaced.""" + from anton.datasource_registry import DatasourceRegistry + + vault.clear_ds_env() + dreg = DatasourceRegistry() + for conn in vault.list_connections(): + vault.inject_env(conn["engine"], conn["name"]) # flat=False by default + edef = dreg.get(conn["engine"]) + if edef is not None: + _register_secret_vars(edef, engine=conn["engine"], name=conn["name"]) + + def _build_runtime_context(settings: AntonSettings) -> str: """Build runtime context string including Minds datasource info if configured.""" 
ctx = ( @@ -2728,7 +2753,7 @@ async def _handle_connect_datasource( vault = DataVault() registry = DatasourceRegistry() - # ── /edit-data-source path: re-enter credentials for an existing slug ───── + # ── /edit-data-source path: update credentials for an existing slug ──────── if datasource_name is not None: slug_parts = datasource_name.split("-", 1) if len(slug_parts) != 2: @@ -2739,7 +2764,8 @@ async def _handle_connect_datasource( console.print() return session edit_engine, edit_name = slug_parts - if vault.load(edit_engine, edit_name) is None: + existing = vault.load(edit_engine, edit_name) + if existing is None: console.print( f"[anton.warning]No connection '{datasource_name}' found in Local Vault.[/]" ) @@ -2756,39 +2782,73 @@ async def _handle_connect_datasource( console.print() console.print( - f"[anton.cyan](anton)[/] Updating credentials for " - f'[bold]"{datasource_name}"[/bold] ({engine_def.display_name}).' + f"[anton.cyan](anton)[/] Editing [bold]\"{datasource_name}\"[/bold]" + f" ({engine_def.display_name})." 
) + console.print("[anton.muted] Press Enter to keep the current value.[/]") console.print() - credentials: dict[str, str] = {} - for f in engine_def.fields: + # Detect which fields to present (handle auth_method=choice) + active_fields = engine_def.fields + if engine_def.auth_method == "choice" and engine_def.auth_methods: + for am in engine_def.auth_methods: + am_field_names = {af.name for af in am.fields} + if any(k in am_field_names for k in existing): + active_fields = am.fields + break + if not active_fields: + active_fields = engine_def.auth_methods[0].fields + + # Start from existing values; let user update field-by-field + credentials: dict[str, str] = dict(existing) + for f in active_fields: + current = existing.get(f.name, "") prompt_label = f"[anton.cyan](anton)[/] {f.name}" if not f.required: - prompt_label += " [anton.muted](optional, press enter to skip)[/]" + prompt_label += " [anton.muted](optional)[/]" + if f.secret: + masked = "••••••••" if current else "" + label = ( + f"{prompt_label} [anton.muted][{masked}][/]" + if masked + else prompt_label + ) + value = Prompt.ask(label, password=True, console=console, default="") + if value: + credentials[f.name] = value + # else: keep existing (already in credentials) + elif current: value = Prompt.ask( - prompt_label, password=True, console=console, default="" + f"{prompt_label} [anton.muted][{current}][/]", + console=console, + default=current, ) + credentials[f.name] = value if value else current elif f.default: value = Prompt.ask( f"{prompt_label} [anton.muted][{f.default}][/]", console=console, default=f.default, ) + if value: + credentials[f.name] = value else: value = Prompt.ask(prompt_label, console=console, default="") - if value: - credentials[f.name] = value + if value: + credentials[f.name] = value if engine_def.test_snippet: while True: console.print() console.print("[anton.cyan](anton)[/] Got it. 
Testing connection…") - import os as _os - for key, value in credentials.items(): - _os.environ[f"DS_{key.upper()}"] = value + # Temporarily save credentials so inject_env(flat=True) can load them, + # then restore all namespaced env vars in the finally block. + vault.save(edit_engine, edit_name, credentials) + vault.clear_ds_env() + vault.inject_env(edit_engine, edit_name, flat=True) + _register_secret_vars(engine_def) # flat names for scrubbing during test try: pad = await scratchpads.get_or_create("__datasource_test__") await pad.reset() @@ -2796,9 +2856,7 @@ async def _handle_connect_datasource( await pad.install_packages([engine_def.pip]) cell = await pad.execute(engine_def.test_snippet) finally: - ds_keys = [k for k in _os.environ if k.startswith("DS_")] - for k in ds_keys: - del _os.environ[k] + _restore_namespaced_env(vault) if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): error_text = ( @@ -2824,7 +2882,7 @@ async def _handle_connect_datasource( if retry != "y": return session console.print() - for f in engine_def.fields: + for f in active_fields: if not f.secret: continue value = Prompt.ask( @@ -2841,8 +2899,8 @@ async def _handle_connect_datasource( break vault.save(edit_engine, edit_name, credentials) - vault.inject_env(edit_engine, edit_name) - _register_secret_vars(engine_def) + _restore_namespaced_env(vault) + _register_secret_vars(engine_def, engine=edit_engine, name=edit_name) console.print() console.print( f' Credentials updated for [bold]"{datasource_name}"[/bold].' 
@@ -2880,12 +2938,11 @@ async def _handle_connect_datasource( known_slugs = {f"{c['engine']}-{c['name']}": c for c in vault.list_connections()} if stripped_answer in known_slugs: conn = known_slugs[stripped_answer] - vault.clear_ds_env() - vault.inject_env(conn["engine"], conn["name"]) + _restore_namespaced_env(vault) session._active_datasource = stripped_answer recon_engine_def = registry.get(conn["engine"]) if recon_engine_def: - _register_secret_vars(recon_engine_def) + _register_secret_vars(recon_engine_def, engine=conn["engine"], name=conn["name"]) engine_label = recon_engine_def.display_name else: engine_label = conn["engine"] @@ -3086,11 +3143,14 @@ async def _handle_connect_datasource( console.print() console.print("[anton.cyan](anton)[/] Got it. Testing connection…") - # Temporarily inject DS_* into os.environ (test before committing to vault) + # Temporarily inject flat DS_* vars for test_snippet execution. + # conn_name is not yet known, so inject directly from credentials. import os as _os + vault.clear_ds_env() for key, value in credentials.items(): _os.environ[f"DS_{key.upper()}"] = value + _register_secret_vars(engine_def) # flat mode, for scrubbing during test try: pad = await scratchpads.get_or_create("__datasource_test__") @@ -3101,10 +3161,7 @@ async def _handle_connect_datasource( cell = await pad.execute(engine_def.test_snippet) finally: - # Always clean up DS_* regardless of outcome - ds_keys = [k for k in _os.environ if k.startswith("DS_")] - for k in ds_keys: - del _os.environ[k] + _restore_namespaced_env(vault) if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): error_text = cell.error or cell.stderr.strip() or cell.stdout.strip() @@ -3177,8 +3234,8 @@ async def _handle_connect_datasource( console.print("[anton.muted]Cancelled.[/]") console.print() return session - vault.inject_env(engine_def.engine, conn_name) - _register_secret_vars(engine_def) + _restore_namespaced_env(vault) + _register_secret_vars(engine_def, 
engine=engine_def.engine, name=conn_name) console.print() console.print( f'[anton.success] ✓ Reconnected to [bold]"{slug}"[/bold].[/]' @@ -3196,7 +3253,9 @@ async def _handle_connect_datasource( return session vault.save(engine_def.engine, conn_name, credentials) - _register_secret_vars(engine_def) + _restore_namespaced_env(vault) + session._active_datasource = slug + _register_secret_vars(engine_def, engine=engine_def.engine, name=conn_name) console.print(f' Credentials saved to Local Vault as [bold]"{slug}"[/bold].') console.print() @@ -3219,8 +3278,11 @@ async def _handle_connect_datasource( def _handle_list_data_sources(console: Console) -> None: - """Print all saved Local Vault connections with their DS_* var names.""" + """Print all saved Local Vault connections in a table with status.""" + from rich.table import Table + vault = DataVault() + registry = DatasourceRegistry() conns = vault.list_connections() console.print() if not conns: @@ -3228,16 +3290,31 @@ def _handle_list_data_sources(console: Console) -> None: console.print("[anton.muted]Use /connect-data-source to add one.[/]") console.print() return - console.print("[anton.cyan]Connected data sources:[/]") - console.print() + + table = Table(title="Local Vault — Saved Connections", show_lines=False) + table.add_column("Name", style="bold") + table.add_column("Source") + table.add_column("Status") + for c in conns: - fields = vault.load(c["engine"], c["name"]) or {} - var_names = ", ".join(f"DS_{k.upper()}" for k in fields) slug = f"{c['engine']}-{c['name']}" - complete = "✓" if var_names else "⚠ incomplete" - console.print(f" [bold]{slug}[/] {complete}") - if var_names: - console.print(f" [anton.muted] {var_names}[/]") + engine_def = registry.get(c["engine"]) + source = engine_def.display_name if engine_def else c["engine"] + fields = vault.load(c["engine"], c["name"]) or {} + + if not fields: + status = "[yellow]incomplete[/]" + elif engine_def and engine_def.auth_method != "choice": + required = 
[f.name for f in engine_def.fields if f.required] + missing = [name for name in required if name not in fields] + status = "[yellow]incomplete[/]" if missing else "[green]saved[/]" + else: + # choice-auth engine or unknown engine: presence of any field = saved + status = "[green]saved[/]" + + table.add_row(slug, source, status) + + console.print(table) console.print() @@ -3266,6 +3343,98 @@ def _handle_remove_data_source(console: Console, slug: str) -> None: console.print() +async def _handle_test_datasource( + console: Console, + scratchpads: ScratchpadManager, + slug: str, +) -> None: + """Test an existing Local Vault connection by running its test_snippet.""" + if not slug: + console.print( + "[anton.warning]Usage: /test-data-source [/]" + ) + console.print() + return + + parts = slug.split("-", 1) + if len(parts) != 2: + console.print( + f"[anton.warning]Invalid name '{slug}'. Use engine-name format.[/]" + ) + console.print() + return + + vault = DataVault() + registry = DatasourceRegistry() + engine, name = parts + fields = vault.load(engine, name) + if fields is None: + console.print( + f"[anton.warning]No connection '{slug}' found in Local Vault.[/]" + ) + console.print() + return + + engine_def = registry.get(engine) + if engine_def is None: + console.print( + f"[anton.warning]Unknown engine '{engine}'. Cannot test.[/]" + ) + console.print() + return + + if not engine_def.test_snippet: + console.print( + f"[anton.warning]No test snippet defined for '{engine}'. 
Cannot test.[/]" + ) + console.print() + return + + console.print() + console.print( + f"[anton.cyan](anton)[/] Testing connection [bold]{slug}[/bold]…" + ) + + vault.clear_ds_env() + vault.inject_env(engine, name, flat=True) + _register_secret_vars(engine_def) # flat names for scrubbing during test + + cell = None + try: + pad = await scratchpads.get_or_create("__datasource_test__") + await pad.reset() + if engine_def.pip: + await pad.install_packages([engine_def.pip]) + cell = await pad.execute(engine_def.test_snippet) + finally: + _restore_namespaced_env(vault) + + if cell is None or cell.error or ( + cell.stdout.strip() != "ok" and cell.stderr.strip() + ): + error_text = "" + if cell is not None: + error_text = cell.error or cell.stderr.strip() or cell.stdout.strip() + first_line = ( + next((ln for ln in error_text.splitlines() if ln.strip()), error_text) + if error_text + else "unknown error" + ) + console.print() + console.print( + f"[anton.warning](anton)[/] ✗ Connection test failed for" + f" [bold]{slug}[/bold]." 
+ ) + console.print() + console.print(f" Error: {first_line}") + else: + console.print( + f"[anton.success] ✓ Connection test passed for" + f" [bold]{slug}[/bold]![/]" + ) + console.print() + + def _print_slash_help(console: Console) -> None: """Print available slash commands.""" console.print() @@ -3279,7 +3448,9 @@ def _print_slash_help(console: Console) -> None: console.print( " [bold]/list-data-sources[/] — List all saved data source connections" ) + console.print(" [bold]/edit-data-source[/] — Edit a saved connection's credentials") console.print(" [bold]/remove-data-source[/] — Remove a saved connection") + console.print(" [bold]/test-data-source[/] — Test a saved connection") console.print( " [bold]/data-connections[/] — View and manage stored keys and connections" ) @@ -3440,15 +3611,15 @@ async def _chat_loop( workspace = Workspace(settings.workspace_path) workspace.apply_env_to_process() - # Inject all Local Vault connections as DS_* env vars so every scratchpad - # subprocess inherits them. Must happen before any ChatSession is created. + # Inject all Local Vault connections as namespaced DS_* env vars so every + # scratchpad subprocess inherits them. Must happen before any ChatSession is created. 
_dv = DataVault() _dreg = DatasourceRegistry() for _conn in _dv.list_connections(): - _dv.inject_env(_conn["engine"], _conn["name"]) + _dv.inject_env(_conn["engine"], _conn["name"]) # flat=False by default _edef = _dreg.get(_conn["engine"]) if _edef is not None: - _register_secret_vars(_edef) + _register_secret_vars(_edef, engine=_conn["engine"], name=_conn["name"]) del _dv, _dreg # --- Memory system (brain-inspired architecture) --- @@ -3710,6 +3881,8 @@ def _bottom_toolbar(): "[anton.warning]Usage: /remove-data-source" " [/]" ) + console.print() + else: _handle_remove_data_source(console, arg) continue elif cmd == "/edit-data-source": @@ -3728,10 +3901,10 @@ def _bottom_toolbar(): ) continue elif cmd == "/test-data-source": - console.print( - "[anton.muted]/test-data-source is not yet" " implemented.[/]" + arg = parts[1].strip() if len(parts) > 1 else "" + await _handle_test_datasource( + console, session._scratchpads, arg ) - console.print() continue elif cmd == "/data-connections": session = await _handle_data_connections( diff --git a/anton/cli.py b/anton/cli.py index fa922e5..7c98a36 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -509,3 +509,96 @@ async def _run() -> None: await updated._scratchpads.close_all() asyncio.run(_run()) + + +@app.command("list-data-sources") +def list_data_sources(ctx: typer.Context) -> None: + """List all saved data source connections in the Local Vault.""" + from anton.chat import _handle_list_data_sources + + _handle_list_data_sources(console) + + +@app.command("edit-data-source") +def edit_data_source( + ctx: typer.Context, + name: str = typer.Argument(..., help="Connection slug to edit (e.g. 
postgres-mydb)."), +) -> None: + """Edit credentials for an existing Local Vault connection.""" + import asyncio + + from anton.chat import ChatSession, _handle_connect_datasource + from anton.llm.client import LLMClient + from anton.scratchpad import ScratchpadManager + + settings = _get_settings(ctx) + _ensure_workspace(settings) + _ensure_api_key(settings) + + llm_client = LLMClient.from_settings(settings) + scratchpads = ScratchpadManager( + coding_provider=settings.coding_provider, + coding_model=settings.coding_model, + coding_api_key=( + settings.anthropic_api_key + if settings.coding_provider == "anthropic" + else settings.openai_api_key + ) + or "", + ) + session = ChatSession(llm_client) + + async def _run() -> None: + updated = await _handle_connect_datasource( + console, + scratchpads, + session, + datasource_name=name, + ) + await updated._scratchpads.close_all() + + asyncio.run(_run()) + + +@app.command("remove-data-source") +def remove_data_source( + ctx: typer.Context, + name: str = typer.Argument(..., help="Connection slug to remove (e.g. postgres-mydb)."), +) -> None: + """Remove a saved connection from the Local Vault.""" + from anton.chat import _handle_remove_data_source + + _handle_remove_data_source(console, name) + + +@app.command("test-data-source") +def test_data_source( + ctx: typer.Context, + name: str = typer.Argument(..., help="Connection slug to test (e.g. 
postgres-mydb)."), +) -> None: + """Test a saved Local Vault connection using its test snippet.""" + import asyncio + + from anton.chat import _handle_test_datasource + from anton.scratchpad import ScratchpadManager + + settings = _get_settings(ctx) + _ensure_workspace(settings) + _ensure_api_key(settings) + + scratchpads = ScratchpadManager( + coding_provider=settings.coding_provider, + coding_model=settings.coding_model, + coding_api_key=( + settings.anthropic_api_key + if settings.coding_provider == "anthropic" + else settings.openai_api_key + ) + or "", + ) + + async def _run() -> None: + await _handle_test_datasource(console, scratchpads, name) + await scratchpads.close_all() + + asyncio.run(_run()) diff --git a/anton/data_vault.py b/anton/data_vault.py index 8ce1d25..61f55cc 100644 --- a/anton/data_vault.py +++ b/anton/data_vault.py @@ -12,6 +12,18 @@ def _sanitize(value: str) -> str: return re.sub(r"[^\w\-]", "_", value).strip("_") +def _slug_env_prefix(engine: str, name: str) -> str: + """Return the DS_ prefix for a namespaced connection env var. + + Examples: + engine="postgres", name="prod_db" → "DS_POSTGRES_PROD_DB" + engine="hubspot", name="main" → "DS_HUBSPOT_MAIN" + engine="postgres", name="prod-db.eu" → "DS_POSTGRES_PROD_DB_EU" + """ + raw = f"{engine}-{name}" + return "DS_" + re.sub(r"[^\w]", "_", raw).upper() + + class DataVault: """Manages data source connection credentials in ~/.anton/data_vault/.""" @@ -81,20 +93,30 @@ def list_connections(self) -> list[dict[str, str]]: continue return results - def inject_env(self, engine: str, name: str) -> list[str] | None: - """Load credentials and set DS_ in os.environ. + def inject_env(self, engine: str, name: str, *, flat: bool = False) -> list[str] | None: + """Load credentials and set DS_* environment variables. + + Default (flat=False): injects namespaced vars, e.g. DS_POSTGRES_PROD_DB__HOST. + flat=True: injects legacy flat vars, e.g. 
DS_HOST — use only during + single-connection test_snippet execution. Returns the list of env var names set, or None if connection not found. - Call this before a scratchpad exec that needs the data source. """ fields = self.load(engine, name) if fields is None: return None var_names: list[str] = [] - for key, value in fields.items(): - var = f"DS_{key.upper()}" - os.environ[var] = value - var_names.append(var) + if flat: + for key, value in fields.items(): + var = f"DS_{key.upper()}" + os.environ[var] = value + var_names.append(var) + else: + prefix = _slug_env_prefix(engine, name) + for key, value in fields.items(): + var = f"{prefix}__{key.upper()}" + os.environ[var] = value + var_names.append(var) return var_names def clear_ds_env(self) -> None: diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 63a3d4f..1d1f194 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -223,9 +223,9 @@ class TestDataVaultEnvInjection: def test_inject_sets_ds_vars(self, vault): vault.save("postgresql", "prod_db", {"host": "db.example.com", "password": "s3cr3t"}) var_names = vault.inject_env("postgresql", "prod_db") - assert os.environ.get("DS_HOST") == "db.example.com" - assert os.environ.get("DS_PASSWORD") == "s3cr3t" - assert set(var_names) == {"DS_HOST", "DS_PASSWORD"} + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + assert os.environ.get("DS_POSTGRESQL_PROD_DB__PASSWORD") == "s3cr3t" + assert set(var_names) == {"DS_POSTGRESQL_PROD_DB__HOST", "DS_POSTGRESQL_PROD_DB__PASSWORD"} # Cleanup vault.clear_ds_env() @@ -237,7 +237,7 @@ def test_clear_removes_ds_vars(self, vault): vault.save("postgresql", "prod_db", {"host": "x"}) vault.inject_env("postgresql", "prod_db") vault.clear_ds_env() - assert "DS_HOST" not in os.environ + assert "DS_POSTGRESQL_PROD_DB__HOST" not in os.environ def test_clear_leaves_non_ds_vars(self, vault): os.environ["MY_VAR"] = "untouched" @@ -248,9 +248,57 @@ def 
test_clear_leaves_non_ds_vars(self, vault): def test_inject_uppercases_field_names(self, vault): vault.save("postgresql", "prod_db", {"access_token": "tok123"}) vault.inject_env("postgresql", "prod_db") - assert os.environ.get("DS_ACCESS_TOKEN") == "tok123" + assert os.environ.get("DS_POSTGRESQL_PROD_DB__ACCESS_TOKEN") == "tok123" + vault.clear_ds_env() + + def test_inject_flat_mode_sets_flat_vars(self, vault): + """flat=True injects legacy DS_FIELD vars, not namespaced ones.""" + vault.save("postgresql", "prod_db", {"host": "db.example.com"}) + var_names = vault.inject_env("postgresql", "prod_db", flat=True) + assert os.environ.get("DS_HOST") == "db.example.com" + assert "DS_POSTGRESQL_PROD_DB__HOST" not in os.environ + assert set(var_names) == {"DS_HOST"} vault.clear_ds_env() + def test_two_same_type_connections_no_collision(self, vault): + """Two connections of the same engine type coexist without overwriting each other.""" + vault.save("postgres", "prod_db", {"host": "prod.example.com"}) + vault.save("postgres", "analytics", {"host": "analytics.example.com"}) + vault.inject_env("postgres", "prod_db") + vault.inject_env("postgres", "analytics") + try: + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "prod.example.com" + assert os.environ.get("DS_POSTGRES_ANALYTICS__HOST") == "analytics.example.com" + # The two vars are distinct — no collision + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") != os.environ.get("DS_POSTGRES_ANALYTICS__HOST") + finally: + vault.clear_ds_env() + + def test_different_engines_no_collision(self, vault): + """Connections from different engines coexist simultaneously.""" + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + vault.inject_env("postgres", "prod_db") + vault.inject_env("hubspot", "main") + try: + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "pg.example.com" + assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" + 
finally: + vault.clear_ds_env() + + def test_slug_env_prefix_sanitizes_special_chars(self, vault): + """Special characters in names are sanitized to underscores.""" + from anton.data_vault import _slug_env_prefix + + assert _slug_env_prefix("postgres", "prod-db.eu") == "DS_POSTGRES_PROD_DB_EU" + # Full var + vault.save("postgres", "prod-db.eu", {"host": "eu.pg.com"}) + vault.inject_env("postgres", "prod-db.eu") + try: + assert os.environ.get("DS_POSTGRES_PROD_DB_EU__HOST") == "eu.pg.com" + finally: + vault.clear_ds_env() + # ───────────────────────────────────────────────────────────────────────────── # DataVault — next_connection_number @@ -631,8 +679,10 @@ async def test_failed_test_no_retry_returns_without_saving(self, registry, vault assert not result._history @pytest.mark.asyncio - async def test_ds_env_cleaned_up_after_test(self, registry, vault_dir): - """DS_* env vars must not leak into os.environ after connection testing.""" + async def test_ds_env_injected_after_successful_connect( + self, registry, vault_dir + ): + """After a successful connect, DS_* vars are injected into the env.""" from anton.chat import _handle_connect_datasource session = self._make_session() @@ -641,7 +691,9 @@ async def test_ds_env_cleaned_up_after_test(self, registry, vault_dir): vault = DataVault(vault_dir=vault_dir) pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + pad.execute = AsyncMock( + return_value=self._make_cell(stdout="ok") + ) session._scratchpads.get_or_create = AsyncMock(return_value=pad) prompt_responses = iter([ @@ -649,16 +701,29 @@ async def test_ds_env_cleaned_up_after_test(self, registry, vault_dir): "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", ]) - with ( - patch("anton.chat.DataVault", return_value=vault), - patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), - ): - await _handle_connect_datasource(console, 
session._scratchpads, session) + try: + with ( + patch("anton.chat.DataVault", return_value=vault), + patch( + "anton.chat.DatasourceRegistry", + return_value=registry, + ), + patch( + "rich.prompt.Prompt.ask", + side_effect=lambda *a, **kw: next( + prompt_responses + ), + ), + ): + await _handle_connect_datasource( + console, session._scratchpads, session + ) - # No DS_* variables should remain - leaked = [k for k in os.environ if k.startswith("DS_")] - assert leaked == [], f"DS_* vars leaked: {leaked}" + # After successful connect, namespaced DS_* vars are injected. + # name_from=database → name="prod_db" → prefix DS_POSTGRESQL_PROD_DB + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + finally: + vault.clear_ds_env() @pytest.mark.asyncio async def test_auth_method_choice_selects_fields(self, registry, vault_dir): @@ -740,10 +805,13 @@ class TestCredentialScrubbing: """_scrub_credentials and _register_secret_vars.""" def setup_method(self): - # Reset the module-level sets before each test + # Reset the module-level sets and clear any DS_* env vars from other tests from anton.chat import _DS_KNOWN_VARS, _DS_SECRET_VARS _DS_SECRET_VARS.clear() _DS_KNOWN_VARS.clear() + ds_keys = [k for k in os.environ if k.startswith("DS_")] + for k in ds_keys: + del os.environ[k] def test_register_secret_vars_adds_secret_fields(self, registry): """Secret fields are added to _DS_SECRET_VARS; non-secret fields are not.""" @@ -860,15 +928,17 @@ async def test_register_and_scrub_on_connect(self, registry, vault_dir): ): await _handle_connect_datasource(MagicMock(), session._scratchpads, session) - # After connect, password should be in the secret set and scrubbed - assert "DS_PASSWORD" in _DS_SECRET_VARS - os.environ["DS_PASSWORD"] = secret_pw + # After connect, namespaced secret var is registered and scrubbed. 
+ # name_from=database → name="mydb" → DS_POSTGRESQL_MYDB__PASSWORD + namespaced_pw_var = "DS_POSTGRESQL_MYDB__PASSWORD" + assert namespaced_pw_var in _DS_SECRET_VARS + os.environ[namespaced_pw_var] = secret_pw try: result = _scrub_credentials(f"error: auth failed with {secret_pw}") assert secret_pw not in result - assert "[DS_PASSWORD]" in result + assert f"[{namespaced_pw_var}]" in result finally: - del os.environ["DS_PASSWORD"] + del os.environ[namespaced_pw_var] # ───────────────────────────────────────────────────────────────────────────── @@ -914,19 +984,19 @@ async def test_reconnect_sets_active_datasource(self, vault_dir): assert result._active_datasource == "hubspot-2" @pytest.mark.asyncio - async def test_reconnect_clears_other_ds_vars(self, vault_dir): - """Reconnecting to one slug removes DS_* vars from all other connections.""" + async def test_reconnect_all_namespaced_vars_available(self, vault_dir): + """After reconnect, ALL saved connections remain available as namespaced vars.""" from anton.chat import _handle_connect_datasource vault = DataVault(vault_dir=vault_dir) vault.save("oracle", "1", {"host": "oracle.host", "user": "admin", "password": "orapass"}) vault.save("hubspot", "2", {"access_token": "pat-xxx"}) - # Simulate startup: inject all connections + # Simulate startup: inject all connections as namespaced vault.inject_env("oracle", "1") vault.inject_env("hubspot", "2") - assert os.environ.get("DS_HOST") == "oracle.host" - assert os.environ.get("DS_ACCESS_TOKEN") == "pat-xxx" + assert os.environ.get("DS_ORACLE_1__HOST") == "oracle.host" + assert os.environ.get("DS_HUBSPOT_2__ACCESS_TOKEN") == "pat-xxx" session = self._make_session() console = MagicMock() @@ -937,15 +1007,19 @@ async def test_reconnect_clears_other_ds_vars(self, vault_dir): patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry"), ): - await _handle_connect_datasource( + result = await _handle_connect_datasource( console, session._scratchpads, 
session, prefill="hubspot-2" ) - # Oracle vars must be gone; HubSpot var must be present + # After reconnect, all connections are restored as namespaced vars. + # No flat DS_* vars are present. assert "DS_HOST" not in os.environ - assert "DS_USER" not in os.environ - assert "DS_PASSWORD" not in os.environ - assert os.environ.get("DS_ACCESS_TOKEN") == "pat-xxx" + assert "DS_ACCESS_TOKEN" not in os.environ + # Both connections remain available as namespaced vars. + assert os.environ.get("DS_ORACLE_1__HOST") == "oracle.host" + assert os.environ.get("DS_HUBSPOT_2__ACCESS_TOKEN") == "pat-xxx" + # Active datasource is updated to the reconnected slug. + assert result._active_datasource == "hubspot-2" finally: vault.clear_ds_env() @@ -988,4 +1062,983 @@ def test_build_datasource_context_active_only_empty_when_no_match(self, vault_di ctx = _build_datasource_context(active_only="hubspot-99") # Header is present but no datasource lines - assert "oracle-1" not in ctx \ No newline at end of file + assert "oracle-1" not in ctx + + +# ───────────────────────────────────────────────────────────────────────────── +# CLI command registration +# ───────────────────────────────────────────────────────────────────────────── + + +class TestCliCommandRegistration: + """Verify all datasource CLI commands are registered.""" + + def _get_command_names(self): + from anton.cli import app + return [cmd.name for cmd in app.registered_commands] + + def test_connect_data_source_registered(self): + names = self._get_command_names() + assert "connect-data-source" in names + + def test_list_data_sources_registered(self): + names = self._get_command_names() + assert "list-data-sources" in names + + def test_edit_data_source_registered(self): + names = self._get_command_names() + assert "edit-data-source" in names + + def test_remove_data_source_registered(self): + names = self._get_command_names() + assert "remove-data-source" in names + + def test_test_data_source_registered(self): + names = 
self._get_command_names() + assert "test-data-source" in names + + +# ───────────────────────────────────────────────────────────────────────────── +# _handle_list_data_sources — improved output +# ───────────────────────────────────────────────────────────────────────────── + + +class TestHandleListDataSources: + def test_empty_vault_shows_message(self, vault_dir): + from unittest.mock import MagicMock, patch + from anton.chat import _handle_list_data_sources + + console = MagicMock() + console.print = MagicMock() + + with patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)): + _handle_list_data_sources(console) + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "No data sources" in printed or "connect-data-source" in printed + + def test_complete_connection_shows_saved(self, vault_dir, registry): + from unittest.mock import MagicMock, patch + from rich.console import Console + from anton.chat import _handle_list_data_sources + import io + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "db.example.com", "port": "5432", + "database": "prod", "user": "alice", "password": "s3cr3t", + }) + + buf = io.StringIO() + rich_console = Console(file=buf, highlight=False, markup=False) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + _handle_list_data_sources(rich_console) + + output = buf.getvalue() + assert "postgresql-prod_db" in output + assert "saved" in output.lower() + + def test_incomplete_connection_shows_incomplete(self, vault_dir, registry): + from unittest.mock import MagicMock, patch + from rich.console import Console + from anton.chat import _handle_list_data_sources + import io + + vault = DataVault(vault_dir=vault_dir) + # Missing required fields: database, user, password + vault.save("postgresql", "partial", {"host": "db.example.com"}) + + buf = io.StringIO() + rich_console = 
Console(file=buf, highlight=False, markup=False) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + _handle_list_data_sources(rich_console) + + output = buf.getvalue() + assert "incomplete" in output.lower() + + def test_shows_source_name(self, vault_dir, registry): + from unittest.mock import patch + from rich.console import Console + from anton.chat import _handle_list_data_sources + import io + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "x", "port": "5432", "database": "d", "user": "u", "password": "p", + }) + + buf = io.StringIO() + rich_console = Console(file=buf, highlight=False, markup=False) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + _handle_list_data_sources(rich_console) + + output = buf.getvalue() + assert "PostgreSQL" in output + + +# ───────────────────────────────────────────────────────────────────────────── +# _handle_test_datasource +# ───────────────────────────────────────────────────────────────────────────── + + +class TestHandleTestDatasource: + def _make_cell(self, stdout="ok", stderr="", error=None): + from unittest.mock import MagicMock + cell = MagicMock() + cell.stdout = stdout + cell.stderr = stderr + cell.error = error + return cell + + @pytest.mark.asyncio + async def test_success_path(self, vault_dir, registry): + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_test_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "db.example.com", "port": "5432", + "database": "prod", "user": "alice", "password": "s3cr3t", + }) + + console = MagicMock() + console.print = MagicMock() + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + scratchpads = AsyncMock() + scratchpads.get_or_create = 
AsyncMock(return_value=pad) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "postgresql-prod_db") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "✓" in printed or "passed" in printed.lower() + + @pytest.mark.asyncio + async def test_failure_path(self, vault_dir, registry): + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_test_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "db.example.com", "port": "5432", + "database": "prod", "user": "alice", "password": "wrongpass", + }) + + console = MagicMock() + console.print = MagicMock() + + pad = AsyncMock() + pad.execute = AsyncMock( + return_value=self._make_cell(stdout="", stderr="password authentication failed") + ) + scratchpads = AsyncMock() + scratchpads.get_or_create = AsyncMock(return_value=pad) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "postgresql-prod_db") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "✗" in printed or "failed" in printed.lower() + + @pytest.mark.asyncio + async def test_unknown_connection(self, vault_dir, registry): + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_test_datasource + + vault = DataVault(vault_dir=vault_dir) + console = MagicMock() + console.print = MagicMock() + scratchpads = AsyncMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "postgresql-ghost") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "not found" in printed.lower() or 
"No connection" in printed + + @pytest.mark.asyncio + async def test_empty_slug_shows_usage(self, vault_dir, registry): + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_test_datasource + + console = MagicMock() + console.print = MagicMock() + scratchpads = AsyncMock() + + with ( + patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "Usage" in printed or "test-data-source" in printed + + @pytest.mark.asyncio + async def test_ds_env_after_test(self, vault_dir, registry): + """After test-data-source: flat vars are gone, namespaced vars are restored.""" + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_test_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "db.example.com", "port": "5432", + "database": "prod", "user": "alice", "password": "s3cr3t", + }) + + console = MagicMock() + pad = AsyncMock() + pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + scratchpads = AsyncMock() + scratchpads.get_or_create = AsyncMock(return_value=pad) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "postgresql-prod_db") + + try: + # Flat vars must be gone + assert "DS_HOST" not in os.environ + assert "DS_PASSWORD" not in os.environ + # Namespaced vars are restored (all saved connections) + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + finally: + vault.clear_ds_env() + + +# ───────────────────────────────────────────────────────────────────────────── +# Edit flow +# ───────────────────────────────────────────────────────────────────────────── + + 
+class TestEditDatasourceFlow: + def _make_session(self): + from unittest.mock import AsyncMock + from anton.chat import ChatSession + + mock_llm = AsyncMock() + session = ChatSession(mock_llm) + session._scratchpads = AsyncMock() + return session + + def _make_cell(self, stdout="ok", stderr="", error=None): + from unittest.mock import MagicMock + cell = MagicMock() + cell.stdout = stdout + cell.stderr = stderr + cell.error = error + return cell + + @pytest.mark.asyncio + async def test_existing_values_loaded(self, registry, vault_dir): + """Edit shows existing non-secret values as defaults.""" + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_connect_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "old.host", "port": "5432", + "database": "prod", "user": "alice", "password": "oldpass", + }) + + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + # The prompt for 'host' should default to "old.host"; user presses Enter (keeps it) + # The prompt for 'password' is secret; user provides new value + prompt_values = iter([ + "old.host", # host — Enter = keep (returns default) + "5432", # port + "prod", # database + "alice", # user + "newpass", # password + "", # schema (optional) + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_values)), + ): + await _handle_connect_datasource( + console, session._scratchpads, session, datasource_name="postgresql-prod_db" + ) + + saved = vault.load("postgresql", "prod_db") + assert saved["host"] == "old.host" + assert saved["password"] == "newpass" + + @pytest.mark.asyncio + async def 
test_enter_preserves_secret_value(self, registry, vault_dir): + """Pressing Enter on a secret field keeps the existing value.""" + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_connect_datasource + + vault = DataVault(vault_dir=vault_dir) + original_pass = "original_secret_pass" + vault.save("postgresql", "prod_db", { + "host": "db.host", "port": "5432", + "database": "prod", "user": "alice", "password": original_pass, + }) + + session = self._make_session() + console = MagicMock() + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + # Empty string for secret field = keep existing + prompt_values = iter([ + "db.host", "5432", "prod", "alice", + "", # password — Enter = keep original + "", # schema + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_values)), + ): + await _handle_connect_datasource( + console, session._scratchpads, session, datasource_name="postgresql-prod_db" + ) + + saved = vault.load("postgresql", "prod_db") + assert saved["password"] == original_pass + + @pytest.mark.asyncio + async def test_unknown_slug_returns_session(self, registry, vault_dir): + """Editing a non-existent slug returns the session unchanged.""" + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_connect_datasource + + vault = DataVault(vault_dir=vault_dir) + session = self._make_session() + console = MagicMock() + console.print = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + result = await _handle_connect_datasource( + console, session._scratchpads, session, datasource_name="postgresql-ghost" + ) + + assert result is session + printed = " 
".join(str(c) for c in console.print.call_args_list) + assert "not found" in printed.lower() or "No connection" in printed + + +# ───────────────────────────────────────────────────────────────────────────── +# Remove flow — full coverage +# ───────────────────────────────────────────────────────────────────────────── + + +class TestRemoveDatasourceFlow: + def test_confirmation_yes_deletes(self, vault, registry): + from unittest.mock import patch + from anton.chat import _handle_remove_data_source + from rich.console import Console + + vault.save("postgresql", "prod_db", {"host": "x"}) + console = Console(quiet=True) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("rich.prompt.Confirm.ask", return_value=True), + ): + _handle_remove_data_source(console, "postgresql-prod_db") + + assert vault.load("postgresql", "prod_db") is None + + def test_confirmation_no_preserves(self, vault, registry): + from unittest.mock import patch + from anton.chat import _handle_remove_data_source + from rich.console import Console + + vault.save("postgresql", "prod_db", {"host": "x"}) + console = Console(quiet=True) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("rich.prompt.Confirm.ask", return_value=False), + ): + _handle_remove_data_source(console, "postgresql-prod_db") + + assert vault.load("postgresql", "prod_db") is not None + + def test_unknown_name_shows_message(self, vault_dir): + from unittest.mock import MagicMock, patch + from anton.chat import _handle_remove_data_source + + vault = DataVault(vault_dir=vault_dir) + console = MagicMock() + console.print = MagicMock() + + with patch("anton.chat.DataVault", return_value=vault): + _handle_remove_data_source(console, "postgresql-ghost") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "not found" in printed.lower() or "No connection" in printed + + def test_invalid_format_shows_warning(self, vault_dir): + from unittest.mock import MagicMock, patch + from 
anton.chat import _handle_remove_data_source + + vault = DataVault(vault_dir=vault_dir) + console = MagicMock() + console.print = MagicMock() + + with patch("anton.chat.DataVault", return_value=vault): + _handle_remove_data_source(console, "nohyphen") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "Invalid" in printed or "engine-name" in printed + + +# ───────────────────────────────────────────────────────────────────────────── +# Remove slash command — empty name bug fix +# ───────────────────────────────────────────────────────────────────────────── + + +class TestRemoveSlashCommandEmptyName: + """Ensure /remove-data-source without arg does NOT call the handler.""" + + def test_empty_arg_does_not_call_handler(self): + """The chat loop must not call _handle_remove_data_source with an empty slug.""" + from unittest.mock import MagicMock, patch + # Import the handler to verify it's not called with empty arg + with patch("anton.chat._handle_remove_data_source") as mock_remove: + # Simulate what the chat loop does when arg is empty + cmd = "/remove-data-source" + parts = [cmd] # no argument + arg = parts[1].strip() if len(parts) > 1 else "" + assert arg == "" + # The fixed logic: + if not arg: + pass # show usage message, do NOT call handler + else: + from anton.chat import _handle_remove_data_source + _handle_remove_data_source(MagicMock(), arg) + + mock_remove.assert_not_called() + + +# ───────────────────────────────────────────────────────────────────────────── +# Environment activation — collision-free behavior +# ───────────────────────────────────────────────────────────────────────────── + + +class TestEnvActivationCollisionFree: + def _make_session(self): + from unittest.mock import AsyncMock + from anton.chat import ChatSession + + mock_llm = AsyncMock() + session = ChatSession(mock_llm) + session._scratchpads = AsyncMock() + return session + + def _make_cell(self, stdout="ok", stderr="", error=None): + from unittest.mock 
import MagicMock + cell = MagicMock() + cell.stdout = stdout + cell.stderr = stderr + cell.error = error + return cell + + @pytest.mark.asyncio + async def test_connect_clears_previous_ds_vars(self, registry, vault_dir): + """After a successful new connect, only the new connection's DS_* vars are set.""" + from unittest.mock import AsyncMock, patch + from anton.chat import _handle_connect_datasource + + # Pre-inject an "old" connection + os.environ["DS_ACCESS_TOKEN"] = "old-token" + + vault = DataVault(vault_dir=vault_dir) + session = self._make_session() + console = MagicMock() + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + + prompt_responses = iter([ + "PostgreSQL", "y", + "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", + ]) + + try: + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + # Old flat token must be gone; flat vars are never kept in runtime + assert "DS_ACCESS_TOKEN" not in os.environ + # Namespaced vars for the new connection must be present + # name_from=database → name="prod_db" → DS_POSTGRESQL_PROD_DB__HOST + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + finally: + vault.clear_ds_env() + os.environ.pop("DS_ACCESS_TOKEN", None) + + @pytest.mark.asyncio + async def test_two_same_type_connections_no_collision(self, registry, vault_dir): + """Activating one of two same-type connections sets only that one's vars.""" + from unittest.mock import MagicMock, patch + from anton.chat import _handle_connect_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "db1", { + "host": "host1.example.com", "port": "5432", + "database": "db1", 
"user": "u1", "password": "p1", + }) + vault.save("postgresql", "db2", { + "host": "host2.example.com", "port": "5432", + "database": "db2", "user": "u2", "password": "p2", + }) + + session = self._make_session() + console = MagicMock() + + try: + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + # Reconnect to db2 (prefill path) + await _handle_connect_datasource( + console, session._scratchpads, session, + prefill="postgresql-db2", + ) + + # Both connections remain available as namespaced vars (no collision) + assert os.environ.get("DS_POSTGRESQL_DB1__HOST") == "host1.example.com" + assert os.environ.get("DS_POSTGRESQL_DB2__HOST") == "host2.example.com" + assert os.environ.get("DS_POSTGRESQL_DB2__DATABASE") == "db2" + # No flat vars + assert "DS_HOST" not in os.environ + assert "DS_DATABASE" not in os.environ + finally: + vault.clear_ds_env() + + @pytest.mark.asyncio + async def test_test_datasource_does_not_leave_vars(self, registry, vault_dir): + """_handle_test_datasource cleans up all DS_* vars after testing.""" + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_test_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "db.example.com", "port": "5432", + "database": "prod", "user": "alice", "password": "s3cr3t", + }) + + pad = AsyncMock() + pad.execute = AsyncMock( + return_value=MagicMock(stdout="ok", stderr="", error=None) + ) + scratchpads = AsyncMock() + scratchpads.get_or_create = AsyncMock(return_value=pad) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(MagicMock(), scratchpads, "postgresql-prod_db") + + try: + # Flat vars must not be present after the test + flat_leaked = [k for k in os.environ if k.startswith("DS_") and "__" not in k] + assert flat_leaked == [], 
f"Flat DS_* vars leaked after test: {flat_leaked}" + # Namespaced vars are restored — that is the expected runtime state + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + finally: + vault.clear_ds_env() + + +# ───────────────────────────────────────────────────────────────────────────── +# Chat slash-command behavior +# ───────────────────────────────────────────────────────────────────────────── + + +class TestChatSlashCommands: + """Verify slash-command routing logic for datasource commands.""" + + def test_list_data_sources_slash_command_routes(self): + """'/list-data-sources' cmd string maps to _handle_list_data_sources.""" + # Verify the function is importable and callable (routing covered by chat loop) + from anton.chat import _handle_list_data_sources + assert callable(_handle_list_data_sources) + + def test_test_data_source_slash_command_routes(self): + """'/test-data-source' is importable and async.""" + import inspect + from anton.chat import _handle_test_datasource + assert inspect.iscoroutinefunction(_handle_test_datasource) + + def test_remove_data_source_slash_command_routes(self): + from anton.chat import _handle_remove_data_source + assert callable(_handle_remove_data_source) + + def test_edit_data_source_routes_to_connect_handler(self): + """'/edit-data-source' uses _handle_connect_datasource with datasource_name arg.""" + import inspect + from anton.chat import _handle_connect_datasource + sig = inspect.signature(_handle_connect_datasource) + assert "datasource_name" in sig.parameters + + @pytest.mark.asyncio + async def test_test_data_source_no_arg_shows_usage(self, vault_dir, registry): + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_test_datasource + + console = MagicMock() + console.print = MagicMock() + scratchpads = AsyncMock() + + with ( + patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), + patch("anton.chat.DatasourceRegistry", 
return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "") + + printed = " ".join(str(c) for c in console.print.call_args_list) + assert "Usage" in printed or "test-data-source" in printed + + @pytest.mark.asyncio + async def test_edit_data_source_no_arg_safe(self, vault_dir, registry): + """'/edit-data-source' without arg (datasource_name=None) triggers new-connect flow.""" + from anton.chat import _handle_connect_datasource + # datasource_name=None means new connect, not edit — no crash expected + from anton.chat import ChatSession + from unittest.mock import AsyncMock, MagicMock, patch + + mock_llm = AsyncMock() + session = ChatSession(mock_llm) + session._scratchpads = AsyncMock() + console = MagicMock() + + with ( + patch( + "anton.chat.DataVault", + return_value=DataVault(vault_dir=vault_dir), + ), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", return_value="UnknownEngine"), + ): + updated = await _handle_connect_datasource( + console, session._scratchpads, session, + datasource_name=None, + ) + # Should return session (even if unknown engine) + assert updated is not None + +# ───────────────────────────────────────────────────────────────────────────── +# _slug_env_prefix +# ───────────────────────────────────────────────────────────────────────────── + + +class TestSlugEnvPrefix: + """Unit tests for the _slug_env_prefix helper.""" + + def test_basic_engine_and_name(self): + from anton.data_vault import _slug_env_prefix + assert _slug_env_prefix("postgres", "prod_db") == "DS_POSTGRES_PROD_DB" + + def test_hubspot_main(self): + from anton.data_vault import _slug_env_prefix + assert _slug_env_prefix("hubspot", "main") == "DS_HUBSPOT_MAIN" + + def test_sanitizes_hyphen_and_dot(self): + from anton.data_vault import _slug_env_prefix + assert _slug_env_prefix("postgres", "prod-db.eu") == "DS_POSTGRES_PROD_DB_EU" + + def test_numeric_name(self): + from anton.data_vault import 
_slug_env_prefix + assert _slug_env_prefix("postgresql", "1") == "DS_POSTGRESQL_1" + + def test_uppercase_result(self): + from anton.data_vault import _slug_env_prefix + result = _slug_env_prefix("myengine", "myname") + assert result == result.upper() + + def test_double_underscore_separator_in_full_var(self): + """The separator between prefix and field must be double underscore.""" + from anton.data_vault import _slug_env_prefix + prefix = _slug_env_prefix("postgres", "prod_db") + full_var = f"{prefix}__HOST" + assert full_var == "DS_POSTGRES_PROD_DB__HOST" + + +# ───────────────────────────────────────────────────────────────────────────── +# Namespaced runtime env — _build_datasource_context +# ───────────────────────────────────────────────────────────────────────────── + + +class TestNamespacedRuntimeEnv: + """Tests for namespaced vars in datasource context and multi-source access.""" + + def test_build_datasource_context_shows_namespaced_vars(self, vault_dir): + from unittest.mock import patch + from anton.chat import _build_datasource_context + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com", "password": "s3cr3t"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "DS_POSTGRES_PROD_DB__HOST" in ctx + assert "DS_POSTGRES_PROD_DB__PASSWORD" in ctx + # No flat vars + assert "DS_HOST" not in ctx + + def test_build_datasource_context_shows_slug_and_engine_label(self, vault_dir): + from unittest.mock import patch + from anton.chat import _build_datasource_context + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "postgres-prod_db" in ctx + assert "(postgres)" in ctx + + def test_multi_source_context_shows_both_connections(self, vault_dir): + """Both connections are visible in the context with 
their namespaced vars.""" + from unittest.mock import patch + from anton.chat import _build_datasource_context + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "postgres-prod_db" in ctx + assert "DS_POSTGRES_PROD_DB__HOST" in ctx + assert "hubspot-main" in ctx + assert "DS_HUBSPOT_MAIN__ACCESS_TOKEN" in ctx + + def test_build_datasource_context_header_mentions_namespaced(self, vault_dir): + """The header text explains the namespaced pattern.""" + from unittest.mock import patch + from anton.chat import _build_datasource_context + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "namespaced" in ctx.lower() or "DS_" in ctx + + +# ───────────────────────────────────────────────────────────────────────────── +# _register_secret_vars — namespaced mode +# ───────────────────────────────────────────────────────────────────────────── + + +class TestRegisterSecretVarsNamespaced: + """Tests for namespaced secret var registration.""" + + def setup_method(self): + from anton.chat import _DS_KNOWN_VARS, _DS_SECRET_VARS + _DS_SECRET_VARS.clear() + _DS_KNOWN_VARS.clear() + ds_keys = [k for k in os.environ if k.startswith("DS_")] + for k in ds_keys: + del os.environ[k] + + def test_register_with_slug_uses_namespaced_keys(self, registry): + from anton.chat import _DS_KNOWN_VARS, _DS_SECRET_VARS, _register_secret_vars + + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__HOST" not in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__HOST" in _DS_KNOWN_VARS + + 
def test_register_without_slug_uses_flat_keys(self, registry): + from anton.chat import _DS_SECRET_VARS, _register_secret_vars + + pg = registry.get("postgresql") + _register_secret_vars(pg) # no engine/name → flat mode + + assert "DS_PASSWORD" in _DS_SECRET_VARS + assert "DS_HOST" not in _DS_SECRET_VARS + + def test_scrub_replaces_namespaced_secret_value(self, registry): + from anton.chat import _DS_SECRET_VARS, _register_secret_vars, _scrub_credentials + + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + + secret = "namespacedpassword123" + os.environ["DS_POSTGRESQL_PROD_DB__PASSWORD"] = secret + try: + result = _scrub_credentials(f"error: {secret}") + assert secret not in result + assert "[DS_POSTGRESQL_PROD_DB__PASSWORD]" in result + finally: + del os.environ["DS_POSTGRESQL_PROD_DB__PASSWORD"] + + def test_scrub_leaves_namespaced_non_secret_readable(self, registry): + from anton.chat import _register_secret_vars, _scrub_credentials + + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + + os.environ["DS_POSTGRESQL_PROD_DB__HOST"] = "db.example.com" + try: + result = _scrub_credentials("host=db.example.com") + assert "db.example.com" in result + finally: + del os.environ["DS_POSTGRESQL_PROD_DB__HOST"] + + +# ───────────────────────────────────────────────────────────────────────────── +# Temporary flat activation and restoration +# ───────────────────────────────────────────────────────────────────────────── + + +class TestTemporaryFlatExecution: + """Tests that flat vars are used only during test_snippet, then restored.""" + + def test_restore_namespaced_env_clears_flat_and_reinjects(self, vault_dir): + """_restore_namespaced_env replaces flat vars with namespaced vars.""" + from unittest.mock import patch + from anton.chat import _restore_namespaced_env + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "analytics", {"host": 
"analytics.example.com"}) + + # Simulate a flat injection (as done during test_snippet execution) + vault.inject_env("postgres", "analytics", flat=True) + assert os.environ.get("DS_HOST") == "analytics.example.com" + assert "DS_POSTGRES_ANALYTICS__HOST" not in os.environ + + with patch("anton.chat.DataVault", return_value=vault): + _restore_namespaced_env(vault) + + # Flat var is gone; namespaced var is back + assert "DS_HOST" not in os.environ + assert os.environ.get("DS_POSTGRES_ANALYTICS__HOST") == "analytics.example.com" + vault.clear_ds_env() + + def test_restore_namespaced_env_reinjects_all_connections(self, vault_dir): + """_restore_namespaced_env restores ALL saved connections, not just one.""" + from unittest.mock import patch + from anton.chat import _restore_namespaced_env + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "prod.example.com"}) + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + + # Simulate flat injection for one connection only + vault.clear_ds_env() + vault.inject_env("postgres", "prod_db", flat=True) + + with patch("anton.chat.DataVault", return_value=vault): + _restore_namespaced_env(vault) + + # Both connections are available as namespaced vars + assert "DS_HOST" not in os.environ + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "prod.example.com" + assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" + vault.clear_ds_env() + + @pytest.mark.asyncio + async def test_test_datasource_injects_flat_then_restores_namespaced( + self, vault_dir, registry + ): + """_handle_test_datasource uses flat vars during snippet, then restores namespaced.""" + from unittest.mock import AsyncMock, MagicMock, patch + from anton.chat import _handle_test_datasource + + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "s3cr3t", + }) + # Pre-save a second 
connection that should be restored after the test + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + vault.inject_env("postgresql", "prod_db") + vault.inject_env("hubspot", "main") + + env_during_test: dict = {} + + async def capture_execute(snippet): + # Capture env state mid-execution to verify flat vars are set + env_during_test["DS_HOST"] = os.environ.get("DS_HOST") + env_during_test["DS_POSTGRESQL_PROD_DB__HOST"] = os.environ.get( + "DS_POSTGRESQL_PROD_DB__HOST" + ) + return MagicMock(stdout="ok", stderr="", error=None) + + pad = AsyncMock() + pad.execute = capture_execute + pad.reset = AsyncMock() + pad.install_packages = AsyncMock() + + scratchpads = AsyncMock() + scratchpads.get_or_create = AsyncMock(return_value=pad) + console = MagicMock() + console.print = MagicMock() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): + await _handle_test_datasource(console, scratchpads, "postgresql-prod_db") + + # During execution: flat var was set, namespaced was absent + assert env_during_test["DS_HOST"] == "pg.example.com" + assert env_during_test["DS_POSTGRESQL_PROD_DB__HOST"] is None + + # After execution: flat vars gone, namespaced restored + assert "DS_HOST" not in os.environ + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "pg.example.com" + assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" + vault.clear_ds_env() From 445778e789196bf3f8888c3d407cfc98439dfc30 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 20 Mar 2026 20:10:27 +0100 Subject: [PATCH 12/70] Fix test --- tests/test_datasource.py | 1194 ++++++++++---------------------------- 1 file changed, 322 insertions(+), 872 deletions(-) diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 1d1f194..2ceca44 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -1,19 +1,33 @@ from __future__ import annotations +import io import json import os 
-from pathlib import Path from textwrap import dedent from unittest.mock import AsyncMock, MagicMock, patch import pytest - -from anton.data_vault import DataVault +from rich.console import Console + +from anton.chat import ( + ChatSession, + _DS_KNOWN_VARS, + _DS_SECRET_VARS, + _build_datasource_context, + _handle_connect_datasource, + _handle_list_data_sources, + _handle_remove_data_source, + _handle_test_datasource, + _register_secret_vars, + _restore_namespaced_env, + _scrub_credentials, +) +from anton.cli import app as _cli_app +from anton.data_vault import DataVault, _slug_env_prefix from anton.datasource_registry import ( - AuthMethod, DatasourceEngine, - DatasourceField, DatasourceRegistry, + _parse_file, ) @@ -112,15 +126,51 @@ def datasources_md(tmp_path): @pytest.fixture() -def registry(datasources_md, tmp_path): +def registry(datasources_md): """Registry pointing at our temp datasources.md, no user overrides.""" reg = DatasourceRegistry.__new__(DatasourceRegistry) - reg._engines = {} - from anton.datasource_registry import _parse_file reg._engines = _parse_file(datasources_md) return reg +@pytest.fixture() +def make_session(): + """Factory that creates a fresh ChatSession with mocked scratchpads.""" + def _factory(): + mock_llm = AsyncMock() + session = ChatSession(mock_llm) + session._scratchpads = AsyncMock() + return session + return _factory + + +@pytest.fixture() +def make_cell(): + """Factory that creates a MagicMock scratchpad execution cell.""" + def _factory(stdout="ok", stderr="", error=None): + cell = MagicMock() + cell.stdout = stdout + cell.stderr = stderr + cell.error = error + return cell + return _factory + + +@pytest.fixture(autouse=True) +def clean_ds_state(): + """Clear _DS_SECRET_VARS, _DS_KNOWN_VARS, and all DS_* env vars around each test.""" + def _clean(): + _DS_SECRET_VARS.clear() + _DS_KNOWN_VARS.clear() + for k in list(os.environ): + if k.startswith("DS_"): + del os.environ[k] + + _clean() + yield + _clean() + + # 
───────────────────────────────────────────────────────────────────────────── # DataVault — save / load / delete # ───────────────────────────────────────────────────────────────────────────── @@ -145,8 +195,7 @@ def test_vault_dir_permissions(self, vault, vault_dir): def test_load_returns_fields(self, vault): creds = {"host": "db.example.com", "port": "5432", "password": "secret"} vault.save("postgresql", "prod_db", creds) - loaded = vault.load("postgresql", "prod_db") - assert loaded == creds + assert vault.load("postgresql", "prod_db") == creds def test_load_missing_returns_none(self, vault): assert vault.load("postgresql", "nonexistent") is None @@ -210,7 +259,6 @@ def test_skips_corrupt_files(self, vault, vault_dir): assert conns[0]["name"] == "good" def test_vault_dir_missing_returns_empty(self, vault): - # vault_dir was never created assert vault.list_connections() == [] @@ -226,12 +274,9 @@ def test_inject_sets_ds_vars(self, vault): assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" assert os.environ.get("DS_POSTGRESQL_PROD_DB__PASSWORD") == "s3cr3t" assert set(var_names) == {"DS_POSTGRESQL_PROD_DB__HOST", "DS_POSTGRESQL_PROD_DB__PASSWORD"} - # Cleanup - vault.clear_ds_env() def test_inject_missing_returns_none(self, vault): - result = vault.inject_env("postgresql", "ghost") - assert result is None + assert vault.inject_env("postgresql", "ghost") is None def test_clear_removes_ds_vars(self, vault): vault.save("postgresql", "prod_db", {"host": "x"}) @@ -239,17 +284,15 @@ def test_clear_removes_ds_vars(self, vault): vault.clear_ds_env() assert "DS_POSTGRESQL_PROD_DB__HOST" not in os.environ - def test_clear_leaves_non_ds_vars(self, vault): - os.environ["MY_VAR"] = "untouched" + def test_clear_leaves_non_ds_vars(self, vault, monkeypatch): + monkeypatch.setenv("MY_VAR", "untouched") vault.clear_ds_env() assert os.environ.get("MY_VAR") == "untouched" - del os.environ["MY_VAR"] def test_inject_uppercases_field_names(self, vault): 
vault.save("postgresql", "prod_db", {"access_token": "tok123"}) vault.inject_env("postgresql", "prod_db") assert os.environ.get("DS_POSTGRESQL_PROD_DB__ACCESS_TOKEN") == "tok123" - vault.clear_ds_env() def test_inject_flat_mode_sets_flat_vars(self, vault): """flat=True injects legacy DS_FIELD vars, not namespaced ones.""" @@ -258,7 +301,6 @@ def test_inject_flat_mode_sets_flat_vars(self, vault): assert os.environ.get("DS_HOST") == "db.example.com" assert "DS_POSTGRESQL_PROD_DB__HOST" not in os.environ assert set(var_names) == {"DS_HOST"} - vault.clear_ds_env() def test_two_same_type_connections_no_collision(self, vault): """Two connections of the same engine type coexist without overwriting each other.""" @@ -266,13 +308,9 @@ def test_two_same_type_connections_no_collision(self, vault): vault.save("postgres", "analytics", {"host": "analytics.example.com"}) vault.inject_env("postgres", "prod_db") vault.inject_env("postgres", "analytics") - try: - assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "prod.example.com" - assert os.environ.get("DS_POSTGRES_ANALYTICS__HOST") == "analytics.example.com" - # The two vars are distinct — no collision - assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") != os.environ.get("DS_POSTGRES_ANALYTICS__HOST") - finally: - vault.clear_ds_env() + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "prod.example.com" + assert os.environ.get("DS_POSTGRES_ANALYTICS__HOST") == "analytics.example.com" + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") != os.environ.get("DS_POSTGRES_ANALYTICS__HOST") def test_different_engines_no_collision(self, vault): """Connections from different engines coexist simultaneously.""" @@ -280,24 +318,15 @@ def test_different_engines_no_collision(self, vault): vault.save("hubspot", "main", {"access_token": "pat-abc"}) vault.inject_env("postgres", "prod_db") vault.inject_env("hubspot", "main") - try: - assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "pg.example.com" - assert 
os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" - finally: - vault.clear_ds_env() + assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "pg.example.com" + assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" def test_slug_env_prefix_sanitizes_special_chars(self, vault): - """Special characters in names are sanitized to underscores.""" - from anton.data_vault import _slug_env_prefix - + """Special characters in names produce correct namespaced vars.""" assert _slug_env_prefix("postgres", "prod-db.eu") == "DS_POSTGRES_PROD_DB_EU" - # Full var vault.save("postgres", "prod-db.eu", {"host": "eu.pg.com"}) vault.inject_env("postgres", "prod-db.eu") - try: - assert os.environ.get("DS_POSTGRES_PROD_DB_EU__HOST") == "eu.pg.com" - finally: - vault.clear_ds_env() + assert os.environ.get("DS_POSTGRES_PROD_DB_EU__HOST") == "eu.pg.com" # ───────────────────────────────────────────────────────────────────────────── @@ -322,7 +351,6 @@ def test_ignores_named_connections(self, vault): def test_does_not_confuse_engines(self, vault): vault.save("hubspot", "1", {"access_token": "x"}) vault.save("hubspot", "2", {"access_token": "y"}) - # postgresql counter is independent assert vault.next_connection_number("postgresql") == 1 @@ -340,16 +368,9 @@ def test_get_by_slug(self, registry): def test_get_missing_returns_none(self, registry): assert registry.get("mysql") is None - def test_find_by_name_exact(self, registry): - assert registry.find_by_name("PostgreSQL") is not None - - def test_find_by_name_case_insensitive(self, registry): - assert registry.find_by_name("postgresql") is not None - assert registry.find_by_name("POSTGRESQL") is not None - - def test_find_by_slug(self, registry): - # engine slug is also accepted - assert registry.find_by_name("postgresql") is not None + @pytest.mark.parametrize("query", ["PostgreSQL", "postgresql", "POSTGRESQL"]) + def test_find_by_name_variants(self, registry, query): + assert registry.find_by_name(query) is not None 
def test_find_unknown_returns_none(self, registry): assert registry.find_by_name("MySQL") is None @@ -386,7 +407,7 @@ def test_pip_field(self, registry): def test_test_snippet_present(self, registry): engine = registry.get("postgresql") - assert "print(\"ok\")" in engine.test_snippet + assert 'print("ok")' in engine.test_snippet def test_auth_method_choice_parsed(self, registry): engine = registry.get("hubspot") @@ -417,15 +438,13 @@ def test_single_field_name_from(self, registry): def test_missing_name_from_field_returns_empty(self, registry): engine = registry.get("postgresql") - name = registry.derive_name(engine, {"host": "x"}) # no "database" - assert name == "" + assert registry.derive_name(engine, {"host": "x"}) == "" def test_no_name_from_returns_empty(self): engine = DatasourceEngine(engine="test", display_name="Test", name_from="") reg = DatasourceRegistry.__new__(DatasourceRegistry) reg._engines = {} - name = reg.derive_name(engine, {"host": "x"}) - assert name == "" + assert reg.derive_name(engine, {"host": "x"}) == "" def test_list_name_from(self): engine = DatasourceEngine( @@ -446,8 +465,7 @@ def test_list_name_from_skips_missing(self): ) reg = DatasourceRegistry.__new__(DatasourceRegistry) reg._engines = {} - name = reg.derive_name(engine, {"host": "db.example.com"}) - assert name == "db.example.com" + assert reg.derive_name(engine, {"host": "db.example.com"}) == "db.example.com" # ───────────────────────────────────────────────────────────────────────────── @@ -474,8 +492,6 @@ def test_user_override_wins(self, tmp_path, datasources_md): ``` """)) - from anton.datasource_registry import _parse_file - builtin = _parse_file(datasources_md) user = _parse_file(user_md) merged = {**builtin, **user} @@ -483,11 +499,8 @@ def test_user_override_wins(self, tmp_path, datasources_md): assert merged["postgresql"].display_name == "PostgreSQL (custom)" assert merged["postgresql"].pip == "psycopg2" - def test_missing_user_file_falls_back_to_builtin(self, tmp_path, 
datasources_md): - from anton.datasource_registry import _parse_file - - user_engines = _parse_file(tmp_path / "nonexistent.md") - assert user_engines == {} + def test_missing_user_file_falls_back_to_builtin(self, tmp_path): + assert _parse_file(tmp_path / "nonexistent.md") == {} # ───────────────────────────────────────────────────────────────────────────── @@ -498,29 +511,11 @@ def test_missing_user_file_falls_back_to_builtin(self, tmp_path, datasources_md) class TestHandleConnectDatasource: """Test the slash-command handler with mocked prompts and scratchpad.""" - def _make_session(self): - from anton.chat import ChatSession - - mock_llm = AsyncMock() - session = ChatSession(mock_llm) - session._scratchpads = AsyncMock() - return session - - def _make_cell(self, stdout="ok", stderr="", error=None): - cell = MagicMock() - cell.stdout = stdout - cell.stderr = stderr - cell.error = error - return cell - @pytest.mark.asyncio - async def test_unknown_engine_returns_early(self, registry, vault_dir, capsys): + async def test_unknown_engine_returns_early(self, registry, vault_dir, make_session): """Typing an unknown engine name aborts without saving anything.""" - from anton.chat import _handle_connect_datasource - - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() with ( patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), @@ -529,18 +524,14 @@ async def test_unknown_engine_returns_early(self, registry, vault_dir, capsys): ): result = await _handle_connect_datasource(console, session._scratchpads, session) - assert result is session # unchanged session + assert result is session assert DataVault(vault_dir=vault_dir).list_connections() == [] @pytest.mark.asyncio - async def test_partial_save_on_n_answer(self, registry, vault_dir): + async def test_partial_save_on_n_answer(self, registry, vault_dir, make_session): """Answering 'n' saves partial credentials and returns without testing.""" 
- from anton.chat import _handle_connect_datasource - - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() - vault = DataVault(vault_dir=vault_dir) prompt_responses = iter(["PostgreSQL", "n", "db.example.com", "", "", "", "", ""]) @@ -549,39 +540,30 @@ async def test_partial_save_on_n_answer(self, registry, vault_dir): patch("anton.chat.DatasourceRegistry", return_value=registry), patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), ): - result = await _handle_connect_datasource(console, session._scratchpads, session) + await _handle_connect_datasource(console, session._scratchpads, session) conns = vault.list_connections() assert len(conns) == 1 assert conns[0]["engine"] == "postgresql" - # Partial connections get auto-numbered names assert conns[0]["name"].isdigit() - # Scratchpad was NOT used for testing session._scratchpads.get_or_create.assert_not_called() @pytest.mark.asyncio - async def test_successful_connection_saves_and_injects_history(self, registry, vault_dir): + async def test_successful_connection_saves_and_injects_history( + self, registry, vault_dir, make_session, make_cell + ): """Happy path: test passes, credentials saved, history entry added.""" - from anton.chat import _handle_connect_datasource - - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() vault = DataVault(vault_dir=vault_dir) pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) prompt_responses = iter([ - "PostgreSQL", # engine choice - "y", # have all credentials - "db.example.com", # host - "5432", # port - "prod_db", # database - "alice", # user - "s3cr3t", # password - "", # schema (optional) + "PostgreSQL", "y", + "db.example.com", "5432", "prod_db", "alice", "s3cr3t", 
"", ]) with ( @@ -591,46 +573,36 @@ async def test_successful_connection_saves_and_injects_history(self, registry, v ): result = await _handle_connect_datasource(console, session._scratchpads, session) - # Credentials saved conns = vault.list_connections() assert len(conns) == 1 saved = vault.load("postgresql", conns[0]["name"]) + assert saved is not None assert saved["host"] == "db.example.com" assert saved["password"] == "s3cr3t" - - # History entry injected assert result._history last = result._history[-1] assert last["role"] == "assistant" assert "postgresql" in last["content"].lower() @pytest.mark.asyncio - async def test_failed_test_offers_retry(self, registry, vault_dir): + async def test_failed_test_offers_retry(self, registry, vault_dir, make_session, make_cell): """Connection test failure prompts for retry; success on second attempt saves.""" - from anton.chat import _handle_connect_datasource - - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() vault = DataVault(vault_dir=vault_dir) - fail_cell = self._make_cell(stdout="", stderr="password authentication failed") - ok_cell = self._make_cell(stdout="ok") pad = AsyncMock() - pad.execute = AsyncMock(side_effect=[fail_cell, ok_cell]) + pad.execute = AsyncMock(side_effect=[ + make_cell(stdout="", stderr="password authentication failed"), + make_cell(stdout="ok"), + ]) session._scratchpads.get_or_create = AsyncMock(return_value=pad) prompt_responses = iter([ - "PostgreSQL", # engine - "y", # have all creds - "db.example.com", # host - "5432", # port - "prod_db", # database - "alice", # user - "wrongpassword", # password (first attempt - fails) - "", # schema - "y", # retry? - "correctpassword", # new password + "PostgreSQL", "y", + "db.example.com", "5432", "prod_db", "alice", "wrongpassword", "", + "y", # retry? 
+ "correctpassword", ]) with ( @@ -638,33 +610,31 @@ async def test_failed_test_offers_retry(self, registry, vault_dir): patch("anton.chat.DatasourceRegistry", return_value=registry), patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), ): - result = await _handle_connect_datasource(console, session._scratchpads, session) + await _handle_connect_datasource(console, session._scratchpads, session) - # Should have saved after second attempt conns = vault.list_connections() assert len(conns) == 1 saved = vault.load("postgresql", conns[0]["name"]) + assert saved is not None assert saved["password"] == "correctpassword" @pytest.mark.asyncio - async def test_failed_test_no_retry_returns_without_saving(self, registry, vault_dir): + async def test_failed_test_no_retry_returns_without_saving( + self, registry, vault_dir, make_session, make_cell + ): """Declining retry on failed test leaves vault empty.""" - from anton.chat import _handle_connect_datasource - - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() vault = DataVault(vault_dir=vault_dir) - fail_cell = self._make_cell(stdout="", error="connection refused") pad = AsyncMock() - pad.execute = AsyncMock(return_value=fail_cell) + pad.execute = AsyncMock(return_value=make_cell(stdout="", error="connection refused")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) prompt_responses = iter([ "PostgreSQL", "y", "db.example.com", "5432", "prod_db", "alice", "badpass", "", - "n", # don't retry + "n", ]) with ( @@ -675,25 +645,19 @@ async def test_failed_test_no_retry_returns_without_saving(self, registry, vault result = await _handle_connect_datasource(console, session._scratchpads, session) assert vault.list_connections() == [] - # No history injection since save never happened assert not result._history @pytest.mark.asyncio async def test_ds_env_injected_after_successful_connect( - self, registry, vault_dir + self, 
registry, vault_dir, make_session, make_cell ): - """After a successful connect, DS_* vars are injected into the env.""" - from anton.chat import _handle_connect_datasource - - session = self._make_session() + """After a successful connect, namespaced DS_* vars are injected.""" + session = make_session() console = MagicMock() - console.print = MagicMock() vault = DataVault(vault_dir=vault_dir) pad = AsyncMock() - pad.execute = AsyncMock( - return_value=self._make_cell(stdout="ok") - ) + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) prompt_responses = iter([ @@ -701,50 +665,30 @@ async def test_ds_env_injected_after_successful_connect( "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", ]) - try: - with ( - patch("anton.chat.DataVault", return_value=vault), - patch( - "anton.chat.DatasourceRegistry", - return_value=registry, - ), - patch( - "rich.prompt.Prompt.ask", - side_effect=lambda *a, **kw: next( - prompt_responses - ), - ), - ): - await _handle_connect_datasource( - console, session._scratchpads, session - ) - - # After successful connect, namespaced DS_* vars are injected. 
- # name_from=database → name="prod_db" → prefix DS_POSTGRESQL_PROD_DB - assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" - finally: - vault.clear_ds_env() + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) + + # name_from=database → name="prod_db" → prefix DS_POSTGRESQL_PROD_DB + assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" @pytest.mark.asyncio - async def test_auth_method_choice_selects_fields(self, registry, vault_dir): + async def test_auth_method_choice_selects_fields( + self, registry, vault_dir, make_session, make_cell + ): """Selecting an auth method filters to that method's fields only.""" - from anton.chat import _handle_connect_datasource - - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() vault = DataVault(vault_dir=vault_dir) pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - prompt_responses = iter([ - "HubSpot", # engine - "1", # auth method: private_app - "y", # have all creds - "pat-na1-abc123", # access_token - ]) + prompt_responses = iter(["HubSpot", "1", "y", "pat-na1-abc123"]) with ( patch("anton.chat.DataVault", return_value=vault), @@ -756,31 +700,28 @@ async def test_auth_method_choice_selects_fields(self, registry, vault_dir): conns = vault.list_connections() assert len(conns) == 1 saved = vault.load("hubspot", conns[0]["name"]) + assert saved is not None # Only private_app fields collected — no client_id or client_secret assert "access_token" in saved assert "client_id" not in saved assert "client_secret" not in saved 
@pytest.mark.asyncio - async def test_selective_field_collection(self, registry, vault_dir): + async def test_selective_field_collection( + self, registry, vault_dir, make_session, make_cell + ): """Typing 'host,user,password' collects only those three fields.""" - from anton.chat import _handle_connect_datasource - - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() vault = DataVault(vault_dir=vault_dir) pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) prompt_responses = iter([ - "PostgreSQL", # engine - "host,user,password", # selective list - "db.example.com", # host - "alice", # user - "s3cr3t", # password + "PostgreSQL", "host,user,password", + "db.example.com", "alice", "s3cr3t", ]) with ( @@ -793,6 +734,7 @@ async def test_selective_field_collection(self, registry, vault_dir): conns = vault.list_connections() assert len(conns) == 1 saved = vault.load("postgresql", conns[0]["name"]) + assert saved is not None assert set(saved.keys()) == {"host", "user", "password"} @@ -802,103 +744,55 @@ async def test_selective_field_collection(self, registry, vault_dir): class TestCredentialScrubbing: - """_scrub_credentials and _register_secret_vars.""" - - def setup_method(self): - # Reset the module-level sets and clear any DS_* env vars from other tests - from anton.chat import _DS_KNOWN_VARS, _DS_SECRET_VARS - _DS_SECRET_VARS.clear() - _DS_KNOWN_VARS.clear() - ds_keys = [k for k in os.environ if k.startswith("DS_")] - for k in ds_keys: - del os.environ[k] + """_scrub_credentials and _register_secret_vars — flat and namespaced modes.""" def test_register_secret_vars_adds_secret_fields(self, registry): """Secret fields are added to _DS_SECRET_VARS; non-secret fields are not.""" - from anton.chat import _DS_SECRET_VARS, _register_secret_vars - pg = 
registry.get("postgresql") assert pg is not None _register_secret_vars(pg) - assert "DS_PASSWORD" in _DS_SECRET_VARS - # host and port are not secret in the fixture definition assert "DS_HOST" not in _DS_SECRET_VARS assert "DS_PORT" not in _DS_SECRET_VARS - def test_scrub_replaces_registered_secret_value(self): + def test_scrub_replaces_registered_secret_value(self, monkeypatch): """A registered secret value is replaced with its placeholder.""" - import os - from anton.chat import _DS_SECRET_VARS, _scrub_credentials - _DS_SECRET_VARS.add("DS_ACCESS_TOKEN") - os.environ["DS_ACCESS_TOKEN"] = "supersecrettoken123" - try: - result = _scrub_credentials("token is supersecrettoken123 here") - assert "supersecrettoken123" not in result - assert "[DS_ACCESS_TOKEN]" in result - finally: - del os.environ["DS_ACCESS_TOKEN"] - _DS_SECRET_VARS.discard("DS_ACCESS_TOKEN") - - def test_scrub_leaves_non_secret_field_readable(self, registry): - """Non-secret DS_* values (host, port) are left untouched.""" - import os - from anton.chat import _register_secret_vars, _scrub_credentials + monkeypatch.setenv("DS_ACCESS_TOKEN", "supersecrettoken123") + result = _scrub_credentials("token is supersecrettoken123 here") + assert "supersecrettoken123" not in result + assert "[DS_ACCESS_TOKEN]" in result + def test_scrub_leaves_non_secret_field_readable(self, registry, monkeypatch): + """Non-secret DS_* values (host, port) are left untouched.""" pg = registry.get("postgresql") assert pg is not None _register_secret_vars(pg) - - os.environ["DS_HOST"] = "db.example.com" - os.environ["DS_PASSWORD"] = "s3cr3tpassword99" - try: - result = _scrub_credentials("host=db.example.com pass=s3cr3tpassword99") - assert "db.example.com" in result # host left readable - assert "s3cr3tpassword99" not in result # password redacted - assert "[DS_PASSWORD]" in result - finally: - del os.environ["DS_HOST"] - del os.environ["DS_PASSWORD"] - - def test_scrub_skips_short_values(self): + monkeypatch.setenv("DS_HOST", 
"db.example.com") + monkeypatch.setenv("DS_PASSWORD", "s3cr3tpassword99") + result = _scrub_credentials("host=db.example.com pass=s3cr3tpassword99") + assert "db.example.com" in result + assert "s3cr3tpassword99" not in result + assert "[DS_PASSWORD]" in result + + def test_scrub_skips_short_values(self, monkeypatch): """Values of 8 characters or fewer are not scrubbed (e.g. port numbers).""" - import os - from anton.chat import _DS_SECRET_VARS, _scrub_credentials - _DS_SECRET_VARS.add("DS_PASSWORD") - os.environ["DS_PASSWORD"] = "short" # 5 chars — under threshold - try: - result = _scrub_credentials("password=short") - assert "short" in result - finally: - del os.environ["DS_PASSWORD"] - _DS_SECRET_VARS.discard("DS_PASSWORD") - - def test_scrub_fallback_redacts_unknown_long_ds_vars(self): + monkeypatch.setenv("DS_PASSWORD", "short") + result = _scrub_credentials("password=short") + assert "short" in result + + def test_scrub_fallback_redacts_unknown_long_ds_vars(self, monkeypatch): """Long DS_* vars not in _DS_SECRET_VARS are scrubbed as a safety fallback.""" - import os - from anton.chat import _scrub_credentials - - # _DS_SECRET_VARS is empty (cleared in setup_method) - os.environ["DS_WEBHOOK_SECRET"] = "wh_sec_abcdefgh1234" - try: - result = _scrub_credentials("secret=wh_sec_abcdefgh1234 here") - assert "wh_sec_abcdefgh1234" not in result - assert "[DS_WEBHOOK_SECRET]" in result - finally: - del os.environ["DS_WEBHOOK_SECRET"] + monkeypatch.setenv("DS_WEBHOOK_SECRET", "wh_sec_abcdefgh1234") + result = _scrub_credentials("secret=wh_sec_abcdefgh1234 here") + assert "wh_sec_abcdefgh1234" not in result + assert "[DS_WEBHOOK_SECRET]" in result @pytest.mark.asyncio - async def test_register_and_scrub_on_connect(self, registry, vault_dir): + async def test_register_and_scrub_on_connect(self, registry, vault_dir, monkeypatch): """After _handle_connect_datasource, the new secret var is immediately scrubbed.""" - import os - from unittest.mock import AsyncMock, 
MagicMock, patch - - from anton.chat import _DS_SECRET_VARS, _handle_connect_datasource, _scrub_credentials - vault = DataVault(vault_dir=vault_dir) - session = MagicMock() session._history = [] session._cortex = None @@ -911,14 +805,8 @@ async def test_register_and_scrub_on_connect(self, registry, vault_dir): secret_pw = "supersecretpassword999" prompt_responses = iter([ - "PostgreSQL", # engine - "y", # have all credentials - "db.host.com", # host - "5432", # port - "mydb", # database - "alice", # user - secret_pw, # password - "public", # schema (optional, skip) + "PostgreSQL", "y", + "db.host.com", "5432", "mydb", "alice", secret_pw, "public", ]) with ( @@ -928,17 +816,44 @@ async def test_register_and_scrub_on_connect(self, registry, vault_dir): ): await _handle_connect_datasource(MagicMock(), session._scratchpads, session) - # After connect, namespaced secret var is registered and scrubbed. # name_from=database → name="mydb" → DS_POSTGRESQL_MYDB__PASSWORD namespaced_pw_var = "DS_POSTGRESQL_MYDB__PASSWORD" assert namespaced_pw_var in _DS_SECRET_VARS - os.environ[namespaced_pw_var] = secret_pw - try: - result = _scrub_credentials(f"error: auth failed with {secret_pw}") - assert secret_pw not in result - assert f"[{namespaced_pw_var}]" in result - finally: - del os.environ[namespaced_pw_var] + monkeypatch.setenv(namespaced_pw_var, secret_pw) + result = _scrub_credentials(f"error: auth failed with {secret_pw}") + assert secret_pw not in result + assert f"[{namespaced_pw_var}]" in result + + # ── Namespaced mode ────────────────────────────────────────────────────── + + def test_register_with_slug_uses_namespaced_keys(self, registry): + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__HOST" not in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__HOST" in _DS_KNOWN_VARS + + def test_register_without_slug_uses_flat_keys(self, 
registry): + pg = registry.get("postgresql") + _register_secret_vars(pg) # no engine/name → flat mode + assert "DS_PASSWORD" in _DS_SECRET_VARS + assert "DS_HOST" not in _DS_SECRET_VARS + + def test_scrub_replaces_namespaced_secret_value(self, registry, monkeypatch): + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + secret = "namespacedpassword123" + monkeypatch.setenv("DS_POSTGRESQL_PROD_DB__PASSWORD", secret) + result = _scrub_credentials(f"error: {secret}") + assert secret not in result + assert "[DS_POSTGRESQL_PROD_DB__PASSWORD]" in result + + def test_scrub_leaves_namespaced_non_secret_readable(self, registry, monkeypatch): + pg = registry.get("postgresql") + _register_secret_vars(pg, engine="postgresql", name="prod_db") + monkeypatch.setenv("DS_POSTGRESQL_PROD_DB__HOST", "db.example.com") + result = _scrub_credentials("host=db.example.com") + assert "db.example.com" in result # ───────────────────────────────────────────────────────────────────────────── @@ -947,31 +862,20 @@ async def test_register_and_scrub_on_connect(self, registry, vault_dir): class TestActiveDatasourceScoping: - """Tests for /connect-data-source isolating a single datasource.""" - - def _make_session(self): - from anton.chat import ChatSession + """Tests for active datasource routing and multi-source context building.""" - mock_llm = AsyncMock() - session = ChatSession(mock_llm) - session._scratchpads = AsyncMock() - return session - - def test_active_datasource_defaults_to_none(self): - session = self._make_session() + def test_active_datasource_defaults_to_none(self, make_session): + session = make_session() assert session._active_datasource is None @pytest.mark.asyncio - async def test_reconnect_sets_active_datasource(self, vault_dir): + async def test_reconnect_sets_active_datasource(self, vault_dir, make_session): """Reconnecting to a slug via prefill sets session._active_datasource.""" - from anton.chat import 
_handle_connect_datasource - vault = DataVault(vault_dir=vault_dir) vault.save("hubspot", "2", {"access_token": "pat-xxx"}) - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() with ( patch("anton.chat.DataVault", return_value=vault), @@ -984,49 +888,36 @@ async def test_reconnect_sets_active_datasource(self, vault_dir): assert result._active_datasource == "hubspot-2" @pytest.mark.asyncio - async def test_reconnect_all_namespaced_vars_available(self, vault_dir): + async def test_reconnect_all_namespaced_vars_available(self, vault_dir, make_session): """After reconnect, ALL saved connections remain available as namespaced vars.""" - from anton.chat import _handle_connect_datasource - vault = DataVault(vault_dir=vault_dir) vault.save("oracle", "1", {"host": "oracle.host", "user": "admin", "password": "orapass"}) vault.save("hubspot", "2", {"access_token": "pat-xxx"}) - # Simulate startup: inject all connections as namespaced vault.inject_env("oracle", "1") vault.inject_env("hubspot", "2") assert os.environ.get("DS_ORACLE_1__HOST") == "oracle.host" assert os.environ.get("DS_HUBSPOT_2__ACCESS_TOKEN") == "pat-xxx" - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() - - try: - with ( - patch("anton.chat.DataVault", return_value=vault), - patch("anton.chat.DatasourceRegistry"), - ): - result = await _handle_connect_datasource( - console, session._scratchpads, session, prefill="hubspot-2" - ) - - # After reconnect, all connections are restored as namespaced vars. - # No flat DS_* vars are present. - assert "DS_HOST" not in os.environ - assert "DS_ACCESS_TOKEN" not in os.environ - # Both connections remain available as namespaced vars. - assert os.environ.get("DS_ORACLE_1__HOST") == "oracle.host" - assert os.environ.get("DS_HUBSPOT_2__ACCESS_TOKEN") == "pat-xxx" - # Active datasource is updated to the reconnected slug. 
- assert result._active_datasource == "hubspot-2" - finally: - vault.clear_ds_env() + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry"), + ): + result = await _handle_connect_datasource( + console, session._scratchpads, session, prefill="hubspot-2" + ) + + assert "DS_HOST" not in os.environ + assert "DS_ACCESS_TOKEN" not in os.environ + assert os.environ.get("DS_ORACLE_1__HOST") == "oracle.host" + assert os.environ.get("DS_HUBSPOT_2__ACCESS_TOKEN") == "pat-xxx" + assert result._active_datasource == "hubspot-2" def test_build_datasource_context_no_filter(self, vault_dir): """Without active_only, all vault entries appear in the context.""" - from anton.chat import _build_datasource_context - vault = DataVault(vault_dir=vault_dir) vault.save("oracle", "1", {"host": "oracle.host"}) vault.save("hubspot", "2", {"access_token": "pat-xxx"}) @@ -1039,8 +930,6 @@ def test_build_datasource_context_no_filter(self, vault_dir): def test_build_datasource_context_active_only_filters(self, vault_dir): """With active_only set, only the matching slug appears.""" - from anton.chat import _build_datasource_context - vault = DataVault(vault_dir=vault_dir) vault.save("oracle", "1", {"host": "oracle.host"}) vault.save("hubspot", "2", {"access_token": "pat-xxx"}) @@ -1053,76 +942,82 @@ def test_build_datasource_context_active_only_filters(self, vault_dir): def test_build_datasource_context_active_only_empty_when_no_match(self, vault_dir): """If active_only doesn't match any slug, the section has no entries.""" - from anton.chat import _build_datasource_context - vault = DataVault(vault_dir=vault_dir) vault.save("oracle", "1", {"host": "oracle.host"}) with patch("anton.chat.DataVault", return_value=vault): ctx = _build_datasource_context(active_only="hubspot-99") - # Header is present but no datasource lines assert "oracle-1" not in ctx + def test_build_datasource_context_shows_namespaced_vars(self, vault_dir): + vault = 
DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com", "password": "s3cr3t"}) -# ───────────────────────────────────────────────────────────────────────────── -# CLI command registration -# ───────────────────────────────────────────────────────────────────────────── + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + assert "DS_POSTGRES_PROD_DB__HOST" in ctx + assert "DS_POSTGRES_PROD_DB__PASSWORD" in ctx + assert "DS_HOST" not in ctx -class TestCliCommandRegistration: - """Verify all datasource CLI commands are registered.""" + def test_build_datasource_context_shows_slug_and_engine_label(self, vault_dir): + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) - def _get_command_names(self): - from anton.cli import app - return [cmd.name for cmd in app.registered_commands] + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() - def test_connect_data_source_registered(self): - names = self._get_command_names() - assert "connect-data-source" in names + assert "postgres-prod_db" in ctx + assert "(postgres)" in ctx - def test_list_data_sources_registered(self): - names = self._get_command_names() - assert "list-data-sources" in names + def test_multi_source_context_shows_both_connections(self, vault_dir): + """Both connections are visible in the context with their namespaced vars.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgres", "prod_db", {"host": "pg.example.com"}) + vault.save("hubspot", "main", {"access_token": "pat-abc"}) + + with patch("anton.chat.DataVault", return_value=vault): + ctx = _build_datasource_context() + + assert "postgres-prod_db" in ctx + assert "DS_POSTGRES_PROD_DB__HOST" in ctx + assert "hubspot-main" in ctx + assert "DS_HUBSPOT_MAIN__ACCESS_TOKEN" in ctx - def test_edit_data_source_registered(self): - names = self._get_command_names() - assert 
"edit-data-source" in names - def test_remove_data_source_registered(self): - names = self._get_command_names() - assert "remove-data-source" in names +# ───────────────────────────────────────────────────────────────────────────── +# CLI command registration +# ───────────────────────────────────────────────────────────────────────────── - def test_test_data_source_registered(self): - names = self._get_command_names() - assert "test-data-source" in names + +class TestCliCommandRegistration: + @pytest.mark.parametrize("cmd_name", [ + "connect-data-source", + "list-data-sources", + "edit-data-source", + "remove-data-source", + "test-data-source", + ]) + def test_command_registered(self, cmd_name): + names = [cmd.name for cmd in _cli_app.registered_commands] + assert cmd_name in names # ───────────────────────────────────────────────────────────────────────────── -# _handle_list_data_sources — improved output +# _handle_list_data_sources # ───────────────────────────────────────────────────────────────────────────── class TestHandleListDataSources: def test_empty_vault_shows_message(self, vault_dir): - from unittest.mock import MagicMock, patch - from anton.chat import _handle_list_data_sources - console = MagicMock() - console.print = MagicMock() - with patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)): _handle_list_data_sources(console) - printed = " ".join(str(c) for c in console.print.call_args_list) assert "No data sources" in printed or "connect-data-source" in printed - def test_complete_connection_shows_saved(self, vault_dir, registry): - from unittest.mock import MagicMock, patch - from rich.console import Console - from anton.chat import _handle_list_data_sources - import io - + def test_complete_connection_shows_saved_with_engine_name(self, vault_dir, registry): vault = DataVault(vault_dir=vault_dir) vault.save("postgresql", "prod_db", { "host": "db.example.com", "port": "5432", @@ -1141,13 +1036,9 @@ def 
test_complete_connection_shows_saved(self, vault_dir, registry): output = buf.getvalue() assert "postgresql-prod_db" in output assert "saved" in output.lower() + assert "PostgreSQL" in output # engine display_name shown def test_incomplete_connection_shows_incomplete(self, vault_dir, registry): - from unittest.mock import MagicMock, patch - from rich.console import Console - from anton.chat import _handle_list_data_sources - import io - vault = DataVault(vault_dir=vault_dir) # Missing required fields: database, user, password vault.save("postgresql", "partial", {"host": "db.example.com"}) @@ -1164,29 +1055,6 @@ def test_incomplete_connection_shows_incomplete(self, vault_dir, registry): output = buf.getvalue() assert "incomplete" in output.lower() - def test_shows_source_name(self, vault_dir, registry): - from unittest.mock import patch - from rich.console import Console - from anton.chat import _handle_list_data_sources - import io - - vault = DataVault(vault_dir=vault_dir) - vault.save("postgresql", "prod_db", { - "host": "x", "port": "5432", "database": "d", "user": "u", "password": "p", - }) - - buf = io.StringIO() - rich_console = Console(file=buf, highlight=False, markup=False) - - with ( - patch("anton.chat.DataVault", return_value=vault), - patch("anton.chat.DatasourceRegistry", return_value=registry), - ): - _handle_list_data_sources(rich_console) - - output = buf.getvalue() - assert "PostgreSQL" in output - # ───────────────────────────────────────────────────────────────────────────── # _handle_test_datasource @@ -1194,30 +1062,16 @@ def test_shows_source_name(self, vault_dir, registry): class TestHandleTestDatasource: - def _make_cell(self, stdout="ok", stderr="", error=None): - from unittest.mock import MagicMock - cell = MagicMock() - cell.stdout = stdout - cell.stderr = stderr - cell.error = error - return cell - @pytest.mark.asyncio - async def test_success_path(self, vault_dir, registry): - from unittest.mock import AsyncMock, MagicMock, patch - 
from anton.chat import _handle_test_datasource - + async def test_success_path(self, vault_dir, registry, make_cell): vault = DataVault(vault_dir=vault_dir) vault.save("postgresql", "prod_db", { "host": "db.example.com", "port": "5432", "database": "prod", "user": "alice", "password": "s3cr3t", }) - console = MagicMock() - console.print = MagicMock() - pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) scratchpads = AsyncMock() scratchpads.get_or_create = AsyncMock(return_value=pad) @@ -1231,22 +1085,16 @@ async def test_success_path(self, vault_dir, registry): assert "✓" in printed or "passed" in printed.lower() @pytest.mark.asyncio - async def test_failure_path(self, vault_dir, registry): - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_test_datasource - + async def test_failure_path(self, vault_dir, registry, make_cell): vault = DataVault(vault_dir=vault_dir) vault.save("postgresql", "prod_db", { "host": "db.example.com", "port": "5432", "database": "prod", "user": "alice", "password": "wrongpass", }) - console = MagicMock() - console.print = MagicMock() - pad = AsyncMock() pad.execute = AsyncMock( - return_value=self._make_cell(stdout="", stderr="password authentication failed") + return_value=make_cell(stdout="", stderr="password authentication failed") ) scratchpads = AsyncMock() scratchpads.get_or_create = AsyncMock(return_value=pad) @@ -1262,12 +1110,8 @@ async def test_failure_path(self, vault_dir, registry): @pytest.mark.asyncio async def test_unknown_connection(self, vault_dir, registry): - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_test_datasource - vault = DataVault(vault_dir=vault_dir) console = MagicMock() - console.print = MagicMock() scratchpads = AsyncMock() with ( @@ -1281,11 +1125,7 @@ async def test_unknown_connection(self, vault_dir, registry): 
@pytest.mark.asyncio async def test_empty_slug_shows_usage(self, vault_dir, registry): - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_test_datasource - console = MagicMock() - console.print = MagicMock() scratchpads = AsyncMock() with ( @@ -1297,39 +1137,6 @@ async def test_empty_slug_shows_usage(self, vault_dir, registry): printed = " ".join(str(c) for c in console.print.call_args_list) assert "Usage" in printed or "test-data-source" in printed - @pytest.mark.asyncio - async def test_ds_env_after_test(self, vault_dir, registry): - """After test-data-source: flat vars are gone, namespaced vars are restored.""" - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_test_datasource - - vault = DataVault(vault_dir=vault_dir) - vault.save("postgresql", "prod_db", { - "host": "db.example.com", "port": "5432", - "database": "prod", "user": "alice", "password": "s3cr3t", - }) - - console = MagicMock() - pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) - scratchpads = AsyncMock() - scratchpads.get_or_create = AsyncMock(return_value=pad) - - with ( - patch("anton.chat.DataVault", return_value=vault), - patch("anton.chat.DatasourceRegistry", return_value=registry), - ): - await _handle_test_datasource(console, scratchpads, "postgresql-prod_db") - - try: - # Flat vars must be gone - assert "DS_HOST" not in os.environ - assert "DS_PASSWORD" not in os.environ - # Namespaced vars are restored (all saved connections) - assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" - finally: - vault.clear_ds_env() - # ───────────────────────────────────────────────────────────────────────────── # Edit flow @@ -1337,52 +1144,23 @@ async def test_ds_env_after_test(self, vault_dir, registry): class TestEditDatasourceFlow: - def _make_session(self): - from unittest.mock import AsyncMock - from anton.chat import ChatSession - - mock_llm = AsyncMock() - session = 
ChatSession(mock_llm) - session._scratchpads = AsyncMock() - return session - - def _make_cell(self, stdout="ok", stderr="", error=None): - from unittest.mock import MagicMock - cell = MagicMock() - cell.stdout = stdout - cell.stderr = stderr - cell.error = error - return cell - @pytest.mark.asyncio - async def test_existing_values_loaded(self, registry, vault_dir): + async def test_existing_values_loaded(self, registry, vault_dir, make_session, make_cell): """Edit shows existing non-secret values as defaults.""" - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_connect_datasource - vault = DataVault(vault_dir=vault_dir) vault.save("postgresql", "prod_db", { "host": "old.host", "port": "5432", "database": "prod", "user": "alice", "password": "oldpass", }) - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() - pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - # The prompt for 'host' should default to "old.host"; user presses Enter (keeps it) - # The prompt for 'password' is secret; user provides new value prompt_values = iter([ - "old.host", # host — Enter = keep (returns default) - "5432", # port - "prod", # database - "alice", # user - "newpass", # password - "", # schema (optional) + "old.host", "5432", "prod", "alice", "newpass", "", ]) with ( @@ -1395,15 +1173,13 @@ async def test_existing_values_loaded(self, registry, vault_dir): ) saved = vault.load("postgresql", "prod_db") + assert saved is not None assert saved["host"] == "old.host" assert saved["password"] == "newpass" @pytest.mark.asyncio - async def test_enter_preserves_secret_value(self, registry, vault_dir): + async def test_enter_preserves_secret_value(self, registry, vault_dir, make_session, make_cell): """Pressing Enter on a secret field 
keeps the existing value.""" - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_connect_datasource - vault = DataVault(vault_dir=vault_dir) original_pass = "original_secret_pass" vault.save("postgresql", "prod_db", { @@ -1411,14 +1187,12 @@ async def test_enter_preserves_secret_value(self, registry, vault_dir): "database": "prod", "user": "alice", "password": original_pass, }) - session = self._make_session() + session = make_session() console = MagicMock() - pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) - # Empty string for secret field = keep existing prompt_values = iter([ "db.host", "5432", "prod", "alice", "", # password — Enter = keep original @@ -1435,18 +1209,15 @@ async def test_enter_preserves_secret_value(self, registry, vault_dir): ) saved = vault.load("postgresql", "prod_db") + assert saved is not None assert saved["password"] == original_pass @pytest.mark.asyncio - async def test_unknown_slug_returns_session(self, registry, vault_dir): + async def test_unknown_slug_returns_session(self, registry, vault_dir, make_session): """Editing a non-existent slug returns the session unchanged.""" - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_connect_datasource - vault = DataVault(vault_dir=vault_dir) - session = self._make_session() + session = make_session() console = MagicMock() - console.print = MagicMock() with ( patch("anton.chat.DataVault", return_value=vault), @@ -1462,16 +1233,12 @@ async def test_unknown_slug_returns_session(self, registry, vault_dir): # ───────────────────────────────────────────────────────────────────────────── -# Remove flow — full coverage +# Remove flow # ───────────────────────────────────────────────────────────────────────────── class TestRemoveDatasourceFlow: def 
test_confirmation_yes_deletes(self, vault, registry): - from unittest.mock import patch - from anton.chat import _handle_remove_data_source - from rich.console import Console - vault.save("postgresql", "prod_db", {"host": "x"}) console = Console(quiet=True) @@ -1484,10 +1251,6 @@ def test_confirmation_yes_deletes(self, vault, registry): assert vault.load("postgresql", "prod_db") is None def test_confirmation_no_preserves(self, vault, registry): - from unittest.mock import patch - from anton.chat import _handle_remove_data_source - from rich.console import Console - vault.save("postgresql", "prod_db", {"host": "x"}) console = Console(quiet=True) @@ -1500,12 +1263,8 @@ def test_confirmation_no_preserves(self, vault, registry): assert vault.load("postgresql", "prod_db") is not None def test_unknown_name_shows_message(self, vault_dir): - from unittest.mock import MagicMock, patch - from anton.chat import _handle_remove_data_source - vault = DataVault(vault_dir=vault_dir) console = MagicMock() - console.print = MagicMock() with patch("anton.chat.DataVault", return_value=vault): _handle_remove_data_source(console, "postgresql-ghost") @@ -1514,12 +1273,8 @@ def test_unknown_name_shows_message(self, vault_dir): assert "not found" in printed.lower() or "No connection" in printed def test_invalid_format_shows_warning(self, vault_dir): - from unittest.mock import MagicMock, patch - from anton.chat import _handle_remove_data_source - vault = DataVault(vault_dir=vault_dir) console = MagicMock() - console.print = MagicMock() with patch("anton.chat.DataVault", return_value=vault): _handle_remove_data_source(console, "nohyphen") @@ -1528,72 +1283,24 @@ def test_invalid_format_shows_warning(self, vault_dir): assert "Invalid" in printed or "engine-name" in printed -# ───────────────────────────────────────────────────────────────────────────── -# Remove slash command — empty name bug fix -# ───────────────────────────────────────────────────────────────────────────── - - -class 
TestRemoveSlashCommandEmptyName: - """Ensure /remove-data-source without arg does NOT call the handler.""" - - def test_empty_arg_does_not_call_handler(self): - """The chat loop must not call _handle_remove_data_source with an empty slug.""" - from unittest.mock import MagicMock, patch - # Import the handler to verify it's not called with empty arg - with patch("anton.chat._handle_remove_data_source") as mock_remove: - # Simulate what the chat loop does when arg is empty - cmd = "/remove-data-source" - parts = [cmd] # no argument - arg = parts[1].strip() if len(parts) > 1 else "" - assert arg == "" - # The fixed logic: - if not arg: - pass # show usage message, do NOT call handler - else: - from anton.chat import _handle_remove_data_source - _handle_remove_data_source(MagicMock(), arg) - - mock_remove.assert_not_called() - - # ───────────────────────────────────────────────────────────────────────────── # Environment activation — collision-free behavior # ───────────────────────────────────────────────────────────────────────────── class TestEnvActivationCollisionFree: - def _make_session(self): - from unittest.mock import AsyncMock - from anton.chat import ChatSession - - mock_llm = AsyncMock() - session = ChatSession(mock_llm) - session._scratchpads = AsyncMock() - return session - - def _make_cell(self, stdout="ok", stderr="", error=None): - from unittest.mock import MagicMock - cell = MagicMock() - cell.stdout = stdout - cell.stderr = stderr - cell.error = error - return cell - @pytest.mark.asyncio - async def test_connect_clears_previous_ds_vars(self, registry, vault_dir): - """After a successful new connect, only the new connection's DS_* vars are set.""" - from unittest.mock import AsyncMock, patch - from anton.chat import _handle_connect_datasource - - # Pre-inject an "old" connection - os.environ["DS_ACCESS_TOKEN"] = "old-token" - + async def test_connect_clears_previous_ds_vars( + self, registry, vault_dir, make_session, make_cell, monkeypatch + ): + 
"""After a successful new connect, stale DS_* vars are cleared.""" + monkeypatch.setenv("DS_ACCESS_TOKEN", "old-token") vault = DataVault(vault_dir=vault_dir) - session = self._make_session() + session = make_session() console = MagicMock() pad = AsyncMock() - pad.execute = AsyncMock(return_value=self._make_cell(stdout="ok")) + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) session._scratchpads.get_or_create = AsyncMock(return_value=pad) prompt_responses = iter([ @@ -1601,29 +1308,21 @@ async def test_connect_clears_previous_ds_vars(self, registry, vault_dir): "db.example.com", "5432", "prod_db", "alice", "s3cr3t", "", ]) - try: - with ( - patch("anton.chat.DataVault", return_value=vault), - patch("anton.chat.DatasourceRegistry", return_value=registry), - patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), - ): - await _handle_connect_datasource(console, session._scratchpads, session) - - # Old flat token must be gone; flat vars are never kept in runtime - assert "DS_ACCESS_TOKEN" not in os.environ - # Namespaced vars for the new connection must be present - # name_from=database → name="prod_db" → DS_POSTGRESQL_PROD_DB__HOST - assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" - finally: - vault.clear_ds_env() - os.environ.pop("DS_ACCESS_TOKEN", None) + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + ): + await _handle_connect_datasource(console, session._scratchpads, session) - @pytest.mark.asyncio - async def test_two_same_type_connections_no_collision(self, registry, vault_dir): - """Activating one of two same-type connections sets only that one's vars.""" - from unittest.mock import MagicMock, patch - from anton.chat import _handle_connect_datasource + assert "DS_ACCESS_TOKEN" not in os.environ + assert 
os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" + @pytest.mark.asyncio + async def test_two_same_type_connections_no_collision( + self, registry, vault_dir, make_session + ): + """Both same-type connections remain available as distinct namespaced vars.""" vault = DataVault(vault_dir=vault_dir) vault.save("postgresql", "db1", { "host": "host1.example.com", "port": "5432", @@ -1634,103 +1333,33 @@ async def test_two_same_type_connections_no_collision(self, registry, vault_dir) "database": "db2", "user": "u2", "password": "p2", }) - session = self._make_session() + session = make_session() console = MagicMock() - try: - with ( - patch("anton.chat.DataVault", return_value=vault), - patch("anton.chat.DatasourceRegistry", return_value=registry), - ): - # Reconnect to db2 (prefill path) - await _handle_connect_datasource( - console, session._scratchpads, session, - prefill="postgresql-db2", - ) - - # Both connections remain available as namespaced vars (no collision) - assert os.environ.get("DS_POSTGRESQL_DB1__HOST") == "host1.example.com" - assert os.environ.get("DS_POSTGRESQL_DB2__HOST") == "host2.example.com" - assert os.environ.get("DS_POSTGRESQL_DB2__DATABASE") == "db2" - # No flat vars - assert "DS_HOST" not in os.environ - assert "DS_DATABASE" not in os.environ - finally: - vault.clear_ds_env() - - @pytest.mark.asyncio - async def test_test_datasource_does_not_leave_vars(self, registry, vault_dir): - """_handle_test_datasource cleans up all DS_* vars after testing.""" - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_test_datasource - - vault = DataVault(vault_dir=vault_dir) - vault.save("postgresql", "prod_db", { - "host": "db.example.com", "port": "5432", - "database": "prod", "user": "alice", "password": "s3cr3t", - }) - - pad = AsyncMock() - pad.execute = AsyncMock( - return_value=MagicMock(stdout="ok", stderr="", error=None) - ) - scratchpads = AsyncMock() - scratchpads.get_or_create = 
AsyncMock(return_value=pad) - with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), ): - await _handle_test_datasource(MagicMock(), scratchpads, "postgresql-prod_db") + await _handle_connect_datasource( + console, session._scratchpads, session, + prefill="postgresql-db2", + ) - try: - # Flat vars must not be present after the test - flat_leaked = [k for k in os.environ if k.startswith("DS_") and "__" not in k] - assert flat_leaked == [], f"Flat DS_* vars leaked after test: {flat_leaked}" - # Namespaced vars are restored — that is the expected runtime state - assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "db.example.com" - finally: - vault.clear_ds_env() + assert os.environ.get("DS_POSTGRESQL_DB1__HOST") == "host1.example.com" + assert os.environ.get("DS_POSTGRESQL_DB2__HOST") == "host2.example.com" + assert "DS_HOST" not in os.environ + assert "DS_DATABASE" not in os.environ # ───────────────────────────────────────────────────────────────────────────── -# Chat slash-command behavior +# Datasource slash-command behavior # ───────────────────────────────────────────────────────────────────────────── -class TestChatSlashCommands: - """Verify slash-command routing logic for datasource commands.""" - - def test_list_data_sources_slash_command_routes(self): - """'/list-data-sources' cmd string maps to _handle_list_data_sources.""" - # Verify the function is importable and callable (routing covered by chat loop) - from anton.chat import _handle_list_data_sources - assert callable(_handle_list_data_sources) - - def test_test_data_source_slash_command_routes(self): - """'/test-data-source' is importable and async.""" - import inspect - from anton.chat import _handle_test_datasource - assert inspect.iscoroutinefunction(_handle_test_datasource) - - def test_remove_data_source_slash_command_routes(self): - from anton.chat import _handle_remove_data_source - assert callable(_handle_remove_data_source) 
- - def test_edit_data_source_routes_to_connect_handler(self): - """'/edit-data-source' uses _handle_connect_datasource with datasource_name arg.""" - import inspect - from anton.chat import _handle_connect_datasource - sig = inspect.signature(_handle_connect_datasource) - assert "datasource_name" in sig.parameters - +class TestDatasourceSlashCommandBehavior: @pytest.mark.asyncio async def test_test_data_source_no_arg_shows_usage(self, vault_dir, registry): - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_test_datasource - console = MagicMock() - console.print = MagicMock() scratchpads = AsyncMock() with ( @@ -1743,23 +1372,13 @@ async def test_test_data_source_no_arg_shows_usage(self, vault_dir, registry): assert "Usage" in printed or "test-data-source" in printed @pytest.mark.asyncio - async def test_edit_data_source_no_arg_safe(self, vault_dir, registry): - """'/edit-data-source' without arg (datasource_name=None) triggers new-connect flow.""" - from anton.chat import _handle_connect_datasource - # datasource_name=None means new connect, not edit — no crash expected - from anton.chat import ChatSession - from unittest.mock import AsyncMock, MagicMock, patch - - mock_llm = AsyncMock() - session = ChatSession(mock_llm) - session._scratchpads = AsyncMock() + async def test_edit_data_source_no_arg_safe(self, vault_dir, registry, make_session): + """datasource_name=None triggers new-connect flow without crash.""" + session = make_session() console = MagicMock() with ( - patch( - "anton.chat.DataVault", - return_value=DataVault(vault_dir=vault_dir), - ), + patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)), patch("anton.chat.DatasourceRegistry", return_value=registry), patch("rich.prompt.Prompt.ask", return_value="UnknownEngine"), ): @@ -1767,176 +1386,28 @@ async def test_edit_data_source_no_arg_safe(self, vault_dir, registry): console, session._scratchpads, session, datasource_name=None, ) - # Should 
return session (even if unknown engine) + assert updated is not None + # ───────────────────────────────────────────────────────────────────────────── # _slug_env_prefix # ───────────────────────────────────────────────────────────────────────────── class TestSlugEnvPrefix: - """Unit tests for the _slug_env_prefix helper.""" - def test_basic_engine_and_name(self): - from anton.data_vault import _slug_env_prefix assert _slug_env_prefix("postgres", "prod_db") == "DS_POSTGRES_PROD_DB" def test_hubspot_main(self): - from anton.data_vault import _slug_env_prefix assert _slug_env_prefix("hubspot", "main") == "DS_HUBSPOT_MAIN" def test_sanitizes_hyphen_and_dot(self): - from anton.data_vault import _slug_env_prefix assert _slug_env_prefix("postgres", "prod-db.eu") == "DS_POSTGRES_PROD_DB_EU" def test_numeric_name(self): - from anton.data_vault import _slug_env_prefix assert _slug_env_prefix("postgresql", "1") == "DS_POSTGRESQL_1" - def test_uppercase_result(self): - from anton.data_vault import _slug_env_prefix - result = _slug_env_prefix("myengine", "myname") - assert result == result.upper() - - def test_double_underscore_separator_in_full_var(self): - """The separator between prefix and field must be double underscore.""" - from anton.data_vault import _slug_env_prefix - prefix = _slug_env_prefix("postgres", "prod_db") - full_var = f"{prefix}__HOST" - assert full_var == "DS_POSTGRES_PROD_DB__HOST" - - -# ───────────────────────────────────────────────────────────────────────────── -# Namespaced runtime env — _build_datasource_context -# ───────────────────────────────────────────────────────────────────────────── - - -class TestNamespacedRuntimeEnv: - """Tests for namespaced vars in datasource context and multi-source access.""" - - def test_build_datasource_context_shows_namespaced_vars(self, vault_dir): - from unittest.mock import patch - from anton.chat import _build_datasource_context - - vault = DataVault(vault_dir=vault_dir) - vault.save("postgres", "prod_db", 
{"host": "pg.example.com", "password": "s3cr3t"}) - - with patch("anton.chat.DataVault", return_value=vault): - ctx = _build_datasource_context() - - assert "DS_POSTGRES_PROD_DB__HOST" in ctx - assert "DS_POSTGRES_PROD_DB__PASSWORD" in ctx - # No flat vars - assert "DS_HOST" not in ctx - - def test_build_datasource_context_shows_slug_and_engine_label(self, vault_dir): - from unittest.mock import patch - from anton.chat import _build_datasource_context - - vault = DataVault(vault_dir=vault_dir) - vault.save("postgres", "prod_db", {"host": "pg.example.com"}) - - with patch("anton.chat.DataVault", return_value=vault): - ctx = _build_datasource_context() - - assert "postgres-prod_db" in ctx - assert "(postgres)" in ctx - - def test_multi_source_context_shows_both_connections(self, vault_dir): - """Both connections are visible in the context with their namespaced vars.""" - from unittest.mock import patch - from anton.chat import _build_datasource_context - - vault = DataVault(vault_dir=vault_dir) - vault.save("postgres", "prod_db", {"host": "pg.example.com"}) - vault.save("hubspot", "main", {"access_token": "pat-abc"}) - - with patch("anton.chat.DataVault", return_value=vault): - ctx = _build_datasource_context() - - assert "postgres-prod_db" in ctx - assert "DS_POSTGRES_PROD_DB__HOST" in ctx - assert "hubspot-main" in ctx - assert "DS_HUBSPOT_MAIN__ACCESS_TOKEN" in ctx - - def test_build_datasource_context_header_mentions_namespaced(self, vault_dir): - """The header text explains the namespaced pattern.""" - from unittest.mock import patch - from anton.chat import _build_datasource_context - - vault = DataVault(vault_dir=vault_dir) - vault.save("postgres", "prod_db", {"host": "pg.example.com"}) - - with patch("anton.chat.DataVault", return_value=vault): - ctx = _build_datasource_context() - - assert "namespaced" in ctx.lower() or "DS_" in ctx - - -# ───────────────────────────────────────────────────────────────────────────── -# _register_secret_vars — namespaced mode 
-# ───────────────────────────────────────────────────────────────────────────── - - -class TestRegisterSecretVarsNamespaced: - """Tests for namespaced secret var registration.""" - - def setup_method(self): - from anton.chat import _DS_KNOWN_VARS, _DS_SECRET_VARS - _DS_SECRET_VARS.clear() - _DS_KNOWN_VARS.clear() - ds_keys = [k for k in os.environ if k.startswith("DS_")] - for k in ds_keys: - del os.environ[k] - - def test_register_with_slug_uses_namespaced_keys(self, registry): - from anton.chat import _DS_KNOWN_VARS, _DS_SECRET_VARS, _register_secret_vars - - pg = registry.get("postgresql") - _register_secret_vars(pg, engine="postgresql", name="prod_db") - - assert "DS_POSTGRESQL_PROD_DB__PASSWORD" in _DS_SECRET_VARS - assert "DS_POSTGRESQL_PROD_DB__HOST" not in _DS_SECRET_VARS - assert "DS_POSTGRESQL_PROD_DB__HOST" in _DS_KNOWN_VARS - - def test_register_without_slug_uses_flat_keys(self, registry): - from anton.chat import _DS_SECRET_VARS, _register_secret_vars - - pg = registry.get("postgresql") - _register_secret_vars(pg) # no engine/name → flat mode - - assert "DS_PASSWORD" in _DS_SECRET_VARS - assert "DS_HOST" not in _DS_SECRET_VARS - - def test_scrub_replaces_namespaced_secret_value(self, registry): - from anton.chat import _DS_SECRET_VARS, _register_secret_vars, _scrub_credentials - - pg = registry.get("postgresql") - _register_secret_vars(pg, engine="postgresql", name="prod_db") - - secret = "namespacedpassword123" - os.environ["DS_POSTGRESQL_PROD_DB__PASSWORD"] = secret - try: - result = _scrub_credentials(f"error: {secret}") - assert secret not in result - assert "[DS_POSTGRESQL_PROD_DB__PASSWORD]" in result - finally: - del os.environ["DS_POSTGRESQL_PROD_DB__PASSWORD"] - - def test_scrub_leaves_namespaced_non_secret_readable(self, registry): - from anton.chat import _register_secret_vars, _scrub_credentials - - pg = registry.get("postgresql") - _register_secret_vars(pg, engine="postgresql", name="prod_db") - - os.environ["DS_POSTGRESQL_PROD_DB__HOST"] 
= "db.example.com" - try: - result = _scrub_credentials("host=db.example.com") - assert "db.example.com" in result - finally: - del os.environ["DS_POSTGRESQL_PROD_DB__HOST"] - # ───────────────────────────────────────────────────────────────────────────── # Temporary flat activation and restoration @@ -1948,13 +1419,9 @@ class TestTemporaryFlatExecution: def test_restore_namespaced_env_clears_flat_and_reinjects(self, vault_dir): """_restore_namespaced_env replaces flat vars with namespaced vars.""" - from unittest.mock import patch - from anton.chat import _restore_namespaced_env - vault = DataVault(vault_dir=vault_dir) vault.save("postgres", "analytics", {"host": "analytics.example.com"}) - # Simulate a flat injection (as done during test_snippet execution) vault.inject_env("postgres", "analytics", flat=True) assert os.environ.get("DS_HOST") == "analytics.example.com" assert "DS_POSTGRES_ANALYTICS__HOST" not in os.environ @@ -1962,47 +1429,34 @@ def test_restore_namespaced_env_clears_flat_and_reinjects(self, vault_dir): with patch("anton.chat.DataVault", return_value=vault): _restore_namespaced_env(vault) - # Flat var is gone; namespaced var is back assert "DS_HOST" not in os.environ assert os.environ.get("DS_POSTGRES_ANALYTICS__HOST") == "analytics.example.com" - vault.clear_ds_env() def test_restore_namespaced_env_reinjects_all_connections(self, vault_dir): """_restore_namespaced_env restores ALL saved connections, not just one.""" - from unittest.mock import patch - from anton.chat import _restore_namespaced_env - vault = DataVault(vault_dir=vault_dir) vault.save("postgres", "prod_db", {"host": "prod.example.com"}) vault.save("hubspot", "main", {"access_token": "pat-abc"}) - # Simulate flat injection for one connection only - vault.clear_ds_env() vault.inject_env("postgres", "prod_db", flat=True) with patch("anton.chat.DataVault", return_value=vault): _restore_namespaced_env(vault) - # Both connections are available as namespaced vars assert "DS_HOST" not in 
os.environ assert os.environ.get("DS_POSTGRES_PROD_DB__HOST") == "prod.example.com" assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" - vault.clear_ds_env() @pytest.mark.asyncio async def test_test_datasource_injects_flat_then_restores_namespaced( self, vault_dir, registry ): """_handle_test_datasource uses flat vars during snippet, then restores namespaced.""" - from unittest.mock import AsyncMock, MagicMock, patch - from anton.chat import _handle_test_datasource - vault = DataVault(vault_dir=vault_dir) vault.save("postgresql", "prod_db", { "host": "pg.example.com", "port": "5432", "database": "prod_db", "user": "alice", "password": "s3cr3t", }) - # Pre-save a second connection that should be restored after the test vault.save("hubspot", "main", {"access_token": "pat-abc"}) vault.inject_env("postgresql", "prod_db") vault.inject_env("hubspot", "main") @@ -2010,7 +1464,6 @@ async def test_test_datasource_injects_flat_then_restores_namespaced( env_during_test: dict = {} async def capture_execute(snippet): - # Capture env state mid-execution to verify flat vars are set env_during_test["DS_HOST"] = os.environ.get("DS_HOST") env_during_test["DS_POSTGRESQL_PROD_DB__HOST"] = os.environ.get( "DS_POSTGRESQL_PROD_DB__HOST" @@ -2024,14 +1477,12 @@ async def capture_execute(snippet): scratchpads = AsyncMock() scratchpads.get_or_create = AsyncMock(return_value=pad) - console = MagicMock() - console.print = MagicMock() with ( patch("anton.chat.DataVault", return_value=vault), patch("anton.chat.DatasourceRegistry", return_value=registry), ): - await _handle_test_datasource(console, scratchpads, "postgresql-prod_db") + await _handle_test_datasource(MagicMock(), scratchpads, "postgresql-prod_db") # During execution: flat var was set, namespaced was absent assert env_during_test["DS_HOST"] == "pg.example.com" @@ -2041,4 +1492,3 @@ async def capture_execute(snippet): assert "DS_HOST" not in os.environ assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == 
"pg.example.com" assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" - vault.clear_ds_env() From f4d8f6b2417a31bf12ecb182f47b5b2bbf6b5f18 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 20 Mar 2026 21:11:37 +0100 Subject: [PATCH 13/70] Add moee dataspources --- anton/chat.py | 168 ++++++++++++++++++++++++----------- anton/datasource_registry.py | 30 +++++++ datasources.md | 133 +++++++++++++++++++++++++++ 3 files changed, 277 insertions(+), 54 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index cb802a4..b55a191 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2585,9 +2585,12 @@ async def _handle_add_custom_datasource( """Ask the user how they authenticate, use the LLM to identify fields, save definition.""" console.print() + if name: + preamble = f"[anton.cyan](anton)[/] '{name}' isn't in my built-in list.\n " + else: + preamble = "[anton.cyan](anton)[/] " user_answer = Prompt.ask( - f"[anton.cyan](anton)[/] '{name}' isn't in my built-in list.\n" - f" How do you authenticate with it? " + f"{preamble}How do you authenticate with it? 
" f"Describe what you have or paste credentials directly", console=console, ) @@ -2604,7 +2607,7 @@ async def _handle_add_custom_datasource( { "role": "user", "content": ( - f"The user wants to connect to '{name}' and said: {user_answer}\n\n" + f"The user wants to connect to{(' ' + repr(name)) if name else ' a custom data source'} and said: {user_answer}\n\n" "Return ONLY valid JSON (no markdown fences, no commentary):\n" '{"display_name":"Human-readable name","pip":"pip-package or empty string",' '"fields":[{"name":"snake_case_name","value":"value if given inline else empty",' @@ -2923,13 +2926,20 @@ async def _handle_connect_datasource( # ── Normal flow: connect a new (or reconnect an existing) data source ───── console.print() - engine_names = ", ".join(e.display_name for e in registry.all_engines()) + all_engines = registry.all_engines() + if prefill: answer = prefill else: + console.print( + "[anton.cyan](anton)[/] Which data source would you like to connect?\n" + ) + console.print(" [bold] 0.[/bold] Create a custom datasource") + for i, e in enumerate(all_engines, 1): + console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") + console.print() answer = Prompt.ask( - f"[anton.cyan](anton)[/] Which data source would you like to connect?\n" - f" [anton.muted](e.g. {engine_names})[/]\n", + "[anton.cyan](anton)[/] Enter a number, or type the name", console=console, ) @@ -2962,60 +2972,110 @@ async def _handle_connect_datasource( ) return session - engine_def = registry.find_by_name(stripped_answer) - if engine_def is None: - # Check whether the input is ambiguous before treating it as unknown - needle = stripped_answer.lower() - candidates = [ - e - for e in registry.all_engines() - if needle in e.display_name.lower() or needle in e.engine.lower() - ] - if len(candidates) > 1: - console.print() - console.print( - f"[anton.warning](anton)[/] '{stripped_answer}' matches multiple engines — " - "which one did you mean?" 
- ) - console.print() - for i, e in enumerate(candidates, 1): - console.print(f" {i}. {e.display_name}") - console.print() - pick = Prompt.ask( - "[anton.cyan](anton)[/] Enter a number", - console=console, - ).strip() - try: - engine_def = candidates[int(pick) - 1] - except (ValueError, IndexError): - console.print("[anton.warning]Invalid choice. Aborting.[/]") - console.print() - return session + # ── Number selection ─────────────────────────────────────────────────────── + engine_def: DatasourceEngine | None = None + _go_custom = False + + if stripped_answer.isdigit() or (stripped_answer.lstrip("-").isdigit()): + pick_num = int(stripped_answer) + if pick_num == 0: + _go_custom = True + elif 1 <= pick_num <= len(all_engines): + engine_def = all_engines[pick_num - 1] else: - result = await _handle_add_custom_datasource( - console, stripped_answer, registry, session - ) - if result is None: - return session - engine_def, credentials = result - conn_num = vault.next_connection_number(engine_def.engine) - vault.save(engine_def.engine, str(conn_num), credentials) - slug = f"{engine_def.engine}-{conn_num}" console.print( - f' Credentials saved to Local Vault as [bold]"{slug}"[/bold].' + f"[anton.warning](anton)[/] '{stripped_answer}' is out of range. " + f"Please enter 0–{len(all_engines)}.[/]" ) console.print() - session._history.append( - { - "role": "assistant", - "content": ( - f'I\'ve saved a {engine_def.display_name} connection named "{slug}" ' - f"to the Local Vault. I can now query this data source when needed." - ), - } - ) return session + # ── Name-based resolution (when not a number) ───────────────────────────── + if engine_def is None and not _go_custom: + # 1. Exact / case-insensitive / whitespace-normalized match + engine_def = registry.find_by_name(stripped_answer) + + if engine_def is None: + # 2. 
Substring match + needle = stripped_answer.lower() + candidates = [ + e + for e in all_engines + if needle in e.display_name.lower() or needle in e.engine.lower() + ] + if len(candidates) == 1: + engine_def = candidates[0] + elif len(candidates) > 1: + console.print() + console.print( + f"[anton.warning](anton)[/] '{stripped_answer}' matches multiple engines — " + "which one did you mean?" + ) + console.print() + for i, e in enumerate(candidates, 1): + console.print(f" {i}. {e.display_name}") + console.print() + pick = Prompt.ask( + "[anton.cyan](anton)[/] Enter a number", + console=console, + ).strip() + try: + engine_def = candidates[int(pick) - 1] + except (ValueError, IndexError): + console.print("[anton.warning]Invalid choice. Aborting.[/]") + console.print() + return session + + if engine_def is None: + # 3. Fuzzy / close match + fuzzy_matches = registry.fuzzy_find(stripped_answer) + for suggestion in fuzzy_matches: + console.print() + console.print( + f'[anton.cyan](anton)[/] Did you mean [bold]"{suggestion.display_name}"[/bold]?' + ) + confirm = ( + Prompt.ask( + "[anton.cyan](anton)[/] [y/n]", + console=console, + default="n", + ) + .strip() + .lower() + ) + if confirm == "y": + engine_def = suggestion + break + + if engine_def is None: + _go_custom = True + + # ── Custom datasource flow ──────────────────────────────────────────────── + if _go_custom: + result = await _handle_add_custom_datasource( + console, stripped_answer if not stripped_answer.isdigit() else "", registry, session + ) + if result is None: + return session + engine_def, credentials = result + conn_num = vault.next_connection_number(engine_def.engine) + vault.save(engine_def.engine, str(conn_num), credentials) + slug = f"{engine_def.engine}-{conn_num}" + console.print( + f' Credentials saved to Local Vault as [bold]"{slug}"[/bold].' 
+ ) + console.print() + session._history.append( + { + "role": "assistant", + "content": ( + f'I\'ve saved a {engine_def.display_name} connection named "{slug}" ' + f"to the Local Vault. I can now query this data source when needed." + ), + } + ) + return session + # ── Step 2a: auth method choice (if engine requires it) ─────── active_fields = engine_def.fields if engine_def.auth_method == "choice" and engine_def.auth_methods: diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py index b106833..009aa55 100644 --- a/anton/datasource_registry.py +++ b/anton/datasource_registry.py @@ -1,5 +1,6 @@ from __future__ import annotations +import difflib import re from dataclasses import dataclass, field from pathlib import Path @@ -143,6 +144,35 @@ def find_by_name(self, display_name: str) -> DatasourceEngine | None: ] return matches[0] if len(matches) == 1 else None + def fuzzy_find(self, text: str) -> list[DatasourceEngine]: + """Return engines whose name/slug closely matches *text* (fuzzy, for typo tolerance).""" + + def _normalize(s: str) -> str: + return re.sub(r"[\s\-_]", "", s).lower() + + needle = _normalize(text) + # Build a map from normalized key → engine (display_name takes priority) + key_to_engine: dict[str, DatasourceEngine] = {} + for engine in self._engines.values(): + key_to_engine[_normalize(engine.display_name)] = engine + # Don't overwrite display_name key with slug key + slug_key = _normalize(engine.engine) + if slug_key not in key_to_engine: + key_to_engine[slug_key] = engine + + close_keys = difflib.get_close_matches( + needle, key_to_engine.keys(), n=3, cutoff=0.6 + ) + # Deduplicate while preserving order + seen: set[str] = set() + results: list[DatasourceEngine] = [] + for k in close_keys: + eng = key_to_engine[k] + if eng.engine not in seen: + seen.add(eng.engine) + results.append(eng) + return results + def all_engines(self) -> list[DatasourceEngine]: return sorted(self._engines.values(), key=lambda e: e.display_name) diff 
--git a/datasources.md b/datasources.md index 029fcc3..53bd5a9 100644 --- a/datasources.md +++ b/datasources.md @@ -524,6 +524,139 @@ Grant required API permissions (read_products, read_orders, etc.) then install t --- +## NetSuite + +```yaml +engine: netsuite +display_name: NetSuite +pip: requests-oauthlib>=1.3.1 +name_from: account_id +fields: + - { name: account_id, required: true, secret: false, description: "NetSuite account/realm ID (e.g. 123456_SB1)" } + - { name: consumer_key, required: true, secret: true, description: "OAuth consumer key for the NetSuite integration" } + - { name: consumer_secret,required: true, secret: true, description: "OAuth consumer secret for the NetSuite integration" } + - { name: token_id, required: true, secret: true, description: "Token ID generated for the integration role" } + - { name: token_secret, required: true, secret: true, description: "Token secret generated for the integration role" } + - { name: rest_domain, required: false, secret: false, description: "REST domain override (defaults to https://.suitetalk.api.netsuite.com)" } + - { name: record_types, required: false, secret: false, description: "Comma-separated NetSuite record types to expose (e.g. 
customer,item,salesOrder)" } +test_snippet: | + import os + from requests_oauthlib import OAuth1Session + account_id = os.environ['DS_ACCOUNT_ID'] + rest_domain = os.environ.get('DS_REST_DOMAIN') or f'https://{account_id.lower().replace("_", "-")}.suitetalk.api.netsuite.com' + url = f'{rest_domain.rstrip("/")}/services/rest/record/v1/metadata-catalog/' + session = OAuth1Session( + client_key=os.environ['DS_CONSUMER_KEY'], + client_secret=os.environ['DS_CONSUMER_SECRET'], + resource_owner_key=os.environ['DS_TOKEN_ID'], + resource_owner_secret=os.environ['DS_TOKEN_SECRET'], + realm=account_id, + signature_method='HMAC-SHA256', + ) + r = session.get(url, headers={'Prefer': 'transient'}) + assert r.status_code < 400, f'HTTP {r.status_code}: {r.text[:200]}' + print("ok") +``` + +NetSuite uses OAuth 1.0a Token-Based Authentication (TBA). Create an integration record in +NetSuite (Setup → Integration → Manage Integrations), then generate token credentials via +Setup → Users/Roles → Access Tokens. The account ID can be found in Setup → Company → Company Information. 
+ +--- + +## Denodo + +```yaml +engine: denodo +display_name: Denodo +pip: psycopg2-binary +name_from: [host, database] +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of the Denodo server" } + - { name: port, required: false, secret: false, description: "port number (default 9996)", default: "9996" } + - { name: database, required: true, secret: false, description: "Denodo virtual database name" } + - { name: user, required: true, secret: false, description: "Denodo username" } + - { name: password, required: true, secret: true, description: "Denodo password" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '9996')), + dbname=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +Denodo exposes a PostgreSQL-compatible wire protocol on port 9996 by default. +Connect via the Denodo Platform Control Center → Server Configuration → ODBC/JDBC to find the port. 
+ +--- + +## TimescaleDB + +```yaml +engine: timescaledb +display_name: TimescaleDB +pip: psycopg2-binary +name_from: [host, database] +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of the TimescaleDB server" } + - { name: port, required: false, secret: false, description: "port number", default: "5432" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "database username" } + - { name: password, required: true, secret: true, description: "database password" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to public)" } +test_snippet: | + import psycopg2, os + conn = psycopg2.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '5432')), + dbname=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + cur = conn.cursor() + cur.execute("SELECT extname FROM pg_extension WHERE extname = 'timescaledb'") + if not cur.fetchone(): + raise RuntimeError("timescaledb extension not found — is TimescaleDB installed?") + conn.close() + print("ok") +``` + +TimescaleDB is a PostgreSQL extension for time-series data. Managed options include Timescale Cloud +and self-hosted PostgreSQL with the TimescaleDB extension installed. 
+ +--- + +## Email + +```yaml +engine: email +display_name: Email +name_from: email +fields: + - { name: email, required: true, secret: false, description: "email address to connect" } + - { name: password, required: true, secret: true, description: "email account password or app-specific password" } + - { name: imap_server, required: false, secret: false, description: "IMAP server hostname", default: "imap.gmail.com" } + - { name: smtp_server, required: false, secret: false, description: "SMTP server hostname", default: "smtp.gmail.com" } + - { name: smtp_port, required: false, secret: false, description: "SMTP port", default: "587" } +test_snippet: | + import imaplib, os + imap = imaplib.IMAP4_SSL(os.environ.get('DS_IMAP_SERVER', 'imap.gmail.com')) + imap.login(os.environ['DS_EMAIL'], os.environ['DS_PASSWORD']) + imap.logout() + print("ok") +``` + +For Gmail, enable IMAP in Settings → See all settings → Forwarding and POP/IMAP, then use an +App Password (Google Account → Security → 2-Step Verification → App passwords) instead of your +account password. For other providers, set imap_server and smtp_server accordingly. + +--- + ## Adding a new data source Follow the YAML format above. Add to `~/.anton/datasources.md` (user overrides). 
From 75faa5693ff2d066cdaf37641f3b95b6325d466c Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Sat, 21 Mar 2026 18:16:51 +0100 Subject: [PATCH 14/70] Fix issues with variables on reconnect --- anton/chat.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/anton/chat.py b/anton/chat.py index b55a191..2b6aecd 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1044,6 +1044,12 @@ def _apply_error_tracking( _DS_KNOWN_VARS: set[str] = set() +def reset_registered_ds_vars() -> None: + """Clear the DS_* var registries so they can be rebuilt from current vault state.""" + _DS_SECRET_VARS.clear() + _DS_KNOWN_VARS.clear() + + def _register_secret_vars( engine_def: "DatasourceEngine", *, engine: str = "", name: str = "" ) -> None: @@ -1127,6 +1133,7 @@ def _restore_namespaced_env(vault: DataVault) -> None: """Clear all DS_* vars, then reinject every saved connection as namespaced.""" from anton.datasource_registry import DatasourceRegistry + reset_registered_ds_vars() vault.clear_ds_env() dreg = DatasourceRegistry() for conn in vault.list_connections(): @@ -3397,6 +3404,7 @@ def _handle_remove_data_source(console: Console, slug: str) -> None: f"Remove '{slug}' from Local Vault?", default=False, console=console ): vault.delete(engine, name) + _restore_namespaced_env(vault) console.print(f"[anton.success]Removed {slug}.[/]") else: console.print("[anton.muted]Cancelled.[/]") From 756025e4c9dab28da210f1383c3af259309d23b2 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Sat, 21 Mar 2026 18:29:38 +0100 Subject: [PATCH 15/70] Update tests and fix issues reported in tests --- anton/chat.py | 4 +- tests/test_datasource.py | 79 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 2b6aecd..8dd211f 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1044,7 +1044,7 @@ def _apply_error_tracking( _DS_KNOWN_VARS: set[str] = set() -def reset_registered_ds_vars() -> None: +def 
_reset_registered_ds_vars() -> None: """Clear the DS_* var registries so they can be rebuilt from current vault state.""" _DS_SECRET_VARS.clear() _DS_KNOWN_VARS.clear() @@ -1133,7 +1133,7 @@ def _restore_namespaced_env(vault: DataVault) -> None: """Clear all DS_* vars, then reinject every saved connection as namespaced.""" from anton.datasource_registry import DatasourceRegistry - reset_registered_ds_vars() + _reset_registered_ds_vars() vault.clear_ds_env() dreg = DatasourceRegistry() for conn in vault.list_connections(): diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 2ceca44..a71942b 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -1492,3 +1492,82 @@ async def capture_execute(snippet): assert "DS_HOST" not in os.environ assert os.environ.get("DS_POSTGRESQL_PROD_DB__HOST") == "pg.example.com" assert os.environ.get("DS_HUBSPOT_MAIN__ACCESS_TOKEN") == "pat-abc" + + +# ───────────────────────────────────────────────────────────────────────────── +# Stale registration state regression tests +# ───────────────────────────────────────────────────────────────────────────── + + +class TestStaleDsRegistrationState: + """Regression tests: _DS_SECRET_VARS/_DS_KNOWN_VARS must mirror vault contents.""" + + def test_remove_clears_stale_secret_vars(self, vault_dir, registry): + """After removing a connection, its secret var names leave _DS_SECRET_VARS.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "s3cr3t", + }) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" in _DS_KNOWN_VARS + + vault.delete("postgresql", "prod_db") + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + 
_restore_namespaced_env(vault) + + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" not in _DS_SECRET_VARS + assert "DS_POSTGRESQL_PROD_DB__PASSWORD" not in _DS_KNOWN_VARS + + def test_edit_connection_refreshes_secret_vars(self, vault_dir, registry): + """Overwriting a connection via vault.save rebuilds registration without duplication.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "old-pass", + }) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + secret_key = "DS_POSTGRESQL_PROD_DB__PASSWORD" + assert secret_key in _DS_SECRET_VARS + count_before = len(_DS_SECRET_VARS) + + # Simulate edit: overwrite with new credentials + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "new-pass", + }) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + assert secret_key in _DS_SECRET_VARS + assert len(_DS_SECRET_VARS) == count_before + assert os.environ.get(secret_key) == "new-pass" + + def test_reconnect_no_duplicate_secret_vars(self, vault_dir, registry): + """Calling _restore_namespaced_env multiple times does not grow _DS_SECRET_VARS.""" + vault = DataVault(vault_dir=vault_dir) + vault.save("postgresql", "prod_db", { + "host": "pg.example.com", "port": "5432", + "database": "prod_db", "user": "alice", "password": "s3cr3t", + }) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + count_after_first = len(_DS_SECRET_VARS) + known_after_first = len(_DS_KNOWN_VARS) + + with patch("anton.datasource_registry.DatasourceRegistry", return_value=registry): + _restore_namespaced_env(vault) + + assert len(_DS_SECRET_VARS) == count_after_first + assert 
len(_DS_KNOWN_VARS) == known_after_first From 02e92f32086528c323375bf0f15dd2059d35a75f Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Sat, 21 Mar 2026 18:55:14 +0100 Subject: [PATCH 16/70] Improve parsing of the datasource names and slugs --- anton/chat.py | 53 +++++++++++++++++-------- anton/datasource_registry.py | 8 ++++ tests/test_datasource.py | 75 ++++++++++++++++++++++++++++++++++-- 3 files changed, 116 insertions(+), 20 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 8dd211f..5340761 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -47,7 +47,6 @@ DatasourceEngine, DatasourceField, DatasourceRegistry, - _parse_file as _ds_parse_file, ) from rich.prompt import Confirm, Prompt @@ -1050,6 +1049,22 @@ def _reset_registered_ds_vars() -> None: _DS_KNOWN_VARS.clear() +def parse_connection_slug(slug: str, known_engines: list[str]) -> tuple[str, str] | None: + """Split a connection slug into (engine, name) using longest-prefix matching. + + Tries each known engine slug longest-first so that 'sql-server-prod-db' is + correctly parsed as engine='sql-server', name='prod-db' rather than + engine='sql', name='server-prod-db'. + + Returns None if no known engine prefix matches or if the name part is empty. + """ + for engine in sorted(known_engines, key=len, reverse=True): + prefix = engine + "-" + if slug.startswith(prefix) and len(slug) > len(prefix): + return (engine, slug[len(prefix):]) + return None + + def _register_secret_vars( engine_def: "DatasourceEngine", *, engine: str = "", name: str = "" ) -> None: @@ -1084,12 +1099,16 @@ def _scrub_credentials(text: str) -> str: """ for key in _DS_SECRET_VARS: value = os.environ.get(key, "") - if not value or len(value) <= 8: + if not value: continue text = text.replace(value, f"[{key}]") for key, value in os.environ.items(): if not key.startswith("DS_") or key in _DS_KNOWN_VARS: continue + # Length guard only for unknown DS_* vars (not registered secrets). 
+ # Unknown vars are matched heuristically — a short value like "on" + # or "true" in a DS_ENABLE_X var should not be scrubbed. + # Registered secret vars bypass this check entirely. if not value or len(value) <= 8: continue text = text.replace(value, f"[{key}]") @@ -2723,7 +2742,7 @@ async def _handle_add_custom_datasource( ) tmp_path.write_text(existing + yaml_block, encoding="utf-8") - parsed = _ds_parse_file(tmp_path) + parsed = registry.validate_file(tmp_path) if slug in parsed: import shutil @@ -2735,7 +2754,7 @@ async def _handle_add_custom_datasource( "credentials saved but engine not written to datasources.md.[/]" ) - registry._load() + registry.reload() engine_def = registry.get(slug) if engine_def is None: # Fallback: construct inline so the flow can continue even if parse failed @@ -2765,15 +2784,17 @@ async def _handle_connect_datasource( # ── /edit-data-source path: update credentials for an existing slug ──────── if datasource_name is not None: - slug_parts = datasource_name.split("-", 1) - if len(slug_parts) != 2: + _parsed = parse_connection_slug( + datasource_name, [e.engine for e in registry.all_engines()] + ) + if _parsed is None: console.print( f"[anton.warning]Invalid slug '{datasource_name}'. " "Expected format: engine-name.[/]" ) console.print() return session - edit_engine, edit_name = slug_parts + edit_engine, edit_name = _parsed existing = vault.load(edit_engine, edit_name) if existing is None: console.print( @@ -3388,14 +3409,15 @@ def _handle_list_data_sources(console: Console) -> None: def _handle_remove_data_source(console: Console, slug: str) -> None: """Delete a connection from the Local Vault by slug (engine-name).""" vault = DataVault() - parts = slug.split("-", 1) - if len(parts) != 2: + registry = DatasourceRegistry() + _parsed = parse_connection_slug(slug, [e.engine for e in registry.all_engines()]) + if _parsed is None: console.print( f"[anton.warning]Invalid name '{slug}'. 
Use engine-name format.[/]" ) console.print() return - engine, name = parts + engine, name = _parsed if vault.load(engine, name) is None: console.print(f"[anton.warning]No connection '{slug}' found.[/]") console.print() @@ -3424,17 +3446,16 @@ async def _handle_test_datasource( console.print() return - parts = slug.split("-", 1) - if len(parts) != 2: + vault = DataVault() + registry = DatasourceRegistry() + _parsed = parse_connection_slug(slug, [e.engine for e in registry.all_engines()]) + if _parsed is None: console.print( f"[anton.warning]Invalid name '{slug}'. Use engine-name format.[/]" ) console.print() return - - vault = DataVault() - registry = DatasourceRegistry() - engine, name = parts + engine, name = _parsed fields = vault.load(engine, name) if fields is None: console.print( diff --git a/anton/datasource_registry.py b/anton/datasource_registry.py index 009aa55..0596a20 100644 --- a/anton/datasource_registry.py +++ b/anton/datasource_registry.py @@ -128,6 +128,14 @@ def _load(self) -> None: for slug, engine in _parse_file(self._USER_PATH).items(): self._engines[slug] = engine + def reload(self) -> None: + """Reload datasource definitions from disk.""" + self._load() + + def validate_file(self, path: Path) -> dict[str, DatasourceEngine]: + """Parse a datasources.md file and return its engine definitions.""" + return _parse_file(path) + def get(self, engine_slug: str) -> DatasourceEngine | None: return self._engines.get(engine_slug) diff --git a/tests/test_datasource.py b/tests/test_datasource.py index a71942b..d5e3ec0 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -21,6 +21,7 @@ _register_secret_vars, _restore_namespaced_env, _scrub_credentials, + parse_connection_slug, ) from anton.cli import app as _cli_app from anton.data_vault import DataVault, _slug_env_prefix @@ -424,6 +425,42 @@ def test_auth_method_fields_parsed(self, registry): assert private.fields[0].name == "access_token" assert private.fields[0].secret is True + def 
test_validate_file_returns_engines(self, registry, datasources_md): + result = registry.validate_file(datasources_md) + assert "postgresql" in result + assert "hubspot" in result + assert result["postgresql"].display_name == "PostgreSQL" + + def test_validate_file_missing_returns_empty(self, registry, tmp_path): + result = registry.validate_file(tmp_path / "nonexistent.md") + assert result == {} + + def test_reload_picks_up_new_engine(self, tmp_path): + md = tmp_path / "datasources.md" + md.write_text(dedent("""\ + ## MySQL + + ```yaml + engine: mysql + display_name: MySQL + pip: pymysql + name_from: database + fields: + - name: host + required: true + description: hostname + test_snippet: | + print("ok") + ``` + """)) + reg = DatasourceRegistry.__new__(DatasourceRegistry) + reg._engines = {} + reg._BUILTIN_PATH = md + reg._USER_PATH = tmp_path / "user.md" + reg.reload() + assert reg.get("mysql") is not None + assert reg.get("mysql").display_name == "MySQL" + # ───────────────────────────────────────────────────────────────────────────── # DatasourceRegistry — derive_name @@ -776,11 +813,12 @@ def test_scrub_leaves_non_secret_field_readable(self, registry, monkeypatch): assert "[DS_PASSWORD]" in result def test_scrub_skips_short_values(self, monkeypatch): - """Values of 8 characters or fewer are not scrubbed (e.g. 
port numbers).""" + """Registered secrets are always scrubbed regardless of length.""" _DS_SECRET_VARS.add("DS_PASSWORD") monkeypatch.setenv("DS_PASSWORD", "short") result = _scrub_credentials("password=short") - assert "short" in result + assert "short" not in result + assert "[DS_PASSWORD]" in result def test_scrub_fallback_redacts_unknown_long_ds_vars(self, monkeypatch): """Long DS_* vars not in _DS_SECRET_VARS are scrubbed as a safety fallback.""" @@ -1244,6 +1282,7 @@ def test_confirmation_yes_deletes(self, vault, registry): with ( patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), patch("rich.prompt.Confirm.ask", return_value=True), ): _handle_remove_data_source(console, "postgresql-prod_db") @@ -1256,17 +1295,21 @@ def test_confirmation_no_preserves(self, vault, registry): with ( patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), patch("rich.prompt.Confirm.ask", return_value=False), ): _handle_remove_data_source(console, "postgresql-prod_db") assert vault.load("postgresql", "prod_db") is not None - def test_unknown_name_shows_message(self, vault_dir): + def test_unknown_name_shows_message(self, vault_dir, registry): vault = DataVault(vault_dir=vault_dir) console = MagicMock() - with patch("anton.chat.DataVault", return_value=vault): + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=registry), + ): _handle_remove_data_source(console, "postgresql-ghost") printed = " ".join(str(c) for c in console.print.call_args_list) @@ -1395,6 +1438,30 @@ async def test_edit_data_source_no_arg_safe(self, vault_dir, registry, make_sess # ───────────────────────────────────────────────────────────────────────────── +class TestParseConnectionSlug: + ENGINES = ["postgresql", "sql-server", "google-big-query"] + + def test_simple_engine(self): + assert 
parse_connection_slug("postgresql-prod_db", self.ENGINES) == ("postgresql", "prod_db") + + def test_hyphenated_engine(self): + assert parse_connection_slug("sql-server-prod-db", self.ENGINES) == ("sql-server", "prod-db") + + def test_longest_prefix_wins(self): + engines = ["google", "google-big-query"] + assert parse_connection_slug("google-big-query-main", engines) == ("google-big-query", "main") + + def test_ambiguous_resolves_to_longest(self): + engines = ["sql", "sql-server"] + assert parse_connection_slug("sql-server-1", engines) == ("sql-server", "1") + + def test_invalid_slug_no_match(self): + assert parse_connection_slug("unknown-engine-name", self.ENGINES) is None + + def test_slug_with_empty_name_part(self): + assert parse_connection_slug("postgresql-", self.ENGINES) is None + + class TestSlugEnvPrefix: def test_basic_engine_and_name(self): assert _slug_env_prefix("postgres", "prod_db") == "DS_POSTGRES_PROD_DB" From d071f50ef8c0e34f3c6eb9f2cf07f19132809a7c Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Sat, 21 Mar 2026 18:55:37 +0100 Subject: [PATCH 17/70] Split tests a bit --- tests/test_scrubbing.py | 68 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 tests/test_scrubbing.py diff --git a/tests/test_scrubbing.py b/tests/test_scrubbing.py new file mode 100644 index 0000000..cdce8cd --- /dev/null +++ b/tests/test_scrubbing.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import os +from unittest.mock import patch + +import pytest + +from anton.chat import ( + _DS_KNOWN_VARS, + _DS_SECRET_VARS, + _scrub_credentials, +) + + +@pytest.fixture(autouse=True) +def clean_ds_state(): + """Clear _DS_SECRET_VARS, _DS_KNOWN_VARS, and all DS_* env vars around each test.""" + def _clean(): + _DS_SECRET_VARS.clear() + _DS_KNOWN_VARS.clear() + for k in list(os.environ): + if k.startswith("DS_"): + del os.environ[k] + + _clean() + yield + _clean() + + +class TestScrubCredentials: + """Focused regression tests 
for _scrub_credentials short-secret handling.""" + + def test_registered_6char_secret_scrubbed(self, monkeypatch): + """A 6-character registered secret is scrubbed regardless of length.""" + _DS_SECRET_VARS.add("DS_PASSWORD") + monkeypatch.setenv("DS_PASSWORD", "abc123") + result = _scrub_credentials("auth failed: abc123") + assert "abc123" not in result + assert "[DS_PASSWORD]" in result + + def test_registered_8char_secret_scrubbed(self, monkeypatch): + """An 8-character registered secret is scrubbed (was at the old threshold).""" + _DS_SECRET_VARS.add("DS_API_KEY") + monkeypatch.setenv("DS_API_KEY", "tok12345") + result = _scrub_credentials("token=tok12345 rejected") + assert "tok12345" not in result + assert "[DS_API_KEY]" in result + + def test_registered_1char_secret_scrubbed(self, monkeypatch): + """A 1-character registered secret is scrubbed.""" + _DS_SECRET_VARS.add("DS_SECRET") + monkeypatch.setenv("DS_SECRET", "x") + result = _scrub_credentials("value=x here") + assert "=x " not in result + assert "[DS_SECRET]" in result + + def test_non_secret_var_not_scrubbed(self, monkeypatch): + """A known but non-secret DS_* var (e.g. 
DS_HOST) stays readable.""" + _DS_KNOWN_VARS.add("DS_HOST") + monkeypatch.setenv("DS_HOST", "db.example.com") + result = _scrub_credentials("host=db.example.com") + assert "db.example.com" in result + + def test_unknown_short_ds_var_not_scrubbed(self, monkeypatch): + """Unknown DS_* vars with short values are NOT scrubbed (heuristic threshold).""" + monkeypatch.setenv("DS_ENABLE_FEATURE", "on") + result = _scrub_credentials("flag=on active") + assert "on" in result From 2a35a80d1a39f70de12b7093d9a266e36dcc5d5e Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Sat, 21 Mar 2026 19:12:00 +0100 Subject: [PATCH 18/70] Fix the custom datasource secrets handling --- anton/chat.py | 38 ++++++++++ tests/test_datasource.py | 150 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 188 insertions(+) diff --git a/anton/chat.py b/anton/chat.py index 5340761..3e85e65 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2712,6 +2712,34 @@ async def _handle_add_custom_datasource( if value: credentials[f.name] = value + # Prompt for any required non-secret fields not provided inline + for f, raw in zip(fields, raw_fields): + if f.secret: + continue # already handled above + if not f.required: + continue # optional fields handled below + if f.name in credentials: + continue # already collected inline + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + console=console, + default="", + ) + if value: + credentials[f.name] = value + + # Offer to collect optional non-secret fields + for f, raw in zip(fields, raw_fields): + if f.secret or f.required or f.name in credentials: + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name} (optional — press Enter to skip)", + console=console, + default="", + ) + if value: + credentials[f.name] = value + if not credentials: console.print("[anton.warning] No credentials collected. 
Aborting.[/]") console.print() @@ -2765,6 +2793,16 @@ async def _handle_add_custom_datasource( fields=fields, ) + # All required fields must be present before the caller saves credentials + missing_required = [f.name for f in fields if f.required and f.name not in credentials] + if missing_required: + console.print( + "[anton.warning] Cannot save — missing required fields: " + f"{', '.join(missing_required)}. Aborting.[/]" + ) + console.print() + return None + return engine_def, credentials diff --git a/tests/test_datasource.py b/tests/test_datasource.py index d5e3ec0..d051bf7 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -14,6 +14,7 @@ _DS_KNOWN_VARS, _DS_SECRET_VARS, _build_datasource_context, + _handle_add_custom_datasource, _handle_connect_datasource, _handle_list_data_sources, _handle_remove_data_source, @@ -1638,3 +1639,152 @@ def test_reconnect_no_duplicate_secret_vars(self, vault_dir, registry): assert len(_DS_SECRET_VARS) == count_after_first assert len(_DS_KNOWN_VARS) == known_after_first + + +# ───────────────────────────────────────────────────────────────────────────── +# TestAddCustomDatasourceFlow +# ───────────────────────────────────────────────────────────────────────────── + + +class TestAddCustomDatasourceFlow: + """Tests for _handle_add_custom_datasource field-collection logic.""" + + def _make_llm_response( + self, fields: list[dict], display_name: str = "MyDB" + ) -> str: + """Return a JSON string mimicking the LLM's plan() response.""" + import json as _json + return _json.dumps( + {"display_name": display_name, "pip": "", "fields": fields} + ) + + def _make_registry(self, tmp_path): + """Return a minimal registry mock that accepts any slug.""" + reg = MagicMock() + reg.validate_file.return_value = {"mydb": MagicMock()} + reg.reload.return_value = None + reg.get.return_value = None # triggers inline fallback + return reg + + def _make_llm(self, json_text: str): + """Return an AsyncMock LLM whose plan() returns 
json_text.""" + llm = AsyncMock() + response = MagicMock() + response.content = json_text + llm.plan = AsyncMock(return_value=response) + return llm + + def _mock_ds_path(self, mock_path_cls, tmp_path): + """Wire Path mock so datasources.md writes go to tmp_path.""" + mock_ds = MagicMock() + mock_ds.is_file.return_value = False + mock_ds.with_suffix.return_value = tmp_path / "datasources.tmp" + mock_path_cls.return_value.expanduser.return_value = mock_ds + + @pytest.mark.asyncio + async def test_missing_required_non_secret_field_prompts_user( + self, tmp_path, make_session + ): + """Required non-secret field without inline value triggers Prompt.ask.""" + session = make_session() + session._llm = self._make_llm(self._make_llm_response([ + { + "name": "host", "value": "", "secret": False, + "required": True, "description": "hostname", + }, + ])) + console = MagicMock() + registry = self._make_registry(tmp_path) + + # first response: initial auth question; second: the host field prompt + prompt_responses = iter(["I want to connect to mydb", "localhost"]) + + with ( + patch( + "rich.prompt.Prompt.ask", + side_effect=lambda *a, **kw: next(prompt_responses), + ), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_add_custom_datasource( + console, "mydb", registry, session + ) + + assert result is not None + _, credentials = result + assert credentials["host"] == "localhost" + + @pytest.mark.asyncio + async def test_missing_required_secret_field_prompts_user( + self, tmp_path, make_session + ): + """Required secret field without inline value triggers password prompt.""" + session = make_session() + session._llm = self._make_llm(self._make_llm_response([ + { + "name": "api_key", "value": "", "secret": True, + "required": True, "description": "API key", + }, + ])) + console = MagicMock() + registry = self._make_registry(tmp_path) + + password_calls = [] + + def fake_prompt(*args, **kwargs): + if 
kwargs.get("password"): + password_calls.append(kwargs) + return "mysecret" + return "I want to connect" + + with ( + patch("rich.prompt.Prompt.ask", side_effect=fake_prompt), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_add_custom_datasource( + console, "mydb", registry, session + ) + + assert result is not None + _, credentials = result + assert credentials["api_key"] == "mysecret" + assert len(password_calls) == 1, "expected exactly one password prompt" + + @pytest.mark.asyncio + async def test_incomplete_custom_datasource_not_saved( + self, tmp_path, make_session + ): + """Empty responses for all required fields causes a hard stop (None).""" + session = make_session() + session._llm = self._make_llm(self._make_llm_response([ + { + "name": "host", "value": "", "secret": False, + "required": True, "description": "hostname", + }, + { + "name": "api_key", "value": "", "secret": True, + "required": True, "description": "API key", + }, + ])) + console = MagicMock() + registry = self._make_registry(tmp_path) + + # User presses Enter (empty) for every prompt + prompt_responses = iter(["I want to connect", "", ""]) + + with ( + patch( + "rich.prompt.Prompt.ask", + side_effect=lambda *a, **kw: next(prompt_responses), + ), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_add_custom_datasource( + console, "mydb", registry, session + ) + + # Must return None — caller must not call vault.save() + assert result is None From 17d67625e1111581f5b08e03bc637394826efcdb Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Sat, 21 Mar 2026 19:35:21 +0100 Subject: [PATCH 19/70] Fix tests --- anton/chat.py | 26 +++++++++++++++++++------ anton/llm/prompts.py | 13 +++++++++---- tests/test_datasource.py | 41 ++++++++++++++++++++++++++++++++++++---- 3 files changed, 66 insertions(+), 14 deletions(-) diff --git a/anton/chat.py 
b/anton/chat.py index 3e85e65..b8720c7 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1049,19 +1049,33 @@ def _reset_registered_ds_vars() -> None: _DS_KNOWN_VARS.clear() -def parse_connection_slug(slug: str, known_engines: list[str]) -> tuple[str, str] | None: +def parse_connection_slug( + slug: str, + known_engines: list[str], + *, + vault: DataVault | None = None, +) -> tuple[str, str] | None: """Split a connection slug into (engine, name) using longest-prefix matching. - Tries each known engine slug longest-first so that 'sql-server-prod-db' is + First tries each known registry engine longest-first so that 'sql-server-prod-db' is correctly parsed as engine='sql-server', name='prod-db' rather than engine='sql', name='server-prod-db'. - Returns None if no known engine prefix matches or if the name part is empty. + If nothing matches and a vault is supplied, falls back to scanning vault + connections for an exact slug match — handles custom/unregistered engines. + + Returns None if no match found or name part is empty. 
""" for engine in sorted(known_engines, key=len, reverse=True): prefix = engine + "-" if slug.startswith(prefix) and len(slug) > len(prefix): return (engine, slug[len(prefix):]) + + if vault is not None: + for conn in vault.list_connections(): + if f"{conn['engine']}-{conn['name']}" == slug: + return (conn["engine"], conn["name"]) + return None @@ -2823,7 +2837,7 @@ async def _handle_connect_datasource( # ── /edit-data-source path: update credentials for an existing slug ──────── if datasource_name is not None: _parsed = parse_connection_slug( - datasource_name, [e.engine for e in registry.all_engines()] + datasource_name, [e.engine for e in registry.all_engines()], vault=vault ) if _parsed is None: console.print( @@ -3448,7 +3462,7 @@ def _handle_remove_data_source(console: Console, slug: str) -> None: """Delete a connection from the Local Vault by slug (engine-name).""" vault = DataVault() registry = DatasourceRegistry() - _parsed = parse_connection_slug(slug, [e.engine for e in registry.all_engines()]) + _parsed = parse_connection_slug(slug, [e.engine for e in registry.all_engines()], vault=vault) if _parsed is None: console.print( f"[anton.warning]Invalid name '{slug}'. Use engine-name format.[/]" @@ -3486,7 +3500,7 @@ async def _handle_test_datasource( vault = DataVault() registry = DatasourceRegistry() - _parsed = parse_connection_slug(slug, [e.engine for e in registry.all_engines()]) + _parsed = parse_connection_slug(slug, [e.engine for e in registry.all_engines()], vault=vault) if _parsed is None: console.print( f"[anton.warning]Invalid name '{slug}'. Use engine-name format.[/]" diff --git a/anton/llm/prompts.py b/anton/llm/prompts.py index c26eef1..2ba485c 100644 --- a/anton/llm/prompts.py +++ b/anton/llm/prompts.py @@ -57,10 +57,15 @@ tool-call loop inside scratchpad code. The LLM reasons and calls your tools iteratively. \ handle_tool(name, inputs) is a plain sync function returning a string result. 
Use this for \ multi-step AI workflows like classification, extraction, or analysis with structured outputs. -- All .anton/.env secrets are available as environment variables (os.environ). -- Data source credentials are injected as DS_ environment \ -variables (e.g. DS_HOST, DS_PASSWORD, DS_ACCESS_TOKEN). Use them directly \ -in scratchpad code — never read ~/.anton/data_vault/ files directly. +- All .anton/.env variables are available as environment variables (os.environ). +- Connected data source credentials are injected as namespaced environment \ +variables in the form DS___ \ +(e.g. DS_POSTGRES_PROD_DB__HOST, DS_POSTGRES_PROD_DB__PASSWORD, \ +DS_HUBSPOT_MAIN__ACCESS_TOKEN). Use those variables directly in scratchpad \ +code and never read ~/.anton/data_vault/ files directly. +- Flat variables like DS_HOST or DS_PASSWORD are used only temporarily \ +during internal connection test snippets. Do not assume they exist during \ +normal chat/runtime execution. - When the user asks how you solved something or wants to see your work, use the scratchpad \ dump action — it shows a clean notebook-style summary without wasting tokens on reformatting. - Always use print() to produce output — scratchpad captures stdout. 
diff --git a/tests/test_datasource.py b/tests/test_datasource.py index d051bf7..49c89ae 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -1462,6 +1462,42 @@ def test_invalid_slug_no_match(self): def test_slug_with_empty_name_part(self): assert parse_connection_slug("postgresql-", self.ENGINES) is None + def test_fallback_to_vault_for_custom_engine(self, tmp_path): + """Custom engine not in registry is resolved via vault fallback.""" + vault = DataVault(vault_dir=tmp_path / "vault") + vault.save("my_custom_db", "prod", {"host": "localhost"}) + result = parse_connection_slug( + "my_custom_db-prod", + known_engines=["postgresql"], + vault=vault, + ) + assert result == ("my_custom_db", "prod") + + def test_registry_match_takes_priority_over_vault(self, tmp_path): + """Registry prefix match wins even when vault also has the slug.""" + vault = DataVault(vault_dir=tmp_path / "vault") + vault.save("postgresql", "prod", {"host": "localhost"}) + result = parse_connection_slug( + "postgresql-prod", + known_engines=["postgresql"], + vault=vault, + ) + assert result == ("postgresql", "prod") + + def test_no_match_returns_none_with_vault(self, tmp_path): + """Truly unknown slug returns None even with vault supplied.""" + vault = DataVault(vault_dir=tmp_path / "vault") + result = parse_connection_slug( + "ghost-engine-1", + known_engines=["postgresql"], + vault=vault, + ) + assert result is None + + def test_no_vault_still_returns_none_for_unknown(self): + """Backward compat: no vault arg, unknown engine still returns None.""" + assert parse_connection_slug("custom-1", known_engines=["postgresql"]) is None + class TestSlugEnvPrefix: def test_basic_engine_and_name(self): @@ -1676,10 +1712,7 @@ def _make_llm(self, json_text: str): def _mock_ds_path(self, mock_path_cls, tmp_path): """Wire Path mock so datasources.md writes go to tmp_path.""" - mock_ds = MagicMock() - mock_ds.is_file.return_value = False - mock_ds.with_suffix.return_value = tmp_path / 
"datasources.tmp" - mock_path_cls.return_value.expanduser.return_value = mock_ds + mock_path_cls.return_value.expanduser.return_value = tmp_path / "datasources.md" @pytest.mark.asyncio async def test_missing_required_non_secret_field_prompts_user( From f69a6c7f6fc60f19cf788886f6b96ff030a23df6 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Sat, 21 Mar 2026 19:36:30 +0100 Subject: [PATCH 20/70] Fix tests --- tests/test_datasource.py | 12 ++++++------ tests/test_scrubbing.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 49c89ae..a548f03 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -806,10 +806,10 @@ def test_scrub_leaves_non_secret_field_readable(self, registry, monkeypatch): pg = registry.get("postgresql") assert pg is not None _register_secret_vars(pg) - monkeypatch.setenv("DS_HOST", "db.example.com") + monkeypatch.setenv("DS_HOST", "mydbhostname") monkeypatch.setenv("DS_PASSWORD", "s3cr3tpassword99") - result = _scrub_credentials("host=db.example.com pass=s3cr3tpassword99") - assert "db.example.com" in result + result = _scrub_credentials("host=mydbhostname pass=s3cr3tpassword99") + assert "mydbhostname" in result assert "s3cr3tpassword99" not in result assert "[DS_PASSWORD]" in result @@ -890,9 +890,9 @@ def test_scrub_replaces_namespaced_secret_value(self, registry, monkeypatch): def test_scrub_leaves_namespaced_non_secret_readable(self, registry, monkeypatch): pg = registry.get("postgresql") _register_secret_vars(pg, engine="postgresql", name="prod_db") - monkeypatch.setenv("DS_POSTGRESQL_PROD_DB__HOST", "db.example.com") - result = _scrub_credentials("host=db.example.com") - assert "db.example.com" in result + monkeypatch.setenv("DS_POSTGRESQL_PROD_DB__HOST", "mydbhostname") + result = _scrub_credentials("host=mydbhostname") + assert "mydbhostname" in result # ───────────────────────────────────────────────────────────────────────────── diff 
--git a/tests/test_scrubbing.py b/tests/test_scrubbing.py index cdce8cd..cad690c 100644 --- a/tests/test_scrubbing.py +++ b/tests/test_scrubbing.py @@ -57,9 +57,9 @@ def test_registered_1char_secret_scrubbed(self, monkeypatch): def test_non_secret_var_not_scrubbed(self, monkeypatch): """A known but non-secret DS_* var (e.g. DS_HOST) stays readable.""" _DS_KNOWN_VARS.add("DS_HOST") - monkeypatch.setenv("DS_HOST", "db.example.com") - result = _scrub_credentials("host=db.example.com") - assert "db.example.com" in result + monkeypatch.setenv("DS_HOST", "mydbhostname") + result = _scrub_credentials("host=mydbhostname") + assert "mydbhostname" in result def test_unknown_short_ds_var_not_scrubbed(self, monkeypatch): """Unknown DS_* vars with short values are NOT scrubbed (heuristic threshold).""" From af587f7d37f181ae8e767e587cbe0d4d44a30dd7 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 23 Mar 2026 14:59:01 +0100 Subject: [PATCH 21/70] Remove Denodo and replace with Big Commerce --- datasources.md | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/datasources.md b/datasources.md index 53bd5a9..d435404 100644 --- a/datasources.md +++ b/datasources.md @@ -564,35 +564,27 @@ Setup → Users/Roles → Access Tokens.
The account ID can be found in Setup --- -## Denodo +## Big Commerce ```yaml -engine: denodo -display_name: Denodo -pip: psycopg2-binary -name_from: [host, database] +engine: bigcommerce +display_name: Big Commerce +pip: httpx +name_from: store_hash fields: - - { name: host, required: true, secret: false, description: "hostname or IP of the Denodo server" } - - { name: port, required: false, secret: false, description: "port number (default 9996)", default: "9996" } - - { name: database, required: true, secret: false, description: "Denodo virtual database name" } - - { name: user, required: true, secret: false, description: "Denodo username" } - - { name: password, required: true, secret: true, description: "Denodo password" } + - { name: api_base, required: true, secret: false, description: "Base URL of the BigCommerce API (e.g. https://api.bigcommerce.com/stores/0fh0fh0fh0/v3/)" } + - { name: access_token, required: true, secret: true, description: "API token for authenticating with BigCommerce" } test_snippet: | - import psycopg2, os - conn = psycopg2.connect( - host=os.environ['DS_HOST'], - port=int(os.environ.get('DS_PORT', '9996')), - dbname=os.environ['DS_DATABASE'], - user=os.environ['DS_USER'], - password=os.environ['DS_PASSWORD'], - ) - conn.close() + import httpx, os + api_base = os.environ['DS_API_BASE'].rstrip('/') + access_token = os.environ['DS_ACCESS_TOKEN'] + headers = {'X-Auth-Token': access_token} + r = httpx.get(f'{api_base}/catalog/products', headers=headers) + assert r.status_code < 400, f'HTTP {r.status_code}: {r.text[:200]}' print("ok") ``` -Denodo exposes a PostgreSQL-compatible wire protocol on port 9996 by default. -Connect via the Denodo Platform Control Center → Server Configuration → ODBC/JDBC to find the port. - +BigCommerce API tokens can be created in the BigCommerce control panel under Advanced Settings → API Accounts. Choose "Create API Account", then select "V2/V3 API Token" and grant the necessary permissions (e.g. 
"Products: Read-Only" to access product data). --- ## TimescaleDB From de8ebc0776f0e3f2143055607ca4138b7547f6d3 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 23 Mar 2026 15:01:26 +0100 Subject: [PATCH 22/70] Make port required field for all engines --- datasources.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/datasources.md b/datasources.md index d435404..7a9b85e 100644 --- a/datasources.md +++ b/datasources.md @@ -18,7 +18,7 @@ pip: psycopg2-binary name_from: database fields: - { name: host, required: true, secret: false, description: "hostname or IP of your database server" } - - { name: port, required: false, secret: false, description: "port number", default: "5432" } + - { name: port, required: true, secret: false, description: "port number", default: "5432" } - { name: database, required: true, secret: false, description: "name of the database to connect" } - { name: user, required: true, secret: false, description: "database username" } - { name: password, required: true, secret: true, description: "database password" } @@ -49,7 +49,7 @@ pip: mysql-connector-python name_from: database fields: - { name: host, required: true, secret: false, description: "hostname or IP of your MySQL server" } - - { name: port, required: false, secret: false, description: "port number", default: "3306" } + - { name: port, required: true, secret: false, description: "port number", default: "3306" } - { name: database, required: true, secret: false, description: "database name to connect" } - { name: user, required: true, secret: false, description: "MySQL username" } - { name: password, required: true, secret: true, description: "MySQL password" } @@ -81,7 +81,7 @@ pip: mysql-connector-python name_from: database fields: - { name: host, required: true, secret: false, description: "hostname or IP of your MariaDB server" } - - { name: port, required: false, secret: false, description: "port number", default: "3306" } + 
- { name: port, required: true, secret: false, description: "port number", default: "3306" } - { name: database, required: true, secret: false, description: "database name to connect" } - { name: user, required: true, secret: false, description: "MariaDB username" } - { name: password, required: true, secret: true, description: "MariaDB password" } @@ -115,7 +115,7 @@ pip: pymssql name_from: database fields: - { name: host, required: true, secret: false, description: "hostname or IP of the SQL Server (for Azure use server field instead)" } - - { name: port, required: false, secret: false, description: "port number", default: "1433" } + - { name: port, required: true, secret: false, description: "port number", default: "1433" } - { name: database, required: true, secret: false, description: "database name" } - { name: user, required: true, secret: false, description: "SQL Server username" } - { name: password, required: true, secret: true, description: "SQL Server password" } @@ -240,8 +240,8 @@ name_from: [host, service_name] fields: - { name: user, required: true, secret: false, description: "Oracle database username" } - { name: password, required: true, secret: true, description: "Oracle database password" } - - { name: host, required: false, secret: false, description: "hostname or IP address of the Oracle server" } - - { name: port, required: false, secret: false, description: "port number (default 1521)", default: "1521" } + - { name: host, required: true, secret: false, description: "hostname or IP address of the Oracle server" } + - { name: port, required: true, secret: false, description: "port number (default 1521)", default: "1521" } - { name: service_name, required: false, secret: false, description: "Oracle service name (preferred over SID)" } - { name: sid, required: false, secret: false, description: "Oracle SID — use service_name if possible" } - { name: dsn, required: false, secret: false, description: "full DSN string — overrides 
host/port/service_name" } @@ -307,7 +307,7 @@ pip: psycopg2-binary name_from: [host, database] fields: - { name: host, required: true, secret: false, description: "Redshift cluster endpoint (e.g. mycluster.abc123.us-east-1.redshift.amazonaws.com)" } - - { name: port, required: false, secret: false, description: "port number", default: "5439" } + - { name: port, required: true, secret: false, description: "port number", default: "5439" } - { name: database, required: true, secret: false, description: "database name" } - { name: user, required: true, secret: false, description: "Redshift username" } - { name: password, required: true, secret: true, description: "Redshift password" } @@ -403,7 +403,7 @@ pip: pgvector psycopg2-binary name_from: database fields: - { name: host, required: true, secret: false, description: "hostname or IP of your PostgreSQL server with pgvector extension" } - - { name: port, required: false, secret: false, description: "port number", default: "5432" } + - { name: port, required: true, secret: false, description: "port number", default: "5432" } - { name: database, required: true, secret: false, description: "database name" } - { name: user, required: true, secret: false, description: "database username" } - { name: password, required: true, secret: true, description: "database password" } @@ -439,8 +439,8 @@ display_name: ChromaDB pip: chromadb name_from: host fields: - - { name: host, required: false, secret: false, description: "ChromaDB server host for HTTP client mode (omit for local in-process mode)" } - - { name: port, required: false, secret: false, description: "ChromaDB server port", default: "8000" } + - { name: host, required: true, secret: false, description: "ChromaDB server host for HTTP client mode (omit for local in-process mode)" } + - { name: port, required: true, secret: false, description: "ChromaDB server port", default: "8000" } - { name: persist_directory, required: false, secret: false, description: "local 
directory for persistent storage (local mode only)" } test_snippet: | import chromadb, os @@ -596,7 +596,7 @@ pip: psycopg2-binary name_from: [host, database] fields: - { name: host, required: true, secret: false, description: "hostname or IP of the TimescaleDB server" } - - { name: port, required: false, secret: false, description: "port number", default: "5432" } + - { name: port, required: true, secret: false, description: "port number", default: "5432" } - { name: database, required: true, secret: false, description: "database name" } - { name: user, required: true, secret: false, description: "database username" } - { name: password, required: true, secret: true, description: "database password" } From 74ac2b7b95dedcc7f478ee85acb232d2bfa454f0 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 23 Mar 2026 17:05:58 +0100 Subject: [PATCH 23/70] Fix issues with custom datasources --- anton/chat.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index b8720c7..d18d47e 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2622,20 +2622,20 @@ async def _handle_add_custom_datasource( registry, session: "ChatSession", ): - """Ask the user how they authenticate, use the LLM to identify fields, save definition.""" + """Ask for the tool name, use the LLM to identify required fields, then collect credentials.""" console.print() + preamble = "[anton.cyan](anton)[/] " if name: - preamble = f"[anton.cyan](anton)[/] '{name}' isn't in my built-in list.\n " + tool_name = name else: - preamble = "[anton.cyan](anton)[/] " - user_answer = Prompt.ask( - f"{preamble}How do you authenticate with it? 
" - f"Describe what you have or paste credentials directly", - console=console, - ) - if not user_answer.strip(): + tool_name = Prompt.ask( + f"{preamble}What is the name of the tool or service?", + console=console, + ) + if not tool_name.strip(): return None + tool_name = tool_name.strip() console.print() console.print("[anton.muted] Got it — working out the connection details…[/]") @@ -2647,7 +2647,7 @@ async def _handle_add_custom_datasource( { "role": "user", "content": ( - f"The user wants to connect to{(' ' + repr(name)) if name else ' a custom data source'} and said: {user_answer}\n\n" + f"The user wants to connect to {repr(tool_name)}.\n\n" "Return ONLY valid JSON (no markdown fences, no commentary):\n" '{"display_name":"Human-readable name","pip":"pip-package or empty string",' '"fields":[{"name":"snake_case_name","value":"value if given inline else empty",' @@ -2658,6 +2658,7 @@ async def _handle_add_custom_datasource( max_tokens=1024, ) text = response.content.strip() + # Keep text = _re.sub(r"^```[^\n]*\n|```\s*$", "", text, flags=_re.MULTILINE).strip() data = _json.loads(text) except Exception: @@ -3139,12 +3140,20 @@ async def _handle_connect_datasource( return session engine_def, credentials = result conn_num = vault.next_connection_number(engine_def.engine) - vault.save(engine_def.engine, str(conn_num), credentials) + conn_name = str(conn_num) + vault.save(engine_def.engine, conn_name, credentials) slug = f"{engine_def.engine}-{conn_num}" + _restore_namespaced_env(vault) + session._active_datasource = slug + _register_secret_vars(engine_def, engine=engine_def.engine, name=conn_name) console.print( f' Credentials saved to Local Vault as [bold]"{slug}"[/bold].' 
) console.print() + console.print( + "[anton.muted] You can now ask me questions about your data.[/]" + ) + console.print() session._history.append( { "role": "assistant", From 5081cd8ee942aa9b595da519109071dfd4db8a6b Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 23 Mar 2026 17:07:22 +0100 Subject: [PATCH 24/70] Format code --- anton/chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/anton/chat.py b/anton/chat.py index d18d47e..e319866 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1,8 +1,8 @@ from __future__ import annotations import asyncio -import os import json as _json +import os import re as _re import sys import time @@ -48,6 +48,7 @@ DatasourceField, DatasourceRegistry, ) + from rich.prompt import Confirm, Prompt if TYPE_CHECKING: From a88a1127eb4086edebbe593dab8fcc594efeba7b Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 23 Mar 2026 18:20:42 +0100 Subject: [PATCH 25/70] Prevent duplicate entries for custom datasources --- anton/chat.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/anton/chat.py b/anton/chat.py index e319866..d82e787 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -5,6 +5,7 @@ import os import re as _re import sys +import yaml as _yaml import time from collections.abc import AsyncIterator, Callable from pathlib import Path @@ -47,6 +48,7 @@ DatasourceEngine, DatasourceField, DatasourceRegistry, + _YAML_BLOCK_RE, ) from rich.prompt import Confirm, Prompt @@ -2617,6 +2619,27 @@ def _human_size(nbytes: int) -> str: return f"{nbytes:.1f}TB" +def _remove_engine_block(text: str, slug: str) -> str: + """Return *text* with any YAML datasource block for *slug* removed.""" + cleaned = [] + prev = 0 + for m in _YAML_BLOCK_RE.finditer(text): + try: + data = _yaml.safe_load(m.group(3)) + is_dup = isinstance(data, dict) and str(data.get("engine", "")) == slug + except Exception: + is_dup = False + if is_dup: + pre = text[prev : m.start()].rstrip() + pre = 
_re.sub(r"\n---\s*$", "", pre) + cleaned.append(pre) + else: + cleaned.append(text[prev : m.end()]) + prev = m.end() + cleaned.append(text[prev:]) + return "".join(cleaned) + + async def _handle_add_custom_datasource( console: Console, name: str, @@ -2784,6 +2807,9 @@ async def _handle_add_custom_datasource( existing = ( user_ds_path.read_text(encoding="utf-8") if user_ds_path.is_file() else "" ) + + existing = _remove_engine_block(existing, slug) + tmp_path.write_text(existing + yaml_block, encoding="utf-8") parsed = registry.validate_file(tmp_path) From 8c42a86ee9780db325d9384554b51de8a2acbe44 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 23 Mar 2026 18:25:26 +0100 Subject: [PATCH 26/70] Revert the questions flow for custom datasources --- anton/chat.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index d82e787..4202a51 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2652,14 +2652,24 @@ async def _handle_add_custom_datasource( preamble = "[anton.cyan](anton)[/] " if name: tool_name = name + name_context = f"'{name}' isn't in my built-in list.\n " else: tool_name = Prompt.ask( f"{preamble}What is the name of the tool or service?", console=console, ) - if not tool_name.strip(): + if not tool_name.strip(): + return None + tool_name = tool_name.strip() + name_context = "" + + user_answer = Prompt.ask( + f"{preamble}{name_context}How do you authenticate with it? 
" + "Describe what credentials you have (don't paste actual values)", + console=console, + ) + if not user_answer.strip(): return None - tool_name = tool_name.strip() console.print() console.print("[anton.muted] Got it — working out the connection details…[/]") @@ -2671,7 +2681,7 @@ async def _handle_add_custom_datasource( { "role": "user", "content": ( - f"The user wants to connect to {repr(tool_name)}.\n\n" + f"The user wants to connect to {repr(tool_name)} and said: {user_answer}\n\n" "Return ONLY valid JSON (no markdown fences, no commentary):\n" '{"display_name":"Human-readable name","pip":"pip-package or empty string",' '"fields":[{"name":"snake_case_name","value":"value if given inline else empty",' @@ -2754,11 +2764,11 @@ async def _handle_add_custom_datasource( # Prompt for any required non-secret fields not provided inline for f, raw in zip(fields, raw_fields): if f.secret: - continue # already handled above + continue if not f.required: - continue # optional fields handled below + continue if f.name in credentials: - continue # already collected inline + continue value = Prompt.ask( f"[anton.cyan](anton)[/] {f.name}", console=console, @@ -2862,7 +2872,6 @@ async def _handle_connect_datasource( vault = DataVault() registry = DatasourceRegistry() - # ── /edit-data-source path: update credentials for an existing slug ──────── if datasource_name is not None: _parsed = parse_connection_slug( datasource_name, [e.engine for e in registry.all_engines()], vault=vault @@ -3032,7 +3041,6 @@ async def _handle_connect_datasource( ) return session - # ── Normal flow: connect a new (or reconnect an existing) data source ───── console.print() all_engines = registry.all_engines() @@ -3051,7 +3059,6 @@ async def _handle_connect_datasource( console=console, ) - # ── Reconnect path: user typed an existing vault slug ───────────────────── stripped_answer = answer.strip() known_slugs = {f"{c['engine']}-{c['name']}": c for c in vault.list_connections()} if stripped_answer in 
known_slugs: @@ -3080,7 +3087,6 @@ async def _handle_connect_datasource( ) return session - # ── Number selection ─────────────────────────────────────────────────────── engine_def: DatasourceEngine | None = None _go_custom = False @@ -3098,13 +3104,10 @@ async def _handle_connect_datasource( console.print() return session - # ── Name-based resolution (when not a number) ───────────────────────────── if engine_def is None and not _go_custom: - # 1. Exact / case-insensitive / whitespace-normalized match engine_def = registry.find_by_name(stripped_answer) - + # if exact match not found, try substring match against display and engine names if engine_def is None: - # 2. Substring match needle = stripped_answer.lower() candidates = [ e @@ -3133,9 +3136,8 @@ async def _handle_connect_datasource( console.print("[anton.warning]Invalid choice. Aborting.[/]") console.print() return session - + # fuzzy match against display and engine names if exact match not found if engine_def is None: - # 3. Fuzzy / close match fuzzy_matches = registry.fuzzy_find(stripped_answer) for suggestion in fuzzy_matches: console.print() @@ -3158,7 +3160,6 @@ async def _handle_connect_datasource( if engine_def is None: _go_custom = True - # ── Custom datasource flow ──────────────────────────────────────────────── if _go_custom: result = await _handle_add_custom_datasource( console, stripped_answer if not stripped_answer.isdigit() else "", registry, session @@ -3192,7 +3193,6 @@ async def _handle_connect_datasource( ) return session - # ── Step 2a: auth method choice (if engine requires it) ─────── active_fields = engine_def.fields if engine_def.auth_method == "choice" and engine_def.auth_methods: console.print() @@ -3244,7 +3244,6 @@ async def _handle_connect_datasource( console.print() - # ── Step 3: determine collection mode ──────────────────────── mode_answer = ( Prompt.ask( "[anton.cyan](anton)[/] Do you have these available? 
[y/n/]", From c02c14a4f1dab771ce8272084b305edc7b0c4002 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Mon, 23 Mar 2026 12:59:23 -0700 Subject: [PATCH 27/70] minds passthrough --- anton/cli.py | 15 ++++++++++----- anton/config/settings.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index a701342..130e606 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -224,7 +224,15 @@ def main( def _has_api_key(settings) -> bool: - """Check if all configured providers have API keys.""" + """Check if all configured providers have API keys. + + Also returns False when the Minds API key is missing, so that + upgrading users are prompted to set it up on first re-launch. + """ + # Minds key is always required now + if not (settings.minds_api_key or os.environ.get("ANTON_MINDS_API_KEY")): + return False + providers = {settings.planning_provider, settings.coding_provider} for p in providers: if p == "anthropic" and not (settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY")): @@ -245,10 +253,7 @@ def _ensure_api_key(settings) -> None: ws = Workspace(Path.home()) - if settings.minds_enabled: - _ensure_minds_api_key(settings, ws) - else: - _ensure_anthropic_api_key(settings, ws) + _ensure_minds_api_key(settings, ws) # Reload env vars into the process so the scratchpad subprocess inherits them ws.apply_env_to_process() diff --git a/anton/config/settings.py b/anton/config/settings.py index 1aedfc2..ea4f7d0 100644 --- a/anton/config/settings.py +++ b/anton/config/settings.py @@ -51,7 +51,7 @@ class AntonSettings(BaseSettings): disable_autoupdates: bool = False # Minds datasource integration - minds_enabled: bool = False # opt-in: use Minds server as LLM provider + minds_enabled: bool = True # use Minds server as LLM provider minds_api_key: str | None = None minds_url: str = "https://mdb.ai" minds_mind_name: str | None = None From ad1457b8fc0bf779d88cb15d7c907050de515d13 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Mon, 23 
Mar 2026 13:23:37 -0700 Subject: [PATCH 28/70] ssl passthrough --- anton/cli.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/anton/cli.py b/anton/cli.py index 130e606..da03327 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -319,11 +319,13 @@ def _ensure_minds_api_key(settings, ws) -> None: ws.set_secret("ANTON_MINDS_URL", minds_url) # Test if the Minds server supports LLM endpoints (_code_/_reason_) - # (silenced: was printing "Testing LLM endpoints..." and "not available" messages) from anton.chat import _minds_test_llm + ssl_verify = True llm_ok = _minds_test_llm(minds_url, api_key, verify=True) if not llm_ok: llm_ok = _minds_test_llm(minds_url, api_key, verify=False) + if llm_ok: + ssl_verify = False if llm_ok: console.print("[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]") @@ -334,12 +336,15 @@ def _ensure_minds_api_key(settings, ws) -> None: settings.coding_provider = "openai-compatible" settings.planning_model = "_reason_" settings.coding_model = "_code_" + settings.minds_ssl_verify = ssl_verify ws.set_secret("ANTON_OPENAI_API_KEY", api_key) ws.set_secret("ANTON_OPENAI_BASE_URL", base_url) ws.set_secret("ANTON_PLANNING_PROVIDER", "openai-compatible") ws.set_secret("ANTON_CODING_PROVIDER", "openai-compatible") ws.set_secret("ANTON_PLANNING_MODEL", "_reason_") ws.set_secret("ANTON_CODING_MODEL", "_code_") + if not ssl_verify: + ws.set_secret("ANTON_MINDS_SSL_VERIFY", "false") else: # LLM endpoints not available — fall back to Anthropic _ensure_anthropic_api_key(settings, ws) From 4144a98c3c343e3e3e8118bccaa373b19c1bc93b Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Mon, 23 Mar 2026 13:26:52 -0700 Subject: [PATCH 29/70] ssl verify option --- anton/cli.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index da03327..5e7ce57 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -319,13 +319,27 @@ def _ensure_minds_api_key(settings, ws) ->
None: ws.set_secret("ANTON_MINDS_URL", minds_url) # Test if the Minds server supports LLM endpoints (_code_/_reason_) + from rich.prompt import Confirm + from anton.chat import _minds_test_llm ssl_verify = True llm_ok = _minds_test_llm(minds_url, api_key, verify=True) if not llm_ok: - llm_ok = _minds_test_llm(minds_url, api_key, verify=False) - if llm_ok: - ssl_verify = False + # SSL verification failed — check if the server is reachable without it + llm_ok_no_ssl = _minds_test_llm(minds_url, api_key, verify=False) + if llm_ok_no_ssl: + console.print("[anton.warning]SSL certificate verification failed for this server.[/]") + skip_ssl = Confirm.ask( + "Continue without verifying SSL certificates?", + default=False, + console=console, + ) + if skip_ssl: + ssl_verify = False + llm_ok = True + else: + console.print("[anton.error]Cannot connect with SSL verification. Check your server certificate.[/]") + llm_ok = False if llm_ok: console.print("[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]") From 0cc1a65597cd958335d656acae7036edafa2f27d Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 09:29:44 -0700 Subject: [PATCH 30/70] launch super --- anton/cli.py | 264 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 198 insertions(+), 66 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 5e7ce57..587c9ae 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -224,15 +224,7 @@ def main( def _has_api_key(settings) -> bool: - """Check if all configured providers have API keys. - - Also returns False when the Minds API key is missing, so that - upgrading users are prompted to set it up on first re-launch. 
- """ - # Minds key is always required now - if not (settings.minds_api_key or os.environ.get("ANTON_MINDS_API_KEY")): - return False - + """Check if any LLM provider is fully configured.""" providers = {settings.planning_provider, settings.coding_provider} for p in providers: if p == "anthropic" and not (settings.anthropic_api_key or os.environ.get("ANTHROPIC_API_KEY")): @@ -247,64 +239,91 @@ def _ensure_api_key(settings) -> None: if _has_api_key(settings): return + from rich.panel import Panel from rich.prompt import Prompt + from rich.text import Text from anton.workspace import Workspace ws = Workspace(Path.home()) - _ensure_minds_api_key(settings, ws) - - # Reload env vars into the process so the scratchpad subprocess inherits them - ws.apply_env_to_process() - + # Header console.print() - console.print(f"[anton.success]Saved to {ws.env_path}[/]") + header = Text() + header.append(" First-time setup", style="bold anton.cyan") + header.append(" — pick your LLM provider\n", style="anton.muted") + console.print(header) + + # Provider choices + minds_line = Text() + minds_line.append(" 1 ", style="bold") + minds_line.append("Minds ", style="anton.cyan") + minds_line.append("mdb.ai", style="anton.muted") + minds_line.append(" recommended", style="bold anton.success") + console.print(minds_line) + + benefits = Text() + benefits.append(" Optimized model routing ", style="anton.muted") + benefits.append("|", style="dim") + benefits.append(" Faster responses ", style="anton.muted") + benefits.append("|", style="dim") + benefits.append(" Built-in billing", style="anton.muted") + console.print(benefits) console.print() - -def _ensure_anthropic_api_key(settings, ws) -> None: - """Prompt for Anthropic API key (default flow).""" - from rich.prompt import Prompt - - console.print() - console.print("[anton.cyan]Anthropic configuration[/]") + other_line = Text() + other_line.append(" 2 ", style="bold") + other_line.append("Bring your own key ", style="anton.cyan") + 
other_line.append("Anthropic or OpenAI", style="anton.muted") + console.print(other_line) console.print() - api_key = Prompt.ask("Anthropic API key", console=console) - if not api_key.strip(): - console.print("[anton.error]No API key provided. Exiting.[/]") - raise typer.Exit(1) - api_key = api_key.strip() - - settings.anthropic_api_key = api_key - settings.planning_provider = "anthropic" - settings.coding_provider = "anthropic" - settings.planning_model = "claude-sonnet-4-6" - settings.coding_model = "claude-haiku-4-5-20251001" - ws.set_secret("ANTON_ANTHROPIC_API_KEY", api_key) - ws.set_secret("ANTON_PLANNING_PROVIDER", "anthropic") - ws.set_secret("ANTON_CODING_PROVIDER", "anthropic") - ws.set_secret("ANTON_PLANNING_MODEL", "claude-sonnet-4-6") - ws.set_secret("ANTON_CODING_MODEL", "claude-haiku-4-5-20251001") + choice = Prompt.ask( + "[anton.cyan]>[/]", + choices=["1", "2"], + default="1", + console=console, + show_choices=False, + show_default=False, + ) + if choice == "1": + _setup_minds(settings, ws) + else: + _setup_other_provider(settings, ws) -def _ensure_minds_api_key(settings, ws) -> None: - """Prompt for Minds API key and configure LLM endpoints (opt-in flow).""" - from rich.prompt import Prompt + # Reload env vars into the process so the scratchpad subprocess inherits them + ws.apply_env_to_process() + # Summary console.print() - console.print("[anton.cyan]Minds configuration[/]") + provider_label = settings.planning_provider + model_label = settings.planning_model + if provider_label == "openai-compatible": + provider_label = "Minds" + summary = Text() + summary.append(" Provider ", style="anton.muted") + summary.append(provider_label, style="anton.cyan") + summary.append(" Model ", style="anton.muted") + summary.append(model_label, style="anton.cyan") + console.print(summary) + console.print(f" [anton.success]Ready.[/] [anton.muted]Saved to {ws.env_path}[/]") console.print() - api_key = Prompt.ask("Minds API key", console=console) + +def 
_setup_minds(settings, ws) -> None: + """Set up Minds (mdb.ai) as the LLM provider.""" + from rich.prompt import Confirm, Prompt + + console.print() + api_key = Prompt.ask(" [anton.cyan]Minds API key[/]", console=console) if not api_key.strip(): - console.print("[anton.error]No API key provided. Exiting.[/]") + console.print(" [anton.error]No API key provided.[/]") raise typer.Exit(1) api_key = api_key.strip() minds_url = Prompt.ask( - "Minds URL", + " [anton.cyan]Minds URL[/]", default="https://mdb.ai", console=console, ).strip() @@ -318,31 +337,36 @@ def _ensure_minds_api_key(settings, ws) -> None: ws.set_secret("ANTON_MINDS_API_KEY", api_key) ws.set_secret("ANTON_MINDS_URL", minds_url) - # Test if the Minds server supports LLM endpoints (_code_/_reason_) - from rich.prompt import Confirm - + # Test connection with a spinner from anton.chat import _minds_test_llm + + from rich.live import Live + from rich.spinner import Spinner + ssl_verify = True - llm_ok = _minds_test_llm(minds_url, api_key, verify=True) - if not llm_ok: - # SSL verification failed — check if the server is reachable without it - llm_ok_no_ssl = _minds_test_llm(minds_url, api_key, verify=False) - if llm_ok_no_ssl: - console.print("[anton.warning]SSL certificate verification failed for this server.[/]") - skip_ssl = Confirm.ask( - "Continue without verifying SSL certificates?", - default=False, - console=console, - ) - if skip_ssl: + llm_ok = False + + with Live(Spinner("dots", text=" Connecting...", style="anton.cyan"), console=console, transient=True): + llm_ok = _minds_test_llm(minds_url, api_key, verify=True) + if not llm_ok: + llm_ok_no_ssl = _minds_test_llm(minds_url, api_key, verify=False) + if llm_ok_no_ssl: ssl_verify = False llm_ok = True - else: - console.print("[anton.error]Cannot connect with SSL verification. 
Check your server certificate.[/]") - llm_ok = False + + if llm_ok and not ssl_verify: + console.print(" [anton.warning]SSL certificate verification failed.[/]") + skip_ssl = Confirm.ask( + " Continue without SSL verification?", + default=False, + console=console, + ) + if not skip_ssl: + console.print(" [anton.error]Setup cancelled.[/]") + raise typer.Exit(1) if llm_ok: - console.print("[anton.success]LLM endpoints available — using Minds server as LLM provider.[/]") + console.print(" [anton.success]Connected[/]") base_url = f"{minds_url}/api/v1" settings.openai_api_key = api_key settings.openai_base_url = base_url @@ -360,8 +384,116 @@ def _ensure_minds_api_key(settings, ws) -> None: if not ssl_verify: ws.set_secret("ANTON_MINDS_SSL_VERIFY", "false") else: - # LLM endpoints not available — fall back to Anthropic - _ensure_anthropic_api_key(settings, ws) + console.print(" [anton.error]Could not connect. Check your API key and URL.[/]") + raise typer.Exit(1) + + +def _setup_other_provider(settings, ws) -> None: + """Set up Anthropic or OpenAI as the LLM provider.""" + from rich.prompt import Prompt + from rich.text import Text + + console.print() + for label, idx in [("Anthropic (Claude)", "1"), ("OpenAI (GPT)", "2")]: + line = Text() + line.append(f" {idx} ", style="bold") + line.append(label, style="anton.cyan") + console.print(line) + console.print() + + provider_choice = Prompt.ask( + "[anton.cyan]>[/]", + choices=["1", "2"], + console=console, + show_choices=False, + ) + + if provider_choice == "1": + _setup_anthropic(settings, ws) + else: + _setup_openai(settings, ws) + + +def _validate_with_spinner(console, label: str, fn) -> None: + """Run a validation function with a spinner, print result.""" + from rich.live import Live + from rich.spinner import Spinner + + with Live(Spinner("dots", text=f" Validating {label}...", style="anton.cyan"), console=console, transient=True): + fn() + console.print(f" [anton.success]Validated[/] [anton.muted]{label}[/]") + + 
+def _setup_anthropic(settings, ws) -> None: + """Set up Anthropic with a single model for both reasoning and coding.""" + from rich.prompt import Prompt + + console.print() + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) + if not api_key.strip(): + console.print(" [anton.error]No API key provided.[/]") + raise typer.Exit(1) + api_key = api_key.strip() + + model = Prompt.ask(" [anton.cyan]Model[/]", default="claude-sonnet-4-6", console=console).strip() + + try: + def _test(): + import anthropic + client = anthropic.Anthropic(api_key=api_key) + client.messages.create(model=model, max_tokens=1, messages=[{"role": "user", "content": "ping"}]) + + _validate_with_spinner(console, model, _test) + except Exception as exc: + console.print(f" [anton.error]Failed:[/] {exc}") + raise typer.Exit(1) + + settings.anthropic_api_key = api_key + settings.planning_provider = "anthropic" + settings.coding_provider = "anthropic" + settings.planning_model = model + settings.coding_model = model + ws.set_secret("ANTON_ANTHROPIC_API_KEY", api_key) + ws.set_secret("ANTON_PLANNING_PROVIDER", "anthropic") + ws.set_secret("ANTON_CODING_PROVIDER", "anthropic") + ws.set_secret("ANTON_PLANNING_MODEL", model) + ws.set_secret("ANTON_CODING_MODEL", model) + + +def _setup_openai(settings, ws) -> None: + """Set up OpenAI with a single model for both reasoning and coding.""" + from rich.prompt import Prompt + + console.print() + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) + if not api_key.strip(): + console.print(" [anton.error]No API key provided.[/]") + raise typer.Exit(1) + api_key = api_key.strip() + + model = Prompt.ask(" [anton.cyan]Model[/]", default="gpt-4o", console=console).strip() + + try: + def _test(): + import openai + client = openai.OpenAI(api_key=api_key) + client.chat.completions.create(model=model, max_tokens=1, messages=[{"role": "user", "content": "ping"}]) + + _validate_with_spinner(console, model, _test) + except Exception as exc: + 
console.print(f" [anton.error]Failed:[/] {exc}") + raise typer.Exit(1) + + settings.openai_api_key = api_key + settings.planning_provider = "openai" + settings.coding_provider = "openai" + settings.planning_model = model + settings.coding_model = model + ws.set_secret("ANTON_OPENAI_API_KEY", api_key) + ws.set_secret("ANTON_PLANNING_PROVIDER", "openai") + ws.set_secret("ANTON_CODING_PROVIDER", "openai") + ws.set_secret("ANTON_PLANNING_MODEL", model) + ws.set_secret("ANTON_CODING_MODEL", model) @app.command("setup") From 09ea441cd607f923ccdb892912211a7b75b8c1aa Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 09:42:58 -0700 Subject: [PATCH 31/70] cli onboarding --- anton/cli.py | 69 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 587c9ae..b3655df 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -247,36 +247,49 @@ def _ensure_api_key(settings) -> None: ws = Workspace(Path.home()) - # Header - console.print() - header = Text() - header.append(" First-time setup", style="bold anton.cyan") - header.append(" — pick your LLM provider\n", style="anton.muted") - console.print(header) - - # Provider choices - minds_line = Text() - minds_line.append(" 1 ", style="bold") - minds_line.append("Minds ", style="anton.cyan") - minds_line.append("mdb.ai", style="anton.muted") - minds_line.append(" recommended", style="bold anton.success") - console.print(minds_line) - - benefits = Text() - benefits.append(" Optimized model routing ", style="anton.muted") - benefits.append("|", style="dim") - benefits.append(" Faster responses ", style="anton.muted") - benefits.append("|", style="dim") - benefits.append(" Built-in billing", style="anton.muted") - console.print(benefits) - console.print() + from rich.table import Table + + table = Table( + show_header=False, + show_edge=False, + show_lines=False, + padding=(0, 2), + expand=True, + ) + table.add_column(ratio=1) + 
table.add_column("", width=1, style="anton.cyan_dim") + table.add_column(ratio=1) + + # Left: provider choices + choices = Text() + choices.append("1 ", style="bold") + choices.append("MindsDB Cloud ", style="bold anton.cyan") + choices.append("(recommended)\n\n", style="anton.success") + choices.append("2 ", style="bold") + choices.append("Bring your own key\n", style="anton.cyan") + choices.append(" Anthropic or OpenAI", style="anton.muted") + + # Right: why Minds + info = Text() + info.append("MindsDB Cloud ", style="bold anton.cyan") + info.append("mdb.ai\n\n", style="anton.muted") + info.append("MindsDB is the maker of Anton\n", style="anton.muted") + info.append("and provides an LLM service\n", style="anton.muted") + info.append("optimized for Anton:\n\n", style="anton.muted") + info.append(" Smart model routing\n", style="") + info.append(" Faster responses\n", style="") + info.append(" Cost optimized", style="") + + divider = Text("│\n│\n│\n│\n│\n│\n│\n│\n│\n│", style="anton.cyan_dim") + table.add_row(choices, divider, info) - other_line = Text() - other_line.append(" 2 ", style="bold") - other_line.append("Bring your own key ", style="anton.cyan") - other_line.append("Anthropic or OpenAI", style="anton.muted") - console.print(other_line) console.print() + console.print(Panel( + table, + title="[bold anton.cyan]LLM Setup[/]", + border_style="anton.cyan_dim", + padding=(1, 2), + )) choice = Prompt.ask( "[anton.cyan]>[/]", From b5569b1f00d45673f6e0715f63abd544c728a6ee Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 09:53:36 -0700 Subject: [PATCH 32/70] cli update onboarding --- anton/cli.py | 117 +++++++++++++++++++++++++-------------------------- 1 file changed, 58 insertions(+), 59 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index b3655df..aacb5a6 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -214,12 +214,14 @@ def main( ctx.obj["settings"] = settings if ctx.invoked_subcommand is None: - from anton.channel.branding import 
render_banner from anton.chat import run_chat - render_banner(console) _ensure_workspace(settings) - _ensure_api_key(settings) + if not _has_api_key(settings): + _onboard(settings) + else: + from anton.channel.branding import render_banner + render_banner(console) run_chat(console, settings, resume=resume) @@ -234,70 +236,69 @@ def _has_api_key(settings) -> bool: return True -def _ensure_api_key(settings) -> None: - """Prompt the user to configure a provider and API key if none is set.""" - if _has_api_key(settings): - return - - from rich.panel import Panel +def _onboard(settings) -> None: + """First-time onboarding: robot intro + LLM provider selection.""" from rich.prompt import Prompt from rich.text import Text + from anton import __version__ + from anton.channel.branding import _render_robot_static from anton.workspace import Workspace ws = Workspace(Path.home()) - from rich.table import Table + # Robot on the left, intro text on the right + g = "anton.glow" + m = "anton.muted" + + # Build robot lines (static, no animation for clean layout) + robot_lines = [ + f"[{g}] \u2590[/]", + f"[{g}] \u2584\u2588\u2580\u2588\u2588\u2580\u2588\u2584[/] [{g}]\u2661\u2661\u2661\u2661[/]", + f"[{g}] \u2588\u2588[/] [{m}](\u00b0\u1d17\u00b0)[/] [{g}]\u2588\u2588[/]", + f"[{g}] \u2580\u2588\u2584\u2588\u2588\u2584\u2588\u2580[/]" + f" [{g}]\u2584\u2580\u2588 \u2588\u2584 \u2588 \u2580\u2588\u2580 \u2588\u2580\u2588 \u2588\u2584 \u2588[/]", + f"[{g}] \u2590 \u2590[/]" + f" [{g}]\u2588\u2580\u2588 \u2588 \u2580\u2588 \u2588 \u2588\u2584\u2588 \u2588 \u2580\u2588[/]", + f"[{g}] \u2590 \u2590[/]", + f"[{g}] {'━' * 40}[/]", + f" v{__version__}", + ] + + for line in robot_lines: + console.print(line) - table = Table( - show_header=False, - show_edge=False, - show_lines=False, - padding=(0, 2), - expand=True, + console.print() + console.print( + "[anton.cyan]Anton[/] is an autonomous AI coworker built by " + "[bold anton.cyan]MindsDB[/]." 
+ ) + console.print() + console.print( + "For the best experience, we recommend [bold anton.cyan]MindsDB Cloud[/] " + "[anton.muted](mdb.ai)[/]" + ) + console.print( + "as your LLM provider. It is optimized for Anton with:" ) - table.add_column(ratio=1) - table.add_column("", width=1, style="anton.cyan_dim") - table.add_column(ratio=1) - - # Left: provider choices - choices = Text() - choices.append("1 ", style="bold") - choices.append("MindsDB Cloud ", style="bold anton.cyan") - choices.append("(recommended)\n\n", style="anton.success") - choices.append("2 ", style="bold") - choices.append("Bring your own key\n", style="anton.cyan") - choices.append(" Anthropic or OpenAI", style="anton.muted") - - # Right: why Minds - info = Text() - info.append("MindsDB Cloud ", style="bold anton.cyan") - info.append("mdb.ai\n\n", style="anton.muted") - info.append("MindsDB is the maker of Anton\n", style="anton.muted") - info.append("and provides an LLM service\n", style="anton.muted") - info.append("optimized for Anton:\n\n", style="anton.muted") - info.append(" Smart model routing\n", style="") - info.append(" Faster responses\n", style="") - info.append(" Cost optimized", style="") - - divider = Text("│\n│\n│\n│\n│\n│\n│\n│\n│\n│", style="anton.cyan_dim") - table.add_row(choices, divider, info) + console.print() + console.print(" [anton.success]\u2713[/] Smart model routing") + console.print(" [anton.success]\u2713[/] Faster responses") + console.print(" [anton.success]\u2713[/] Cost optimized") + console.print() + + console.print(f"[{g}] {'━' * 40}[/]") + console.print() + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]MindsDB Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") console.print() - console.print(Panel( - table, - title="[bold anton.cyan]LLM Setup[/]", - border_style="anton.cyan_dim", - padding=(1, 2), - )) choice = Prompt.ask( - "[anton.cyan]>[/]", + 
"Choose LLM Provider", choices=["1", "2"], default="1", console=console, - show_choices=False, - show_default=False, ) if choice == "1": @@ -305,7 +306,7 @@ def _ensure_api_key(settings) -> None: else: _setup_other_provider(settings, ws) - # Reload env vars into the process so the scratchpad subprocess inherits them + # Reload env vars so the scratchpad subprocess inherits them ws.apply_env_to_process() # Summary @@ -313,13 +314,11 @@ def _ensure_api_key(settings) -> None: provider_label = settings.planning_provider model_label = settings.planning_model if provider_label == "openai-compatible": - provider_label = "Minds" - summary = Text() - summary.append(" Provider ", style="anton.muted") - summary.append(provider_label, style="anton.cyan") - summary.append(" Model ", style="anton.muted") - summary.append(model_label, style="anton.cyan") - console.print(summary) + provider_label = "MindsDB Cloud" + console.print( + f" [anton.muted]Provider[/] [anton.cyan]{provider_label}[/]" + f" [anton.muted]Model[/] [anton.cyan]{model_label}[/]" + ) console.print(f" [anton.success]Ready.[/] [anton.muted]Saved to {ws.env_path}[/]") console.print() From e263b239dc369dd694a896c7630a86aab9d4c064 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 09:57:13 -0700 Subject: [PATCH 33/70] on boarding --- anton/cli.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/anton/cli.py b/anton/cli.py index aacb5a6..fadb6c5 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -328,7 +328,12 @@ def _setup_minds(settings, ws) -> None: from rich.prompt import Confirm, Prompt console.print() - api_key = Prompt.ask(" [anton.cyan]Minds API key[/]", console=console) + console.print( + " [anton.muted]Don't have a key yet? 
Create one in seconds at[/]" + " [link=https://mdb.ai][bold anton.cyan]mdb.ai[/][/link]" + ) + console.print() + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) if not api_key.strip(): console.print(" [anton.error]No API key provided.[/]") raise typer.Exit(1) From bb6c953fbd6bf4d1ec3823b177ad5d725ee566c0 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:03:01 -0700 Subject: [PATCH 34/70] browser --- anton/cli.py | 97 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 57 insertions(+), 40 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index fadb6c5..ce06499 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -236,57 +236,71 @@ def _has_api_key(settings) -> bool: return True +def _typewrite(text: str, delay: float = 0.02) -> None: + """Print text character by character with a typing effect.""" + import sys + import time + + for ch in text: + sys.stdout.write(ch) + sys.stdout.flush() + time.sleep(delay) + sys.stdout.write("\n") + sys.stdout.flush() + + def _onboard(settings) -> None: - """First-time onboarding: robot intro + LLM provider selection.""" + """First-time onboarding: animated robot + typed intro + LLM provider selection.""" + import sys + import time + from rich.prompt import Prompt - from rich.text import Text from anton import __version__ - from anton.channel.branding import _render_robot_static + from anton.channel.branding import render_banner from anton.workspace import Workspace ws = Workspace(Path.home()) - - # Robot on the left, intro text on the right g = "anton.glow" - m = "anton.muted" - - # Build robot lines (static, no animation for clean layout) - robot_lines = [ - f"[{g}] \u2590[/]", - f"[{g}] \u2584\u2588\u2580\u2588\u2588\u2580\u2588\u2584[/] [{g}]\u2661\u2661\u2661\u2661[/]", - f"[{g}] \u2588\u2588[/] [{m}](\u00b0\u1d17\u00b0)[/] [{g}]\u2588\u2588[/]", - f"[{g}] \u2580\u2588\u2584\u2588\u2588\u2584\u2588\u2580[/]" - f" [{g}]\u2584\u2580\u2588 \u2588\u2584 \u2588 \u2580\u2588\u2580 
\u2588\u2580\u2588 \u2588\u2584 \u2588[/]", - f"[{g}] \u2590 \u2590[/]" - f" [{g}]\u2588\u2580\u2588 \u2588 \u2580\u2588 \u2588 \u2588\u2584\u2588 \u2588 \u2580\u2588[/]", - f"[{g}] \u2590 \u2590[/]", - f"[{g}] {'━' * 40}[/]", - f" v{__version__}", - ] - - for line in robot_lines: - console.print(line) - console.print() - console.print( - "[anton.cyan]Anton[/] is an autonomous AI coworker built by " - "[bold anton.cyan]MindsDB[/]." - ) - console.print() - console.print( - "For the best experience, we recommend [bold anton.cyan]MindsDB Cloud[/] " - "[anton.muted](mdb.ai)[/]" - ) - console.print( - "as your LLM provider. It is optimized for Anton with:" - ) - console.print() - console.print(" [anton.success]\u2713[/] Smart model routing") - console.print(" [anton.success]\u2713[/] Faster responses") - console.print(" [anton.success]\u2713[/] Cost optimized") + # Animated robot banner (same as normal launch) + render_banner(console) + console.print() + # Type out the intro + if sys.stdout.isatty(): + _typewrite("Anton is an autonomous AI coworker built by MindsDB.") + time.sleep(0.3) + console.print() + _typewrite("For the best experience, we recommend MindsDB Cloud (mdb.ai)") + _typewrite("as your LLM provider. It is optimized for Anton with:") + time.sleep(0.2) + console.print() + for line in [ + " \u2713 Smart model routing", + " \u2713 Faster responses", + " \u2713 Cost optimized", + ]: + _typewrite(line, delay=0.015) + time.sleep(0.1) + else: + console.print( + "[anton.cyan]Anton[/] is an autonomous AI coworker built by " + "[bold anton.cyan]MindsDB[/]." + ) + console.print() + console.print( + "For the best experience, we recommend [bold anton.cyan]MindsDB Cloud[/] " + "[anton.muted](mdb.ai)[/]" + ) + console.print("as your LLM provider. 
It is optimized for Anton with:") + console.print() + console.print(" [anton.success]\u2713[/] Smart model routing") + console.print(" [anton.success]\u2713[/] Faster responses") + console.print(" [anton.success]\u2713[/] Cost optimized") + + console.print() console.print(f"[{g}] {'━' * 40}[/]") console.print() @@ -327,11 +341,14 @@ def _setup_minds(settings, ws) -> None: """Set up Minds (mdb.ai) as the LLM provider.""" from rich.prompt import Confirm, Prompt + import webbrowser + console.print() console.print( " [anton.muted]Don't have a key yet? Create one in seconds at[/]" " [link=https://mdb.ai][bold anton.cyan]mdb.ai[/][/link]" ) + webbrowser.open("https://mdb.ai") console.print() api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) if not api_key.strip(): From a79fd53728103452312f3cbb71bf871b7ff325a0 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:07:06 -0700 Subject: [PATCH 35/70] anton welcome message chat lalala --- anton/cli.py | 146 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 99 insertions(+), 47 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index ce06499..b479c05 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -236,74 +236,126 @@ def _has_api_key(settings) -> bool: return True -def _typewrite(text: str, delay: float = 0.02) -> None: - """Print text character by character with a typing effect.""" - import sys - import time - - for ch in text: - sys.stdout.write(ch) - sys.stdout.flush() - time.sleep(delay) - sys.stdout.write("\n") - sys.stdout.flush() - - def _onboard(settings) -> None: - """First-time onboarding: animated robot + typed intro + LLM provider selection.""" + """First-time onboarding: animated robot talking the intro + LLM provider selection.""" import sys import time from rich.prompt import Prompt from anton import __version__ - from anton.channel.branding import render_banner from anton.workspace import Workspace ws = Workspace(Path.home()) g = "anton.glow" - # Animated 
robot banner (same as normal launch) - render_banner(console) - - console.print() + _INTRO_LINES = [ + "", + "Hi! I'm Anton, an autonomous AI coworker built by MindsDB.", + "", + "For the best experience, I recommend MindsDB Cloud (mdb.ai)", + "as your LLM provider. It is optimized for me with:", + "", + " \u2713 Smart model routing", + " \u2713 Faster responses", + " \u2713 Cost optimized", + ] - # Type out the intro if sys.stdout.isatty(): - _typewrite("Anton is an autonomous AI coworker built by MindsDB.") - time.sleep(0.3) - console.print() - _typewrite("For the best experience, we recommend MindsDB Cloud (mdb.ai)") - _typewrite("as your LLM provider. It is optimized for Anton with:") - time.sleep(0.2) - console.print() - for line in [ - " \u2713 Smart model routing", - " \u2713 Faster responses", - " \u2713 Cost optimized", - ]: - _typewrite(line, delay=0.015) - time.sleep(0.1) + _animate_onboard(console, __version__, _INTRO_LINES) else: - console.print( - "[anton.cyan]Anton[/] is an autonomous AI coworker built by " - "[bold anton.cyan]MindsDB[/]." - ) - console.print() - console.print( - "For the best experience, we recommend [bold anton.cyan]MindsDB Cloud[/] " - "[anton.muted](mdb.ai)[/]" - ) - console.print("as your LLM provider. 
It is optimized for Anton with:") + # Static fallback for non-interactive terminals + from anton.channel.branding import render_banner + + render_banner(console, animate=False) console.print() - console.print(" [anton.success]\u2713[/] Smart model routing") - console.print(" [anton.success]\u2713[/] Faster responses") - console.print(" [anton.success]\u2713[/] Cost optimized") + for line in _INTRO_LINES: + console.print(line) console.print() console.print(f"[{g}] {'━' * 40}[/]") console.print() + +def _animate_onboard(console, version: str, intro_lines: list[str]) -> None: + """Animate the robot talking while typing out the intro text below.""" + import time + + from rich.live import Live + from rich.text import Text + + from anton.channel.branding import ( + _MOUTH_SMILE, + _MOUTH_TALK, + _build_robot_text, + pick_tagline, + ) + + tagline = pick_tagline() + char_delay = 0.02 + line_pause = 0.15 + char_count = 0 # drives mouth animation + + def _build_frame(mouth: str, typed_lines: list[str]) -> Text: + """Build robot + separator + typed text as a single renderable.""" + frame = _build_robot_text(mouth, "\u2661\u2661\u2661\u2661") + frame.append(f" {'━' * 40}\n", style="bold cyan") + frame.append(f" v{version} \u2014 \"{tagline}\"\n", style="dim") + frame.append("\n") + frame.append(" anton> ", style="bold cyan") + for line in typed_lines: + frame.append(line) + return frame + + with Live( + _build_frame(_MOUTH_SMILE, []), + console=console, + refresh_per_second=30, + transient=True, + ) as live: + time.sleep(0.4) + + typed_so_far: list[str] = [] + + for line_idx, line in enumerate(intro_lines): + if line == "": + typed_so_far.append("\n") + live.update(_build_frame(_MOUTH_SMILE, typed_so_far)) + time.sleep(line_pause) + continue + + # Type out each character + current = "" + for ch in line: + current += ch + char_count += 1 + mouth = _MOUTH_TALK[char_count % 2] + live.update(_build_frame(mouth, typed_so_far + [current])) + time.sleep(char_delay) + + 
typed_so_far.append(current + "\n") + live.update(_build_frame(_MOUTH_SMILE, typed_so_far)) + time.sleep(line_pause) + + # Hold final frame briefly + time.sleep(0.3) + + # Print the static final state + from anton.channel.branding import _render_robot_static + + _render_robot_static(console, "\u2661\u2661\u2661\u2661") + console.print(f"[{g}] {'━' * 40}[/]") + console.print(f" v{version} \u2014 [anton.muted]\"{tagline}\"[/]") + console.print() + console.print("[anton.cyan] anton>[/] ", end="") + for line in intro_lines: + if line == "": + console.print() + elif line.startswith(" \u2713"): + console.print(f"[anton.success]{line}[/]") + else: + console.print(line) + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]MindsDB Cloud[/][/link] [anton.success](recommended)[/]") console.print(" [bold]2[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") console.print() From 4072818787eeaa9b7bd0d147559d203b1c9a309e Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:12:23 -0700 Subject: [PATCH 36/70] onboarding --- anton/cli.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index b479c05..b79aed9 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -344,16 +344,20 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: from anton.channel.branding import _render_robot_static _render_robot_static(console, "\u2661\u2661\u2661\u2661") - console.print(f"[{g}] {'━' * 40}[/]") + console.print(f"[anton.glow] {'━' * 40}[/]") console.print(f" v{version} \u2014 [anton.muted]\"{tagline}\"[/]") console.print() - console.print("[anton.cyan] anton>[/] ", end="") + console.print("[anton.cyan]anton>[/] ", end="") + first_text = True for line in intro_lines: if line == "": - console.print() + if not first_text: + console.print() elif line.startswith(" \u2713"): + first_text = False console.print(f"[anton.success]{line}[/]") else: + first_text = False console.print(line) 
console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]MindsDB Cloud[/][/link] [anton.success](recommended)[/]") From be8317e153a0cbbbf6751531668fd6f2d56f55fa Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:15:51 -0700 Subject: [PATCH 37/70] cli onboarding --- anton/cli.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index b79aed9..8986b0f 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -272,10 +272,6 @@ def _onboard(settings) -> None: for line in _INTRO_LINES: console.print(line) - console.print() - console.print(f"[{g}] {'━' * 40}[/]") - console.print() - def _animate_onboard(console, version: str, intro_lines: list[str]) -> None: """Animate the robot talking while typing out the intro text below.""" @@ -360,6 +356,9 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: first_text = False console.print(line) + console.print() + console.print(f"[anton.glow] {'━' * 40}[/]") + console.print() console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]MindsDB Cloud[/][/link] [anton.success](recommended)[/]") console.print(" [bold]2[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") console.print() From 326004581c0e33c3676f36e96c43f38216421f06 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:19:09 -0700 Subject: [PATCH 38/70] making on boarding feel great --- anton/cli.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 8986b0f..353b283 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -250,7 +250,6 @@ def _onboard(settings) -> None: g = "anton.glow" _INTRO_LINES = [ - "", "Hi! 
I'm Anton, an autonomous AI coworker built by MindsDB.", "", "For the best experience, I recommend MindsDB Cloud (mdb.ai)", @@ -298,7 +297,7 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: frame.append(f" {'━' * 40}\n", style="bold cyan") frame.append(f" v{version} \u2014 \"{tagline}\"\n", style="dim") frame.append("\n") - frame.append(" anton> ", style="bold cyan") + frame.append("anton> ", style="bold cyan") for line in typed_lines: frame.append(line) return frame From 421f598942461ad26bc3bfb2a3889a3a5bdfe989 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:22:39 -0700 Subject: [PATCH 39/70] anton cli fixes onboarding --- anton/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 353b283..0f89465 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -261,7 +261,7 @@ def _onboard(settings) -> None: ] if sys.stdout.isatty(): - _animate_onboard(console, __version__, _INTRO_LINES) + _animate_onboard(console, __version__, _INTRO_LINES, settings=settings, ws=ws) else: # Static fallback for non-interactive terminals from anton.channel.branding import render_banner @@ -272,7 +272,7 @@ def _onboard(settings) -> None: console.print(line) -def _animate_onboard(console, version: str, intro_lines: list[str]) -> None: +def _animate_onboard(console, version: str, intro_lines: list[str], *, settings, ws) -> None: """Animate the robot talking while typing out the intro text below.""" import time From 842b6080005b525876779ffbdc674e8325cb8afb Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:28:33 -0700 Subject: [PATCH 40/70] option 2 --- anton/cli.py | 51 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 0f89465..1c08e60 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -252,7 +252,7 @@ def _onboard(settings) -> None: _INTRO_LINES = [ "Hi! 
I'm Anton, an autonomous AI coworker built by MindsDB.", "", - "For the best experience, I recommend MindsDB Cloud (mdb.ai)", + "For the best experience, I recommend MindsDB Cloud (https://mdb.ai)", "as your LLM provider. It is optimized for me with:", "", " \u2713 Smart model routing", @@ -359,18 +359,21 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: console.print(f"[anton.glow] {'━' * 40}[/]") console.print() console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]MindsDB Cloud[/][/link] [anton.success](recommended)[/]") - console.print(" [bold]2[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") + console.print(" [bold]2[/] [anton.cyan]MindsDB Enterprise Server[/]") + console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") console.print() choice = Prompt.ask( "Choose LLM Provider", - choices=["1", "2"], + choices=["1", "2", "3"], default="1", console=console, ) if choice == "1": _setup_minds(settings, ws) + elif choice == "2": + _setup_minds(settings, ws, enterprise=True) else: _setup_other_provider(settings, ws) @@ -391,34 +394,42 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: console.print() -def _setup_minds(settings, ws) -> None: - """Set up Minds (mdb.ai) as the LLM provider.""" +def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: + """Set up Minds as the LLM provider (cloud or enterprise).""" from rich.prompt import Confirm, Prompt import webbrowser console.print() - console.print( - " [anton.muted]Don't have a key yet? 
Create one in seconds at[/]" - " [link=https://mdb.ai][bold anton.cyan]mdb.ai[/][/link]" - ) - webbrowser.open("https://mdb.ai") - console.print() + + if enterprise: + # Enterprise: ask for server URL first + minds_url = Prompt.ask( + " [anton.cyan]Server URL[/]", + console=console, + ).strip() + if not minds_url: + console.print(" [anton.error]No URL provided.[/]") + raise typer.Exit(1) + if not minds_url.startswith("http://") and not minds_url.startswith("https://"): + minds_url = "https://" + minds_url + minds_url = minds_url.rstrip("/") + console.print() + else: + minds_url = "https://mdb.ai" + console.print( + " [anton.muted]Don't have a key yet? Create one in seconds at[/]" + " [link=https://mdb.ai][bold anton.cyan]https://mdb.ai[/][/link]" + ) + webbrowser.open("https://mdb.ai") + console.print() + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) if not api_key.strip(): console.print(" [anton.error]No API key provided.[/]") raise typer.Exit(1) api_key = api_key.strip() - minds_url = Prompt.ask( - " [anton.cyan]Minds URL[/]", - default="https://mdb.ai", - console=console, - ).strip() - if not minds_url.startswith("http://") and not minds_url.startswith("https://"): - minds_url = "https://" + minds_url - minds_url = minds_url.rstrip("/") - # Store Minds credentials settings.minds_api_key = api_key settings.minds_url = minds_url From 981ddd33328bb10d91b4a329e83ad1ff87ac4978 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:40:38 -0700 Subject: [PATCH 41/70] anton cli workflow --- anton/cli.py | 95 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 1c08e60..13dca6b 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -252,7 +252,7 @@ def _onboard(settings) -> None: _INTRO_LINES = [ "Hi! 
I'm Anton, an autonomous AI coworker built by MindsDB.", "", - "For the best experience, I recommend MindsDB Cloud (https://mdb.ai)", + "For the best experience, I recommend MindsDB Cloud https://mdb.ai", "as your LLM provider. It is optimized for me with:", "", " \u2713 Smart model routing", @@ -350,7 +350,7 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: console.print() elif line.startswith(" \u2713"): first_text = False - console.print(f"[anton.success]{line}[/]") + console.print(f" [anton.success]\u2713[/] {line[4:]}") else: first_text = False console.print(line) @@ -363,19 +363,29 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") console.print() - choice = Prompt.ask( - "Choose LLM Provider", - choices=["1", "2", "3"], - default="1", - console=console, - ) + while True: + choice = Prompt.ask( + "Choose LLM Provider", + choices=["1", "2", "3"], + default="1", + console=console, + ) - if choice == "1": - _setup_minds(settings, ws) - elif choice == "2": - _setup_minds(settings, ws, enterprise=True) - else: - _setup_other_provider(settings, ws) + try: + if choice == "1": + _setup_minds(settings, ws) + elif choice == "2": + _setup_minds(settings, ws, enterprise=True) + else: + _setup_other_provider(settings, ws) + break # success + except _SetupRetry: + console.print() + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]MindsDB Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]MindsDB Enterprise Server[/]") + console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") + console.print() + continue # Reload env vars so the scratchpad subprocess inherits them ws.apply_env_to_process() @@ -394,6 +404,11 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: console.print() +class _SetupRetry(Exception): + """Raised by setup functions to go back 
to provider selection.""" + pass + + def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: """Set up Minds as the LLM provider (cloud or enterprise).""" from rich.prompt import Confirm, Prompt @@ -402,21 +417,7 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: console.print() - if enterprise: - # Enterprise: ask for server URL first - minds_url = Prompt.ask( - " [anton.cyan]Server URL[/]", - console=console, - ).strip() - if not minds_url: - console.print(" [anton.error]No URL provided.[/]") - raise typer.Exit(1) - if not minds_url.startswith("http://") and not minds_url.startswith("https://"): - minds_url = "https://" + minds_url - minds_url = minds_url.rstrip("/") - console.print() - else: - minds_url = "https://mdb.ai" + if not enterprise: console.print( " [anton.muted]Don't have a key yet? Create one in seconds at[/]" " [link=https://mdb.ai][bold anton.cyan]https://mdb.ai[/][/link]" @@ -424,6 +425,15 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: webbrowser.open("https://mdb.ai") console.print() + minds_url = Prompt.ask( + " [anton.cyan]Server URL[/]", + default="https://mdb.ai", + console=console, + ).strip() + if not minds_url.startswith("http://") and not minds_url.startswith("https://"): + minds_url = "https://" + minds_url + minds_url = minds_url.rstrip("/") + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) if not api_key.strip(): console.print(" [anton.error]No API key provided.[/]") @@ -461,8 +471,7 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: console=console, ) if not skip_ssl: - console.print(" [anton.error]Setup cancelled.[/]") - raise typer.Exit(1) + llm_ok = False if llm_ok: console.print(" [anton.success]Connected[/]") @@ -484,7 +493,11 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: ws.set_secret("ANTON_MINDS_SSL_VERIFY", "false") else: console.print(" [anton.error]Could not connect. 
Check your API key and URL.[/]") - raise typer.Exit(1) + retry = Confirm.ask(" Try again?", default=True, console=console) + if retry: + _setup_minds(settings, ws, enterprise=enterprise) + else: + raise _SetupRetry() def _setup_other_provider(settings, ws) -> None: @@ -525,7 +538,7 @@ def _validate_with_spinner(console, label: str, fn) -> None: def _setup_anthropic(settings, ws) -> None: """Set up Anthropic with a single model for both reasoning and coding.""" - from rich.prompt import Prompt + from rich.prompt import Confirm, Prompt console.print() api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) @@ -545,7 +558,12 @@ def _test(): _validate_with_spinner(console, model, _test) except Exception as exc: console.print(f" [anton.error]Failed:[/] {exc}") - raise typer.Exit(1) + retry = Confirm.ask(" Try again?", default=True, console=console) + if retry: + _setup_anthropic(settings, ws) + return + else: + raise _SetupRetry() settings.anthropic_api_key = api_key settings.planning_provider = "anthropic" @@ -561,7 +579,7 @@ def _test(): def _setup_openai(settings, ws) -> None: """Set up OpenAI with a single model for both reasoning and coding.""" - from rich.prompt import Prompt + from rich.prompt import Confirm, Prompt console.print() api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) @@ -581,7 +599,12 @@ def _test(): _validate_with_spinner(console, model, _test) except Exception as exc: console.print(f" [anton.error]Failed:[/] {exc}") - raise typer.Exit(1) + retry = Confirm.ask(" Try again?", default=True, console=console) + if retry: + _setup_openai(settings, ws) + return + else: + raise _SetupRetry() settings.openai_api_key = api_key settings.planning_provider = "openai" From 95ec38cdad44bbacd77faa0c53565d25b2034eb3 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:46:30 -0700 Subject: [PATCH 42/70] best cli baby --- anton/cli.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git 
a/anton/cli.py b/anton/cli.py index 13dca6b..e9f9558 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -252,12 +252,13 @@ def _onboard(settings) -> None: _INTRO_LINES = [ "Hi! I'm Anton, an autonomous AI coworker built by MindsDB.", "", - "For the best experience, I recommend MindsDB Cloud https://mdb.ai", - "as your LLM provider. It is optimized for me with:", + "For the best experience, I recommend MindsDB Cloud as your", + "LLM Provider, it is optimized for:", "", " \u2713 Smart model routing", " \u2713 Faster responses", " \u2713 Cost optimized", + " \u2713 Secure data connectors", ] if sys.stdout.isatty(): @@ -417,14 +418,6 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: console.print() - if not enterprise: - console.print( - " [anton.muted]Don't have a key yet? Create one in seconds at[/]" - " [link=https://mdb.ai][bold anton.cyan]https://mdb.ai[/][/link]" - ) - webbrowser.open("https://mdb.ai") - console.print() - minds_url = Prompt.ask( " [anton.cyan]Server URL[/]", default="https://mdb.ai", @@ -434,6 +427,18 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: minds_url = "https://" + minds_url minds_url = minds_url.rstrip("/") + has_key = Confirm.ask( + " Do you have an API key?", + default=True, + console=console, + ) + if not has_key: + console.print( + " [anton.muted]No problem — it only takes a few seconds to create one.[/]" + ) + webbrowser.open(f"{minds_url}/apiKeys") + console.print() + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) if not api_key.strip(): console.print(" [anton.error]No API key provided.[/]") From c50a243bfaa62e554f661a05918e6d1772739123 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:49:41 -0700 Subject: [PATCH 43/70] optimized for --- anton/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index e9f9558..ba000aa 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -252,8 +252,8 @@ def 
_onboard(settings) -> None: _INTRO_LINES = [ "Hi! I'm Anton, an autonomous AI coworker built by MindsDB.", "", - "For the best experience, I recommend MindsDB Cloud as your", - "LLM Provider, it is optimized for:", + "For the best experience, I recommend MindsDB Cloud as your LLM Provider,", + "optimized for:", "", " \u2713 Smart model routing", " \u2713 Faster responses", From f6d3819311ed27a79d6c9090e7856b5166de6559 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:51:20 -0700 Subject: [PATCH 44/70] glow anton glow --- anton/channel/branding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/anton/channel/branding.py b/anton/channel/branding.py index cf3419b..a05c882 100644 --- a/anton/channel/branding.py +++ b/anton/channel/branding.py @@ -61,8 +61,8 @@ def pick_tagline(seed: int | None = None) -> str: def _build_robot_text(mouth: str, bubble: str) -> Text: """Build the full robot as a Rich Text object with styling.""" - g = "bold cyan" - m = "dim" + g = "anton.glow" + m = "anton.muted" # Pad bubble to avoid layout jitter (longest phrase is ~16 chars) padded = bubble.ljust(16) lines = [ From 82402164d3e7126a6f7e0eed2fc8b548d1d4bc6c Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:53:19 -0700 Subject: [PATCH 45/70] llm cli --- anton/cli.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index ba000aa..f95297a 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -439,10 +439,11 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: webbrowser.open(f"{minds_url}/apiKeys") console.print() - api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) - if not api_key.strip(): - console.print(" [anton.error]No API key provided.[/]") - raise typer.Exit(1) + while True: + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) + if api_key.strip(): + break + console.print(" [anton.warning]Please enter your 
API key.[/]") api_key = api_key.strip() # Store Minds credentials @@ -546,10 +547,11 @@ def _setup_anthropic(settings, ws) -> None: from rich.prompt import Confirm, Prompt console.print() - api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) - if not api_key.strip(): - console.print(" [anton.error]No API key provided.[/]") - raise typer.Exit(1) + while True: + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) + if api_key.strip(): + break + console.print(" [anton.warning]Please enter your API key.[/]") api_key = api_key.strip() model = Prompt.ask(" [anton.cyan]Model[/]", default="claude-sonnet-4-6", console=console).strip() @@ -587,10 +589,11 @@ def _setup_openai(settings, ws) -> None: from rich.prompt import Confirm, Prompt console.print() - api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) - if not api_key.strip(): - console.print(" [anton.error]No API key provided.[/]") - raise typer.Exit(1) + while True: + api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) + if api_key.strip(): + break + console.print(" [anton.warning]Please enter your API key.[/]") api_key = api_key.strip() model = Prompt.ask(" [anton.cyan]Model[/]", default="gpt-4o", console=console).strip() From 4eff12fbef7ef36678056e37ac92ba8b95c5a5a6 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 10:55:55 -0700 Subject: [PATCH 46/70] cli option 2 --- anton/cli.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index f95297a..62c19dd 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -376,7 +376,7 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: if choice == "1": _setup_minds(settings, ws) elif choice == "2": - _setup_minds(settings, ws, enterprise=True) + _setup_minds(settings, ws, default_url=None) else: _setup_other_provider(settings, ws) break # success @@ -410,7 +410,7 @@ class _SetupRetry(Exception): pass -def _setup_minds(settings, ws, *, enterprise: 
bool = False) -> None: +def _setup_minds(settings, ws, *, default_url: str | None = "https://mdb.ai") -> None: """Set up Minds as the LLM provider (cloud or enterprise).""" from rich.prompt import Confirm, Prompt @@ -418,11 +418,10 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: console.print() - minds_url = Prompt.ask( - " [anton.cyan]Server URL[/]", - default="https://mdb.ai", - console=console, - ).strip() + prompt_kwargs = {"console": console} + if default_url: + prompt_kwargs["default"] = default_url + minds_url = Prompt.ask(" [anton.cyan]Server URL[/]", **prompt_kwargs).strip() if not minds_url.startswith("http://") and not minds_url.startswith("https://"): minds_url = "https://" + minds_url minds_url = minds_url.rstrip("/") @@ -501,7 +500,7 @@ def _setup_minds(settings, ws, *, enterprise: bool = False) -> None: console.print(" [anton.error]Could not connect. Check your API key and URL.[/]") retry = Confirm.ask(" Try again?", default=True, console=console) if retry: - _setup_minds(settings, ws, enterprise=enterprise) + _setup_minds(settings, ws, default_url=default_url) else: raise _SetupRetry() From 0d9de04425bd5a2c34a1fce15cd69dea7749e205 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 11:15:23 -0700 Subject: [PATCH 47/70] cli with escape, --- anton/cli.py | 67 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 13 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 62c19dd..acf13f1 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -393,10 +393,16 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: # Summary console.print() + console.print(f"[anton.glow] {'━' * 40}[/]") + console.print() provider_label = settings.planning_provider model_label = settings.planning_model if provider_label == "openai-compatible": - provider_label = "MindsDB Cloud" + if settings.minds_url and "mdb.ai" in settings.minds_url: + provider_label = "MindsDB Cloud" + else: + provider_label 
= "MindsDB Enterprise Server" + model_label = "smart_router" console.print( f" [anton.muted]Provider[/] [anton.cyan]{provider_label}[/]" f" [anton.muted]Model[/] [anton.cyan]{model_label}[/]" @@ -410,6 +416,44 @@ class _SetupRetry(Exception): pass +def _setup_prompt(label: str, default: str | None = None) -> str: + """Prompt for input with ESC-to-go-back and a bottom toolbar hint. + + Returns the user's input string. + Raises _SetupRetry if the user presses ESC. + """ + from prompt_toolkit import PromptSession + from prompt_toolkit.formatted_text import HTML + from prompt_toolkit.key_binding import KeyBindings + from prompt_toolkit.styles import Style as PTStyle + + bindings = KeyBindings() + + @bindings.add("escape") + def _on_esc(event): + event.app.exit(exception=_SetupRetry()) + + pt_style = PTStyle.from_dict({ + "bottom-toolbar": "noreverse nounderline bg:default", + }) + + def _toolbar(): + return HTML("") + + suffix = f" ({default}): " if default else ": " + session: PromptSession[str] = PromptSession( + mouse_support=False, + bottom_toolbar=_toolbar, + style=pt_style, + key_bindings=bindings, + ) + + result = session.prompt(f" {label}{suffix}") + if not result and default: + return default + return result + + def _setup_minds(settings, ws, *, default_url: str | None = "https://mdb.ai") -> None: """Set up Minds as the LLM provider (cloud or enterprise).""" from rich.prompt import Confirm, Prompt @@ -418,10 +462,7 @@ def _setup_minds(settings, ws, *, default_url: str | None = "https://mdb.ai") -> console.print() - prompt_kwargs = {"console": console} - if default_url: - prompt_kwargs["default"] = default_url - minds_url = Prompt.ask(" [anton.cyan]Server URL[/]", **prompt_kwargs).strip() + minds_url = _setup_prompt("Server URL", default=default_url).strip() if not minds_url.startswith("http://") and not minds_url.startswith("https://"): minds_url = "https://" + minds_url minds_url = minds_url.rstrip("/") @@ -439,7 +480,7 @@ def _setup_minds(settings, ws, *, 
default_url: str | None = "https://mdb.ai") -> console.print() while True: - api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) + api_key = _setup_prompt("API key") if api_key.strip(): break console.print(" [anton.warning]Please enter your API key.[/]") @@ -543,17 +584,17 @@ def _validate_with_spinner(console, label: str, fn) -> None: def _setup_anthropic(settings, ws) -> None: """Set up Anthropic with a single model for both reasoning and coding.""" - from rich.prompt import Confirm, Prompt + from rich.prompt import Confirm console.print() while True: - api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) + api_key = _setup_prompt("API key") if api_key.strip(): break console.print(" [anton.warning]Please enter your API key.[/]") api_key = api_key.strip() - model = Prompt.ask(" [anton.cyan]Model[/]", default="claude-sonnet-4-6", console=console).strip() + model = _setup_prompt("Model", default="claude-sonnet-4-6").strip() try: def _test(): @@ -585,17 +626,17 @@ def _test(): def _setup_openai(settings, ws) -> None: """Set up OpenAI with a single model for both reasoning and coding.""" - from rich.prompt import Confirm, Prompt + from rich.prompt import Confirm console.print() while True: - api_key = Prompt.ask(" [anton.cyan]API key[/]", console=console) + api_key = _setup_prompt("API key") if api_key.strip(): break console.print(" [anton.warning]Please enter your API key.[/]") api_key = api_key.strip() - model = Prompt.ask(" [anton.cyan]Model[/]", default="gpt-4o", console=console).strip() + model = _setup_prompt("Model", default="gpt-4o").strip() try: def _test(): @@ -630,7 +671,7 @@ def setup(ctx: typer.Context) -> None: """Configure provider, model, and API key.""" settings = _get_settings(ctx) _ensure_workspace(settings) - _ensure_api_key(settings) + _onboard(settings) console.print("[anton.success]Setup complete.[/]") From 6c113a0061a00cb25e0d62f944e52f4f5125e60b Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 
11:17:55 -0700 Subject: [PATCH 48/70] ESC button --- anton/cli.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index acf13f1..944c310 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -427,18 +427,22 @@ def _setup_prompt(label: str, default: str | None = None) -> str: from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.styles import Style as PTStyle + _esc_pressed = False + bindings = KeyBindings() @bindings.add("escape") def _on_esc(event): - event.app.exit(exception=_SetupRetry()) + nonlocal _esc_pressed + _esc_pressed = True + event.app.exit(result="") pt_style = PTStyle.from_dict({ "bottom-toolbar": "noreverse nounderline bg:default", }) def _toolbar(): - return HTML("") + return HTML("") suffix = f" ({default}): " if default else ": " session: PromptSession[str] = PromptSession( @@ -449,6 +453,11 @@ def _toolbar(): ) result = session.prompt(f" {label}{suffix}") + + if _esc_pressed: + console.print(" [anton.muted]Going back...[/]") + raise _SetupRetry() + if not result and default: return default return result From a73af38f4c955a48f331329cfb687969c6ceec26 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 11:22:31 -0700 Subject: [PATCH 49/70] cli updates --- anton/cli.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 944c310..7cfb46c 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -250,10 +250,9 @@ def _onboard(settings) -> None: g = "anton.glow" _INTRO_LINES = [ - "Hi! I'm Anton, an autonomous AI coworker built by MindsDB.", + "Hi! 
I'm Anton, an autonomous AI coworker.", "", - "For the best experience, I recommend MindsDB Cloud as your LLM Provider,", - "optimized for:", + "For the best experience, I recommend MindsDB-Cloud as your LLM Provider:", "", " \u2713 Smart model routing", " \u2713 Faster responses", @@ -557,28 +556,25 @@ def _setup_minds(settings, ws, *, default_url: str | None = "https://mdb.ai") -> def _setup_other_provider(settings, ws) -> None: """Set up Anthropic or OpenAI as the LLM provider.""" - from rich.prompt import Prompt from rich.text import Text console.print() - for label, idx in [("Anthropic (Claude)", "1"), ("OpenAI (GPT)", "2")]: + for label, idx in [("Anthropic", "1"), ("OpenAI", "2")]: line = Text() line.append(f" {idx} ", style="bold") line.append(label, style="anton.cyan") console.print(line) console.print() - provider_choice = Prompt.ask( - "[anton.cyan]>[/]", - choices=["1", "2"], - console=console, - show_choices=False, - ) + choice = _setup_prompt("Provider", default="Anthropic").strip().lower() - if provider_choice == "1": + if choice in ("1", "anthropic"): _setup_anthropic(settings, ws) - else: + elif choice in ("2", "openai"): _setup_openai(settings, ws) + else: + console.print(f" [anton.warning]Unknown provider '{choice}', using Anthropic.[/]") + _setup_anthropic(settings, ws) def _validate_with_spinner(console, label: str, fn) -> None: From 6aa1c36a006155c293d4ed5f3a4da49d1fd64743 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 11:26:53 -0700 Subject: [PATCH 50/70] ok cli --- anton/chat.py | 2 +- anton/cli.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index b86e36c..dc6d509 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2408,7 +2408,7 @@ async def _chat_loop(console: Console, settings: AntonSettings, *, resume: bool current_session_id = resumed_id - console.print("[anton.muted] Chat with Anton. 
Type '/help' for commands or 'exit' to quit.[/]") + console.print("[anton.muted] Chat with me, type '/help' for commands or 'exit' to quit.[/]") console.print(f"[anton.cyan_dim] {'━' * 40}[/]") console.print() diff --git a/anton/cli.py b/anton/cli.py index 7cfb46c..780d42a 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -406,7 +406,7 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: f" [anton.muted]Provider[/] [anton.cyan]{provider_label}[/]" f" [anton.muted]Model[/] [anton.cyan]{model_label}[/]" ) - console.print(f" [anton.success]Ready.[/] [anton.muted]Saved to {ws.env_path}[/]") + console.print(f" [anton.success]Ready.[/]") console.print() From 3945720266a456bfe0f5cae5abb40ef62e5c26e7 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 11:33:21 -0700 Subject: [PATCH 51/70] just final details on cli --- anton/chat.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index dc6d509..a751bdd 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1218,12 +1218,30 @@ async def _handle_setup_models( # Always persist API keys and model settings to global ~/.anton/.env global_ws = _Workspace(Path.home()) + def _provider_label(provider: str) -> str: + if provider == "openai-compatible": + if settings.minds_url and "mdb.ai" in settings.minds_url: + return "MindsDB-Cloud" + return "MindsDB-Enterprise" + return provider.capitalize() + + def _model_label(model: str, role: str) -> str: + if model in ("_reason_", "_code_"): + return f"smart_router({role})" + return model + + provider_display = _provider_label(settings.planning_provider) + planning_display = _model_label(settings.planning_model, "planning") + coding_display = _model_label(settings.coding_model, "coding") + console.print() console.print("[anton.cyan]Current configuration:[/]") - console.print(f" Provider (planning): [bold]{settings.planning_provider}[/]") - console.print(f" Provider (coding): 
[bold]{settings.coding_provider}[/]") - console.print(f" Planning model: [bold]{settings.planning_model}[/]") - console.print(f" Coding model: [bold]{settings.coding_model}[/]") + console.print(f" Provider: [bold]{provider_display}[/]") + if planning_display == coding_display: + console.print(f" Model: [bold]{planning_display}[/]") + else: + console.print(f" Planning: [bold]{planning_display}[/]") + console.print(f" Coding: [bold]{coding_display}[/]") console.print() # --- Provider --- From 1ca7b3d052ad731676f237d3dbe1b41479d3013e Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Mar 2026 19:38:04 +0100 Subject: [PATCH 52/70] Add test snippet for custom datasources --- anton/chat.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/anton/chat.py b/anton/chat.py index 4202a51..97cb98d 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2684,6 +2684,7 @@ async def _handle_add_custom_datasource( f"The user wants to connect to {repr(tool_name)} and said: {user_answer}\n\n" "Return ONLY valid JSON (no markdown fences, no commentary):\n" '{"display_name":"Human-readable name","pip":"pip-package or empty string",' + '"test_snippet":"python code that tests the connection using os.environ vars DS_FIELDNAME (uppercase field name with DS_ prefix) and prints ok on success, or empty string if untestable",' '"fields":[{"name":"snake_case_name","value":"value if given inline else empty",' '"secret":true or false,"required":true or false,"description":"what it is"}]}' ), @@ -2702,6 +2703,7 @@ async def _handle_add_custom_datasource( console.print() return None + test_snippet = str(data.get("test_snippet", "")).strip() raw_fields = data.get("fields") or [] fields: list[DatasourceField] = [] for f in raw_fields: @@ -2801,6 +2803,11 @@ async def _handle_add_custom_datasource( f'secret: {str(f.secret).lower()}, description: "{f.description}" }}' for f in fields ) + test_snippet_yaml = "" + if test_snippet: + indented = "\n".join(f" {line}" for 
line in test_snippet.splitlines()) + test_snippet_yaml = f"test_snippet: |\n{indented}\n" + yaml_block = ( f"\n---\n\n## {display_name}\n" "```yaml\n" @@ -2808,7 +2815,8 @@ async def _handle_add_custom_datasource( f"display_name: {display_name}\n" + (f"pip: {pip_pkg}\n" if pip_pkg else "") + f"fields:\n{field_lines}\n" - "```\n" + + test_snippet_yaml + + "```\n" ) user_ds_path = Path("~/.anton/datasources.md").expanduser() tmp_path = user_ds_path.with_suffix(".tmp") From 54bf9599884a219757e2a94529f3ead16ac9a493 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 11:42:22 -0700 Subject: [PATCH 53/70] Minds-Cloud --- anton/chat.py | 110 +++++++++++++------------------------------------- anton/cli.py | 34 +++++++++------- 2 files changed, 48 insertions(+), 96 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index a751bdd..71619a7 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1221,8 +1221,8 @@ async def _handle_setup_models( def _provider_label(provider: str) -> str: if provider == "openai-compatible": if settings.minds_url and "mdb.ai" in settings.minds_url: - return "MindsDB-Cloud" - return "MindsDB-Enterprise" + return "Minds-Cloud" + return "Minds-Enterprise" return provider.capitalize() def _model_label(model: str, role: str) -> str: @@ -1244,91 +1244,39 @@ def _model_label(model: str, role: str) -> str: console.print(f" Coding: [bold]{coding_display}[/]") console.print() - # --- Provider --- - providers = {"1": "anthropic", "2": "openai", "3": "openai-compatible"} - current_num = {"anthropic": "1", "openai": "2", "openai-compatible": "3"}.get(settings.planning_provider, "1") - console.print("[anton.cyan]Available providers:[/]") - console.print(r" [bold]1[/] Anthropic (Claude) [dim]\[recommended][/]") - console.print(r" [bold]2[/] OpenAI (GPT / o-series) [dim]\[experimental][/]") - console.print(r" [bold]3[/] OpenAI-compatible (custom endpoint) [dim]\[experimental][/]") - console.print() + # Use the same onboarding flow from cli.py + 
from anton.cli import _setup_minds, _setup_other_provider, _SetupRetry, _setup_prompt - choice = Prompt.ask( - "Select provider", - choices=["1", "2", "3"], - default=current_num, - console=console, - ) - provider = providers[choice] + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") + console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") + console.print() - # --- Base URL (OpenAI-compatible only) --- - if provider == "openai-compatible": - current_base_url = settings.openai_base_url or "" - console.print() - base_url = Prompt.ask( - f"API base URL [dim](e.g. http://localhost:11434/v1)[/]", - default=current_base_url, + while True: + choice = Prompt.ask( + "Choose LLM Provider", + choices=["1", "2", "3"], + default="1", console=console, ) - base_url = base_url.strip() - if base_url: - settings.openai_base_url = base_url - global_ws.set_secret("ANTON_OPENAI_BASE_URL", base_url) - - # --- API key --- - key_attr = "anthropic_api_key" if provider == "anthropic" else "openai_api_key" - current_key = getattr(settings, key_attr) or "" - masked = current_key[:4] + "..." 
+ current_key[-4:] if len(current_key) > 8 else "***" - console.print() - api_key = Prompt.ask( - f"API key for {provider.title()} [dim](Enter to keep {masked})[/]", - default="", - console=console, - ) - api_key = api_key.strip() - - # --- Models --- - defaults = { - "anthropic": ("claude-sonnet-4-6", "claude-haiku-4-5-20251001"), - "openai": ("gpt-5-mini", "gpt-5-nano"), - } - default_planning, default_coding = defaults.get(provider, ("", "")) - console.print() - planning_model = Prompt.ask( - "Planning model", - default=settings.planning_model if provider == settings.planning_provider else default_planning, - console=console, - ) - coding_model = Prompt.ask( - "Coding model", - default=settings.coding_model if provider == settings.coding_provider else default_coding, - console=console, - ) - - # --- Persist to global ~/.anton/.env --- - settings.planning_provider = provider - settings.coding_provider = provider - settings.planning_model = planning_model - settings.coding_model = coding_model - - global_ws.set_secret("ANTON_PLANNING_PROVIDER", provider) - global_ws.set_secret("ANTON_CODING_PROVIDER", provider) - global_ws.set_secret("ANTON_PLANNING_MODEL", planning_model) - global_ws.set_secret("ANTON_CODING_MODEL", coding_model) - - if api_key: - setattr(settings, key_attr, api_key) - key_name = f"ANTON_{provider.upper()}_API_KEY" - global_ws.set_secret(key_name, api_key) + try: + if choice == "1": + _setup_minds(settings, global_ws) + elif choice == "2": + _setup_minds(settings, global_ws, default_url=None) + else: + _setup_other_provider(settings, global_ws) + break + except _SetupRetry: + console.print() + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") + console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") + console.print() + continue - # Validate that we actually have an API 
key for the chosen provider - final_key = getattr(settings, key_attr) - if not final_key: - console.print() - console.print(f"[anton.error]No API key set for {provider}. Configuration not applied.[/]") - console.print() - return session + global_ws.apply_env_to_process() console.print() console.print("[anton.success]Configuration updated.[/]") diff --git a/anton/cli.py b/anton/cli.py index 780d42a..8858d65 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -252,7 +252,7 @@ def _onboard(settings) -> None: _INTRO_LINES = [ "Hi! I'm Anton, an autonomous AI coworker.", "", - "For the best experience, I recommend MindsDB-Cloud as your LLM Provider:", + "For the best experience, I recommend Minds-Cloud as your LLM Provider:", "", " \u2713 Smart model routing", " \u2713 Faster responses", @@ -358,8 +358,8 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: console.print() console.print(f"[anton.glow] {'━' * 40}[/]") console.print() - console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]MindsDB Cloud[/][/link] [anton.success](recommended)[/]") - console.print(" [bold]2[/] [anton.cyan]MindsDB Enterprise Server[/]") + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") console.print() @@ -381,8 +381,8 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: break # success except _SetupRetry: console.print() - console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]MindsDB Cloud[/][/link] [anton.success](recommended)[/]") - console.print(" [bold]2[/] [anton.cyan]MindsDB Enterprise Server[/]") + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") console.print(" [bold]3[/] 
[anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") console.print() continue @@ -398,9 +398,9 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: model_label = settings.planning_model if provider_label == "openai-compatible": if settings.minds_url and "mdb.ai" in settings.minds_url: - provider_label = "MindsDB Cloud" + provider_label = "Minds-Cloud" else: - provider_label = "MindsDB Enterprise Server" + provider_label = "Minds-Enterprise Server" model_label = "smart_router" console.print( f" [anton.muted]Provider[/] [anton.cyan]{provider_label}[/]" @@ -470,20 +470,24 @@ def _setup_minds(settings, ws, *, default_url: str | None = "https://mdb.ai") -> console.print() - minds_url = _setup_prompt("Server URL", default=default_url).strip() - if not minds_url.startswith("http://") and not minds_url.startswith("https://"): - minds_url = "https://" + minds_url - minds_url = minds_url.rstrip("/") + is_cloud = default_url == "https://mdb.ai" + if is_cloud: + minds_url = "https://mdb.ai" + else: + minds_url = _setup_prompt("Server URL", default=default_url).strip() + if not minds_url.startswith("http://") and not minds_url.startswith("https://"): + minds_url = "https://" + minds_url + minds_url = minds_url.rstrip("/") + + console.print(" [anton.muted]If you don't have an API key yet, we'll help you create one — it takes a few seconds.[/]") + console.print() has_key = Confirm.ask( - " Do you have an API key?", + " Do you have an mdb.ai API key?" 
if is_cloud else " Do you have an API key?", default=True, console=console, ) if not has_key: - console.print( - " [anton.muted]No problem — it only takes a few seconds to create one.[/]" - ) webbrowser.open(f"{minds_url}/apiKeys") console.print() From 597ded37fe532759c59a963ef20c002ba21f7228 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 11:47:56 -0700 Subject: [PATCH 54/70] ok, lets mambo, it is done now --- anton/chat.py | 24 ++++++++++++++---------- anton/cli.py | 21 +++++++++++---------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 71619a7..253a4bf 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1247,19 +1247,26 @@ def _model_label(model: str, role: str) -> str: # Use the same onboarding flow from cli.py from anton.cli import _setup_minds, _setup_other_provider, _SetupRetry, _setup_prompt - console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") - console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") - console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") - console.print() + def _print_choices(): + console.print(" [bold]1[/] [link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") + console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") + console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") + console.print(" [bold]q[/] [anton.muted]Back[/]") + console.print() + + _print_choices() while True: choice = Prompt.ask( "Choose LLM Provider", - choices=["1", "2", "3"], - default="1", + choices=["1", "2", "3", "q"], + default="q", console=console, ) + if choice == "q": + return session + try: if choice == "1": _setup_minds(settings, global_ws) @@ -1270,10 +1277,7 @@ def _model_label(model: str, role: str) -> str: break except _SetupRetry: console.print() - console.print(" [bold]1[/] 
[link=https://mdb.ai][anton.cyan]Minds-Cloud[/][/link] [anton.success](recommended)[/]") - console.print(" [bold]2[/] [anton.cyan]Minds-Enterprise Server[/]") - console.print(" [bold]3[/] [anton.cyan]Bring your own key[/] [anton.muted]Anthropic / OpenAI[/]") - console.print() + _print_choices() continue global_ws.apply_env_to_process() diff --git a/anton/cli.py b/anton/cli.py index 8858d65..def1b51 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -250,7 +250,7 @@ def _onboard(settings) -> None: g = "anton.glow" _INTRO_LINES = [ - "Hi! I'm Anton, an autonomous AI coworker.", + "Hi Boss! I'm Anton, your AI coworker.", "", "For the best experience, I recommend Minds-Cloud as your LLM Provider:", "", @@ -480,16 +480,17 @@ def _setup_minds(settings, ws, *, default_url: str | None = "https://mdb.ai") -> minds_url = "https://" + minds_url minds_url = minds_url.rstrip("/") - console.print(" [anton.muted]If you don't have an API key yet, we'll help you create one — it takes a few seconds.[/]") - console.print() - has_key = Confirm.ask( - " Do you have an mdb.ai API key?" 
if is_cloud else " Do you have an API key?", - default=True, - console=console, - ) - if not has_key: - webbrowser.open(f"{minds_url}/apiKeys") + if is_cloud: + console.print(" [anton.muted]If you don't have an API key yet, we'll help you create one — it takes a few seconds.[/]") console.print() + has_key = Confirm.ask( + " Do you have an mdb.ai API key?", + default=True, + console=console, + ) + if not has_key: + webbrowser.open(f"{minds_url}/apiKeys") + console.print() while True: api_key = _setup_prompt("API key") From 913615ca0cfd67e7587d2d72ad742180b4be7671 Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Tue, 24 Mar 2026 11:54:20 -0700 Subject: [PATCH 55/70] and BYOK --- anton/cli.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index def1b51..45b9e16 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -402,11 +402,8 @@ def _build_frame(mouth: str, typed_lines: list[str]) -> Text: else: provider_label = "Minds-Enterprise Server" model_label = "smart_router" - console.print( - f" [anton.muted]Provider[/] [anton.cyan]{provider_label}[/]" - f" [anton.muted]Model[/] [anton.cyan]{model_label}[/]" - ) - console.print(f" [anton.success]Ready.[/]") + console.print(f" [anton.muted]Provider:[/] [anton.cyan]{provider_label}[/]") + console.print(f" [anton.muted]Model:[/] [anton.cyan]{model_label}[/]") console.print() @@ -420,7 +417,10 @@ def _setup_prompt(label: str, default: str | None = None) -> str: Returns the user's input string. Raises _SetupRetry if the user presses ESC. + Works both from sync context (onboarding) and async context (/setup). 
""" + import asyncio + from prompt_toolkit import PromptSession from prompt_toolkit.formatted_text import HTML from prompt_toolkit.key_binding import KeyBindings @@ -441,7 +441,7 @@ def _on_esc(event): }) def _toolbar(): - return HTML("") + return HTML("") suffix = f" ({default}): " if default else ": " session: PromptSession[str] = PromptSession( @@ -451,7 +451,22 @@ def _toolbar(): key_bindings=bindings, ) - result = session.prompt(f" {label}{suffix}") + # Use async prompt if inside a running event loop, sync otherwise + try: + asyncio.get_running_loop() + in_async = True + except RuntimeError: + in_async = False + + if in_async: + # We're inside an async context (e.g. /setup from chat loop) + # Run prompt_toolkit in a thread to avoid nested event loop conflict + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(session.prompt, f" {label}{suffix}") + result = future.result() + else: + result = session.prompt(f" {label}{suffix}") if _esc_pressed: console.print(" [anton.muted]Going back...[/]") From 4d47a93f569e76099a92bf6aa01829012b1e63d6 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Mar 2026 21:07:33 +0100 Subject: [PATCH 56/70] Cleanup merge confilct --- anton/chat.py | 70 --------------------------------------------------- 1 file changed, 70 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 64d91f3..524df22 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1589,24 +1589,8 @@ def _print_choices(): console=console, ) -<<<<<<< HEAD - # --- API key --- - key_attr = "anthropic_api_key" if provider == "anthropic" else "openai_api_key" - current_key = getattr(settings, key_attr) or "" - masked = ( - current_key[:4] + "..." 
+ current_key[-4:] if len(current_key) > 8 else "***" - ) - console.print() - api_key = Prompt.ask( - f"API key for {provider.title()} [dim](Enter to keep {masked})[/]", - default="", - console=console, - ) - api_key = api_key.strip() -======= if choice == "q": return session ->>>>>>> origin/onboarding-launch try: if choice == "1": @@ -1621,55 +1605,7 @@ def _print_choices(): _print_choices() continue -<<<<<<< HEAD - console.print() - planning_model = Prompt.ask( - "Planning model", - default=( - settings.planning_model - if provider == settings.planning_provider - else default_planning - ), - console=console, - ) - coding_model = Prompt.ask( - "Coding model", - default=( - settings.coding_model - if provider == settings.coding_provider - else default_coding - ), - console=console, - ) - - # --- Persist to global ~/.anton/.env --- - settings.planning_provider = provider - settings.coding_provider = provider - settings.planning_model = planning_model - settings.coding_model = coding_model - - global_ws.set_secret("ANTON_PLANNING_PROVIDER", provider) - global_ws.set_secret("ANTON_CODING_PROVIDER", provider) - global_ws.set_secret("ANTON_PLANNING_MODEL", planning_model) - global_ws.set_secret("ANTON_CODING_MODEL", coding_model) - - if api_key: - setattr(settings, key_attr, api_key) - key_name = f"ANTON_{provider.upper()}_API_KEY" - global_ws.set_secret(key_name, api_key) - - # Validate that we actually have an API key for the chosen provider - final_key = getattr(settings, key_attr) - if not final_key: - console.print() - console.print( - f"[anton.error]No API key set for {provider}. Configuration not applied.[/]" - ) - console.print() - return session -======= global_ws.apply_env_to_process() ->>>>>>> origin/onboarding-launch console.print() console.print("[anton.success]Configuration updated.[/]") @@ -3916,14 +3852,8 @@ async def _chat_loop( if resumed_id: current_session_id = resumed_id -<<<<<<< HEAD - console.print( - "[anton.muted] Chat with Anton. 
Type '/help' for commands or 'exit' to quit.[/]" - ) -======= console.print("[anton.muted] Chat with me, type '/help' for commands or 'exit' to quit.[/]") ->>>>>>> origin/onboarding-launch console.print(f"[anton.cyan_dim] {'━' * 40}[/]") console.print() From c8b3eede7e36e5f1858be621586cc2541d804473 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Mar 2026 21:26:35 +0100 Subject: [PATCH 57/70] Edit name of the vault entries --- anton/chat.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 524df22..fbb59f4 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2881,6 +2881,14 @@ async def _handle_connect_datasource( console.print("[anton.muted] Press Enter to keep the current value.[/]") console.print() + new_name = Prompt.ask( + f"[anton.cyan](anton)[/] name [anton.muted][{edit_name}][/]", + console=console, + default=edit_name, + ) + new_name = new_name.strip() or edit_name + new_slug = f"{edit_engine}-{new_name}" + # Detect which fields to present (handle auth_method=choice) active_fields = engine_def.fields if engine_def.auth_method == "choice" and engine_def.auth_methods: @@ -2991,12 +2999,16 @@ async def _handle_connect_datasource( console.print("[anton.success] ✓ Connected successfully![/]") break - vault.save(edit_engine, edit_name, credentials) + vault.save(edit_engine, new_name, credentials) + if new_name != edit_name: + vault.delete(edit_engine, edit_name) _restore_namespaced_env(vault) - _register_secret_vars(engine_def, engine=edit_engine, name=edit_name) + _register_secret_vars(engine_def, engine=edit_engine, name=new_name) + if new_name != edit_name and session._active_datasource == datasource_name: + session._active_datasource = new_slug console.print() console.print( - f' Credentials updated for [bold]"{datasource_name}"[/bold].' + f' Credentials updated for [bold]"{new_slug}"[/bold].' 
) console.print() console.print( @@ -3008,7 +3020,7 @@ async def _handle_connect_datasource( "role": "assistant", "content": ( f"I've updated the credentials for the {engine_def.display_name} connection " - f'"{datasource_name}" in the Local Vault.' + f'"{new_slug}" in the Local Vault.' ), } ) From 803c2a570512d587796753dbe1cc0ff1efd7e7ea Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Mar 2026 21:55:04 +0100 Subject: [PATCH 58/70] Return wrapper for ensure_api_key --- anton/cli.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/anton/cli.py b/anton/cli.py index a255d0c..be0ea1c 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -290,6 +290,11 @@ def _onboard(settings) -> None: console.print(line) +def _ensure_api_key(settings) -> None: + if not _has_api_key(settings): + _onboard(settings) + + def _animate_onboard(console, version: str, intro_lines: list[str], *, settings, ws) -> None: """Animate the robot talking while typing out the intro text below.""" import time From 449689582b171bcf97e38a8099d82ccf8ccd2ed0 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Mar 2026 22:09:59 +0100 Subject: [PATCH 59/70] Revert back edit name because breaks env variables --- anton/chat.py | 20 ++++---------------- anton/cli.py | 8 ++++---- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index fbb59f4..524df22 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2881,14 +2881,6 @@ async def _handle_connect_datasource( console.print("[anton.muted] Press Enter to keep the current value.[/]") console.print() - new_name = Prompt.ask( - f"[anton.cyan](anton)[/] name [anton.muted][{edit_name}][/]", - console=console, - default=edit_name, - ) - new_name = new_name.strip() or edit_name - new_slug = f"{edit_engine}-{new_name}" - # Detect which fields to present (handle auth_method=choice) active_fields = engine_def.fields if engine_def.auth_method == "choice" and engine_def.auth_methods: @@ -2999,16 +2991,12 @@ 
async def _handle_connect_datasource( console.print("[anton.success] ✓ Connected successfully![/]") break - vault.save(edit_engine, new_name, credentials) - if new_name != edit_name: - vault.delete(edit_engine, edit_name) + vault.save(edit_engine, edit_name, credentials) _restore_namespaced_env(vault) - _register_secret_vars(engine_def, engine=edit_engine, name=new_name) - if new_name != edit_name and session._active_datasource == datasource_name: - session._active_datasource = new_slug + _register_secret_vars(engine_def, engine=edit_engine, name=edit_name) console.print() console.print( - f' Credentials updated for [bold]"{new_slug}"[/bold].' + f' Credentials updated for [bold]"{datasource_name}"[/bold].' ) console.print() console.print( @@ -3020,7 +3008,7 @@ async def _handle_connect_datasource( "role": "assistant", "content": ( f"I've updated the credentials for the {engine_def.display_name} connection " - f'"{new_slug}" in the Local Vault.' + f'"{datasource_name}" in the Local Vault.' 
), } ) diff --git a/anton/cli.py b/anton/cli.py index be0ea1c..1e7a17f 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -853,13 +853,13 @@ def connect_data_source( session = ChatSession(llm_client) async def _run() -> None: - updated = await _handle_connect_datasource( + await _handle_connect_datasource( console, scratchpads, session, datasource_name=slug or None, ) - await updated._scratchpads.close_all() + await scratchpads.close_all() asyncio.run(_run()) @@ -902,13 +902,13 @@ def edit_data_source( session = ChatSession(llm_client) async def _run() -> None: - updated = await _handle_connect_datasource( + await _handle_connect_datasource( console, scratchpads, session, datasource_name=name, ) - await updated._scratchpads.close_all() + await scratchpads.close_all() asyncio.run(_run()) From 41247291db7aea8aed5edde58156724c31631403 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Mar 2026 23:43:10 +0100 Subject: [PATCH 60/70] Return back imports --- anton/chat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/anton/chat.py b/anton/chat.py index 524df22..0ac9ca6 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -1524,6 +1524,7 @@ async def _handle_setup_models( from rich.prompt import Prompt from anton.workspace import Workspace as _Workspace + from anton.cli import _SetupRetry, _setup_minds, _setup_other_provider # Always persist API keys and model settings to global ~/.anton/.env global_ws = _Workspace(Path.home()) From 30df7a0e8de9f20e8c19ab8315952bb886a4d0bf Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Wed, 25 Mar 2026 12:57:02 +0100 Subject: [PATCH 61/70] Small change in prompt to follow same pattern as other ones --- anton/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/anton/cli.py b/anton/cli.py index 1e7a17f..5878af1 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -609,7 +609,7 @@ def _setup_other_provider(settings, ws) -> None: console.print(line) console.print() - choice = _setup_prompt("Provider", 
default="Anthropic").strip().lower() + choice = _setup_prompt("Provider", default="1").strip().lower() if choice in ("1", "anthropic"): _setup_anthropic(settings, ws) From 4588b89aa14a1defd319a746af22b869e904fbea Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Wed, 25 Mar 2026 13:03:07 +0100 Subject: [PATCH 62/70] Hide api keys from displaying --- anton/cli.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/anton/cli.py b/anton/cli.py index 5878af1..5f82aa9 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -435,7 +435,7 @@ class _SetupRetry(Exception): pass -def _setup_prompt(label: str, default: str | None = None) -> str: +def _setup_prompt(label: str, default: str | None = None, is_password: bool = False) -> str: """Prompt for input with ESC-to-go-back and a bottom toolbar hint. Returns the user's input string. @@ -472,6 +472,7 @@ def _toolbar(): bottom_toolbar=_toolbar, style=pt_style, key_bindings=bindings, + is_password=is_password, ) # Use async prompt if inside a running event loop, sync otherwise @@ -531,7 +532,7 @@ def _setup_minds(settings, ws, *, default_url: str | None = "https://mdb.ai") -> console.print() while True: - api_key = _setup_prompt("API key") + api_key = _setup_prompt("API key", is_password=True) if api_key.strip(): break console.print(" [anton.warning]Please enter your API key.[/]") @@ -636,7 +637,7 @@ def _setup_anthropic(settings, ws) -> None: console.print() while True: - api_key = _setup_prompt("API key") + api_key = _setup_prompt("API key", is_password=True) if api_key.strip(): break console.print(" [anton.warning]Please enter your API key.[/]") @@ -678,7 +679,7 @@ def _setup_openai(settings, ws) -> None: console.print() while True: - api_key = _setup_prompt("API key") + api_key = _setup_prompt("API key", is_password=True) if api_key.strip(): break console.print(" [anton.warning]Please enter your API key.[/]") From 57a9945e3fc667c33d37c1eab2ba8814e32f8b16 Mon Sep 17 00:00:00 2001 From: Konstantin 
Sivakov Date: Wed, 25 Mar 2026 15:02:26 +0100 Subject: [PATCH 63/70] Add hash value for the datasource name --- anton/chat.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 0ac9ca6..041f7d1 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -5,6 +5,7 @@ import os import re as _re import sys +import uuid import yaml as _yaml import time from collections.abc import AsyncIterator, Callable @@ -3141,10 +3142,9 @@ async def _handle_connect_datasource( if result is None: return session engine_def, credentials = result - conn_num = vault.next_connection_number(engine_def.engine) - conn_name = str(conn_num) + conn_name = uuid.uuid4().hex[:8] vault.save(engine_def.engine, conn_name, credentials) - slug = f"{engine_def.engine}-{conn_num}" + slug = f"{engine_def.engine}-{conn_name}" _restore_namespaced_env(vault) session._active_datasource = slug _register_secret_vars(engine_def, engine=engine_def.engine, name=conn_name) @@ -3274,8 +3274,7 @@ async def _handle_connect_datasource( credentials[f.name] = value if partial: - n = vault.next_connection_number(engine_def.engine) - auto_name = str(n) + auto_name = uuid.uuid4().hex[:8] vault.save(engine_def.engine, auto_name, credentials) slug = f"{engine_def.engine}-{auto_name}" console.print() @@ -3359,8 +3358,7 @@ async def _handle_connect_datasource( conn_name = registry.derive_name(engine_def, credentials) if not conn_name: - n = vault.next_connection_number(engine_def.engine) - conn_name = str(n) + conn_name = uuid.uuid4().hex[:8] slug = f"{engine_def.engine}-{conn_name}" From 1c0a67d6886975d5cefef82d817579213e8f8276 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Wed, 25 Mar 2026 16:50:43 +0100 Subject: [PATCH 64/70] Remove data-connections command --- anton/chat.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 041f7d1..fd79348 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -6,6 
+6,7 @@ import re as _re import sys import uuid +from warnings import deprecated import yaml as _yaml import time from collections.abc import AsyncIterator, Callable @@ -1962,6 +1963,7 @@ def _display_value(key: str, value: str) -> str: return value or "[dim][/]" +@deprecated("The /data-connections menu is deprecated and will be removed in a future release.") async def _handle_data_connections( console: Console, settings: AntonSettings, @@ -3599,9 +3601,6 @@ def _print_slash_help(console: Console) -> None: console.print(" [bold]/edit-data-source[/] — Edit a saved connection's credentials") console.print(" [bold]/remove-data-source[/] — Remove a saved connection") console.print(" [bold]/test-data-source[/] — Test a saved connection") - console.print( - " [bold]/data-connections[/] — View and manage stored keys and connections" - ) console.print( " [bold]/setup[/] — Configure models or memory settings" ) @@ -4053,14 +4052,6 @@ def _bottom_toolbar(): console, session._scratchpads, arg ) continue - elif cmd == "/data-connections": - session = await _handle_data_connections( - console, - settings, - workspace, - session, - ) - continue elif cmd == "/resume": session, resumed_id = await _handle_resume( console, From 23004752b8ddc8ee33f70d7a45e1c4a8dab5bede Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Wed, 25 Mar 2026 18:02:09 +0100 Subject: [PATCH 65/70] Remove imporinng of deprecated --- anton/chat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index fd79348..bb3933e 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -6,7 +6,6 @@ import re as _re import sys import uuid -from warnings import deprecated import yaml as _yaml import time from collections.abc import AsyncIterator, Callable @@ -1963,7 +1962,7 @@ def _display_value(key: str, value: str) -> str: return value or "[dim][/]" -@deprecated("The /data-connections menu is deprecated and will be removed in a future release.") +#TODO: The /data-connections 
menu is deprecated and will be removed in a future release. async def _handle_data_connections( console: Console, settings: AntonSettings, From 05146562771597bd4e91eda882d3a3b0c999f5b9 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Wed, 25 Mar 2026 18:47:12 +0100 Subject: [PATCH 66/70] Fix 0 alligment --- anton/chat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index bb3933e..de53b1f 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2847,7 +2847,7 @@ async def _handle_connect_datasource( vault = DataVault() registry = DatasourceRegistry() - + if datasource_name is not None: _parsed = parse_connection_slug( datasource_name, [e.engine for e in registry.all_engines()], vault=vault @@ -3019,14 +3019,13 @@ async def _handle_connect_datasource( console.print() all_engines = registry.all_engines() - if prefill: answer = prefill else: console.print( "[anton.cyan](anton)[/] Which data source would you like to connect?\n" ) - console.print(" [bold] 0.[/bold] Create a custom datasource") + console.print(" [bold] 0.[/bold] Create a custom datasource") for i, e in enumerate(all_engines, 1): console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") console.print() From 16a903d5b3dbde29c6b2f2c4be8b819b3799e418 Mon Sep 17 00:00:00 2001 From: martyna-mindsdb Date: Wed, 25 Mar 2026 19:58:38 +0100 Subject: [PATCH 67/70] changed command names to /connect, /list, /remove, /test; and disabled /connect for minds --- anton/chat.py | 65 ++++++++++++++++++++-------------------- anton/cli.py | 12 ++++---- tests/test_datasource.py | 16 +++++----- 3 files changed, 47 insertions(+), 46 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index bb3933e..f44773f 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -3233,7 +3233,7 @@ async def _handle_connect_datasource( console.print( "[anton.cyan](anton)[/] No problem. Which parameters do you have? 
" "I'll save a partial connection now, and you can fill in the rest later " - "with [bold]/edit-data-source[/]." + "with [bold]/edit[/]." ) console.print() console.print(" Provide what you have (press enter to skip any field):") @@ -3282,7 +3282,7 @@ async def _handle_connect_datasource( console.print( f"[anton.muted]Partial connection saved to Local Vault as " f'[bold]"{slug}"[/bold]. ' - f"Run [bold]/edit-data-source {slug}[/bold] to complete it when you're ready.[/]" + f"Run [bold]/edit {slug}[/bold] to complete it when you're ready.[/]" ) console.print() return session @@ -3435,7 +3435,7 @@ def _handle_list_data_sources(console: Console) -> None: console.print() if not conns: console.print("[anton.muted]No data sources connected yet.[/]") - console.print("[anton.muted]Use /connect-data-source to add one.[/]") + console.print("[anton.muted]Use /connect to add one.[/]") console.print() return @@ -3501,7 +3501,7 @@ async def _handle_test_datasource( """Test an existing Local Vault connection by running its test_snippet.""" if not slug: console.print( - "[anton.warning]Usage: /test-data-source [/]" + "[anton.warning]Usage: /test [/]" ) console.print() return @@ -3588,18 +3588,18 @@ def _print_slash_help(console: Console) -> None: """Print available slash commands.""" console.print() console.print("[anton.cyan]Available commands:[/]") +# console.print( +# " [bold]/connect[/] — Connect to a Minds server and select a mind" +# ) console.print( - " [bold]/connect[/] — Connect to a Minds server and select a mind" + " [bold]/connect[/] — Connect a database or API to the Local Vault" ) console.print( - " [bold]/connect-data-source[/] — Connect a database or API to the Local Vault" + " [bold]/list[/] — List all saved data source connections" ) - console.print( - " [bold]/list-data-sources[/] — List all saved data source connections" - ) - console.print(" [bold]/edit-data-source[/] — Edit a saved connection's credentials") - console.print(" [bold]/remove-data-source[/] — 
Remove a saved connection") - console.print(" [bold]/test-data-source[/] — Test a saved connection") + console.print(" [bold]/edit[/] — Edit a saved connection's credentials") + console.print(" [bold]/remove[/] — Remove a saved connection") + console.print(" [bold]/test[/] — Test a saved connection") console.print( " [bold]/setup[/] — Configure models or memory settings" ) @@ -3978,19 +3978,20 @@ def _bottom_toolbar(): if message_content is None and stripped.startswith("/"): parts = stripped.split(maxsplit=1) cmd = parts[0].lower() - if cmd == "/connect": - session = await _handle_connect( - console, - settings, - workspace, - state, - self_awareness, - cortex, - session, - episodic=episodic, - ) - continue - elif cmd == "/setup": +# if cmd == "/connect": +# session = await _handle_connect( +# console, +# settings, +# workspace, +# state, +# self_awareness, +# cortex, +# session, +# episodic=episodic, +# ) +# continue +# elif cmd == "/setup": + if cmd == "/setup": session = await _handle_setup( console, settings, @@ -4007,7 +4008,7 @@ def _bottom_toolbar(): elif cmd == "/memory": _handle_memory(console, settings, cortex, episodic=episodic) continue - elif cmd == "/connect-data-source": + elif cmd == "/connect": arg = parts[1].strip() if len(parts) > 1 else "" session = await _handle_connect_datasource( console, @@ -4016,25 +4017,25 @@ def _bottom_toolbar(): prefill=arg or None, ) continue - elif cmd == "/list-data-sources": + elif cmd == "/list": _handle_list_data_sources(console) continue - elif cmd == "/remove-data-source": + elif cmd == "/remove": arg = parts[1].strip() if len(parts) > 1 else "" if not arg: console.print( - "[anton.warning]Usage: /remove-data-source" + "[anton.warning]Usage: /remove" " [/]" ) console.print() else: _handle_remove_data_source(console, arg) continue - elif cmd == "/edit-data-source": + elif cmd == "/edit": arg = parts[1].strip() if len(parts) > 1 else "" if not arg: console.print( - "[anton.warning]Usage: /edit-data-source [/]" + 
"[anton.warning]Usage: /edit [/]" ) console.print() else: @@ -4045,7 +4046,7 @@ def _bottom_toolbar(): datasource_name=arg, ) continue - elif cmd == "/test-data-source": + elif cmd == "/test": arg = parts[1].strip() if len(parts) > 1 else "" await _handle_test_datasource( console, session._scratchpads, arg diff --git a/anton/cli.py b/anton/cli.py index 5f82aa9..365973a 100644 --- a/anton/cli.py +++ b/anton/cli.py @@ -817,7 +817,7 @@ def version() -> None: console.print(f"Anton v{__version__}") -@app.command("connect-data-source") +@app.command("connect") def connect_data_source( ctx: typer.Context, slug: str = typer.Argument( @@ -827,7 +827,7 @@ def connect_data_source( """Connect a database or API to the Local Vault. Pass an existing connection slug (e.g. postgres-mydb) to reconnect using - stored credentials without re-entering them. Use /edit-data-source to + stored credentials without re-entering them. Use /edit to update credentials for an existing connection. """ import asyncio @@ -865,7 +865,7 @@ async def _run() -> None: asyncio.run(_run()) -@app.command("list-data-sources") +@app.command("list") def list_data_sources(ctx: typer.Context) -> None: """List all saved data source connections in the Local Vault.""" from anton.chat import _handle_list_data_sources @@ -873,7 +873,7 @@ def list_data_sources(ctx: typer.Context) -> None: _handle_list_data_sources(console) -@app.command("edit-data-source") +@app.command("edit") def edit_data_source( ctx: typer.Context, name: str = typer.Argument(..., help="Connection slug to edit (e.g. postgres-mydb)."), @@ -914,7 +914,7 @@ async def _run() -> None: asyncio.run(_run()) -@app.command("remove-data-source") +@app.command("remove") def remove_data_source( ctx: typer.Context, name: str = typer.Argument(..., help="Connection slug to remove (e.g. 
postgres-mydb)."), @@ -925,7 +925,7 @@ def remove_data_source( _handle_remove_data_source(console, name) -@app.command("test-data-source") +@app.command("test") def test_data_source( ctx: typer.Context, name: str = typer.Argument(..., help="Connection slug to test (e.g. postgres-mydb)."), diff --git a/tests/test_datasource.py b/tests/test_datasource.py index a548f03..6caa3ea 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -1032,11 +1032,11 @@ def test_multi_source_context_shows_both_connections(self, vault_dir): class TestCliCommandRegistration: @pytest.mark.parametrize("cmd_name", [ - "connect-data-source", - "list-data-sources", - "edit-data-source", - "remove-data-source", - "test-data-source", + "connect", + "list", + "edit", + "remove", + "test", ]) def test_command_registered(self, cmd_name): names = [cmd.name for cmd in _cli_app.registered_commands] @@ -1054,7 +1054,7 @@ def test_empty_vault_shows_message(self, vault_dir): with patch("anton.chat.DataVault", return_value=DataVault(vault_dir=vault_dir)): _handle_list_data_sources(console) printed = " ".join(str(c) for c in console.print.call_args_list) - assert "No data sources" in printed or "connect-data-source" in printed + assert "No data sources" in printed or "connect" in printed def test_complete_connection_shows_saved_with_engine_name(self, vault_dir, registry): vault = DataVault(vault_dir=vault_dir) @@ -1174,7 +1174,7 @@ async def test_empty_slug_shows_usage(self, vault_dir, registry): await _handle_test_datasource(console, scratchpads, "") printed = " ".join(str(c) for c in console.print.call_args_list) - assert "Usage" in printed or "test-data-source" in printed + assert "Usage" in printed or "test" in printed # ───────────────────────────────────────────────────────────────────────────── @@ -1413,7 +1413,7 @@ async def test_test_data_source_no_arg_shows_usage(self, vault_dir, registry): await _handle_test_datasource(console, scratchpads, "") printed = " ".join(str(c) for c 
in console.print.call_args_list) - assert "Usage" in printed or "test-data-source" in printed + assert "Usage" in printed or "test" in printed @pytest.mark.asyncio async def test_edit_data_source_no_arg_safe(self, vault_dir, registry, make_session): From 43e3d12edb29281bc51c66c500257c890416d2ef Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Wed, 25 Mar 2026 20:26:03 +0100 Subject: [PATCH 68/70] Fix testing for custom datasources --- anton/chat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index de53b1f..2a21048 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2819,6 +2819,7 @@ async def _handle_add_custom_datasource( display_name=display_name, pip=pip_pkg, fields=fields, + test_snippet=test_snippet, ) # All required fields must be present before the caller saves credentials @@ -3025,7 +3026,8 @@ async def _handle_connect_datasource( console.print( "[anton.cyan](anton)[/] Which data source would you like to connect?\n" ) - console.print(" [bold] 0.[/bold] Create a custom datasource") + console.print(" [bold] 0.[/bold] Connect to a custom datasource") + console.print(" [bold]OR Select from the list below[/bold]") for i, e in enumerate(all_engines, 1): console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") console.print() @@ -3285,7 +3287,7 @@ async def _handle_connect_datasource( ) console.print() return session - + if engine_def.test_snippet: while True: console.print() From 48c67f88a80d8c3d0c22caf71f16e3d82b9b0ae4 Mon Sep 17 00:00:00 2001 From: martyna-mindsdb Date: Thu, 26 Mar 2026 13:04:07 +0100 Subject: [PATCH 69/70] updated initial output of /connect --- anton/chat.py | 11 +- datasources.md | 284 ++++++++++++++++++++++++------------------------- 2 files changed, 148 insertions(+), 147 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index a6f40fa..3c687cf 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -3024,15 +3024,16 @@ async def _handle_connect_datasource( answer = prefill 
else: console.print( - "[anton.cyan](anton)[/] Which data source would you like to connect?\n" + "[anton.cyan](anton)[/] Choose a data source:\n" ) - console.print(" [bold] 0.[/bold] Connect to a custom datasource") - console.print(" [bold]OR Select from the list below[/bold]") + console.print(" [bold] Primary") + console.print(" [bold] 0.[/bold] Custom datasource (connect anything via API, SQL, or MCP)\n") + console.print(" [bold] Predefined") for i, e in enumerate(all_engines, 1): - console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") + console.print(f" [bold]{i:>2}.[/bold] {e.display_name}") console.print() answer = Prompt.ask( - "[anton.cyan](anton)[/] Enter a number, or type the name", + "[anton.cyan](anton)[/] Enter a number or type a name", console=console, ) diff --git a/datasources.md b/datasources.md index 7a9b85e..2cafc88 100644 --- a/datasources.md +++ b/datasources.md @@ -72,115 +72,6 @@ test_snippet: | --- -## MariaDB - -```yaml -engine: mariadb -display_name: MariaDB -pip: mysql-connector-python -name_from: database -fields: - - { name: host, required: true, secret: false, description: "hostname or IP of your MariaDB server" } - - { name: port, required: true, secret: false, description: "port number", default: "3306" } - - { name: database, required: true, secret: false, description: "database name to connect" } - - { name: user, required: true, secret: false, description: "MariaDB username" } - - { name: password, required: true, secret: true, description: "MariaDB password" } - - { name: ssl, required: false, secret: false, description: "enable SSL (true/false)" } - - { name: ssl_ca, required: false, secret: false, description: "path to CA certificate file" } - - { name: ssl_cert, required: false, secret: false, description: "path to client certificate file" } - - { name: ssl_key, required: false, secret: false, description: "path to client private key file" } -test_snippet: | - import mysql.connector, os - conn = mysql.connector.connect( - 
host=os.environ['DS_HOST'], - port=int(os.environ.get('DS_PORT', '3306')), - database=os.environ['DS_DATABASE'], - user=os.environ['DS_USER'], - password=os.environ['DS_PASSWORD'], - ) - conn.close() - print("ok") -``` - -MariaDB is wire-compatible with MySQL, so the mysql-connector-python driver works for both. - ---- - -## Microsoft SQL Server - -```yaml -engine: mssql -display_name: Microsoft SQL Server -pip: pymssql -name_from: database -fields: - - { name: host, required: true, secret: false, description: "hostname or IP of the SQL Server (for Azure use server field instead)" } - - { name: port, required: true, secret: false, description: "port number", default: "1433" } - - { name: database, required: true, secret: false, description: "database name" } - - { name: user, required: true, secret: false, description: "SQL Server username" } - - { name: password, required: true, secret: true, description: "SQL Server password" } - - { name: server, required: false, secret: false, description: "server name — use for named instances or Azure SQL (e.g. myserver.database.windows.net)" } - - { name: schema, required: false, secret: false, description: "schema name (defaults to dbo)" } -test_snippet: | - import pymssql, os - conn = pymssql.connect( - server=os.environ.get('DS_SERVER') or os.environ['DS_HOST'], - port=os.environ.get('DS_PORT', '1433'), - database=os.environ['DS_DATABASE'], - user=os.environ['DS_USER'], - password=os.environ['DS_PASSWORD'], - ) - conn.close() - print("ok") -``` - -For Azure SQL Database use `server` field with value like `myserver.database.windows.net`. -For Windows Authentication omit user/password and ensure pymssql is built with Kerberos support. 
- ---- - -## HubSpot - -```yaml -engine: hubspot -display_name: HubSpot -pip: hubspot-api-client -auth_method: choice -auth_methods: - - name: pat - display: "Private App Token (recommended)" - fields: - - { name: access_token, required: true, secret: true, description: "HubSpot Private App token (starts with pat-na1-)" } - - name: oauth2 - display: "OAuth2 (for multi-account or publishable apps)" - fields: - - { name: client_id, required: true, secret: false, description: "OAuth2 client ID" } - - { name: client_secret, required: true, secret: true, description: "OAuth2 client secret" } - oauth2: - auth_url: https://app.hubspot.com/oauth/authorize - token_url: https://api.hubapi.com/oauth/v1/token - scopes: [crm.objects.contacts.read, crm.objects.deals.read] - store_fields: [access_token, refresh_token] -test_snippet: | - import hubspot, os - client = hubspot.Client.create(access_token=os.environ['DS_ACCESS_TOKEN']) - client.crm.contacts.basic_api.get_page(limit=1) - print("ok") -``` - -For Private App Token: HubSpot → Settings → Integrations → Private Apps → Create. -Recommended scopes: `crm.objects.contacts.read`, `crm.objects.deals.read`, `crm.objects.companies.read`. - -For OAuth2: collect client_id and client_secret, then use the scratchpad to: - -1. Build the authorization URL using `auth_url` + params above -2. Start a local HTTP server on port 8099 to catch the callback -3. Open the URL in the user's browser with `webbrowser.open()` -4. Extract the `code` from the callback, POST to `token_url` for tokens -5. Return `access_token` and `refresh_token` to store in wallet - ---- - ## Snowflake ```yaml @@ -230,39 +121,6 @@ Format is either `-` or `.. 
--- -## Oracle Database - -```yaml -engine: oracle_database -display_name: Oracle Database -pip: oracledb -name_from: [host, service_name] -fields: - - { name: user, required: true, secret: false, description: "Oracle database username" } - - { name: password, required: true, secret: true, description: "Oracle database password" } - - { name: host, required: true, secret: false, description: "hostname or IP address of the Oracle server" } - - { name: port, required: true, secret: false, description: "port number (default 1521)", default: "1521" } - - { name: service_name, required: false, secret: false, description: "Oracle service name (preferred over SID)" } - - { name: sid, required: false, secret: false, description: "Oracle SID — use service_name if possible" } - - { name: dsn, required: false, secret: false, description: "full DSN string — overrides host/port/service_name" } - - { name: auth_mode, required: false, secret: false, description: "authorization mode (e.g. SYSDBA)" } -test_snippet: | - import oracledb, os - dsn = os.environ.get('DS_DSN') or oracledb.makedsn( - os.environ.get('DS_HOST', 'localhost'), - os.environ.get('DS_PORT', '1521'), - service_name=os.environ.get('DS_SERVICE_NAME', ''), - ) - conn = oracledb.connect(user=os.environ['DS_USER'], password=os.environ['DS_PASSWORD'], dsn=dsn) - conn.close() - print("ok") -``` - -oracledb runs in thin mode by default (no Oracle Client libraries needed). -Set `auth_mode` to `SYSDBA` or `SYSOPER` for privileged connections. - ---- - ## Google BigQuery ```yaml @@ -298,6 +156,39 @@ Keys → Add Key → JSON. 
Grant the account `BigQuery Data Viewer` + `BigQuery --- +## Microsoft SQL Server + +```yaml +engine: mssql +display_name: Microsoft SQL Server +pip: pymssql +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of the SQL Server (for Azure use server field instead)" } + - { name: port, required: true, secret: false, description: "port number", default: "1433" } + - { name: database, required: true, secret: false, description: "database name" } + - { name: user, required: true, secret: false, description: "SQL Server username" } + - { name: password, required: true, secret: true, description: "SQL Server password" } + - { name: server, required: false, secret: false, description: "server name — use for named instances or Azure SQL (e.g. myserver.database.windows.net)" } + - { name: schema, required: false, secret: false, description: "schema name (defaults to dbo)" } +test_snippet: | + import pymssql, os + conn = pymssql.connect( + server=os.environ.get('DS_SERVER') or os.environ['DS_HOST'], + port=os.environ.get('DS_PORT', '1433'), + database=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +For Azure SQL Database use `server` field with value like `myserver.database.windows.net`. +For Windows Authentication omit user/password and ensure pymssql is built with Kerberos support. 
+ +--- + ## Amazon Redshift ```yaml @@ -365,6 +256,115 @@ HTTP path and server hostname: SQL Warehouses → your warehouse → Connection --- +## MariaDB + +```yaml +engine: mariadb +display_name: MariaDB +pip: mysql-connector-python +name_from: database +fields: + - { name: host, required: true, secret: false, description: "hostname or IP of your MariaDB server" } + - { name: port, required: true, secret: false, description: "port number", default: "3306" } + - { name: database, required: true, secret: false, description: "database name to connect" } + - { name: user, required: true, secret: false, description: "MariaDB username" } + - { name: password, required: true, secret: true, description: "MariaDB password" } + - { name: ssl, required: false, secret: false, description: "enable SSL (true/false)" } + - { name: ssl_ca, required: false, secret: false, description: "path to CA certificate file" } + - { name: ssl_cert, required: false, secret: false, description: "path to client certificate file" } + - { name: ssl_key, required: false, secret: false, description: "path to client private key file" } +test_snippet: | + import mysql.connector, os + conn = mysql.connector.connect( + host=os.environ['DS_HOST'], + port=int(os.environ.get('DS_PORT', '3306')), + database=os.environ['DS_DATABASE'], + user=os.environ['DS_USER'], + password=os.environ['DS_PASSWORD'], + ) + conn.close() + print("ok") +``` + +MariaDB is wire-compatible with MySQL, so the mysql-connector-python driver works for both. 
+ +--- + +## HubSpot + +```yaml +engine: hubspot +display_name: HubSpot +pip: hubspot-api-client +auth_method: choice +auth_methods: + - name: pat + display: "Private App Token (recommended)" + fields: + - { name: access_token, required: true, secret: true, description: "HubSpot Private App token (starts with pat-na1-)" } + - name: oauth2 + display: "OAuth2 (for multi-account or publishable apps)" + fields: + - { name: client_id, required: true, secret: false, description: "OAuth2 client ID" } + - { name: client_secret, required: true, secret: true, description: "OAuth2 client secret" } + oauth2: + auth_url: https://app.hubspot.com/oauth/authorize + token_url: https://api.hubapi.com/oauth/v1/token + scopes: [crm.objects.contacts.read, crm.objects.deals.read] + store_fields: [access_token, refresh_token] +test_snippet: | + import hubspot, os + client = hubspot.Client.create(access_token=os.environ['DS_ACCESS_TOKEN']) + client.crm.contacts.basic_api.get_page(limit=1) + print("ok") +``` + +For Private App Token: HubSpot → Settings → Integrations → Private Apps → Create. +Recommended scopes: `crm.objects.contacts.read`, `crm.objects.deals.read`, `crm.objects.companies.read`. + +For OAuth2: collect client_id and client_secret, then use the scratchpad to: + +1. Build the authorization URL using `auth_url` + params above +2. Start a local HTTP server on port 8099 to catch the callback +3. Open the URL in the user's browser with `webbrowser.open()` +4. Extract the `code` from the callback, POST to `token_url` for tokens +5. 
Return `access_token` and `refresh_token` to store in wallet + +--- + +## Oracle Database + +```yaml +engine: oracle_database +display_name: Oracle Database +pip: oracledb +name_from: [host, service_name] +fields: + - { name: user, required: true, secret: false, description: "Oracle database username" } + - { name: password, required: true, secret: true, description: "Oracle database password" } + - { name: host, required: true, secret: false, description: "hostname or IP address of the Oracle server" } + - { name: port, required: true, secret: false, description: "port number (default 1521)", default: "1521" } + - { name: service_name, required: false, secret: false, description: "Oracle service name (preferred over SID)" } + - { name: sid, required: false, secret: false, description: "Oracle SID — use service_name if possible" } + - { name: dsn, required: false, secret: false, description: "full DSN string — overrides host/port/service_name" } + - { name: auth_mode, required: false, secret: false, description: "authorization mode (e.g. SYSDBA)" } +test_snippet: | + import oracledb, os + dsn = os.environ.get('DS_DSN') or oracledb.makedsn( + os.environ.get('DS_HOST', 'localhost'), + os.environ.get('DS_PORT', '1521'), + service_name=os.environ.get('DS_SERVICE_NAME', ''), + ) + conn = oracledb.connect(user=os.environ['DS_USER'], password=os.environ['DS_PASSWORD'], dsn=dsn) + conn.close() + print("ok") +``` + +oracledb runs in thin mode by default (no Oracle Client libraries needed). +Set `auth_mode` to `SYSDBA` or `SYSOPER` for privileged connections. 
+ +--- + ## DuckDB ```yaml From 760ab33fd144ff2f146d75d57d0818a28739967c Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Thu, 26 Mar 2026 16:04:27 +0100 Subject: [PATCH 70/70] Fix testing the custom datasources --- anton/chat.py | 160 ++++++++++++++++-------------- tests/test_datasource.py | 203 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 73 deletions(-) diff --git a/anton/chat.py b/anton/chat.py index 3c687cf..0ec523b 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -2835,6 +2835,78 @@ async def _handle_add_custom_datasource( return engine_def, credentials +async def _run_connection_test( + console: "Console", + scratchpads: "ScratchpadManager", + vault: "DataVault", + engine_def: "DatasourceEngine", + credentials: dict[str, str], + retry_fields: "list[DatasourceField]", +) -> bool: + """Inject flat DS_* vars, run engine_def.test_snippet, restore env. + + Returns True on success, False if the user declines retry after failure. + Mutates credentials in-place when the user re-enters secrets on retry. + """ + import os as _os + + while True: + console.print() + console.print("[anton.cyan](anton)[/] Got it. 
Testing connection…") + + vault.clear_ds_env() + for key, value in credentials.items(): + _os.environ[f"DS_{key.upper()}"] = value + _register_secret_vars(engine_def) # flat mode, for scrubbing during test + + try: + pad = await scratchpads.get_or_create("__datasource_test__") + await pad.reset() + if engine_def.pip: + await pad.install_packages([engine_def.pip]) + cell = await pad.execute(engine_def.test_snippet) + finally: + _restore_namespaced_env(vault) + + if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): + error_text = cell.error or cell.stderr.strip() or cell.stdout.strip() + first_line = next( + (ln for ln in error_text.splitlines() if ln.strip()), error_text + ) + console.print() + console.print("[anton.warning](anton)[/] ✗ Connection failed.") + console.print() + console.print(f" Error: {first_line}") + console.print() + retry = ( + Prompt.ask( + "[anton.cyan](anton)[/] Would you like to re-enter your credentials? [y/n]", + console=console, + default="n", + ) + .strip() + .lower() + ) + if retry != "y": + return False + console.print() + for f in retry_fields: + if not f.secret: + continue + value = Prompt.ask( + f"[anton.cyan](anton)[/] {f.name}", + password=True, + console=console, + default="", + ) + if value: + credentials[f.name] = value + continue + + console.print("[anton.success] ✓ Connected successfully![/]") + return True + + async def _handle_connect_datasource( console: Console, scratchpads: ScratchpadManager, @@ -3066,12 +3138,12 @@ async def _handle_connect_datasource( return session engine_def: DatasourceEngine | None = None - _go_custom = False + custom_source = False if stripped_answer.isdigit() or (stripped_answer.lstrip("-").isdigit()): pick_num = int(stripped_answer) if pick_num == 0: - _go_custom = True + custom_source = True elif 1 <= pick_num <= len(all_engines): engine_def = all_engines[pick_num - 1] else: @@ -3082,7 +3154,7 @@ async def _handle_connect_datasource( console.print() return session - if 
engine_def is None and not _go_custom: + if engine_def is None and not custom_source: engine_def = registry.find_by_name(stripped_answer) # if exact match not found, try substring match against display and engine names if engine_def is None: @@ -3136,15 +3208,20 @@ async def _handle_connect_datasource( break if engine_def is None: - _go_custom = True + custom_source = True - if _go_custom: + if custom_source: result = await _handle_add_custom_datasource( console, stripped_answer if not stripped_answer.isdigit() else "", registry, session ) if result is None: return session engine_def, credentials = result + if engine_def.test_snippet: + if not await _run_connection_test( + console, scratchpads, vault, engine_def, credentials, engine_def.fields + ): + return session conn_name = uuid.uuid4().hex[:8] vault.save(engine_def.engine, conn_name, credentials) slug = f"{engine_def.engine}-{conn_name}" @@ -3170,6 +3247,7 @@ async def _handle_connect_datasource( ) return session + assert engine_def is not None # custom_source path always returns before this line active_fields = engine_def.fields if engine_def.auth_method == "choice" and engine_def.auth_methods: console.print() @@ -3290,74 +3368,10 @@ async def _handle_connect_datasource( return session if engine_def.test_snippet: - while True: - console.print() - console.print("[anton.cyan](anton)[/] Got it. Testing connection…") - - # Temporarily inject flat DS_* vars for test_snippet execution. - # conn_name is not yet known, so inject directly from credentials. 
- import os as _os - - vault.clear_ds_env() - for key, value in credentials.items(): - _os.environ[f"DS_{key.upper()}"] = value - _register_secret_vars(engine_def) # flat mode, for scrubbing during test - - try: - pad = await scratchpads.get_or_create("__datasource_test__") - await pad.reset() # fresh subprocess inherits current os.environ - - if engine_def.pip: - await pad.install_packages([engine_def.pip]) - - cell = await pad.execute(engine_def.test_snippet) - finally: - _restore_namespaced_env(vault) - - if cell.error or (cell.stdout.strip() != "ok" and cell.stderr.strip()): - error_text = cell.error or cell.stderr.strip() or cell.stdout.strip() - # Show first meaningful line of the error - first_line = next( - (ln for ln in error_text.splitlines() if ln.strip()), error_text - ) - console.print() - console.print("[anton.warning](anton)[/] ✗ Connection failed.") - console.print() - console.print(f" Error: {first_line}") - console.print() - - retry = ( - Prompt.ask( - "[anton.cyan](anton)[/] Would you like to re-enter your credentials? 
[y/n]", - console=console, - default="n", - ) - .strip() - .lower() - ) - - if retry != "y": - return session - - # Re-collect secret fields only - console.print() - for f in active_fields: - if not f.secret: - continue - value = Prompt.ask( - f"[anton.cyan](anton)[/] {f.name}", - password=True, - console=console, - default="", - ) - if value: - credentials[f.name] = value - # Try again with updated credentials - continue - - # Success - console.print("[anton.success] ✓ Connected successfully![/]") - break + if not await _run_connection_test( + console, scratchpads, vault, engine_def, credentials, active_fields + ): + return session conn_name = registry.derive_name(engine_def, credentials) if not conn_name: diff --git a/tests/test_datasource.py b/tests/test_datasource.py index 6caa3ea..3989cdf 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -1821,3 +1821,206 @@ async def test_incomplete_custom_datasource_not_saved( # Must return None — caller must not call vault.save() assert result is None + + +class TestCustomDatasourceConnectFlow: + """Tests for the custom_source path in _handle_connect_datasource: + test_snippet is run before saving, and failures prevent saving.""" + + # ── helpers (mirrors TestAddCustomDatasourceFlow) ──────────────────── + + def _make_llm_response( + self, + fields: list[dict], + display_name: str = "My API Service", + test_snippet: str = "", + ) -> str: + import json as _json + return _json.dumps({ + "display_name": display_name, + "pip": "", + "test_snippet": test_snippet, + "fields": fields, + }) + + def _make_registry(self, tmp_path): + reg = MagicMock() + reg.all_engines.return_value = [] + reg.find_by_name.return_value = None + reg.fuzzy_find.return_value = [] + reg.validate_file.return_value = {"my_api_service": MagicMock()} + reg.reload.return_value = None + reg.get.return_value = None # triggers inline fallback engine_def + return reg + + def _make_llm(self, json_text: str): + llm = AsyncMock() + response = 
MagicMock() + response.content = json_text + llm.plan = AsyncMock(return_value=response) + return llm + + def _mock_ds_path(self, mock_path_cls, tmp_path): + mock_path_cls.return_value.expanduser.return_value = tmp_path / "datasources.md" + + # ── tests ───────────────────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_custom_with_test_snippet_success( + self, vault_dir, make_session, make_cell, tmp_path + ): + """Custom datasource with test_snippet: test passes → connection saved.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="ok")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + session._llm = self._make_llm(self._make_llm_response( + [{"name": "api_key", "value": "", "secret": True, "required": True, "description": "API key"}], + test_snippet="print('ok')", + )) + + prompt_responses = iter([ + "0", # choose custom + "My API Service", # tool name + "I have an API key", # auth description + "my_secret_key", # api_key (secret prompt) + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load(conns[0]["engine"], conns[0]["name"]) + assert saved is not None + assert saved.get("api_key") == "my_secret_key" + assert result._history + assert result._history[-1]["role"] == "assistant" + pad.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_custom_with_test_snippet_fail_no_retry( + self, vault_dir, make_session, make_cell, tmp_path 
+ ): + """Custom datasource: test fails and user declines retry → not saved.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(return_value=make_cell(stdout="", stderr="connection refused")) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + session._llm = self._make_llm(self._make_llm_response( + [{"name": "api_key", "value": "", "secret": True, "required": True, "description": "API key"}], + test_snippet="print('ok')", + )) + + prompt_responses = iter([ + "0", + "My API Service", + "I have an API key", + "bad_key", # api_key + "n", # retry? + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_connect_datasource(console, session._scratchpads, session) + + assert vault.list_connections() == [] + assert not result._history + + @pytest.mark.asyncio + async def test_custom_with_test_snippet_fail_retry_success( + self, vault_dir, make_session, make_cell, tmp_path + ): + """Custom datasource: test fails, user retries with corrected creds → saved.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + pad.execute = AsyncMock(side_effect=[ + make_cell(stdout="", stderr="invalid key"), + make_cell(stdout="ok"), + ]) + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + session._llm = self._make_llm(self._make_llm_response( + [{"name": "api_key", "value": "", "secret": True, "required": True, "description": "API key"}], + test_snippet="print('ok')", + )) + + prompt_responses = iter([ + "0", + "My API Service", + "I have an API key", + "bad_key", # api_key first attempt + "y", # 
retry? + "good_key", # api_key retry + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + result = await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + saved = vault.load(conns[0]["engine"], conns[0]["name"]) + assert saved is not None + assert saved.get("api_key") == "good_key" + assert result._history + + @pytest.mark.asyncio + async def test_custom_without_test_snippet_saves( + self, vault_dir, make_session, make_cell, tmp_path + ): + """Custom datasource without test_snippet: saves directly, no scratchpad call.""" + session = make_session() + console = MagicMock() + vault = DataVault(vault_dir=vault_dir) + + pad = AsyncMock() + session._scratchpads.get_or_create = AsyncMock(return_value=pad) + session._llm = self._make_llm(self._make_llm_response( + [{"name": "api_key", "value": "", "secret": True, "required": True, "description": "API key"}], + test_snippet="", + )) + + prompt_responses = iter([ + "0", + "My API Service", + "I have an API key", + "my_key", # api_key + ]) + + with ( + patch("anton.chat.DataVault", return_value=vault), + patch("anton.chat.DatasourceRegistry", return_value=self._make_registry(tmp_path)), + patch("rich.prompt.Prompt.ask", side_effect=lambda *a, **kw: next(prompt_responses)), + patch("anton.chat.Path") as mock_path_cls, + ): + self._mock_ds_path(mock_path_cls, tmp_path) + await _handle_connect_datasource(console, session._scratchpads, session) + + conns = vault.list_connections() + assert len(conns) == 1 + pad.execute.assert_not_called()