diff --git a/.github/workflows/ci-release.yml b/.github/workflows/ci-release.yml index b15931a3..fcaf8fee 100644 --- a/.github/workflows/ci-release.yml +++ b/.github/workflows/ci-release.yml @@ -22,6 +22,7 @@ jobs: test-python: name: Python Test & Lint runs-on: ubuntu-latest + timeout-minutes: 40 steps: - uses: actions/checkout@v4 @@ -47,7 +48,7 @@ jobs: run: uv run ruff check . - name: Run tests - run: uv run pytest -v --cov=src/kurt --cov-report=term-missing + run: uv run pytest -v -n auto --ignore=src/kurt/tools/e2e --ignore-glob="**/test_*e2e.py" # Run frontend tests test-frontend: diff --git a/pyproject.toml b/pyproject.toml index 3e432b13..6e50118a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,5 +128,6 @@ dev = [ "pytest>=8.4.2", "pytest-asyncio>=0.24.0", "pytest-cov>=4.1.0", + "pytest-xdist>=3.0.0", "ruff>=0.1.0", ] diff --git a/src/kurt/cli/doctor.py b/src/kurt/cli/doctor.py index 752f419f..c12be17d 100644 --- a/src/kurt/cli/doctor.py +++ b/src/kurt/cli/doctor.py @@ -10,13 +10,16 @@ 3. branch_sync: Git branch matches Dolt branch 4. no_uncommitted_dolt: Dolt status is clean 5. remotes_configured: Both Git and Dolt have 'origin' remote -6. sql_server: Dolt SQL server reachable (server mode required) -7. no_stale_locks: No .git/kurt-hook.lock older than 30s +6. no_stale_dolt_locks: No stale .dolt/noms/LOCK files +7. stale_server_info: No stale .dolt/kurt-server.json with dead PIDs +8. sql_server: Dolt SQL server reachable (server mode required) +9. no_stale_locks: No .git/kurt-hook.lock older than 30s SQL Runtime: Kurt uses server mode exclusively for SQL operations. The dolt sql-server is auto-started for local targets (localhost) if not running. Remote servers -must be started and accessible independently. +must be started and accessible independently. Each project gets its own +server on its own port to avoid conflicts. """ from __future__ import annotations @@ -411,7 +414,7 @@ def check_sql_server(dolt_path: Path) -> CheckResult: if server_running: # For local servers, verify it's the correct project's server if is_local: - info_file = dolt_path / "sql-server.info" + info_file = dolt_path / "kurt-server.json" if info_file.exists(): try: import json as json_mod @@ -455,6 +458,109 @@ def check_sql_server(dolt_path: Path) -> CheckResult: ) +def check_no_stale_dolt_locks(dolt_path: Path) -> CheckResult: + """Check for stale Dolt noms LOCK files. + + These can occur when dolt sql-server is killed ungracefully (e.g., pkill). + A LOCK file is considered stale only if no server is running for this project. + """ + lock_file = dolt_path / "noms" / "LOCK" + + if not lock_file.exists(): + return CheckResult( + name="no_stale_dolt_locks", + status=CheckStatus.PASS, + message="No Dolt lock files present", + ) + + # Check if a server is running for this project + info_file = dolt_path / "kurt-server.json" + server_running = False + if info_file.exists(): + try: + import json as json_mod + + info = json_mod.loads(info_file.read_text()) + pid = info.get("pid") + if pid: + try: + os.kill(pid, 0) # Check if process exists + server_running = True + except OSError: + pass + except Exception: + pass + + if server_running: + return CheckResult( + name="no_stale_dolt_locks", + status=CheckStatus.PASS, + message="Dolt lock file present (server running)", + ) + + # No server running but lock file exists - it's stale + return CheckResult( + name="no_stale_dolt_locks", + status=CheckStatus.FAIL, + message="Stale Dolt lock file (no server running)", + details="Run 'kurt repair' to remove stale lock", + ) + + +def check_stale_server_info(dolt_path: Path) -> CheckResult: + """Check for stale kurt-server.json files with dead PIDs. + + The kurt-server.json file tracks which project started a server. + If the PID is dead but the file exists, it's stale and can cause + connection issues. + """ + info_file = dolt_path / "kurt-server.json" + + if not info_file.exists(): + return CheckResult( + name="stale_server_info", + status=CheckStatus.PASS, + message="No server info file", + ) + + try: + import json as json_mod + + info = json_mod.loads(info_file.read_text()) + pid = info.get("pid") + + if pid is None: + return CheckResult( + name="stale_server_info", + status=CheckStatus.PASS, + message="Server info file present (no PID)", + ) + + # Check if process is still running + try: + os.kill(pid, 0) # Signal 0 = check if process exists + return CheckResult( + name="stale_server_info", + status=CheckStatus.PASS, + message=f"Server info valid (PID {pid} running)", + ) + except OSError: + # Process doesn't exist - stale info file + return CheckResult( + name="stale_server_info", + status=CheckStatus.WARN, + message=f"Stale server info (PID {pid} not running)", + details="Run 'kurt repair' to clean up", + ) + except Exception as e: + return CheckResult( + name="stale_server_info", + status=CheckStatus.WARN, + message="Could not check server info file", + details=str(e), + ) + + def check_no_stale_locks(git_path: Path) -> CheckResult: """Check for stale kurt-hook.lock files.""" lock_dir = git_path / ".git" / "kurt-hook.lock" @@ -463,7 +569,7 @@ def check_no_stale_locks(git_path: Path) -> CheckResult: return CheckResult( name="no_stale_locks", status=CheckStatus.PASS, - message="No lock files present", + message="No Git hook lock files present", ) # Check lock age @@ -532,6 +638,8 @@ def run_doctor(git_path: Path, dolt_path: Path) -> DoctorReport: checks.append(check_branch_sync(git_path, dolt_path)) checks.append(check_no_uncommitted_dolt(dolt_path)) checks.append(check_remotes_configured(git_path, dolt_path)) + checks.append(check_no_stale_dolt_locks(dolt_path)) + checks.append(check_stale_server_info(dolt_path)) checks.append(check_sql_server(dolt_path)) checks.append(check_no_stale_locks(git_path)) @@ -553,7 +661,9 @@ def run_doctor(git_path: Path, dolt_path: Path) -> DoctorReport: @click.command(name="doctor") -@click.option("--json", "as_json", is_flag=True, help="Output as JSON (deprecated: use global --json)") +@click.option( + "--json", "as_json", is_flag=True, help="Output as JSON (deprecated: use global --json)" +) @click.pass_context def doctor_cmd(ctx, as_json: bool): """Check project health and report issues. @@ -566,8 +676,10 @@ def doctor_cmd(ctx, as_json: bool): 3. branch_sync: Git branch matches Dolt branch 4. no_uncommitted_dolt: Dolt status is clean 5. remotes_configured: Both Git and Dolt have 'origin' remote - 6. sql_server: Dolt SQL server is reachable (server mode required) - 7. no_stale_locks: No .git/kurt-hook.lock older than 30s + 6. no_stale_dolt_locks: No stale .dolt/noms/LOCK files + 7. stale_server_info: No stale .dolt/kurt-server.json with dead PIDs + 8. sql_server: Dolt SQL server is reachable (server mode required) + 9. no_stale_locks: No .git/kurt-hook.lock older than 30s Exit codes: 0: All checks passed @@ -708,10 +820,26 @@ def get_repair_actions(report: DoctorReport) -> list[RepairAction]: actions.append( RepairAction( check_name="no_stale_locks", - description="Remove stale lock file", + description="Remove stale Git hook lock file", action="remove_lock", ) ) + elif check.name == "no_stale_dolt_locks" and check.status == CheckStatus.FAIL: + actions.append( + RepairAction( + check_name="no_stale_dolt_locks", + description="Remove stale Dolt lock file", + action="remove_dolt_lock", + ) + ) + elif check.name == "stale_server_info" and check.status == CheckStatus.WARN: + actions.append( + RepairAction( + check_name="stale_server_info", + description="Clean up stale server info file", + action="clean_server_info", + ) + ) elif check.name == "sql_server" and check.status == CheckStatus.FAIL: # Use consistent server config parsing host, _ = _parse_server_config() @@ -785,7 +913,7 @@ def do_commit_dolt(dolt_path: Path) -> bool: def do_remove_lock(git_path: Path) -> bool: - """Remove stale lock file.""" + """Remove stale Git hook lock file.""" import shutil lock_dir = git_path / ".git" / "kurt-hook.lock" @@ -798,6 +926,51 @@ def do_remove_lock(git_path: Path) -> bool: return True +def do_remove_dolt_lock(dolt_path: Path) -> bool: + """Remove stale Dolt noms LOCK file. + + This file can become stale when dolt sql-server is killed ungracefully. + """ + lock_file = dolt_path / "noms" / "LOCK" + if lock_file.exists(): + try: + lock_file.unlink() + return True + except Exception: + return False + return True + + +def do_clean_server_info(dolt_path: Path) -> bool: + """Clean up stale kurt-server.json file. + + Removes the file if the PID it references is no longer running. + """ + info_file = dolt_path / "kurt-server.json" + if not info_file.exists(): + return True + + try: + import json as json_mod + + info = json_mod.loads(info_file.read_text()) + pid = info.get("pid") + + # If no PID or PID is dead, remove the file + if pid is not None: + try: + os.kill(pid, 0) + # Process is running - don't remove + return False + except OSError: + pass # Process dead, continue to remove + + info_file.unlink() + return True + except Exception: + return False + + def do_start_server(dolt_path: Path) -> bool: """Start Dolt SQL server for local development. @@ -834,13 +1007,14 @@ def do_start_server(dolt_path: Path) -> bool: # Start the server try: - subprocess.Popen( + proc = subprocess.Popen( ["dolt", "sql-server", "--port", str(port), "--host", "127.0.0.1"], cwd=dolt_path.parent, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True, ) + server_pid = proc.pid # Wait for server to be ready import time @@ -853,11 +1027,15 @@ def do_start_server(dolt_path: Path) -> bool: sock.close() if result == 0: # Write server info file - info_file = dolt_path / "sql-server.info" + info_file = dolt_path / "kurt-server.json" try: import json as json_mod - info = {"path": str(dolt_path.parent.resolve()), "port": port} + info = { + "path": str(dolt_path.parent.resolve()), + "port": port, + "pid": server_pid, + } info_file.write_text(json_mod.dumps(info)) except Exception: pass @@ -904,7 +1082,9 @@ def repair_cmd(dry_run: bool, yes: bool, check_name: str | None, force: bool): - hooks_installed=fail: Reinstall Git hooks - branch_sync=fail: Sync Dolt branch to match Git - no_uncommitted_dolt=warn: Commit pending Dolt changes - - no_stale_locks=fail: Remove stale lock files + - no_stale_locks=fail: Remove stale Git hook lock files + - no_stale_dolt_locks=fail: Remove stale Dolt noms LOCK files + - stale_server_info=warn: Clean up stale server info files - sql_server=fail: Start Dolt SQL server (local only) SQL Server Repair: @@ -918,6 +1098,7 @@ def repair_cmd(dry_run: bool, yes: bool, check_name: str | None, force: bool): kurt repair --dry-run # Preview repairs kurt repair --check=hooks_installed # Fix specific check kurt repair --check=sql_server # Start local SQL server + kurt repair --check=no_stale_dolt_locks # Remove stale Dolt locks kurt repair --yes # Skip confirmations """ try: @@ -982,6 +1163,10 @@ def repair_cmd(dry_run: bool, yes: bool, check_name: str | None, force: bool): success = do_commit_dolt(dolt_path) elif action.action == "remove_lock": success = do_remove_lock(git_path) + elif action.action == "remove_dolt_lock": + success = do_remove_dolt_lock(dolt_path) + elif action.action == "clean_server_info": + success = do_clean_server_info(dolt_path) elif action.action == "start_server": success = do_start_server(dolt_path) elif action.action == "notify_remote_server": diff --git a/src/kurt/cli/main.py b/src/kurt/cli/main.py index d56aaefd..b9e0a9f8 100644 --- a/src/kurt/cli/main.py +++ b/src/kurt/cli/main.py @@ -18,7 +18,6 @@ - Command aliases for LLM typo tolerance (e.g., doc -> docs) """ - import click from dotenv import load_dotenv @@ -53,8 +52,10 @@ def _auto_migrate_schema(): dolt_path = Path.cwd() / ".dolt" if dolt_path.exists(): from kurt.db.dolt import DoltDB, check_schema_exists, init_observability_schema + from kurt.observability.tracking import init_tracking db = DoltDB(Path.cwd()) + init_tracking(db) # Enable global tracking for track_event() calls schema_status = check_schema_exists(db) # Only initialize if any table is missing @@ -206,8 +207,9 @@ def main(ctx, json_output: bool, quiet: bool): ctx.obj["output"] = OutputContext(json_output, quiet) ctx.obj["json_output"] = json_output # Backwards compat for existing commands - # Skip auto-migrate for init command (no DB yet) - if ctx.invoked_subcommand in ["init", "help"]: + # Skip auto-migrate for commands that don't need DB or use Dolt CLI directly + # doctor/repair use Dolt CLI commands which conflict with the auto-started server + if ctx.invoked_subcommand in ["init", "help", "doctor", "repair"]: return # Skip if no project initialized diff --git a/src/kurt/cli/tests/test_doctor.py b/src/kurt/cli/tests/test_doctor.py index 8af06c73..43d4b337 100644 --- a/src/kurt/cli/tests/test_doctor.py +++ b/src/kurt/cli/tests/test_doctor.py @@ -260,7 +260,7 @@ def test_no_lock_file(self, tmp_path: Path): (tmp_path / ".git").mkdir() result = check_no_stale_locks(tmp_path) assert result.status == CheckStatus.PASS - assert "No lock" in result.message + assert "lock" in result.message.lower() def test_stale_lock_file(self, tmp_path: Path): """Test when stale lock file exists.""" @@ -558,9 +558,7 @@ def test_repair_specific_check(self, cli_runner: CliRunner, temp_git_repo: Path) summary={"passed": 0, "failed": 2, "warnings": 0}, exit_code=1, ) - result = cli_runner.invoke( - repair_cmd, ["--dry-run", "--check=hooks_installed"] - ) + result = cli_runner.invoke(repair_cmd, ["--dry-run", "--check=hooks_installed"]) # Should only show hooks repair assert "Reinstall Git hooks" in result.output diff --git a/src/kurt/conftest.py b/src/kurt/conftest.py index db64cf69..6add008a 100644 --- a/src/kurt/conftest.py +++ b/src/kurt/conftest.py @@ -125,18 +125,24 @@ def assert_json_output(result) -> dict: # ============================================================================ -@pytest.fixture -def tmp_database(tmp_path: Path, monkeypatch): +@pytest.fixture(scope="session") +def tmp_database(tmp_path_factory): """ Fixture for a temporary Dolt database with its own server. - Creates a fresh Dolt database for each test on a unique port. + Creates ONE shared Dolt database for the entire session. Sets DATABASE_URL environment variable. + + Using session scope dramatically speeds up tests by avoiding + repeated server startup/shutdown. All tests share one server. """ # Skip if dolt is not installed if not shutil.which("dolt"): pytest.skip("Dolt CLI not installed") + # Create temp directory for this session + tmp_path = tmp_path_factory.mktemp("dolt_session") + # Change to temp directory original_cwd = os.getcwd() os.chdir(tmp_path) @@ -165,10 +171,11 @@ def tmp_database(tmp_path: Path, monkeypatch): os.chdir(original_cwd) pytest.fail(f"Dolt server failed to start on port {port}") - # Set DATABASE_URL to connect to this test's Dolt server + # Set DATABASE_URL to connect to this session's Dolt server # Database name is the directory name (created by dolt init) database_name = tmp_path.name - monkeypatch.setenv("DATABASE_URL", f"mysql+pymysql://root@127.0.0.1:{port}/{database_name}") + database_url = f"mysql+pymysql://root@127.0.0.1:{port}/{database_name}" + os.environ["DATABASE_URL"] = database_url # Initialize the database from kurt.db import init_database @@ -188,14 +195,17 @@ def tmp_database(tmp_path: Path, monkeypatch): pass os.chdir(original_cwd) + # Restore original DATABASE_URL if it existed + if "DATABASE_URL" in os.environ: + del os.environ["DATABASE_URL"] -@pytest.fixture +@pytest.fixture(scope="session") def tmp_database_with_data(tmp_database: Path): """ Fixture with a temporary database pre-populated with sample data. - Use when tests need existing data. + Use when tests need existing data. Session-scoped for performance. """ from sqlalchemy import text diff --git a/src/kurt/db/connection.py b/src/kurt/db/connection.py index 396448bc..bd46df1e 100644 --- a/src/kurt/db/connection.py +++ b/src/kurt/db/connection.py @@ -258,6 +258,11 @@ def __init__( self._database = database or self.path.name self._pool_size = pool_size + # Check if this project already has a saved port from a previous run + saved_port = self._read_saved_port() + if saved_port is not None: + self._port = saved_port + # Connection pool (lazy init) - for raw query mode self._pool: ConnectionPool | None = None @@ -285,25 +290,47 @@ def _is_local_server_target(self) -> bool: """ return self._host in {"localhost", "127.0.0.1", "::1"} + def _read_saved_port(self) -> int | None: + """Read the saved port from this project's server info file. + + Returns the port if this project has a saved server info file, + None otherwise. + """ + info_file = self.path / ".dolt" / "kurt-server.json" + if not info_file.exists(): + return None + + try: + import json + + info = json.loads(info_file.read_text()) + # Only use saved port if it's for this project + if info.get("path") == str(self.path.resolve()): + return info.get("port") + except Exception: + pass + return None + + def _find_free_port(self) -> int: + """Find a free TCP port on localhost.""" + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + return s.getsockname()[1] + def _get_pool(self) -> ConnectionPool: """Get or create connection pool for dolt sql-server. Auto-starts the server for local targets if not running. Remote servers must be running - we never try to start them. - For local servers, verifies the running server matches this project - to prevent connecting to a stale server from a different project. + For local servers, if the port is occupied by another project's server, + a new port is automatically selected. """ if self._auto_start and self._is_local_server_target(): - if not self._is_server_running(): - self._start_server() - elif not self._is_correct_server(): - # Server running but for wrong project - restart it - logger.warning( - f"Dolt server on port {self._port} is for a different project. Restarting..." - ) - self._stop_server() - self._start_server() + # _start_server handles port conflicts by finding a free port + self._start_server() if self._pool is None: self._pool = ConnectionPool( @@ -330,7 +357,9 @@ def _is_server_running(self) -> bool: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(1) - result = sock.connect_ex((self._host if self._host != "localhost" else "127.0.0.1", self._port)) + result = sock.connect_ex( + (self._host if self._host != "localhost" else "127.0.0.1", self._port) + ) sock.close() return result == 0 except Exception: @@ -349,10 +378,11 @@ def _is_correct_server(self) -> bool: # Can only verify local servers via info file if not self._is_local_server_target(): return False - info_file = self.path / ".dolt" / "sql-server.info" + info_file = self.path / ".dolt" / "kurt-server.json" if not info_file.exists(): - # No info file - can't verify, assume it's wrong (will restart) - return False + # No info file - can't verify ownership, but if server is running + # on our port, assume it's usable (e.g., test fixture started it) + return True try: import json @@ -373,7 +403,7 @@ def _write_server_info(self) -> None: if not self._is_local_server_target(): return - info_file = self.path / ".dolt" / "sql-server.info" + info_file = self.path / ".dolt" / "kurt-server.json" try: import json @@ -403,16 +433,19 @@ def _start_server(self) -> None: if self._is_server_running(): # Server is running - but is it OUR server or another project's? if self._is_correct_server(): - logger.debug(f"Dolt SQL server already running on port {self._port} for this project") + logger.debug( + f"Dolt SQL server already running on port {self._port} for this project" + ) return else: - # Wrong server running on our port - this is a conflict - # Log a warning but try to proceed (the database name in URL may differ) - logger.warning( - f"Dolt SQL server on port {self._port} belongs to a different project. " - f"Consider stopping it or using a different port." + # Wrong server running on our port - find a free port instead + old_port = self._port + self._port = self._find_free_port() + logger.info( + f"Port {old_port} in use by another project. " + f"Using port {self._port} instead." ) - return # Try to connect anyway - database name mismatch will fail gracefully + # Fall through to start server on new port if not shutil.which("dolt"): raise DoltConnectionError( @@ -471,7 +504,8 @@ def get_database_url(self) -> str: def _get_engine(self) -> "Engine": """Get or create SQLAlchemy engine.""" if self._engine is None: - if self._auto_start and self._is_local_server_target() and not self._is_server_running(): + if self._auto_start and self._is_local_server_target(): + # _start_server handles port conflicts by finding a free port self._start_server() self._engine = create_engine( @@ -530,7 +564,8 @@ def _make_async_url(self, url: str) -> str: def get_async_engine(self) -> AsyncEngine: """Get or create async SQLAlchemy engine.""" if self._async_engine is None: - if self._auto_start and self._is_local_server_target() and not self._is_server_running(): + if self._auto_start and self._is_local_server_target(): + # _start_server handles port conflicts by finding a free port self._start_server() async_url = self._make_async_url(self.get_database_url()) diff --git a/src/kurt/tools/tests/test_fetch_core_engines.py b/src/kurt/tools/tests/test_fetch_core_engines.py index 45b508c3..e8d89d5e 100644 --- a/src/kurt/tools/tests/test_fetch_core_engines.py +++ b/src/kurt/tools/tests/test_fetch_core_engines.py @@ -17,7 +17,6 @@ # Check if APIFY_API_KEY is available for integration tests HAS_APIFY_KEY = bool(os.environ.get("APIFY_API_KEY")) - class TestFetcherConfig: """Test FetcherConfig.""" diff --git a/src/kurt/tools/tests/test_map_engines.py b/src/kurt/tools/tests/test_map_engines.py index b1edf783..3f167a6c 100644 --- a/src/kurt/tools/tests/test_map_engines.py +++ b/src/kurt/tools/tests/test_map_engines.py @@ -20,7 +20,6 @@ reason="APIFY_API_KEY not set - skipping Apify integration tests", ) - class TestEngineRegistry: """Test EngineRegistry.""" diff --git a/src/kurt/workflows/toml/cli.py b/src/kurt/workflows/toml/cli.py index 4594c250..ba7edb57 100644 --- a/src/kurt/workflows/toml/cli.py +++ b/src/kurt/workflows/toml/cli.py @@ -72,7 +72,9 @@ def _get_dolt_db(): help="Parse and validate workflow without executing", ) @track_command -def run_cmd(workflow_path: Path, inputs: tuple[str, ...], background: bool, foreground: bool, dry_run: bool): +def run_cmd( + workflow_path: Path, inputs: tuple[str, ...], background: bool, foreground: bool, dry_run: bool +): """Run a workflow from a TOML or Markdown file. Supports both workflow types: @@ -97,6 +99,7 @@ def run_cmd(workflow_path: Path, inputs: tuple[str, ...], background: bool, fore if suffix == ".toml" and not dry_run: try: from kurt.workflows.agents.parser import parse_workflow as parse_agent_workflow + parsed = parse_agent_workflow(workflow_path) # If it has agent config but no steps, treat as agent workflow if parsed.agent is not None and not parsed.steps: @@ -140,14 +143,14 @@ def run_cmd(workflow_path: Path, inputs: tuple[str, ...], background: bool, fore else: console.print("[green]Workflow completed[/green]") console.print(f" Status: {result.get('status')}") - if result.get('turns'): + if result.get("turns"): console.print(f" Turns: {result.get('turns')}") - if result.get('tool_calls'): + if result.get("tool_calls"): console.print(f" Tool Calls: {result.get('tool_calls')}") - tokens = (result.get('tokens_in', 0) or 0) + (result.get('tokens_out', 0) or 0) + tokens = (result.get("tokens_in", 0) or 0) + (result.get("tokens_out", 0) or 0) if tokens: console.print(f" Tokens: {tokens:,}") - if result.get('duration_seconds'): + if result.get("duration_seconds"): console.print(f" Duration: {result.get('duration_seconds')}s") return @@ -243,20 +246,20 @@ def run_cmd(workflow_path: Path, inputs: tuple[str, ...], background: bool, fore # Note: Actual background execution would require a separate process # For now, we just create the record. Full implementation would use # subprocess or a task queue. - console.print( - "\n[yellow]Note: Background execution creates run record only.[/yellow]" - ) + console.print("\n[yellow]Note: Background execution creates run record only.[/yellow]") console.print("[dim]Use 'kurt logs {run_id}' to check progress.[/dim]") return # Foreground execution # Look for tools.py in the same directory as the workflow file tools_path = workflow_path.parent / "tools.py" + db = _get_dolt_db() result = asyncio.run( execute_workflow( workflow=workflow_def, inputs=merged_inputs, context=context, + db=db, tools_path=tools_path if tools_path.exists() else None, ) ) @@ -315,7 +318,9 @@ def status_cmd(run_id: str, output_json: bool, follow: bool): if run is None: if output_json: - print(json.dumps({"run_id": run_id, "status": "not_found", "error": "Workflow not found"})) + print( + json.dumps({"run_id": run_id, "status": "not_found", "error": "Workflow not found"}) + ) else: console.print(f"[red]Workflow not found: {run_id}[/red]") return @@ -324,7 +329,9 @@ def status_cmd(run_id: str, output_json: bool, follow: bool): step_logs = lifecycle.get_step_logs(run_id) # Calculate completed steps - completed_steps = sum(1 for s in step_logs if s.get("status") in ("completed", "failed", "canceled")) + completed_steps = sum( + 1 for s in step_logs if s.get("status") in ("completed", "failed", "canceled") + ) total_steps = len(step_logs) if output_json: @@ -438,7 +445,12 @@ def _follow_workflow(db, run_id: str): @click.argument("run_id") @click.option("--step", "step_filter", default=None, help="Filter by step name") @click.option("--substep", "substep_filter", default=None, help="Filter by substep name") -@click.option("--status", "status_filter", default=None, help="Filter by status (running|progress|completed|failed)") +@click.option( + "--status", + "status_filter", + default=None, + help="Filter by status (running|progress|completed|failed)", +) @click.option("--json", "output_json", is_flag=True, help="Output as JSON lines") @click.option("--tail", "-f", is_flag=True, help="Stream new events as they arrive (like tail -f)") @click.option("--limit", default=100, help="Maximum number of log entries to show") @@ -487,9 +499,7 @@ def logs_cmd( step_logs = _fetch_step_logs(db, run_id, step_filter) # Fetch step events - step_events = _fetch_step_events( - db, run_id, step_filter, substep_filter, status_filter, limit - ) + step_events = _fetch_step_events(db, run_id, step_filter, substep_filter, status_filter, limit) if output_json: _output_logs_json(step_logs, step_events) @@ -497,9 +507,7 @@ def logs_cmd( _output_logs_text(run, step_logs, step_events) -def _fetch_step_logs( - db, run_id: str, step_filter: str | None -) -> list[dict[str, Any]]: +def _fetch_step_logs(db, run_id: str, step_filter: str | None) -> list[dict[str, Any]]: """Fetch step logs from the database.""" sql = "SELECT * FROM step_logs WHERE run_id = ?" params: list[Any] = [run_id] @@ -545,9 +553,7 @@ def _fetch_step_events( return list(reversed(result.rows)) -def _output_logs_json( - step_logs: list[dict[str, Any]], step_events: list[dict[str, Any]] -): +def _output_logs_json(step_logs: list[dict[str, Any]], step_events: list[dict[str, Any]]): """Output logs as JSON lines.""" for event in step_events: output = { @@ -600,14 +606,18 @@ def _output_logs_text( if step.get("started_at") and step.get("completed_at"): try: started = datetime.fromisoformat(str(step["started_at"]).replace("Z", "+00:00")) - completed = datetime.fromisoformat(str(step["completed_at"]).replace("Z", "+00:00")) + completed = datetime.fromisoformat( + str(step["completed_at"]).replace("Z", "+00:00") + ) duration = completed - started duration_str = f"{duration.total_seconds():.1f}s" except (ValueError, TypeError): pass input_str = str(step.get("input_count")) if step.get("input_count") is not None else "-" - output_str = str(step.get("output_count")) if step.get("output_count") is not None else "-" + output_str = ( + str(step.get("output_count")) if step.get("output_count") is not None else "-" + ) error_str = str(step.get("error_count")) if step.get("error_count") else "-" table.add_row( @@ -739,9 +749,7 @@ def _tail_logs( _print_event(row) # Check if workflow terminated - run_result = db.query_one( - "SELECT status FROM workflow_runs WHERE id = ?", [run_id] - ) + run_result = db.query_one("SELECT status FROM workflow_runs WHERE id = ?", [run_id]) if run_result and run_result.get("status") in TERMINAL_STATUSES: # Fetch any final events final_result = db.query(sql, params) @@ -749,7 +757,9 @@ def _tail_logs( if row.get("id", 0) > cursor_id: if output_json: output = { - "timestamp": str(row.get("created_at")) if row.get("created_at") else None, + "timestamp": str(row.get("created_at")) + if row.get("created_at") + else None, "step": row.get("step_id"), "substep": row.get("substep"), "status": row.get("status"), @@ -912,6 +922,7 @@ def _build_dry_run_output( if tools_path and tools_path.exists(): try: import importlib.util + spec = importlib.util.spec_from_file_location("tools", tools_path) if spec and spec.loader: tools_module = importlib.util.module_from_spec(spec) @@ -1096,22 +1107,30 @@ def test_cmd( ) except FixtureNotFoundError as e: if output_json: - print(json.dumps({ - "success": False, - "error": str(e), - "step": e.step_name, - })) + print( + json.dumps( + { + "success": False, + "error": str(e), + "step": e.step_name, + } + ) + ) else: console.print(f"[red]Error: {e}[/red]") raise click.Abort() except FixtureLoadError as e: if output_json: - print(json.dumps({ - "success": False, - "error": str(e), - "step": e.step_name, - "path": str(e.path), - })) + print( + json.dumps( + { + "success": False, + "error": str(e), + "step": e.step_name, + "path": str(e.path), + } + ) + ) else: console.print(f"[red]Error loading fixture: {e}[/red]") raise click.Abort() @@ -1150,7 +1169,10 @@ def test_cmd( "fixture_records": len(fixture_set.get_output_data(name)), "would_execute": name in coverage.steps_without_fixtures, "tool": workflow_def.steps[name].type, - "config_valid": dry_run_output.get("steps", {}).get(name, {}).get("validation", {}).get("valid", False), + "config_valid": dry_run_output.get("steps", {}) + .get(name, {}) + .get("validation", {}) + .get("valid", False), } for name in step_names }, @@ -1184,9 +1206,16 @@ def test_cmd( step_def = workflow_def.steps[step_name] has_fixture = step_name in coverage.steps_with_fixtures records = len(fixture_set.get_output_data(step_name)) - config_valid = dry_run_output.get("steps", {}).get(step_name, {}).get("validation", {}).get("valid", False) + config_valid = ( + dry_run_output.get("steps", {}) + .get(step_name, {}) + .get("validation", {}) + .get("valid", False) + ) - fixture_display = "[green]Yes[/green]" if has_fixture else "[yellow]No (would execute)[/yellow]" + fixture_display = ( + "[green]Yes[/green]" if has_fixture else "[yellow]No (would execute)[/yellow]" + ) records_display = str(records) if has_fixture else "-" valid_display = "[green]Yes[/green]" if config_valid else "[red]No[/red]" diff --git a/src/kurt/workflows/toml/executor.py b/src/kurt/workflows/toml/executor.py index e15e3fc8..30b25ad5 100644 --- a/src/kurt/workflows/toml/executor.py +++ b/src/kurt/workflows/toml/executor.py @@ -26,6 +26,7 @@ from pathlib import Path from typing import Any, Callable, Literal +from kurt.db.dolt import DoltDB from kurt.observability.tracking import track_event from kurt.tools.core import ToolCanceledError, ToolContext, ToolError, ToolResult, execute_tool from kurt.tools.core.provider import get_provider_registry @@ -68,9 +69,7 @@ def _load_user_function(tools_path: Path, function_name: str) -> Callable: spec.loader.exec_module(module) if not hasattr(module, function_name): - raise AttributeError( - f"Function '{function_name}' not found in {tools_path}" - ) + raise AttributeError(f"Function '{function_name}' not found in {tools_path}") return getattr(module, function_name) @@ -263,6 +262,7 @@ def __init__( inputs: dict[str, Any], context: ToolContext | None = None, *, + db: DoltDB | None = None, continue_on_error: bool = False, run_id: str | None = None, tools_path: Path | str | None = None, @@ -274,6 +274,7 @@ def __init__( workflow: Parsed workflow definition. inputs: Input values for the workflow (merged with defaults). context: Tool execution context. + db: Optional DoltDB for event tracking. If None, uses global tracking DB. continue_on_error: If True, continue workflow on step failure. Failed step's dependents receive empty input. run_id: Optional run ID. If not provided, generates a UUID. @@ -283,6 +284,7 @@ def __init__( self.workflow = workflow self.inputs = inputs self.context = context or ToolContext() + self._db = db self.continue_on_error = continue_on_error self.run_id = run_id or str(uuid.uuid4()) self.tools_path = Path(tools_path) if tools_path else Path("tools.py") @@ -932,9 +934,7 @@ async def _handle_cancellation(self, pending_level: list[str]) -> None: if tasks_to_cancel: try: await asyncio.wait_for( - asyncio.gather( - *[t for _, t in tasks_to_cancel], return_exceptions=True - ), + asyncio.gather(*[t for _, t in tasks_to_cancel], return_exceptions=True), timeout=self.CANCEL_TIMEOUT, ) except asyncio.TimeoutError: @@ -960,9 +960,7 @@ async def _handle_cancellation(self, pending_level: list[str]) -> None: metadata={"canceled_steps": pending_level}, ) - def _create_result( - self, started_at: datetime, error: str | None = None - ) -> WorkflowResult: + def _create_result(self, started_at: datetime, error: str | None = None) -> WorkflowResult: """Create the final WorkflowResult.""" completed_at = datetime.now(timezone.utc) duration_ms = int((completed_at - started_at).total_seconds() * 1000) @@ -1021,6 +1019,7 @@ def _emit_event( total=total, message=message, metadata=metadata, + db=self._db, ) except Exception: # Event tracking should not break execution @@ -1032,6 +1031,7 @@ async def execute_workflow( inputs: dict[str, Any], context: ToolContext | None = None, *, + db: DoltDB | None = None, continue_on_error: bool = False, run_id: str | None = None, tools_path: Path | str | None = None, @@ -1050,6 +1050,7 @@ async def execute_workflow( workflow: Parsed workflow definition. inputs: Input values for the workflow. Must include all required inputs. context: Optional tool execution context. + db: Optional DoltDB for event tracking. If None, uses global tracking DB. continue_on_error: If True, continue workflow on step failure. Failed step's dependents receive empty input. run_id: Optional run ID. If not provided, generates a UUID. @@ -1080,6 +1081,7 @@ async def execute_workflow( workflow=workflow, inputs=inputs, context=context, + db=db, continue_on_error=continue_on_error, run_id=run_id, tools_path=tools_path,