Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ func openNewSession(wt worktree.Worktree, t terminal.Terminal) error {
if wt.Type != worktree.TypePRReview {
initialPrompt = ""
action = "Starting new session"
} else {
// Ensure /review-pr command is installed
if err := ensureClaudeCommand("review-pr"); err != nil {
ui.LogInfo(fmt.Sprintf("Warning: could not install /review-pr command: %v", err))
}
}

if resumeNoITerm {
Expand Down
110 changes: 25 additions & 85 deletions cmd/review.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ import (
"strconv"
"strings"

ctxpkg "github.com/mgreau/zen/internal/context"
"github.com/mgreau/zen/internal/github"
"github.com/mgreau/zen/internal/prcache"
"github.com/mgreau/zen/internal/review"
"github.com/mgreau/zen/internal/terminal"
"github.com/mgreau/zen/internal/ui"
wt "github.com/mgreau/zen/internal/worktree"
Expand Down Expand Up @@ -63,14 +62,6 @@ func init() {
rootCmd.AddCommand(reviewCmd)
}

// ReviewResult holds the output for --json mode.
type ReviewResult struct {
WorktreePath string `json:"worktree_path"`
PRNumber int `json:"pr_number"`
Title string `json:"title"`
Author string `json:"author"`
}

func runReview(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return cmd.Help()
Expand All @@ -91,107 +82,56 @@ func runReview(cmd *cobra.Command, args []string) error {
reviewRepo = detected
}

// Validate repo exists in config
// Check if worktree already exists and resume
basePath := cfg.RepoBasePath(reviewRepo)
if basePath == "" {
return fmt.Errorf("unknown repo %q — check ~/.zen/config.yaml", reviewRepo)
}
fullRepo := cfg.RepoFullName(reviewRepo)

// Construct paths
originPath := filepath.Join(basePath, reviewRepo)
worktreeName := fmt.Sprintf("%s-pr-%d", reviewRepo, prNumber)
worktreePath := filepath.Join(basePath, worktreeName)

// If worktree already exists, resume it
if _, err := os.Stat(worktreePath); err == nil {
ui.LogInfo(fmt.Sprintf("Worktree already exists, resuming PR #%d...", prNumber))
// Pass model through to resume path
if reviewModel != "" {
resumeModel = reviewModel
if basePath != "" {
worktreeName := fmt.Sprintf("%s-pr-%d", reviewRepo, prNumber)
worktreePath := filepath.Join(basePath, worktreeName)
if _, err := os.Stat(worktreePath); err == nil {
ui.LogInfo(fmt.Sprintf("Worktree already exists, resuming PR #%d...", prNumber))
if reviewModel != "" {
resumeModel = reviewModel
}
return openReviewTab(worktreePath, worktreeName)
}
return openReviewTab(worktreePath, worktreeName)
}

// Fetch PR details from GitHub
ui.LogInfo(fmt.Sprintf("Fetching PR #%d from %s...", prNumber, fullRepo))
client, err := github.NewClient(ctx)
if err != nil {
return fmt.Errorf("creating GitHub client: %w", err)
}
details, err := client.GetPRDetails(ctx, fullRepo, prNumber)
// Create worktree using shared logic
result, err := review.CreateWorktree(ctx, cfg, reviewRepo, prNumber)
if err != nil {
return fmt.Errorf("fetching PR details: %w", err)
}

ui.LogInfo(fmt.Sprintf("PR #%d: %s (by %s)", prNumber, details.Title, details.Author))

// Create worktree under lock
branchName := fmt.Sprintf("pr-%d", prNumber)

wt.GitMu.Lock()

ui.LogInfo(fmt.Sprintf("Fetching pull/%d/head...", prNumber))
fetchCmd := exec.Command("git", "fetch", "origin", fmt.Sprintf("+pull/%d/head:%s", prNumber, branchName))
fetchCmd.Dir = originPath
if out, err := fetchCmd.CombinedOutput(); err != nil {
wt.GitMu.Unlock()
return fmt.Errorf("git fetch: %w: %s", err, string(out))
}

ui.LogInfo(fmt.Sprintf("Creating worktree %s...", worktreeName))
wtCmd := exec.Command("git", "worktree", "add", worktreePath, branchName)
wtCmd.Dir = originPath
if out, err := wtCmd.CombinedOutput(); err != nil {
wt.GitMu.Unlock()
return fmt.Errorf("git worktree add: %w: %s", err, string(out))
}

// Clean stale index.lock
lockFile := filepath.Join(originPath, ".git", "worktrees", worktreeName, "index.lock")
os.Remove(lockFile)

wt.GitMu.Unlock()

// Inject PR context into CLAUDE.local.md
ui.LogInfo("Injecting PR context into CLAUDE.local.md...")
if err := ctxpkg.InjectPRContext(ctx, worktreePath, fullRepo, prNumber); err != nil {
ui.LogInfo(fmt.Sprintf("Warning: failed to inject context: %v", err))
return err
}

// Cache PR metadata
prcache.Set(reviewRepo, prNumber, details.Title, details.Author)

home := homeDir()
shortPath := ui.ShortenHome(worktreePath, home)
shortPath := ui.ShortenHome(result.WorktreePath, home)

if jsonFlag {
printJSON(ReviewResult{
WorktreePath: worktreePath,
PRNumber: prNumber,
Title: details.Title,
Author: details.Author,
})
printJSON(result)
return nil
}

fmt.Println()
ui.LogSuccess(fmt.Sprintf("Created worktree: %s", shortPath))
fmt.Printf(" PR: #%d — %s\n", prNumber, details.Title)
fmt.Printf(" Author: %s\n", details.Author)
fmt.Printf(" PR: #%d — %s\n", result.PRNumber, result.Title)
fmt.Printf(" Author: %s\n", result.Author)

if reviewModel != "" {
fmt.Printf(" Model: %s\n", ui.CyanText(reviewModel))
}

// Ensure /review-pr command is installed
if err := ensureClaudeCommand("review-pr"); err != nil {
ui.LogInfo(fmt.Sprintf("Warning: could not install /review-pr command: %v", err))
}

if reviewNoITerm {
fmt.Println()
fmt.Println(ui.BoldText("Open manually:"))
modelFlag := ""
if reviewModel != "" {
modelFlag = fmt.Sprintf(" --model %s", reviewModel)
}
fmt.Printf(" cd %s && %s%s \"/review-pr\"\n", worktreePath, cfg.ClaudeBin, modelFlag)
fmt.Printf(" cd %s && %s%s \"/review-pr\"\n", result.WorktreePath, cfg.ClaudeBin, modelFlag)
return nil
}

Expand All @@ -201,7 +141,7 @@ func runReview(cmd *cobra.Command, args []string) error {
return err
}

if err := term.OpenTabWithClaude(worktreePath, "/review-pr", cfg.ClaudeBin, reviewModel); err != nil {
if err := term.OpenTabWithClaude(result.WorktreePath, "/review-pr", cfg.ClaudeBin, reviewModel); err != nil {
return fmt.Errorf("opening %s tab: %w", term.Name(), err)
}

Expand Down
32 changes: 32 additions & 0 deletions cmd/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,38 @@ func promptRequired(scanner *bufio.Scanner, label string) string {
}
}

// ensureClaudeCommand checks if a specific Claude command file exists and
// installs it silently from the embedded FS if missing. Returns true if the
// command was installed (or already existed).
func ensureClaudeCommand(name string) error {
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("resolving home directory: %w", err)
}
targetDir := filepath.Join(home, ".claude", "commands")
dst := filepath.Join(targetDir, name+".md")

if _, err := os.Stat(dst); err == nil {
return nil // already exists
}

srcData, err := fs.ReadFile(EmbeddedCommands, filepath.Join("commands", name+".md"))
if err != nil {
return fmt.Errorf("reading embedded %s.md: %w", name, err)
}

if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("creating %s: %w", targetDir, err)
}

if err := os.WriteFile(dst, srcData, 0o644); err != nil {
return fmt.Errorf("writing %s: %w", dst, err)
}

ui.LogInfo(fmt.Sprintf("Installed Claude command /%s", name))
return nil
}

// installClaudeCommands prompts the user and installs embedded Claude Code
// command files to ~/.claude/commands/.
func installClaudeCommands(scanner *bufio.Scanner) (int, error) {
Expand Down
4 changes: 2 additions & 2 deletions cmd/work.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ func runWorkNew(cmd *cobra.Command, args []string) error {
return fmt.Errorf("git worktree add: %w: %s", err, string(out))
}

// Clean stale index.lock
// Clean stale index.lock (only if holding process is dead)
lockFile := filepath.Join(originPath, ".git", "worktrees", worktreeName, "index.lock")
os.Remove(lockFile)
wt.RemoveStaleLock(lockFile, worktreeName)

wt.GitMu.Unlock()

Expand Down
15 changes: 12 additions & 3 deletions internal/github/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,44 @@ import (
"fmt"
"os/exec"
"strings"
"time"

gh "github.com/google/go-github/v75/github"
"golang.org/x/oauth2"
)

const apiTimeout = 30 * time.Second

// Client wraps go-github with auth from `gh auth token`.
type Client struct {
gh *gh.Client
}

// NewClient creates a GitHub client using the token from `gh auth token`.
func NewClient(ctx context.Context) (*Client, error) {
token, err := ghAuthToken()
token, err := ghAuthToken(ctx)
if err != nil {
return nil, fmt.Errorf("getting GitHub token: %w", err)
}

ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
tc := oauth2.NewClient(ctx, ts)
tc.Timeout = apiTimeout
client := gh.NewClient(tc)

return &Client{gh: client}, nil
}

// ghAuthToken runs `gh auth token` and returns the token string.
func ghAuthToken() (string, error) {
cmd := exec.Command("gh", "auth", "token")
func ghAuthToken(ctx context.Context) (string, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
cmd := exec.CommandContext(ctx, "gh", "auth", "token")
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("gh auth token timed out after %s", apiTimeout)
}
return "", fmt.Errorf("gh auth token failed: %s (is gh CLI installed and authenticated?)", ghError(err))
}
return strings.TrimSpace(string(out)), nil
Expand Down
29 changes: 29 additions & 0 deletions internal/github/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ import (
"strings"
)

// withTimeout returns a context with apiTimeout applied, unless the caller
// already set a deadline.
func withTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
if _, ok := ctx.Deadline(); ok {
return ctx, func() {}
}
return context.WithTimeout(ctx, apiTimeout)
}

// ghError extracts stderr from an exec.ExitError for better error messages.
func ghError(err error) string {
if ee, ok := err.(*exec.ExitError); ok && len(ee.Stderr) > 0 {
Expand Down Expand Up @@ -50,9 +59,14 @@ type ApprovedPR struct {

// GetCurrentUser returns the authenticated GitHub user's login.
func GetCurrentUser(ctx context.Context) (string, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
cmd := exec.CommandContext(ctx, "gh", "api", "user", "--jq", ".login")
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("fetching current user timed out after %s", apiTimeout)
}
return "", fmt.Errorf("fetching current user: %s", ghError(err))
}
return strings.TrimSpace(string(out)), nil
Expand All @@ -61,6 +75,8 @@ func GetCurrentUser(ctx context.Context) (string, error) {
// GetReviewRequests fetches PRs where the user is a requested reviewer,
// including re-reviews. Uses GraphQL via `gh api graphql`.
func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
query := `query($q1: String!, $q2: String!) {
requested: search(query: $q1, type: ISSUE, first: 50) {
nodes {
Expand Down Expand Up @@ -103,6 +119,9 @@ func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest,
)
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("review requests query timed out after %s", apiTimeout)
}
return nil, fmt.Errorf("GraphQL query failed: %s", ghError(err))
}

Expand Down Expand Up @@ -139,6 +158,8 @@ func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest,

// GetApprovedUnmerged fetches the user's own PRs that are approved but not yet merged.
func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
query := `query($q: String!) {
search(query: $q, type: ISSUE, first: 50) {
nodes {
Expand Down Expand Up @@ -168,6 +189,9 @@ func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR,
)
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("approved PRs query timed out after %s", apiTimeout)
}
return nil, fmt.Errorf("GraphQL query failed: %s", ghError(err))
}

Expand All @@ -193,6 +217,8 @@ func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR,

// ListOpenPRs lists open PRs for a repository using `gh pr list`.
func ListOpenPRs(ctx context.Context, fullRepo string, limit int) ([]ReviewRequest, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
cmd := exec.CommandContext(ctx, "gh", "pr", "list",
"-R", fullRepo,
"--state", "open",
Expand All @@ -201,6 +227,9 @@ func ListOpenPRs(ctx context.Context, fullRepo string, limit int) ([]ReviewReque
)
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("listing open PRs timed out after %s", apiTimeout)
}
return nil, err
}

Expand Down
Loading