diff --git a/internal/db/dependencies.go b/internal/db/dependencies.go index 636e1348..6f18fe48 100644 --- a/internal/db/dependencies.go +++ b/internal/db/dependencies.go @@ -110,6 +110,7 @@ func (db *DB) GetBlockers(taskID int64) ([]*Task, error) { COALESCE(t.claude_pane_id, ''), COALESCE(t.shell_pane_id, ''), COALESCE(t.pr_url, ''), COALESCE(t.pr_number, 0), COALESCE(t.dangerous_mode, 0), COALESCE(t.pinned, 0), COALESCE(t.tags, ''), COALESCE(t.summary, ''), + t.parent_id, COALESCE(t.output, ''), t.created_at, t.updated_at, t.started_at, t.completed_at, t.last_distilled_at, t.last_accessed_at FROM tasks t @@ -134,6 +135,7 @@ func (db *DB) GetBlockedBy(taskID int64) ([]*Task, error) { COALESCE(t.claude_pane_id, ''), COALESCE(t.shell_pane_id, ''), COALESCE(t.pr_url, ''), COALESCE(t.pr_number, 0), COALESCE(t.dangerous_mode, 0), COALESCE(t.pinned, 0), COALESCE(t.tags, ''), COALESCE(t.summary, ''), + t.parent_id, COALESCE(t.output, ''), t.created_at, t.updated_at, t.started_at, t.completed_at, t.last_distilled_at, t.last_accessed_at FROM tasks t @@ -308,6 +310,7 @@ func scanTaskRows(rows *sql.Rows) ([]*Task, error) { &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, + &t.ParentID, &t.Output, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, &t.LastDistilledAt, &t.LastAccessedAt, ) diff --git a/internal/db/orchestration.go b/internal/db/orchestration.go new file mode 100644 index 00000000..fb4a0823 --- /dev/null +++ b/internal/db/orchestration.go @@ -0,0 +1,212 @@ +package db + +import ( + "fmt" +) + +// WorkflowStatus represents the aggregate status of a parent task's subtasks. +type WorkflowStatus struct { + ParentID int64 `json:"parent_id"` + ParentTitle string `json:"parent_title"` + Total int `json:"total"` + Pending int `json:"pending"` // backlog + queued + Processing int `json:"processing"` // currently executing + Blocked int `json:"blocked"` // waiting for input + Done int `json:"done"` // completed + Archived int `json:"archived"` // archived + IsComplete bool `json:"is_complete"` // all subtasks are done or archived +} + +// GetSubtasks returns all child tasks of a parent task. +func (db *DB) GetSubtasks(parentID int64) ([]*Task, error) { + rows, err := db.Query(` + SELECT id, title, body, status, type, project, COALESCE(executor, 'claude'), + worktree_path, branch_name, port, claude_session_id, + COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), + COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), + COALESCE(pr_url, ''), COALESCE(pr_number, 0), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(source_branch, ''), COALESCE(summary, ''), + parent_id, COALESCE(output, ''), + created_at, updated_at, started_at, completed_at, + last_distilled_at, last_accessed_at, + COALESCE(archive_ref, ''), COALESCE(archive_commit, ''), + COALESCE(archive_worktree_path, ''), COALESCE(archive_branch_name, '') + FROM tasks + WHERE parent_id = ? + ORDER BY id ASC + `, parentID) + if err != nil { + return nil, fmt.Errorf("get subtasks: %w", err) + } + defer rows.Close() + + var tasks []*Task + for rows.Next() { + t := &Task{} + err := rows.Scan( + &t.ID, &t.Title, &t.Body, &t.Status, &t.Type, &t.Project, &t.Executor, + &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, + &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, + &t.PRURL, &t.PRNumber, + &t.DangerousMode, &t.Pinned, &t.Tags, + &t.SourceBranch, &t.Summary, + &t.ParentID, &t.Output, + &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, + &t.LastDistilledAt, &t.LastAccessedAt, + &t.ArchiveRef, &t.ArchiveCommit, &t.ArchiveWorktreePath, &t.ArchiveBranchName, + ) + if err != nil { + return nil, fmt.Errorf("scan subtask: %w", err) + } + tasks = append(tasks, t) + } + return tasks, nil +} + +// GetSubtaskCount returns the number of subtasks for a parent task. +func (db *DB) GetSubtaskCount(parentID int64) (int, error) { + var count int + err := db.QueryRow(`SELECT COUNT(*) FROM tasks WHERE parent_id = ?`, parentID).Scan(&count) + if err != nil { + return 0, fmt.Errorf("count subtasks: %w", err) + } + return count, nil +} + +// SetTaskOutput stores output/results for a task. +// This is used in orchestration to pass context between tasks. +func (db *DB) SetTaskOutput(taskID int64, output string) error { + _, err := db.Exec(` + UPDATE tasks SET output = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? + `, output, taskID) + if err != nil { + return fmt.Errorf("set task output: %w", err) + } + return nil +} + +// GetTaskOutput retrieves the output for a task. +func (db *DB) GetTaskOutput(taskID int64) (string, error) { + var output string + err := db.QueryRow(`SELECT COALESCE(output, '') FROM tasks WHERE id = ?`, taskID).Scan(&output) + if err != nil { + return "", fmt.Errorf("get task output: %w", err) + } + return output, nil +} + +// GetWorkflowStatus returns the aggregate status of all subtasks for a parent task. +func (db *DB) GetWorkflowStatus(parentID int64) (*WorkflowStatus, error) { + parent, err := db.GetTask(parentID) + if err != nil { + return nil, fmt.Errorf("get parent task: %w", err) + } + if parent == nil { + return nil, fmt.Errorf("parent task %d not found", parentID) + } + + status := &WorkflowStatus{ + ParentID: parentID, + ParentTitle: parent.Title, + } + + rows, err := db.Query(` + SELECT status, COUNT(*) as count + FROM tasks + WHERE parent_id = ? + GROUP BY status + `, parentID) + if err != nil { + return nil, fmt.Errorf("get workflow status: %w", err) + } + defer rows.Close() + + for rows.Next() { + var s string + var count int + if err := rows.Scan(&s, &count); err != nil { + return nil, fmt.Errorf("scan workflow status: %w", err) + } + + status.Total += count + switch s { + case StatusBacklog, StatusQueued: + status.Pending += count + case StatusProcessing: + status.Processing += count + case StatusBlocked: + status.Blocked += count + case StatusDone: + status.Done += count + case StatusArchived: + status.Archived += count + } + } + + // Workflow is complete when all subtasks are done or archived + status.IsComplete = status.Total > 0 && (status.Done+status.Archived) == status.Total + + return status, nil +} + +// CheckAndCompleteParent checks if all subtasks of a parent are done/archived, +// and if so, updates the parent task status to done. +// Returns true if the parent was completed. +func (db *DB) CheckAndCompleteParent(parentID int64) (bool, error) { + status, err := db.GetWorkflowStatus(parentID) + if err != nil { + return false, err + } + + if !status.IsComplete { + return false, nil + } + + // Get parent task to check its current status + parent, err := db.GetTask(parentID) + if err != nil { + return false, fmt.Errorf("get parent task: %w", err) + } + if parent == nil { + return false, nil + } + + // Only auto-complete if parent is still in a pending/processing state + if parent.Status == StatusDone || parent.Status == StatusArchived { + return false, nil + } + + // Collect outputs from all subtasks for the parent's summary + subtasks, err := db.GetSubtasks(parentID) + if err != nil { + return false, fmt.Errorf("get subtasks for summary: %w", err) + } + + var summaryParts []string + for _, st := range subtasks { + if st.Output != "" { + summaryParts = append(summaryParts, fmt.Sprintf("## Subtask #%d: %s\n%s", st.ID, st.Title, st.Output)) + } + } + + // Set workflow completion output on parent + if len(summaryParts) > 0 { + workflowOutput := fmt.Sprintf("Workflow completed: %d/%d subtasks done.\n\n", status.Done, status.Total) + for _, part := range summaryParts { + workflowOutput += part + "\n\n" + } + db.SetTaskOutput(parentID, workflowOutput) + } + + // Log completion + db.AppendTaskLog(parentID, "system", fmt.Sprintf("All %d subtask(s) completed. Workflow auto-completing.", status.Total)) + + // Mark parent as done + if err := db.UpdateTaskStatus(parentID, StatusDone); err != nil { + return false, fmt.Errorf("complete parent: %w", err) + } + + return true, nil +} diff --git a/internal/db/orchestration_test.go b/internal/db/orchestration_test.go new file mode 100644 index 00000000..697d0ea5 --- /dev/null +++ b/internal/db/orchestration_test.go @@ -0,0 +1,382 @@ +package db + +import ( + "os" + "testing" +) + +func setupOrchTestDB(t *testing.T) (*DB, func()) { + tmpFile, err := os.CreateTemp("", "test-orch-*.db") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tmpFile.Close() + + db, err := Open(tmpFile.Name()) + if err != nil { + os.Remove(tmpFile.Name()) + t.Fatalf("Failed to open database: %v", err) + } + + cleanup := func() { + db.Close() + os.Remove(tmpFile.Name()) + } + + return db, cleanup +} + +func TestCreateTaskWithParent(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + // Create parent task + parent := &Task{Title: "Parent Task", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // Create subtask + subtask := &Task{ + Title: "Subtask 1", + Status: StatusBacklog, + ParentID: &parent.ID, + } + if err := db.CreateTask(subtask); err != nil { + t.Fatalf("Failed to create subtask: %v", err) + } + + // Verify subtask has parent_id set + fetched, err := db.GetTask(subtask.ID) + if err != nil { + t.Fatalf("Failed to get subtask: %v", err) + } + if fetched.ParentID == nil { + t.Fatal("Expected subtask to have parent_id set") + } + if *fetched.ParentID != parent.ID { + t.Errorf("Expected parent_id=%d, got %d", parent.ID, *fetched.ParentID) + } +} + +func TestGetSubtasks(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + // Create parent + parent := &Task{Title: "Parent Task", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // Create 3 subtasks + for i := 1; i <= 3; i++ { + st := &Task{ + Title: "Subtask", + Status: StatusBacklog, + ParentID: &parent.ID, + } + if err := db.CreateTask(st); err != nil { + t.Fatalf("Failed to create subtask %d: %v", i, err) + } + } + + // Get subtasks + subtasks, err := db.GetSubtasks(parent.ID) + if err != nil { + t.Fatalf("Failed to get subtasks: %v", err) + } + if len(subtasks) != 3 { + t.Errorf("Expected 3 subtasks, got %d", len(subtasks)) + } + + // Verify all have correct parent + for _, st := range subtasks { + if st.ParentID == nil || *st.ParentID != parent.ID { + t.Errorf("Subtask %d has wrong parent_id", st.ID) + } + } +} + +func TestGetSubtaskCount(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + parent := &Task{Title: "Parent", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // No subtasks initially + count, err := db.GetSubtaskCount(parent.ID) + if err != nil { + t.Fatalf("Failed to get subtask count: %v", err) + } + if count != 0 { + t.Errorf("Expected 0 subtasks, got %d", count) + } + + // Add 2 subtasks + for i := 0; i < 2; i++ { + st := &Task{Title: "Sub", Status: StatusBacklog, ParentID: &parent.ID} + if err := db.CreateTask(st); err != nil { + t.Fatalf("Failed to create subtask: %v", err) + } + } + + count, err = db.GetSubtaskCount(parent.ID) + if err != nil { + t.Fatalf("Failed to get subtask count: %v", err) + } + if count != 2 { + t.Errorf("Expected 2 subtasks, got %d", count) + } +} + +func TestTaskOutput(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + task := &Task{Title: "Task with output", Status: StatusProcessing} + if err := db.CreateTask(task); err != nil { + t.Fatalf("Failed to create task: %v", err) + } + + // Initially no output + output, err := db.GetTaskOutput(task.ID) + if err != nil { + t.Fatalf("Failed to get output: %v", err) + } + if output != "" { + t.Errorf("Expected empty output, got %q", output) + } + + // Set output + expectedOutput := "Results: everything passed" + if err := db.SetTaskOutput(task.ID, expectedOutput); err != nil { + t.Fatalf("Failed to set output: %v", err) + } + + // Get output + output, err = db.GetTaskOutput(task.ID) + if err != nil { + t.Fatalf("Failed to get output: %v", err) + } + if output != expectedOutput { + t.Errorf("Expected output %q, got %q", expectedOutput, output) + } +} + +func TestGetWorkflowStatus(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + parent := &Task{Title: "Workflow Parent", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // Create subtasks in various states + statuses := []string{StatusDone, StatusDone, StatusProcessing, StatusBacklog, StatusBlocked} + for _, s := range statuses { + st := &Task{Title: "Sub", Status: s, ParentID: &parent.ID} + if err := db.CreateTask(st); err != nil { + t.Fatalf("Failed to create subtask: %v", err) + } + } + + status, err := db.GetWorkflowStatus(parent.ID) + if err != nil { + t.Fatalf("Failed to get workflow status: %v", err) + } + + if status.Total != 5 { + t.Errorf("Expected total=5, got %d", status.Total) + } + if status.Done != 2 { + t.Errorf("Expected done=2, got %d", status.Done) + } + if status.Processing != 1 { + t.Errorf("Expected processing=1, got %d", status.Processing) + } + if status.Pending != 1 { + t.Errorf("Expected pending=1, got %d", status.Pending) + } + if status.Blocked != 1 { + t.Errorf("Expected blocked=1, got %d", status.Blocked) + } + if status.IsComplete { + t.Error("Expected workflow to not be complete") + } +} + +func TestGetWorkflowStatusAllDone(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + parent := &Task{Title: "Workflow Parent", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // All subtasks done + for i := 0; i < 3; i++ { + st := &Task{Title: "Sub", Status: StatusDone, ParentID: &parent.ID} + if err := db.CreateTask(st); err != nil { + t.Fatalf("Failed to create subtask: %v", err) + } + } + + status, err := db.GetWorkflowStatus(parent.ID) + if err != nil { + t.Fatalf("Failed to get workflow status: %v", err) + } + + if !status.IsComplete { + t.Error("Expected workflow to be complete") + } + if status.Total != 3 { + t.Errorf("Expected total=3, got %d", status.Total) + } + if status.Done != 3 { + t.Errorf("Expected done=3, got %d", status.Done) + } +} + +func TestCheckAndCompleteParent(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + parent := &Task{Title: "Workflow Parent", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // Create two subtasks, both done + for i := 0; i < 2; i++ { + st := &Task{Title: "Sub", Status: StatusDone, ParentID: &parent.ID} + if err := db.CreateTask(st); err != nil { + t.Fatalf("Failed to create subtask: %v", err) + } + // Set output on subtask + db.SetTaskOutput(st.ID, "output from subtask") + } + + // Check and complete parent + completed, err := db.CheckAndCompleteParent(parent.ID) + if err != nil { + t.Fatalf("Failed to check and complete parent: %v", err) + } + if !completed { + t.Error("Expected parent to be completed") + } + + // Verify parent is now done + updatedParent, err := db.GetTask(parent.ID) + if err != nil { + t.Fatalf("Failed to get parent: %v", err) + } + if updatedParent.Status != StatusDone { + t.Errorf("Expected parent status=done, got %s", updatedParent.Status) + } + + // Verify parent has aggregated output + if updatedParent.Output == "" { + t.Error("Expected parent to have aggregated output from subtasks") + } +} + +func TestCheckAndCompleteParentNotReady(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + parent := &Task{Title: "Workflow Parent", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // One done, one still processing + st1 := &Task{Title: "Sub 1", Status: StatusDone, ParentID: &parent.ID} + st2 := &Task{Title: "Sub 2", Status: StatusProcessing, ParentID: &parent.ID} + if err := db.CreateTask(st1); err != nil { + t.Fatalf("Failed to create subtask 1: %v", err) + } + if err := db.CreateTask(st2); err != nil { + t.Fatalf("Failed to create subtask 2: %v", err) + } + + // Should not complete + completed, err := db.CheckAndCompleteParent(parent.ID) + if err != nil { + t.Fatalf("Failed to check and complete parent: %v", err) + } + if completed { + t.Error("Expected parent not to be completed (subtask still processing)") + } + + // Parent should still be processing + updatedParent, err := db.GetTask(parent.ID) + if err != nil { + t.Fatalf("Failed to get parent: %v", err) + } + if updatedParent.Status != StatusProcessing { + t.Errorf("Expected parent status=processing, got %s", updatedParent.Status) + } +} + +func TestSubtaskCompletionTriggersParentCompletion(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + parent := &Task{Title: "Workflow Parent", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + // Create two subtasks + st1 := &Task{Title: "Sub 1", Status: StatusProcessing, ParentID: &parent.ID} + st2 := &Task{Title: "Sub 2", Status: StatusDone, ParentID: &parent.ID} + if err := db.CreateTask(st1); err != nil { + t.Fatalf("Failed to create subtask 1: %v", err) + } + if err := db.CreateTask(st2); err != nil { + t.Fatalf("Failed to create subtask 2: %v", err) + } + + // Complete the last remaining subtask via UpdateTaskStatus + // This should trigger auto-completion of the parent + if err := db.UpdateTaskStatus(st1.ID, StatusDone); err != nil { + t.Fatalf("Failed to update subtask status: %v", err) + } + + // Verify parent is now done + updatedParent, err := db.GetTask(parent.ID) + if err != nil { + t.Fatalf("Failed to get parent: %v", err) + } + if updatedParent.Status != StatusDone { + t.Errorf("Expected parent to be auto-completed, got status=%s", updatedParent.Status) + } +} + +func TestGetWorkflowStatusNoSubtasks(t *testing.T) { + db, cleanup := setupOrchTestDB(t) + defer cleanup() + + parent := &Task{Title: "No subtasks", Status: StatusProcessing} + if err := db.CreateTask(parent); err != nil { + t.Fatalf("Failed to create parent: %v", err) + } + + status, err := db.GetWorkflowStatus(parent.ID) + if err != nil { + t.Fatalf("Failed to get workflow status: %v", err) + } + + if status.Total != 0 { + t.Errorf("Expected total=0, got %d", status.Total) + } + if status.IsComplete { + t.Error("Expected workflow with no subtasks to not be marked complete") + } +} diff --git a/internal/db/sqlite.go b/internal/db/sqlite.go index cde69a72..9110cc6f 100644 --- a/internal/db/sqlite.go +++ b/internal/db/sqlite.go @@ -251,6 +251,9 @@ func (db *DB) migrate() error { `ALTER TABLE tasks ADD COLUMN archive_branch_name TEXT DEFAULT ''`, // Original branch name before archiving // Source branch for checking out existing branches in worktrees (e.g., for QA deployments) `ALTER TABLE tasks ADD COLUMN source_branch TEXT DEFAULT ''`, // Existing branch to checkout instead of creating new branch + // Orchestration columns for parent-child task relationships + `ALTER TABLE tasks ADD COLUMN parent_id INTEGER REFERENCES tasks(id) ON DELETE SET NULL`, // Parent task ID for subtask orchestration + `ALTER TABLE tasks ADD COLUMN output TEXT DEFAULT ''`, // Task output/results for downstream tasks in workflows } for _, m := range alterMigrations { @@ -258,6 +261,9 @@ func (db *DB) migrate() error { db.Exec(m) } + // Post-ALTER indexes (must run after columns exist) + db.Exec(`CREATE INDEX IF NOT EXISTS idx_tasks_parent_id ON tasks(parent_id)`) + // Note: SQLite doesn't support ALTER COLUMN DEFAULT directly // The default value change for project column will be handled in the application layer // New tasks will get 'personal' as default through the form and executor logic diff --git a/internal/db/tasks.go b/internal/db/tasks.go index cf64c097..ab030409 100644 --- a/internal/db/tasks.go +++ b/internal/db/tasks.go @@ -33,6 +33,8 @@ type Task struct { Tags string // Comma-separated tags for categorization (e.g., "customer-support,email,influence-kit") SourceBranch string // Existing branch to checkout for worktree (e.g., "fix/ui-overflow") instead of creating new branch Summary string // Distilled summary of what was accomplished (for search and context) + ParentID *int64 // Parent task ID for orchestration (nil = top-level task) + Output string // Task output/results for downstream tasks in orchestration workflows CreatedAt LocalTime UpdatedAt LocalTime StartedAt *LocalTime @@ -128,9 +130,9 @@ func (db *DB) CreateTask(t *Task) error { t.Project = project.Name result, err := db.Exec(` - INSERT INTO tasks (title, body, status, type, project, executor, pinned, tags, source_branch) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - `, t.Title, t.Body, t.Status, t.Type, t.Project, t.Executor, t.Pinned, t.Tags, t.SourceBranch) + INSERT INTO tasks (title, body, status, type, project, executor, pinned, tags, source_branch, parent_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, t.Title, t.Body, t.Status, t.Type, t.Project, t.Executor, t.Pinned, t.Tags, t.SourceBranch, t.ParentID) if err != nil { return fmt.Errorf("insert task: %w", err) } @@ -176,6 +178,7 @@ func (db *DB) GetTask(id int64) (*Task, error) { COALESCE(pr_url, ''), COALESCE(pr_number, 0), COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(source_branch, ''), COALESCE(summary, ''), + parent_id, COALESCE(output, ''), created_at, updated_at, started_at, completed_at, last_distilled_at, last_accessed_at, COALESCE(archive_ref, ''), COALESCE(archive_commit, ''), @@ -188,6 +191,7 @@ func (db *DB) GetTask(id int64) (*Task, error) { &t.PRURL, &t.PRNumber, &t.DangerousMode, &t.Pinned, &t.Tags, &t.SourceBranch, &t.Summary, + &t.ParentID, &t.Output, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, &t.LastDistilledAt, &t.LastAccessedAt, &t.ArchiveRef, &t.ArchiveCommit, &t.ArchiveWorktreePath, &t.ArchiveBranchName, @@ -221,6 +225,7 @@ func (db *DB) ListTasks(opts ListTasksOptions) ([]*Task, error) { COALESCE(pr_url, ''), COALESCE(pr_number, 0), COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(source_branch, ''), COALESCE(summary, ''), + parent_id, COALESCE(output, ''), created_at, updated_at, started_at, completed_at, last_distilled_at, last_accessed_at, COALESCE(archive_ref, ''), COALESCE(archive_commit, ''), @@ -281,6 +286,7 @@ func (db *DB) ListTasks(opts ListTasksOptions) ([]*Task, error) { &t.PRURL, &t.PRNumber, &t.DangerousMode, &t.Pinned, &t.Tags, &t.SourceBranch, &t.Summary, + &t.ParentID, &t.Output, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, &t.LastDistilledAt, &t.LastAccessedAt, &t.ArchiveRef, &t.ArchiveCommit, &t.ArchiveWorktreePath, &t.ArchiveBranchName, @@ -306,6 +312,7 @@ func (db *DB) GetMostRecentlyCreatedTask() (*Task, error) { COALESCE(pr_url, ''), COALESCE(pr_number, 0), COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(source_branch, ''), COALESCE(summary, ''), + parent_id, COALESCE(output, ''), created_at, updated_at, started_at, completed_at, last_distilled_at, last_accessed_at, COALESCE(archive_ref, ''), COALESCE(archive_commit, ''), @@ -320,6 +327,7 @@ func (db *DB) GetMostRecentlyCreatedTask() (*Task, error) { &t.PRURL, &t.PRNumber, &t.DangerousMode, &t.Pinned, &t.Tags, &t.SourceBranch, &t.Summary, + &t.ParentID, &t.Output, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, &t.LastDistilledAt, &t.LastAccessedAt, &t.ArchiveRef, &t.ArchiveCommit, &t.ArchiveWorktreePath, &t.ArchiveBranchName, @@ -349,6 +357,7 @@ func (db *DB) SearchTasks(query string, limit int) ([]*Task, error) { COALESCE(pr_url, ''), COALESCE(pr_number, 0), COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(source_branch, ''), COALESCE(summary, ''), + parent_id, COALESCE(output, ''), created_at, updated_at, started_at, completed_at, last_distilled_at, last_accessed_at, COALESCE(archive_ref, ''), COALESCE(archive_commit, ''), @@ -382,6 +391,7 @@ func (db *DB) SearchTasks(query string, limit int) ([]*Task, error) { &t.PRURL, &t.PRNumber, &t.DangerousMode, &t.Pinned, &t.Tags, &t.SourceBranch, &t.Summary, + &t.ParentID, &t.Output, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, &t.LastDistilledAt, &t.LastAccessedAt, &t.ArchiveRef, &t.ArchiveCommit, &t.ArchiveWorktreePath, &t.ArchiveBranchName, @@ -458,6 +468,12 @@ func (db *DB) UpdateTaskStatus(id int64, status string) error { // Process dependent tasks when a blocker is completed if status == StatusDone || status == StatusArchived { db.ProcessCompletedBlocker(id) + + // Check if this task is a subtask and all siblings are done + // If so, auto-complete the parent workflow + if oldTask != nil && oldTask.ParentID != nil { + db.CheckAndCompleteParent(*oldTask.ParentID) + } } return nil @@ -738,6 +754,7 @@ func (db *DB) GetNextQueuedTask() (*Task, error) { COALESCE(pr_url, ''), COALESCE(pr_number, 0), COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(source_branch, ''), COALESCE(summary, ''), + parent_id, COALESCE(output, ''), created_at, updated_at, started_at, completed_at, last_distilled_at, last_accessed_at, COALESCE(archive_ref, ''), COALESCE(archive_commit, ''), @@ -753,6 +770,7 @@ func (db *DB) GetNextQueuedTask() (*Task, error) { &t.PRURL, &t.PRNumber, &t.DangerousMode, &t.Pinned, &t.Tags, &t.SourceBranch, &t.Summary, + &t.ParentID, &t.Output, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, &t.LastDistilledAt, &t.LastAccessedAt, &t.ArchiveRef, &t.ArchiveCommit, &t.ArchiveWorktreePath, &t.ArchiveBranchName, @@ -776,6 +794,7 @@ func (db *DB) GetQueuedTasks() ([]*Task, error) { COALESCE(pr_url, ''), COALESCE(pr_number, 0), COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(source_branch, ''), COALESCE(summary, ''), + parent_id, COALESCE(output, ''), created_at, updated_at, started_at, completed_at, last_distilled_at, last_accessed_at, COALESCE(archive_ref, ''), COALESCE(archive_commit, ''), @@ -799,6 +818,7 @@ func (db *DB) GetQueuedTasks() ([]*Task, error) { &t.PRURL, &t.PRNumber, &t.DangerousMode, &t.Pinned, &t.Tags, &t.SourceBranch, &t.Summary, + &t.ParentID, &t.Output, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, &t.LastDistilledAt, &t.LastAccessedAt, &t.ArchiveRef, &t.ArchiveCommit, &t.ArchiveWorktreePath, &t.ArchiveBranchName, diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 3da51ebc..96d5c4b1 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -297,6 +297,74 @@ func (s *Server) handleRequest(req *jsonRPCRequest) { "required": []string{"action"}, }, }, + // Orchestration tools for long-running task workflows + { + Name: "taskyou_create_subtask", + Description: "Create a subtask of the current task for orchestrating large workflows. Subtasks execute independently and the parent task auto-completes when all subtasks finish. Use this to decompose complex work into parallel or sequential steps. Subtasks can be chained with dependencies using 'depends_on' to create sequential workflows.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "title": map[string]interface{}{ + "type": "string", + "description": "Title of the subtask", + }, + "body": map[string]interface{}{ + "type": "string", + "description": "Detailed description and instructions for the subtask", + }, + "type": map[string]interface{}{ + "type": "string", + "description": "Task type (code, writing, thinking). Defaults to parent's type.", + }, + "status": map[string]interface{}{ + "type": "string", + "description": "Initial status: 'backlog' (default) or 'queued' (starts immediately)", + }, + "depends_on": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{"type": "integer"}, + "description": "List of subtask IDs that must complete before this one starts. Creates blocking dependencies with auto-queue.", + }, + }, + "required": []string{"title"}, + }, + }, + { + Name: "taskyou_get_workflow_status", + Description: "Get the overall progress of the current task's workflow (subtask completion status). Shows how many subtasks are pending, processing, blocked, and done. Use this to monitor orchestration progress.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + { + Name: "taskyou_set_task_output", + Description: "Store output/results from this task for use by downstream tasks or the parent workflow. Use this to pass context, data, or results to other tasks in an orchestration workflow.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "output": map[string]interface{}{ + "type": "string", + "description": "The output/results to store. This will be available to downstream tasks and the parent task.", + }, + }, + "required": []string{"output"}, + }, + }, + { + Name: "taskyou_get_task_output", + Description: "Read the stored output from another task. Use this to get results from upstream tasks in an orchestration workflow. Only works for tasks in the same project.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "task_id": map[string]interface{}{ + "type": "integer", + "description": "The ID of the task whose output to read", + }, + }, + "required": []string{"task_id"}, + }, + }, }, }) @@ -742,6 +810,218 @@ This saves future tasks from re-exploring the codebase.`}, }, }) + case "taskyou_create_subtask": + title, _ := params.Arguments["title"].(string) + if title == "" { + s.sendError(id, -32602, "title is required") + return + } + body, _ := params.Arguments["body"].(string) + taskType, _ := params.Arguments["type"].(string) + status, _ := params.Arguments["status"].(string) + + // Get current task for defaults + currentTask, err := s.db.GetTask(s.taskID) + if err != nil || currentTask == nil { + s.sendError(id, -32603, "Failed to get current task") + return + } + + // Default type to parent's type + if taskType == "" { + taskType = currentTask.Type + } + + if status == "" { + status = db.StatusBacklog + } + + parentID := s.taskID + newTask := &db.Task{ + Title: title, + Body: body, + Project: currentTask.Project, + Type: taskType, + Status: status, + ParentID: &parentID, + } + + if err := s.db.CreateTask(newTask); err != nil { + s.sendError(id, -32603, fmt.Sprintf("Failed to create subtask: %v", err)) + return + } + + // Handle dependencies (depends_on) + if deps, ok := params.Arguments["depends_on"].([]interface{}); ok && len(deps) > 0 { + for _, dep := range deps { + if depID, ok := dep.(float64); ok { + if err := s.db.AddDependency(int64(depID), newTask.ID, true); err != nil { + // Log but don't fail - the subtask was already created + s.db.AppendTaskLog(s.taskID, "system", fmt.Sprintf("Warning: failed to add dependency %d -> #%d: %v", int64(depID), newTask.ID, err)) + } + } + } + } + + s.db.AppendTaskLog(s.taskID, "system", fmt.Sprintf("Created subtask #%d: %s", newTask.ID, newTask.Title)) + + s.sendResult(id, toolCallResult{ + Content: []contentBlock{ + {Type: "text", Text: fmt.Sprintf("Created subtask #%d: %s (status: %s, parent: #%d)", newTask.ID, newTask.Title, status, s.taskID)}, + }, + }) + + case "taskyou_get_workflow_status": + // Get workflow status for the current task (as parent) + status, err := s.db.GetWorkflowStatus(s.taskID) + if err != nil { + s.sendError(id, -32603, fmt.Sprintf("Failed to get workflow status: %v", err)) + return + } + + if status.Total == 0 { + // Check if current task is itself a subtask and report parent's workflow status + currentTask, err := s.db.GetTask(s.taskID) + if err == nil && currentTask != nil && currentTask.ParentID != nil { + status, err = s.db.GetWorkflowStatus(*currentTask.ParentID) + if err != nil { + s.sendError(id, -32603, fmt.Sprintf("Failed to get parent workflow status: %v", err)) + return + } + } + } + + if status.Total == 0 { + s.sendResult(id, toolCallResult{ + Content: []contentBlock{ + {Type: "text", Text: "No subtasks found. Use taskyou_create_subtask to decompose this task into a workflow."}, + }, + }) + return + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("## Workflow Status: %s\n\n", status.ParentTitle)) + sb.WriteString(fmt.Sprintf("**Total subtasks:** %d\n", status.Total)) + sb.WriteString(fmt.Sprintf("**Done:** %d\n", status.Done)) + sb.WriteString(fmt.Sprintf("**Processing:** %d\n", status.Processing)) + sb.WriteString(fmt.Sprintf("**Pending:** %d\n", status.Pending)) + sb.WriteString(fmt.Sprintf("**Blocked:** %d\n", status.Blocked)) + if status.Archived > 0 { + sb.WriteString(fmt.Sprintf("**Archived:** %d\n", status.Archived)) + } + + progress := 0 + if status.Total > 0 { + progress = (status.Done + status.Archived) * 100 / status.Total + } + sb.WriteString(fmt.Sprintf("\n**Progress:** %d%%\n", progress)) + + if status.IsComplete { + sb.WriteString("\n**All subtasks are complete!**\n") + } + + // List individual subtasks + subtasks, err := s.db.GetSubtasks(status.ParentID) + if err == nil && len(subtasks) > 0 { + sb.WriteString("\n### Subtasks:\n") + for _, st := range subtasks { + statusIcon := "⏳" + switch st.Status { + case db.StatusDone: + statusIcon = "✅" + case db.StatusProcessing: + statusIcon = "🔄" + case db.StatusBlocked: + statusIcon = "🚫" + case db.StatusQueued: + statusIcon = "📋" + case db.StatusArchived: + statusIcon = "📦" + } + sb.WriteString(fmt.Sprintf("- %s #%d: %s (%s)\n", statusIcon, st.ID, st.Title, st.Status)) + } + } + + s.sendResult(id, toolCallResult{ + Content: []contentBlock{ + {Type: "text", Text: sb.String()}, + }, + }) + + case "taskyou_set_task_output": + output, _ := params.Arguments["output"].(string) + if output == "" { + s.sendError(id, -32602, "output is required") + return + } + + if err := s.db.SetTaskOutput(s.taskID, output); err != nil { + s.sendError(id, -32603, fmt.Sprintf("Failed to set task output: %v", err)) + return + } + + s.db.AppendTaskLog(s.taskID, "system", fmt.Sprintf("Task output stored (%d bytes)", len(output))) + + s.sendResult(id, toolCallResult{ + Content: []contentBlock{ + {Type: "text", Text: fmt.Sprintf("Output stored for task #%d (%d bytes). This will be available to downstream tasks and the parent workflow.", s.taskID, len(output))}, + }, + }) + + case "taskyou_get_task_output": + taskIDFloat, ok := params.Arguments["task_id"].(float64) + if !ok { + s.sendError(id, -32602, "task_id is required") + return + } + targetTaskID := int64(taskIDFloat) + + // Get current task for project access control + currentTask, err := s.db.GetTask(s.taskID) + if err != nil || currentTask == nil { + s.sendError(id, -32603, "Failed to get current task") + return + } + + // Get target task + targetTask, err := s.db.GetTask(targetTaskID) + if err != nil { + s.sendError(id, -32603, fmt.Sprintf("Failed to get task: %v", err)) + return + } + if targetTask == nil { + s.sendError(id, -32602, fmt.Sprintf("Task #%d not found", targetTaskID)) + return + } + + // Enforce project isolation + if targetTask.Project != currentTask.Project { + s.sendError(id, -32602, fmt.Sprintf("Task #%d is in a different project and cannot be accessed", targetTaskID)) + return + } + + output, err := s.db.GetTaskOutput(targetTaskID) + if err != nil { + s.sendError(id, -32603, fmt.Sprintf("Failed to get task output: %v", err)) + return + } + + if output == "" { + s.sendResult(id, toolCallResult{ + Content: []contentBlock{ + {Type: "text", Text: fmt.Sprintf("Task #%d (%s) has no stored output.", targetTaskID, targetTask.Title)}, + }, + }) + return + } + + s.sendResult(id, toolCallResult{ + Content: []contentBlock{ + {Type: "text", Text: fmt.Sprintf("## Output from Task #%d: %s\n\n%s", targetTaskID, targetTask.Title, output)}, + }, + }) + default: s.sendError(id, -32602, fmt.Sprintf("Unknown tool: %s", params.Name)) } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index ea842c46..37eb4766 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -1338,6 +1338,244 @@ func TestSpotlightSync(t *testing.T) { } } +// callTool is a helper to make a tool call and return the response. +func callTool(t *testing.T, database *db.DB, taskID int64, toolName string, args map[string]interface{}) jsonRPCResponse { + t.Helper() + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]interface{}{ + "name": toolName, + "arguments": args, + }, + } + reqBytes, _ := json.Marshal(request) + reqBytes = append(reqBytes, '\n') + + server, output := testServer(database, taskID, string(reqBytes)) + server.Run() + + var resp jsonRPCResponse + if err := json.Unmarshal(output.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + return resp +} + +func TestOrchestrationCreateSubtask(t *testing.T) { + database := testDB(t) + task := createTestTask(t, database) + + resp := callTool(t, database, task.ID, "taskyou_create_subtask", map[string]interface{}{ + "title": "Build the frontend", + "body": "Create the React components", + "status": "backlog", + }) + + if resp.Error != nil { + t.Fatalf("unexpected error: %s", resp.Error.Message) + } + + // Verify subtask was created with correct parent + subtasks, err := database.GetSubtasks(task.ID) + if err != nil { + t.Fatalf("failed to get subtasks: %v", err) + } + if len(subtasks) != 1 { + t.Fatalf("expected 1 subtask, got %d", len(subtasks)) + } + if subtasks[0].Title != "Build the frontend" { + t.Errorf("expected title 'Build the frontend', got %q", subtasks[0].Title) + } + if subtasks[0].ParentID == nil || *subtasks[0].ParentID != task.ID { + t.Error("subtask should have parent_id pointing to the current task") + } + if subtasks[0].Project != task.Project { + t.Errorf("subtask should inherit project %q, got %q", task.Project, subtasks[0].Project) + } +} + +func TestOrchestrationCreateSubtaskWithDependencies(t *testing.T) { + database := testDB(t) + task := createTestTask(t, database) + + // Create first subtask + resp1 := callTool(t, database, task.ID, "taskyou_create_subtask", map[string]interface{}{ + "title": "Step 1: Design", + "status": "queued", + }) + if resp1.Error != nil { + t.Fatalf("unexpected error creating subtask 1: %s", resp1.Error.Message) + } + + // Get the first subtask's ID + subtasks, _ := database.GetSubtasks(task.ID) + if len(subtasks) != 1 { + t.Fatalf("expected 1 subtask, got %d", len(subtasks)) + } + sub1ID := subtasks[0].ID + + // Create second subtask that depends on the first + resp2 := callTool(t, database, task.ID, "taskyou_create_subtask", map[string]interface{}{ + "title": "Step 2: Implement", + "status": "backlog", + "depends_on": []interface{}{float64(sub1ID)}, + }) + if resp2.Error != nil { + t.Fatalf("unexpected error creating subtask 2: %s", resp2.Error.Message) + } + + // Verify dependency was created + subtasks, _ = database.GetSubtasks(task.ID) + if len(subtasks) != 2 { + t.Fatalf("expected 2 subtasks, got %d", len(subtasks)) + } + + blockers, err := database.GetBlockers(subtasks[1].ID) + if err != nil { + t.Fatalf("failed to get blockers: %v", err) + } + if len(blockers) != 1 { + t.Errorf("expected 1 blocker, got %d", len(blockers)) + } +} + +func TestOrchestrationWorkflowStatus(t *testing.T) { + database := testDB(t) + task := createTestTask(t, database) + + // Create some subtasks + parentID := task.ID + sub1 := &db.Task{Title: "Sub 1", Status: db.StatusDone, Project: task.Project, ParentID: &parentID} + sub2 := &db.Task{Title: "Sub 2", Status: db.StatusProcessing, Project: task.Project, ParentID: &parentID} + database.CreateTask(sub1) + database.CreateTask(sub2) + + resp := callTool(t, database, task.ID, "taskyou_get_workflow_status", map[string]interface{}{}) + if resp.Error != nil { + t.Fatalf("unexpected error: %s", resp.Error.Message) + } + + // Check response contains status info + result, ok := resp.Result.(map[string]interface{}) + if !ok { + t.Fatal("expected result to be a map") + } + content, ok := result["content"].([]interface{}) + if !ok || len(content) == 0 { + t.Fatal("expected content array") + } + block := content[0].(map[string]interface{}) + text := block["text"].(string) + + if !strings.Contains(text, "Total subtasks") { + t.Error("expected workflow status to contain total subtasks") + } + if !strings.Contains(text, "Done") { + t.Error("expected workflow status to contain done count") + } +} + +func TestOrchestrationSetAndGetTaskOutput(t *testing.T) { + database := testDB(t) + task := createTestTask(t, database) + + // Set output + resp := callTool(t, database, task.ID, "taskyou_set_task_output", map[string]interface{}{ + "output": "API endpoints: /users, /posts, /comments", + }) + if resp.Error != nil { + t.Fatalf("unexpected error setting output: %s", resp.Error.Message) + } + + // Create another task in the same project to read the output + task2 := &db.Task{Title: "Reader task", Status: db.StatusProcessing, Project: task.Project} + database.CreateTask(task2) + + // Get output from the first task + resp = callTool(t, database, task2.ID, "taskyou_get_task_output", map[string]interface{}{ + "task_id": float64(task.ID), + }) + if resp.Error != nil { + t.Fatalf("unexpected error getting output: %s", resp.Error.Message) + } + + result, ok := resp.Result.(map[string]interface{}) + if !ok { + t.Fatal("expected result to be a map") + } + content := result["content"].([]interface{}) + block := content[0].(map[string]interface{}) + text := block["text"].(string) + + if !strings.Contains(text, "API endpoints") { + t.Error("expected output to contain the stored text") + } +} + +func TestOrchestrationGetTaskOutputProjectIsolation(t *testing.T) { + database := testDB(t) + task := createTestTask(t, database) + + // Create another project and task + database.CreateProject(&db.Project{Name: "other-project", Path: "/tmp/other"}) + otherTask := &db.Task{Title: "Other task", Status: db.StatusProcessing, Project: "other-project"} + database.CreateTask(otherTask) + database.SetTaskOutput(otherTask.ID, "secret data") + + // Try to read the other project's task output - should fail + resp := callTool(t, database, task.ID, "taskyou_get_task_output", map[string]interface{}{ + "task_id": float64(otherTask.ID), + }) + if resp.Error == nil { + t.Error("expected error when accessing task from different project") + } +} + +func TestOrchestrationToolsInToolsList(t *testing.T) { + database := testDB(t) + task := createTestTask(t, database) + + request := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + } + reqBytes, _ := json.Marshal(request) + reqBytes = append(reqBytes, '\n') + + server, output := testServer(database, task.ID, string(reqBytes)) + server.Run() + + var resp jsonRPCResponse + json.Unmarshal(output.Bytes(), &resp) + + result := resp.Result.(map[string]interface{}) + tools := result["tools"].([]interface{}) + + expectedTools := map[string]bool{ + "taskyou_create_subtask": false, + "taskyou_get_workflow_status": false, + "taskyou_set_task_output": false, + "taskyou_get_task_output": false, + } + + for _, toolI := range tools { + tool := toolI.(map[string]interface{}) + name := tool["name"].(string) + if _, ok := expectedTools[name]; ok { + expectedTools[name] = true + } + } + + for name, found := range expectedTools { + if !found { + t.Errorf("orchestration tool %q not found in tools list", name) + } + } +} + // runGit is a helper to run git commands in tests func runGit(t *testing.T, dir string, args ...string) { t.Helper()