diff --git a/cmd/resume.go b/cmd/resume.go index b57bccf..dff8153 100644 --- a/cmd/resume.go +++ b/cmd/resume.go @@ -133,6 +133,11 @@ func openNewSession(wt worktree.Worktree, t terminal.Terminal) error { if wt.Type != worktree.TypePRReview { initialPrompt = "" action = "Starting new session" + } else { + // Ensure /review-pr command is installed + if err := ensureClaudeCommand("review-pr"); err != nil { + ui.LogInfo(fmt.Sprintf("Warning: could not install /review-pr command: %v", err)) + } } if resumeNoITerm { diff --git a/cmd/review.go b/cmd/review.go index 0ea51b9..3d0b9ed 100644 --- a/cmd/review.go +++ b/cmd/review.go @@ -9,9 +9,8 @@ import ( "strconv" "strings" - ctxpkg "github.com/mgreau/zen/internal/context" "github.com/mgreau/zen/internal/github" - "github.com/mgreau/zen/internal/prcache" + "github.com/mgreau/zen/internal/review" "github.com/mgreau/zen/internal/terminal" "github.com/mgreau/zen/internal/ui" wt "github.com/mgreau/zen/internal/worktree" @@ -63,14 +62,6 @@ func init() { rootCmd.AddCommand(reviewCmd) } -// ReviewResult holds the output for --json mode. -type ReviewResult struct { - WorktreePath string `json:"worktree_path"` - PRNumber int `json:"pr_number"` - Title string `json:"title"` - Author string `json:"author"` -} - func runReview(cmd *cobra.Command, args []string) error { if len(args) != 1 { return cmd.Help() @@ -91,99 +82,48 @@ func runReview(cmd *cobra.Command, args []string) error { reviewRepo = detected } - // Validate repo exists in config + // Check if worktree already exists and resume basePath := cfg.RepoBasePath(reviewRepo) - if basePath == "" { - return fmt.Errorf("unknown repo %q — check ~/.zen/config.yaml", reviewRepo) - } - fullRepo := cfg.RepoFullName(reviewRepo) - - // Construct paths - originPath := filepath.Join(basePath, reviewRepo) - worktreeName := fmt.Sprintf("%s-pr-%d", reviewRepo, prNumber) - worktreePath := filepath.Join(basePath, worktreeName) - - // If worktree already exists, resume it - if _, err := os.Stat(worktreePath); err == nil { - ui.LogInfo(fmt.Sprintf("Worktree already exists, resuming PR #%d...", prNumber)) - // Pass model through to resume path - if reviewModel != "" { - resumeModel = reviewModel + if basePath != "" { + worktreeName := fmt.Sprintf("%s-pr-%d", reviewRepo, prNumber) + worktreePath := filepath.Join(basePath, worktreeName) + if _, err := os.Stat(worktreePath); err == nil { + ui.LogInfo(fmt.Sprintf("Worktree already exists, resuming PR #%d...", prNumber)) + if reviewModel != "" { + resumeModel = reviewModel + } + return openReviewTab(worktreePath, worktreeName) } - return openReviewTab(worktreePath, worktreeName) } - // Fetch PR details from GitHub - ui.LogInfo(fmt.Sprintf("Fetching PR #%d from %s...", prNumber, fullRepo)) - client, err := github.NewClient(ctx) - if err != nil { - return fmt.Errorf("creating GitHub client: %w", err) - } - details, err := client.GetPRDetails(ctx, fullRepo, prNumber) + // Create worktree using shared logic + result, err := review.CreateWorktree(ctx, cfg, reviewRepo, prNumber) if err != nil { - return fmt.Errorf("fetching PR details: %w", err) - } - - ui.LogInfo(fmt.Sprintf("PR #%d: %s (by %s)", prNumber, details.Title, details.Author)) - - // Create worktree under lock - branchName := fmt.Sprintf("pr-%d", prNumber) - - wt.GitMu.Lock() - - ui.LogInfo(fmt.Sprintf("Fetching pull/%d/head...", prNumber)) - fetchCmd := exec.Command("git", "fetch", "origin", fmt.Sprintf("+pull/%d/head:%s", prNumber, branchName)) - fetchCmd.Dir = originPath - if out, err := fetchCmd.CombinedOutput(); err != nil { - wt.GitMu.Unlock() - return fmt.Errorf("git fetch: %w: %s", err, string(out)) - } - - ui.LogInfo(fmt.Sprintf("Creating worktree %s...", worktreeName)) - wtCmd := exec.Command("git", "worktree", "add", worktreePath, branchName) - wtCmd.Dir = originPath - if out, err := wtCmd.CombinedOutput(); err != nil { - wt.GitMu.Unlock() - return fmt.Errorf("git worktree add: %w: %s", err, string(out)) - } - - // Clean stale index.lock - lockFile := filepath.Join(originPath, ".git", "worktrees", worktreeName, "index.lock") - os.Remove(lockFile) - - wt.GitMu.Unlock() - - // Inject PR context into CLAUDE.local.md - ui.LogInfo("Injecting PR context into CLAUDE.local.md...") - if err := ctxpkg.InjectPRContext(ctx, worktreePath, fullRepo, prNumber); err != nil { - ui.LogInfo(fmt.Sprintf("Warning: failed to inject context: %v", err)) + return err } - // Cache PR metadata - prcache.Set(reviewRepo, prNumber, details.Title, details.Author) - home := homeDir() - shortPath := ui.ShortenHome(worktreePath, home) + shortPath := ui.ShortenHome(result.WorktreePath, home) if jsonFlag { - printJSON(ReviewResult{ - WorktreePath: worktreePath, - PRNumber: prNumber, - Title: details.Title, - Author: details.Author, - }) + printJSON(result) return nil } fmt.Println() ui.LogSuccess(fmt.Sprintf("Created worktree: %s", shortPath)) - fmt.Printf(" PR: #%d — %s\n", prNumber, details.Title) - fmt.Printf(" Author: %s\n", details.Author) + fmt.Printf(" PR: #%d — %s\n", result.PRNumber, result.Title) + fmt.Printf(" Author: %s\n", result.Author) if reviewModel != "" { fmt.Printf(" Model: %s\n", ui.CyanText(reviewModel)) } + // Ensure /review-pr command is installed + if err := ensureClaudeCommand("review-pr"); err != nil { + ui.LogInfo(fmt.Sprintf("Warning: could not install /review-pr command: %v", err)) + } + if reviewNoITerm { fmt.Println() fmt.Println(ui.BoldText("Open manually:")) @@ -191,7 +131,7 @@ func runReview(cmd *cobra.Command, args []string) error { if reviewModel != "" { modelFlag = fmt.Sprintf(" --model %s", reviewModel) } - fmt.Printf(" cd %s && %s%s \"/review-pr\"\n", worktreePath, cfg.ClaudeBin, modelFlag) + fmt.Printf(" cd %s && %s%s \"/review-pr\"\n", result.WorktreePath, cfg.ClaudeBin, modelFlag) return nil } @@ -201,7 +141,7 @@ func runReview(cmd *cobra.Command, args []string) error { return err } - if err := term.OpenTabWithClaude(worktreePath, "/review-pr", cfg.ClaudeBin, reviewModel); err != nil { + if err := term.OpenTabWithClaude(result.WorktreePath, "/review-pr", cfg.ClaudeBin, reviewModel); err != nil { return fmt.Errorf("opening %s tab: %w", term.Name(), err) } diff --git a/cmd/setup.go b/cmd/setup.go index fa97a94..96f3bb4 100644 --- a/cmd/setup.go +++ b/cmd/setup.go @@ -182,6 +182,38 @@ func promptRequired(scanner *bufio.Scanner, label string) string { } } +// ensureClaudeCommand checks if a specific Claude command file exists and +// installs it silently from the embedded FS if missing. Returns true if the +// command was installed (or already existed). +func ensureClaudeCommand(name string) error { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("resolving home directory: %w", err) + } + targetDir := filepath.Join(home, ".claude", "commands") + dst := filepath.Join(targetDir, name+".md") + + if _, err := os.Stat(dst); err == nil { + return nil // already exists + } + + srcData, err := fs.ReadFile(EmbeddedCommands, filepath.Join("commands", name+".md")) + if err != nil { + return fmt.Errorf("reading embedded %s.md: %w", name, err) + } + + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return fmt.Errorf("creating %s: %w", targetDir, err) + } + + if err := os.WriteFile(dst, srcData, 0o644); err != nil { + return fmt.Errorf("writing %s: %w", dst, err) + } + + ui.LogInfo(fmt.Sprintf("Installed Claude command /%s", name)) + return nil +} + // installClaudeCommands prompts the user and installs embedded Claude Code // command files to ~/.claude/commands/. func installClaudeCommands(scanner *bufio.Scanner) (int, error) { diff --git a/cmd/work.go b/cmd/work.go index 8932833..d134061 100644 --- a/cmd/work.go +++ b/cmd/work.go @@ -171,9 +171,9 @@ func runWorkNew(cmd *cobra.Command, args []string) error { return fmt.Errorf("git worktree add: %w: %s", err, string(out)) } - // Clean stale index.lock + // Clean stale index.lock (only if holding process is dead) lockFile := filepath.Join(originPath, ".git", "worktrees", worktreeName, "index.lock") - os.Remove(lockFile) + wt.RemoveStaleLock(lockFile, worktreeName) wt.GitMu.Unlock() diff --git a/internal/github/client.go b/internal/github/client.go index 18f9840..7e0096b 100644 --- a/internal/github/client.go +++ b/internal/github/client.go @@ -5,11 +5,14 @@ import ( "fmt" "os/exec" "strings" + "time" gh "github.com/google/go-github/v75/github" "golang.org/x/oauth2" ) +const apiTimeout = 30 * time.Second + // Client wraps go-github with auth from `gh auth token`. type Client struct { gh *gh.Client @@ -17,23 +20,29 @@ type Client struct { // NewClient creates a GitHub client using the token from `gh auth token`. func NewClient(ctx context.Context) (*Client, error) { - token, err := ghAuthToken() + token, err := ghAuthToken(ctx) if err != nil { return nil, fmt.Errorf("getting GitHub token: %w", err) } ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) tc := oauth2.NewClient(ctx, ts) + tc.Timeout = apiTimeout client := gh.NewClient(tc) return &Client{gh: client}, nil } // ghAuthToken runs `gh auth token` and returns the token string. -func ghAuthToken() (string, error) { - cmd := exec.Command("gh", "auth", "token") +func ghAuthToken(ctx context.Context) (string, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() + cmd := exec.CommandContext(ctx, "gh", "auth", "token") out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("gh auth token timed out after %s", apiTimeout) + } return "", fmt.Errorf("gh auth token failed: %s (is gh CLI installed and authenticated?)", ghError(err)) } return strings.TrimSpace(string(out)), nil diff --git a/internal/github/queries.go b/internal/github/queries.go index e947236..09bd6fe 100644 --- a/internal/github/queries.go +++ b/internal/github/queries.go @@ -8,6 +8,15 @@ import ( "strings" ) +// withTimeout returns a context with apiTimeout applied, unless the caller +// already set a deadline. +func withTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + if _, ok := ctx.Deadline(); ok { + return ctx, func() {} + } + return context.WithTimeout(ctx, apiTimeout) +} + // ghError extracts stderr from an exec.ExitError for better error messages. func ghError(err error) string { if ee, ok := err.(*exec.ExitError); ok && len(ee.Stderr) > 0 { @@ -50,9 +59,14 @@ type ApprovedPR struct { // GetCurrentUser returns the authenticated GitHub user's login. func GetCurrentUser(ctx context.Context) (string, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() cmd := exec.CommandContext(ctx, "gh", "api", "user", "--jq", ".login") out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("fetching current user timed out after %s", apiTimeout) + } return "", fmt.Errorf("fetching current user: %s", ghError(err)) } return strings.TrimSpace(string(out)), nil @@ -61,6 +75,8 @@ func GetCurrentUser(ctx context.Context) (string, error) { // GetReviewRequests fetches PRs where the user is a requested reviewer, // including re-reviews. Uses GraphQL via `gh api graphql`. func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() query := `query($q1: String!, $q2: String!) { requested: search(query: $q1, type: ISSUE, first: 50) { nodes { @@ -103,6 +119,9 @@ func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest, ) out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("review requests query timed out after %s", apiTimeout) + } return nil, fmt.Errorf("GraphQL query failed: %s", ghError(err)) } @@ -139,6 +158,8 @@ func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest, // GetApprovedUnmerged fetches the user's own PRs that are approved but not yet merged. func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() query := `query($q: String!) { search(query: $q, type: ISSUE, first: 50) { nodes { @@ -168,6 +189,9 @@ func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR, ) out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("approved PRs query timed out after %s", apiTimeout) + } return nil, fmt.Errorf("GraphQL query failed: %s", ghError(err)) } @@ -193,6 +217,8 @@ func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR, // ListOpenPRs lists open PRs for a repository using `gh pr list`. func ListOpenPRs(ctx context.Context, fullRepo string, limit int) ([]ReviewRequest, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() cmd := exec.CommandContext(ctx, "gh", "pr", "list", "-R", fullRepo, "--state", "open", @@ -201,6 +227,9 @@ func ListOpenPRs(ctx context.Context, fullRepo string, limit int) ([]ReviewReque ) out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("listing open PRs timed out after %s", apiTimeout) + } return nil, err } diff --git a/internal/github/queries_test.go b/internal/github/queries_test.go new file mode 100644 index 0000000..913f698 --- /dev/null +++ b/internal/github/queries_test.go @@ -0,0 +1,91 @@ +package github + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestWithTimeout_addsDeadlineWhenNone(t *testing.T) { + ctx, cancel := withTimeout(context.Background()) + defer cancel() + + deadline, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline to be set") + } + remaining := time.Until(deadline) + if remaining <= 0 || remaining > apiTimeout { + t.Fatalf("expected deadline within %s, got %s remaining", apiTimeout, remaining) + } +} + +func TestWithTimeout_preservesExistingDeadline(t *testing.T) { + existing := time.Now().Add(5 * time.Second) + parent, parentCancel := context.WithDeadline(context.Background(), existing) + defer parentCancel() + + ctx, cancel := withTimeout(parent) + defer cancel() + + deadline, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline to be set") + } + if !deadline.Equal(existing) { + t.Fatalf("expected existing deadline %v, got %v", existing, deadline) + } +} + +func TestGetCurrentUser_timeoutError(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + _, err := GetCurrentUser(ctx) + if err == nil { + t.Fatal("expected error from expired context") + } + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error message, got: %s", err) + } +} + +func TestGetReviewRequests_timeoutError(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + _, err := GetReviewRequests(ctx, "") + if err == nil { + t.Fatal("expected error from expired context") + } + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error message, got: %s", err) + } +} + +func TestGetApprovedUnmerged_timeoutError(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + _, err := GetApprovedUnmerged(ctx, "") + if err == nil { + t.Fatal("expected error from expired context") + } + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error message, got: %s", err) + } +} + +func TestListOpenPRs_timeoutError(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + _, err := ListOpenPRs(ctx, "owner/repo", 10) + if err == nil { + t.Fatal("expected error from expired context") + } + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error message, got: %s", err) + } +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 5728a2f..6291ee4 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -106,4 +106,27 @@ func (s *Server) registerTools() { ), s.handleConfigRepos, ) + + s.server.AddTool( + mcpgo.NewTool("zen_review", + mcpgo.WithDescription("Create a worktree for a PR number (fetches branch, creates worktree, injects context)"), + mcpgo.WithNumber("pr_number", mcpgo.Description("Pull request number"), mcpgo.Required()), + mcpgo.WithString("repo", mcpgo.Description("Short repo name (auto-detected if omitted)")), + mcpgo.WithReadOnlyHintAnnotation(false), + mcpgo.WithDestructiveHintAnnotation(false), + mcpgo.WithOpenWorldHintAnnotation(true), + ), + s.handleReview, + ) + + s.server.AddTool( + mcpgo.NewTool("zen_review_resume", + mcpgo.WithDescription("Get resume info (worktree path and sessions) for an existing PR review worktree"), + mcpgo.WithNumber("pr_number", mcpgo.Description("Pull request number"), mcpgo.Required()), + mcpgo.WithReadOnlyHintAnnotation(true), + mcpgo.WithDestructiveHintAnnotation(false), + mcpgo.WithOpenWorldHintAnnotation(false), + ), + s.handleReviewResume, + ) } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 5c9b268..fb62cbe 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -152,6 +152,53 @@ func TestHandleAgentStatusNoSessions(t *testing.T) { } } +func TestHandleReviewMissingParams(t *testing.T) { + srv := New(testConfig()) + ctx := context.Background() + + // Missing required pr_number + result, err := srv.handleReview(ctx, makeRequest(nil)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Fatal("expected tool error for missing pr_number") + } +} + +func TestHandleReviewResumeMissingParams(t *testing.T) { + srv := New(testConfig()) + ctx := context.Background() + + // Missing required pr_number + result, err := srv.handleReviewResume(ctx, makeRequest(nil)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Fatal("expected tool error for missing pr_number") + } +} + +func TestHandleReviewResumeNoWorktree(t *testing.T) { + // Use paths that definitely don't have worktrees + cfg := &config.Config{ + Repos: map[string]config.RepoConfig{ + "fake": {FullName: "test/fake", BasePath: "/tmp/nonexistent-zen-test"}, + }, + } + srv := New(cfg) + ctx := context.Background() + + result, err := srv.handleReviewResume(ctx, makeRequest(map[string]any{"pr_number": 99999})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Fatal("expected tool error for non-existent worktree") + } +} + func TestJsonResult(t *testing.T) { type testData struct { Name string `json:"name"` diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go index a973dff..47f24dd 100644 --- a/internal/mcp/tools.go +++ b/internal/mcp/tools.go @@ -3,11 +3,13 @@ package coordmcp import ( "context" "encoding/json" + "fmt" "sort" "time" mcpgo "github.com/mark3labs/mcp-go/mcp" ghpkg "github.com/mgreau/zen/internal/github" + "github.com/mgreau/zen/internal/review" "github.com/mgreau/zen/internal/session" "github.com/mgreau/zen/internal/worktree" ) @@ -178,6 +180,66 @@ type repoEntry struct { BasePath string `json:"base_path"` } +// handleReview creates a worktree for a PR number. +func (s *Server) handleReview(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { + prNumber, err := req.RequireInt("pr_number") + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + + repoShort := req.GetString("repo", "") + if repoShort == "" { + detected, err := review.DetectRepo(ctx, s.cfg, prNumber) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + repoShort = detected + } + + result, err := review.CreateWorktree(ctx, s.cfg, repoShort, prNumber) + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + + return jsonResult(result) +} + +// reviewResumeEntry holds the response for zen_review_resume. +type reviewResumeEntry struct { + WorktreePath string `json:"worktree_path"` + Name string `json:"name"` + Sessions []session.Session `json:"sessions"` +} + +// handleReviewResume gets resume info for an existing PR worktree. +func (s *Server) handleReviewResume(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { + prNumber, err := req.RequireInt("pr_number") + if err != nil { + return mcpgo.NewToolResultError(err.Error()), nil + } + + wts, err := worktree.ListAll(s.cfg) + if err != nil { + return mcpgo.NewToolResultError("failed to list worktrees: " + err.Error()), nil + } + + for _, wt := range wts { + if wt.Type == worktree.TypePRReview && wt.PRNumber == prNumber { + sessions, _ := session.FindSessions(wt.Path) + if sessions == nil { + sessions = []session.Session{} + } + return jsonResult(reviewResumeEntry{ + WorktreePath: wt.Path, + Name: wt.Name, + Sessions: sessions, + }) + } + } + + return mcpgo.NewToolResultError(fmt.Sprintf("no PR review worktree for #%d", prNumber)), nil +} + // handleConfigRepos lists configured repositories. func (s *Server) handleConfigRepos(ctx context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) { var repos []repoEntry diff --git a/internal/reconciler/setup.go b/internal/reconciler/setup.go index ba437b1..6e621f3 100644 --- a/internal/reconciler/setup.go +++ b/internal/reconciler/setup.go @@ -132,9 +132,9 @@ func (r *SetupReconciler) ensureWorktree(originPath, worktreePath, worktreeName return fmt.Errorf("git worktree add: %w: %s", err, string(out)) } - // Clean stale lock immediately + // Clean stale index.lock (only if holding process is dead) lockFile := filepath.Join(originPath, ".git", "worktrees", worktreeName, "index.lock") - os.Remove(lockFile) + wt.RemoveStaleLock(lockFile, worktreeName) return nil } diff --git a/internal/review/review.go b/internal/review/review.go new file mode 100644 index 0000000..5849943 --- /dev/null +++ b/internal/review/review.go @@ -0,0 +1,165 @@ +// Package review provides shared logic for creating PR review worktrees. +// Both the CLI commands and the MCP server use this package. +package review + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + + ctxpkg "github.com/mgreau/zen/internal/context" + "github.com/mgreau/zen/internal/config" + "github.com/mgreau/zen/internal/github" + "github.com/mgreau/zen/internal/prcache" + "github.com/mgreau/zen/internal/ui" + wt "github.com/mgreau/zen/internal/worktree" +) + +// Result holds the output of a successful worktree creation. +type Result struct { + WorktreePath string `json:"worktree_path"` + PRNumber int `json:"pr_number"` + Title string `json:"title"` + Author string `json:"author"` +} + +// CreateWorktree creates a PR review worktree. It fetches the PR branch, +// creates the git worktree, injects CLAUDE.local.md context, and caches +// PR metadata. Returns the result or an error. +// +// If the worktree already exists, returns a Result with the existing path. +// The caller is responsible for detecting the repo if repoShort is empty. +func CreateWorktree(ctx context.Context, cfg *config.Config, repoShort string, prNumber int) (*Result, error) { + basePath := cfg.RepoBasePath(repoShort) + if basePath == "" { + return nil, fmt.Errorf("unknown repo %q — check ~/.zen/config.yaml", repoShort) + } + fullRepo := cfg.RepoFullName(repoShort) + + originPath := filepath.Join(basePath, repoShort) + worktreeName := fmt.Sprintf("%s-pr-%d", repoShort, prNumber) + worktreePath := filepath.Join(basePath, worktreeName) + + // If worktree already exists, return it + if _, err := os.Stat(worktreePath); err == nil { + // Try to get cached metadata + meta, ok := prcache.Get(repoShort, prNumber) + title, author := "", "" + if ok { + title = meta.Title + author = meta.Author + } + return &Result{ + WorktreePath: worktreePath, + PRNumber: prNumber, + Title: title, + Author: author, + }, nil + } + + // Fetch PR details from GitHub + ui.LogInfo(fmt.Sprintf("Fetching PR #%d from %s...", prNumber, fullRepo)) + client, err := github.NewClient(ctx) + if err != nil { + return nil, fmt.Errorf("creating GitHub client: %w", err) + } + details, err := client.GetPRDetails(ctx, fullRepo, prNumber) + if err != nil { + return nil, fmt.Errorf("fetching PR details: %w", err) + } + + ui.LogInfo(fmt.Sprintf("PR #%d: %s (by %s)", prNumber, details.Title, details.Author)) + + // Create worktree under lock + branchName := fmt.Sprintf("pr-%d", prNumber) + + wt.GitMu.Lock() + + ui.LogInfo(fmt.Sprintf("Fetching pull/%d/head...", prNumber)) + fetchCmd := exec.Command("git", "fetch", "origin", fmt.Sprintf("+pull/%d/head:%s", prNumber, branchName)) + fetchCmd.Dir = originPath + if out, err := fetchCmd.CombinedOutput(); err != nil { + wt.GitMu.Unlock() + return nil, fmt.Errorf("git fetch: %w: %s", err, string(out)) + } + + ui.LogInfo(fmt.Sprintf("Creating worktree %s...", worktreeName)) + wtCmd := exec.Command("git", "worktree", "add", worktreePath, branchName) + wtCmd.Dir = originPath + if out, err := wtCmd.CombinedOutput(); err != nil { + wt.GitMu.Unlock() + return nil, fmt.Errorf("git worktree add: %w: %s", err, string(out)) + } + + // Clean stale index.lock (only if holding process is dead) + lockFile := filepath.Join(originPath, ".git", "worktrees", worktreeName, "index.lock") + wt.RemoveStaleLock(lockFile, worktreeName) + + wt.GitMu.Unlock() + + // Inject PR context into CLAUDE.local.md + ui.LogInfo("Injecting PR context into CLAUDE.local.md...") + if err := ctxpkg.InjectPRContext(ctx, worktreePath, fullRepo, prNumber); err != nil { + ui.LogInfo(fmt.Sprintf("Warning: failed to inject context: %v", err)) + } + + // Cache PR metadata + prcache.Set(repoShort, prNumber, details.Title, details.Author) + + return &Result{ + WorktreePath: worktreePath, + PRNumber: prNumber, + Title: details.Title, + Author: details.Author, + }, nil +} + +// DetectRepo tries each configured repo to find which one contains the +// given PR number. Returns the repo short name or an error. +// Unlike the CLI version, this does not prompt interactively — it returns +// an error if ambiguous. +func DetectRepo(ctx context.Context, cfg *config.Config, prNumber int) (string, error) { + repos := cfg.RepoNames() + if len(repos) == 1 { + return repos[0], nil + } + + client, err := github.NewClient(ctx) + if err != nil { + return "", fmt.Errorf("creating GitHub client: %w", err) + } + + var matches []string + for _, repo := range repos { + fullRepo := cfg.RepoFullName(repo) + if _, err := client.GetPRDetails(ctx, fullRepo, prNumber); err == nil { + matches = append(matches, repo) + } + } + + switch len(matches) { + case 0: + return "", fmt.Errorf("PR #%d not found in any configured repo", prNumber) + case 1: + return matches[0], nil + default: + // Try reviewer heuristic + currentUser, _ := github.GetCurrentUser(ctx) + if currentUser != "" { + var reviewMatches []string + for _, repo := range matches { + fullRepo := cfg.RepoFullName(repo) + if ok, _ := client.IsRequestedReviewer(ctx, fullRepo, prNumber, currentUser); ok { + reviewMatches = append(reviewMatches, repo) + } + } + if len(reviewMatches) == 1 { + return reviewMatches[0], nil + } + } + return "", fmt.Errorf("PR #%d exists in multiple repos (%s) — specify with repo parameter", + prNumber, fmt.Sprintf("%v", matches)) + } +} diff --git a/internal/worktree/lock.go b/internal/worktree/lock.go index a94d852..2b5f568 100644 --- a/internal/worktree/lock.go +++ b/internal/worktree/lock.go @@ -40,12 +40,12 @@ func CleanStaleLocks(cfg *config.Config, repo string) { continue } lockFile := filepath.Join(worktreesDir, entry.Name(), "index.lock") - removeStaleLock(lockFile, entry.Name()) + RemoveStaleLock(lockFile, entry.Name()) } // Also check the main repo's own index.lock mainLock := filepath.Join(gitDir, "index.lock") - removeStaleLock(mainLock, repo) + RemoveStaleLock(mainLock, repo) } // CleanAllStaleLocks cleans stale locks across all known repos. @@ -55,7 +55,9 @@ func CleanAllStaleLocks(cfg *config.Config) { } } -func removeStaleLock(lockFile, name string) { +// RemoveStaleLock removes an index.lock file only if the holding process +// is no longer running. Safe to call if the file does not exist. +func RemoveStaleLock(lockFile, name string) { data, err := os.ReadFile(lockFile) if err != nil { return // file doesn't exist or can't be read