diff --git a/api/api.go b/api/api.go index d37f8f0..af99706 100644 --- a/api/api.go +++ b/api/api.go @@ -80,7 +80,7 @@ func Init(ctx context.Context, cfg *utils.AppConfig, assets embed.FS, info *AppI registerHealthRoutes(app, humaAPI) registerProjectRoutes(v1) - registerWorktreeRoutes(v1) + registerWorktreeRoutes(v1, cfg) registerBranchRoutes(v1) registerTaskRoutes(v1) registerNotePadRoutes(v1) diff --git a/api/system.go b/api/system.go index b2e9c78..67b0be0 100644 --- a/api/system.go +++ b/api/system.go @@ -4,6 +4,8 @@ import ( "context" "errors" "net/http" + "path/filepath" + "strings" "github.com/danielgtaylor/huma/v2" @@ -159,11 +161,12 @@ func registerSystemRoutes(group *huma.Group, cfg *utils.AppConfig, terminalManag huma.Post(group, "/system/ai-assistant-status/update", func(ctx context.Context, input *struct { Body utils.AIAssistantStatusConfig `json:"body"` }) (*h.MessageResponse, error) { - // 更新内存中的配置 - cfg.Terminal.AIAssistantStatus = input.Body - - // 写回配置文件 - utils.WriteConfig(cfg) + // 原子更新:在锁内完成修改+写盘 + if err := utils.UpdateConfig(cfg, func(c *utils.AppConfig) { + c.Terminal.AIAssistantStatus = input.Body + }); err != nil { + return nil, huma.Error500InternalServerError("failed to save configuration") + } // 热重载:更新所有现有终端的配置 if terminalManager != nil { @@ -194,8 +197,12 @@ func registerSystemRoutes(group *huma.Group, cfg *utils.AppConfig, terminalManag huma.Post(group, "/system/developer-config/update", func(ctx context.Context, input *struct { Body utils.DeveloperConfig `json:"body"` }) (*h.MessageResponse, error) { - cfg.Developer = input.Body - utils.WriteConfig(cfg) + // 原子更新:在锁内完成修改+写盘 + if err := utils.UpdateConfig(cfg, func(c *utils.AppConfig) { + c.Developer = input.Body + }); err != nil { + return nil, huma.Error500InternalServerError("failed to save configuration") + } if terminalManager != nil { terminalManager.UpdateScrollbackEnabled(input.Body.EnableTerminalScrollback) @@ -231,25 +238,29 @@ func registerSystemRoutes(group *huma.Group, cfg *utils.AppConfig, terminalManag Shell string `json:"shell" doc:"Shell命令,空值表示使用自动选择"` } `json:"body"` }) (*h.MessageResponse, error) { - // Validate the shell command if provided + // 验证 Shell 命令有效性 if err := utils.ValidateShellCommand(input.Body.Shell); err != nil { return nil, huma.Error400BadRequest("Invalid shell command: " + err.Error()) } - // Update config based on current platform - switch utils.GetAvailableShells(cfg.Terminal.Shell).Platform { - case "windows": - cfg.Terminal.Shell.Windows = input.Body.Shell - case "darwin": - cfg.Terminal.Shell.Darwin = input.Body.Shell - default: - cfg.Terminal.Shell.Linux = input.Body.Shell + // 获取当前平台以便更新对应配置 + platform := utils.GetAvailableShells(cfg.Terminal.Shell).Platform + + // 原子更新:在锁内完成修改+写盘 + if err := utils.UpdateConfig(cfg, func(c *utils.AppConfig) { + switch platform { + case "windows": + c.Terminal.Shell.Windows = input.Body.Shell + case "darwin": + c.Terminal.Shell.Darwin = input.Body.Shell + default: + c.Terminal.Shell.Linux = input.Body.Shell + } + }); err != nil { + return nil, huma.Error500InternalServerError("failed to save configuration") } - // Persist to config file - utils.WriteConfig(cfg) - - // Hot-reload: update terminal manager's shell config for new sessions + // 热重载:更新终端管理器的 Shell 配置,新会话生效 if terminalManager != nil { terminalManager.UpdateShellConfig(cfg.Terminal.Shell) } @@ -295,6 +306,50 @@ func registerSystemRoutes(group *huma.Group, cfg *utils.AppConfig, terminalManag op.Description = "检查指定的Shell命令是否有效可用" op.Tags = []string{systemTag} }) + + huma.Get(group, "/system/worktree-settings", func(ctx context.Context, input *struct{}) (*h.ItemResponse[utils.WorktreeConfig], error) { + resp := h.NewItemResponse(cfg.Worktree) + resp.Status = http.StatusOK + return resp, nil + }, func(op *huma.Operation) { + op.OperationID = "system-worktree-settings-get" + op.Summary = "获取 Worktree 全局设置" + op.Tags = []string{systemTag} + }) + + huma.Post(group, "/system/worktree-settings/update", func(ctx context.Context, input *struct { + Body utils.WorktreeConfig `json:"body"` + }) (*h.ItemResponse[utils.WorktreeConfig], error) { + globalBaseDir := strings.TrimSpace(input.Body.GlobalBaseDir) + pattern := strings.TrimSpace(input.Body.GlobalDirNamePattern) + if globalBaseDir != "" && !filepath.IsAbs(globalBaseDir) { + return nil, huma.Error400BadRequest("globalBaseDir must be an absolute path") + } + if pattern == "" { + return nil, huma.Error400BadRequest("globalDirNamePattern is required") + } + + // 安全检查:全局基础目录不能是敏感系统目录 + if globalBaseDir != "" && utils.IsSensitiveSystemDir(globalBaseDir) { + return nil, huma.Error400BadRequest("globalBaseDir cannot be a system directory") + } + + // 原子更新:在锁内完成修改+写盘 + if err := utils.UpdateConfig(cfg, func(c *utils.AppConfig) { + c.Worktree.GlobalBaseDir = globalBaseDir + c.Worktree.GlobalDirNamePattern = pattern + }); err != nil { + return nil, huma.Error500InternalServerError("failed to save configuration") + } + + resp := h.NewItemResponse(cfg.Worktree) + resp.Status = http.StatusOK + return resp, nil + }, func(op *huma.Operation) { + op.OperationID = "system-worktree-settings-update" + op.Summary = "更新 Worktree 全局设置" + op.Tags = []string{systemTag} + }) } func mapSystemError(err error) error { diff --git a/api/worktree.go b/api/worktree.go index 6ba41c4..e5186ab 100644 --- a/api/worktree.go +++ b/api/worktree.go @@ -18,9 +18,11 @@ const worktreeTag = "worktree-工作树" type createWorktreeInput struct { Body struct { - BranchName string `json:"branchName" doc:"分支名称" required:"true"` - BaseBranch string `json:"baseBranch" doc:"基础分支" default:""` - CreateBranch bool `json:"createBranch" doc:"是否创建新分支" default:"true"` + BranchName string `json:"branchName" doc:"分支名称" required:"true"` + BaseBranch string `json:"baseBranch" doc:"基础分支" default:""` + CreateBranch bool `json:"createBranch" doc:"是否创建新分支" default:"true"` + Location string `json:"location,omitempty" doc:"创建位置(project/global),为空表示使用项目默认"` + GlobalBaseDirOverride string `json:"globalBaseDirOverride,omitempty" doc:"全局 Worktree 基础目录(仅本次创建,优先级高于全局配置)"` } `json:"body"` } @@ -30,7 +32,7 @@ type commitWorktreeInput struct { } `json:"body"` } -func registerWorktreeRoutes(group *huma.Group) { +func registerWorktreeRoutes(group *huma.Group, cfg *utils.AppConfig) { worktreeSvc := service.NewWorktreeService() huma.Post(group, "/projects/{projectId}/worktrees/create", func( @@ -44,8 +46,14 @@ func registerWorktreeRoutes(group *huma.Group) { ctx, input.ProjectID, input.Body.BranchName, - input.Body.BaseBranch, - input.Body.CreateBranch, + service.CreateWorktreeOptions{ + BaseBranch: input.Body.BaseBranch, + CreateBranch: input.Body.CreateBranch, + Location: input.Body.Location, + GlobalBaseDirOverride: input.Body.GlobalBaseDirOverride, + GlobalBaseDir: cfg.Worktree.GlobalBaseDir, + GlobalDirNamePattern: cfg.Worktree.GlobalDirNamePattern, + }, ) if err != nil { return nil, mapWorktreeError(err) diff --git a/model/db_gen.go b/model/db_gen.go index 42db34c..4e5c094 100644 --- a/model/db_gen.go +++ b/model/db_gen.go @@ -54,6 +54,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.projectUpdateStmt, err = db.PrepareContext(ctx, projectUpdate); err != nil { return nil, fmt.Errorf("error preparing query ProjectUpdate: %w", err) } + if q.projectUpdateWorktreeBasePathStmt, err = db.PrepareContext(ctx, projectUpdateWorktreeBasePath); err != nil { + return nil, fmt.Errorf("error preparing query ProjectUpdateWorktreeBasePath: %w", err) + } if q.projectUpdatePriorityStmt, err = db.PrepareContext(ctx, projectUpdatePriority); err != nil { return nil, fmt.Errorf("error preparing query ProjectUpdatePriority: %w", err) } @@ -160,6 +163,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing projectUpdateStmt: %w", cerr) } } + if q.projectUpdateWorktreeBasePathStmt != nil { + if cerr := q.projectUpdateWorktreeBasePathStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing projectUpdateWorktreeBasePathStmt: %w", cerr) + } + } if q.projectUpdatePriorityStmt != nil { if cerr := q.projectUpdatePriorityStmt.Close(); cerr != nil { err = fmt.Errorf("error closing projectUpdatePriorityStmt: %w", cerr) @@ -282,67 +290,69 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar } type Queries struct { - db DBTX - tx *sql.Tx - accessTokenCreateStmt *sql.Stmt - accessTokenDeleteAllByUserIdStmt *sql.Stmt - accessTokenGetByIdStmt *sql.Stmt - accessTokenRefreshStmt *sql.Stmt - getOneStmt *sql.Stmt - projectCreateStmt *sql.Stmt - projectGetByIDStmt *sql.Stmt - projectListStmt *sql.Stmt - projectSoftDeleteStmt *sql.Stmt - projectUpdateStmt *sql.Stmt - projectUpdatePriorityStmt *sql.Stmt - taskCountByWorktreeStmt *sql.Stmt - userCreateStmt *sql.Stmt - userDeleteStmt *sql.Stmt - userDisableStmt *sql.Stmt - userGetByIdStmt *sql.Stmt - userGetByUsernameStmt *sql.Stmt - userListStmt *sql.Stmt - userListCountStmt *sql.Stmt - userUpdateInfoStmt *sql.Stmt - userUpdatePasswordStmt *sql.Stmt - worktreeCreateStmt *sql.Stmt - worktreeGetByIDStmt *sql.Stmt - worktreeListByProjectStmt *sql.Stmt - worktreeSoftDeleteStmt *sql.Stmt - worktreeUpdateMetadataStmt *sql.Stmt - worktreeUpdateStatusStmt *sql.Stmt + db DBTX + tx *sql.Tx + accessTokenCreateStmt *sql.Stmt + accessTokenDeleteAllByUserIdStmt *sql.Stmt + accessTokenGetByIdStmt *sql.Stmt + accessTokenRefreshStmt *sql.Stmt + getOneStmt *sql.Stmt + projectCreateStmt *sql.Stmt + projectGetByIDStmt *sql.Stmt + projectListStmt *sql.Stmt + projectSoftDeleteStmt *sql.Stmt + projectUpdateStmt *sql.Stmt + projectUpdateWorktreeBasePathStmt *sql.Stmt + projectUpdatePriorityStmt *sql.Stmt + taskCountByWorktreeStmt *sql.Stmt + userCreateStmt *sql.Stmt + userDeleteStmt *sql.Stmt + userDisableStmt *sql.Stmt + userGetByIdStmt *sql.Stmt + userGetByUsernameStmt *sql.Stmt + userListStmt *sql.Stmt + userListCountStmt *sql.Stmt + userUpdateInfoStmt *sql.Stmt + userUpdatePasswordStmt *sql.Stmt + worktreeCreateStmt *sql.Stmt + worktreeGetByIDStmt *sql.Stmt + worktreeListByProjectStmt *sql.Stmt + worktreeSoftDeleteStmt *sql.Stmt + worktreeUpdateMetadataStmt *sql.Stmt + worktreeUpdateStatusStmt *sql.Stmt } func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ - db: tx, - tx: tx, - accessTokenCreateStmt: q.accessTokenCreateStmt, - accessTokenDeleteAllByUserIdStmt: q.accessTokenDeleteAllByUserIdStmt, - accessTokenGetByIdStmt: q.accessTokenGetByIdStmt, - accessTokenRefreshStmt: q.accessTokenRefreshStmt, - getOneStmt: q.getOneStmt, - projectCreateStmt: q.projectCreateStmt, - projectGetByIDStmt: q.projectGetByIDStmt, - projectListStmt: q.projectListStmt, - projectSoftDeleteStmt: q.projectSoftDeleteStmt, - projectUpdateStmt: q.projectUpdateStmt, - projectUpdatePriorityStmt: q.projectUpdatePriorityStmt, - taskCountByWorktreeStmt: q.taskCountByWorktreeStmt, - userCreateStmt: q.userCreateStmt, - userDeleteStmt: q.userDeleteStmt, - userDisableStmt: q.userDisableStmt, - userGetByIdStmt: q.userGetByIdStmt, - userGetByUsernameStmt: q.userGetByUsernameStmt, - userListStmt: q.userListStmt, - userListCountStmt: q.userListCountStmt, - userUpdateInfoStmt: q.userUpdateInfoStmt, - userUpdatePasswordStmt: q.userUpdatePasswordStmt, - worktreeCreateStmt: q.worktreeCreateStmt, - worktreeGetByIDStmt: q.worktreeGetByIDStmt, - worktreeListByProjectStmt: q.worktreeListByProjectStmt, - worktreeSoftDeleteStmt: q.worktreeSoftDeleteStmt, - worktreeUpdateMetadataStmt: q.worktreeUpdateMetadataStmt, - worktreeUpdateStatusStmt: q.worktreeUpdateStatusStmt, + db: tx, + tx: tx, + accessTokenCreateStmt: q.accessTokenCreateStmt, + accessTokenDeleteAllByUserIdStmt: q.accessTokenDeleteAllByUserIdStmt, + accessTokenGetByIdStmt: q.accessTokenGetByIdStmt, + accessTokenRefreshStmt: q.accessTokenRefreshStmt, + getOneStmt: q.getOneStmt, + projectCreateStmt: q.projectCreateStmt, + projectGetByIDStmt: q.projectGetByIDStmt, + projectListStmt: q.projectListStmt, + projectSoftDeleteStmt: q.projectSoftDeleteStmt, + projectUpdateStmt: q.projectUpdateStmt, + projectUpdateWorktreeBasePathStmt: q.projectUpdateWorktreeBasePathStmt, + projectUpdatePriorityStmt: q.projectUpdatePriorityStmt, + taskCountByWorktreeStmt: q.taskCountByWorktreeStmt, + userCreateStmt: q.userCreateStmt, + userDeleteStmt: q.userDeleteStmt, + userDisableStmt: q.userDisableStmt, + userGetByIdStmt: q.userGetByIdStmt, + userGetByUsernameStmt: q.userGetByUsernameStmt, + userListStmt: q.userListStmt, + userListCountStmt: q.userListCountStmt, + userUpdateInfoStmt: q.userUpdateInfoStmt, + userUpdatePasswordStmt: q.userUpdatePasswordStmt, + worktreeCreateStmt: q.worktreeCreateStmt, + worktreeGetByIDStmt: q.worktreeGetByIDStmt, + worktreeListByProjectStmt: q.worktreeListByProjectStmt, + worktreeSoftDeleteStmt: q.worktreeSoftDeleteStmt, + worktreeUpdateMetadataStmt: q.worktreeUpdateMetadataStmt, + worktreeUpdateStatusStmt: q.worktreeUpdateStatusStmt, } } diff --git a/model/project.sql_gen.go b/model/project.sql_gen.go index 43ec5a0..fa33631 100644 --- a/model/project.sql_gen.go +++ b/model/project.sql_gen.go @@ -230,6 +230,43 @@ func (q *Queries) ProjectUpdate(ctx context.Context, arg *ProjectUpdateParams) ( return &i, err } +const projectUpdateWorktreeBasePath = `-- name: ProjectUpdateWorktreeBasePath :one +UPDATE projects +SET + updated_at = ?1, + worktree_base_path = CAST(?2 AS TEXT) +WHERE id = ?3 + AND deleted_at IS NULL +RETURNING id, created_at, updated_at, deleted_at, name, path, description, default_branch, worktree_base_path, remote_url, last_sync_at, hide_path, priority +` + +type ProjectUpdateWorktreeBasePathParams struct { + UpdatedAt time.Time `db:"updated_at" json:"updatedAt"` + WorktreeBasePath *string `db:"worktree_base_path" json:"worktreeBasePath"` + Id string `db:"id" json:"id"` +} + +func (q *Queries) ProjectUpdateWorktreeBasePath(ctx context.Context, arg *ProjectUpdateWorktreeBasePathParams) (*Project, error) { + row := q.queryRow(ctx, q.projectUpdateWorktreeBasePathStmt, projectUpdateWorktreeBasePath, arg.UpdatedAt, arg.WorktreeBasePath, arg.Id) + var i Project + err := row.Scan( + &i.Id, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + &i.Name, + &i.Path, + &i.Description, + &i.DefaultBranch, + &i.WorktreeBasePath, + &i.RemoteUrl, + &i.LastSyncAt, + &i.HidePath, + &i.Priority, + ) + return &i, err +} + const projectUpdatePriority = `-- name: ProjectUpdatePriority :one UPDATE projects SET diff --git a/model/queries/project.sql b/model/queries/project.sql index fc8825e..50cf0b7 100644 --- a/model/queries/project.sql +++ b/model/queries/project.sql @@ -50,6 +50,15 @@ WHERE id = @id AND deleted_at IS NULL RETURNING *; +-- name: ProjectUpdateWorktreeBasePath :one +UPDATE projects +SET + updated_at = @updated_at, + worktree_base_path = CAST(@worktree_base_path AS TEXT) +WHERE id = @id + AND deleted_at IS NULL +RETURNING *; + -- name: ProjectSoftDelete :execrows UPDATE projects SET diff --git a/service/branch_service.go b/service/branch_service.go index c80725f..5aafa55 100644 --- a/service/branch_service.go +++ b/service/branch_service.go @@ -129,7 +129,10 @@ func (s *BranchService) CreateBranch(ctx context.Context, projectID, name, base if createWorktree { worktreeService := NewWorktreeService() - if _, err := worktreeService.CreateWorktree(ctx, projectID, branchName, baseBranch, false); err != nil { + if _, err := worktreeService.CreateWorktree(ctx, projectID, branchName, CreateWorktreeOptions{ + BaseBranch: baseBranch, + CreateBranch: false, + }); err != nil { logger.Error("create worktree for branch failed", zap.Error(err), zap.String("projectId", projectID), diff --git a/service/worktree_service.go b/service/worktree_service.go index cb63d9a..de05693 100644 --- a/service/worktree_service.go +++ b/service/worktree_service.go @@ -17,19 +17,29 @@ import ( "go.uber.org/zap" ) -// WorktreeService coordinates CRUD operations between git worktrees and the database. +// WorktreeService 协调 git worktree 与数据库之间的 CRUD 操作。 type WorktreeService struct { asyncStatusRefresh bool } -// NewWorktreeService builds a WorktreeService with async status refresh enabled. +// CreateWorktreeOptions 创建 Worktree 时的选项参数。 +type CreateWorktreeOptions struct { + BaseBranch string // 基础分支(新建分支时的起始点) + CreateBranch bool // 是否创建新分支 + Location string // 创建位置:"project"(项目目录)或 "global"(全局目录) + GlobalBaseDirOverride string // 全局目录覆盖(仅本次生效,不持久化) + GlobalBaseDir string // 全局 Worktree 基础目录(来自配置) + GlobalDirNamePattern string // 全局目录命名模式(如 {projectName}-{branch}) +} + +// NewWorktreeService 创建一个启用异步状态刷新的 WorktreeService 实例。 func NewWorktreeService() *WorktreeService { return &WorktreeService{ asyncStatusRefresh: true, } } -// AsyncRefresh toggles async status refresh behaviour (useful for tests). +// AsyncRefresh 切换异步状态刷新行为(用于测试)。 func (s *WorktreeService) AsyncRefresh(enabled bool) { if s == nil { return @@ -37,13 +47,12 @@ func (s *WorktreeService) AsyncRefresh(enabled bool) { s.asyncStatusRefresh = enabled } -// CreateWorktree provisions a new git worktree and persists its metadata. +// CreateWorktree 创建一个新的 git worktree 并持久化其元数据。 func (s *WorktreeService) CreateWorktree( ctx context.Context, projectID string, branchName string, - baseBranch string, - createBranch bool, + opts CreateWorktreeOptions, ) (*model.Worktree, error) { if ctx == nil { ctx = context.Background() @@ -74,8 +83,8 @@ func (s *WorktreeService) CreateWorktree( } targetBranch := strings.TrimSpace(branchName) - if createBranch { - refBranch := strings.TrimSpace(baseBranch) + if opts.CreateBranch { + refBranch := strings.TrimSpace(opts.BaseBranch) if refBranch == "" { if project.DefaultBranch != nil && *project.DefaultBranch != "" { refBranch = *project.DefaultBranch @@ -88,7 +97,7 @@ func (s *WorktreeService) CreateWorktree( } } - worktreePath, err := s.resolveWorktreePath(project, targetBranch) + worktreePath, baseDirToPersist, persistRequested, err := s.resolveWorktreePath(project, targetBranch, opts) if err != nil { return nil, err } @@ -123,6 +132,29 @@ func (s *WorktreeService) CreateWorktree( return nil, err } + if persistRequested { + updatedAt := time.Now() + var basePathParam *string + if strings.TrimSpace(baseDirToPersist) != "" { + cleaned := filepath.Clean(baseDirToPersist) + basePathParam = &cleaned + } + + if _, err := q.ProjectUpdateWorktreeBasePath(ctx, &model.ProjectUpdateWorktreeBasePathParams{ + UpdatedAt: updatedAt, + WorktreeBasePath: basePathParam, + Id: projectID, + }); err != nil { + _ = gitRepo.RemoveWorktree(worktreePath, true) + _, _ = q.WorktreeSoftDelete(ctx, &model.WorktreeSoftDeleteParams{ + DeletedAt: &updatedAt, + UpdatedAt: updatedAt, + Id: worktree.Id, + }) + return nil, err + } + } + // 同步刷新状态,确保返回的 worktree 包含最新的 git 状态信息 refreshed, err := s.RefreshWorktreeStatus(ctx, worktree.Id) if err != nil { @@ -137,7 +169,7 @@ func (s *WorktreeService) CreateWorktree( return refreshed, nil } -// ListWorktrees returns worktrees for a project ordered by main flag then creation. +// ListWorktrees 返回项目的所有 worktree,按主 worktree 标志和创建时间排序。 func (s *WorktreeService) ListWorktrees(ctx context.Context, projectID string) ([]*model.Worktree, error) { if ctx == nil { ctx = context.Background() @@ -151,7 +183,7 @@ func (s *WorktreeService) ListWorktrees(ctx context.Context, projectID string) ( return q.WorktreeListByProject(ctx, projectID) } -// GetWorktree fetches a worktree by identifier. +// GetWorktree 根据 ID 获取 worktree 记录。 func (s *WorktreeService) GetWorktree(ctx context.Context, id string) (*model.Worktree, error) { if ctx == nil { ctx = context.Background() @@ -172,7 +204,7 @@ func (s *WorktreeService) GetWorktree(ctx context.Context, id string) (*model.Wo return wt, nil } -// DeleteWorktree removes a worktree from git and the database. +// DeleteWorktree 从 git 和数据库中删除 worktree。 func (s *WorktreeService) DeleteWorktree(ctx context.Context, id string, force, deleteBranch bool) error { if ctx == nil { ctx = context.Background() @@ -208,16 +240,16 @@ func (s *WorktreeService) DeleteWorktree(ctx context.Context, id string, force, return err } - // Check if project path exists + // 检查项目路径是否存在 var gitRepo *git.GitRepo if _, err := os.Stat(project.Path); os.IsNotExist(err) { utils.Logger().Warn("project path does not exist, skipping git operations", zap.String("projectPath", project.Path), zap.String("worktreeId", id), ) - // Skip git operations and proceed to database cleanup + // 跳过 git 操作,继续进行数据库清理 } else { - // Project exists, try git operations + // 项目存在,尝试 git 操作 gitRepo, err = git.DetectRepository(project.Path) if err != nil { utils.Logger().Warn("failed to detect git repository, skipping git removal", @@ -226,17 +258,17 @@ func (s *WorktreeService) DeleteWorktree(ctx context.Context, id string, force, zap.String("worktreeId", id), ) } else { - // Try to remove the worktree from git + // 尝试从 git 中移除 worktree if err := gitRepo.RemoveWorktree(worktree.Path, force); err != nil { - // If the worktree path doesn't exist anymore, we can still proceed - // Check if the error is because the worktree doesn't exist + // 如果 worktree 路径已不存在,可以继续处理 + // 检查错误是否因为 worktree 不存在 if _, statErr := os.Stat(worktree.Path); os.IsNotExist(statErr) { utils.Logger().Warn("worktree path does not exist, skipping git removal", zap.String("path", worktree.Path), zap.String("worktreeId", id), ) } else { - // For other errors, return them + // 其他错误则返回 return err } } @@ -262,7 +294,7 @@ func (s *WorktreeService) DeleteWorktree(ctx context.Context, id string, force, return err } -// RefreshWorktreeStatus updates cached status fields for a worktree and returns the refreshed record. +// RefreshWorktreeStatus 更新 worktree 的缓存状态字段并返回刷新后的记录。 func (s *WorktreeService) RefreshWorktreeStatus(ctx context.Context, id string) (*model.Worktree, error) { if ctx == nil { ctx = context.Background() @@ -328,7 +360,7 @@ func (s *WorktreeService) RefreshWorktreeStatus(ctx context.Context, id string) return updated, nil } -// RefreshAllWorktrees refreshes status for every worktree belonging to a project. +// RefreshAllWorktrees 刷新项目下所有 worktree 的状态。 func (s *WorktreeService) RefreshAllWorktrees(ctx context.Context, projectID string) (updated, failed int, err error) { if ctx == nil { ctx = context.Background() @@ -354,7 +386,7 @@ func (s *WorktreeService) RefreshAllWorktrees(ctx context.Context, projectID str return updated, failed, nil } -// RefreshWorktreeCommitInfo refreshes commit/status metadata for all worktrees and returns the updated list. +// RefreshWorktreeCommitInfo 刷新所有 worktree 的提交/状态元数据并返回更新后的列表。 func (s *WorktreeService) RefreshWorktreeCommitInfo(ctx context.Context, projectID string) ([]*model.Worktree, error) { if ctx == nil { ctx = context.Background() @@ -365,7 +397,7 @@ func (s *WorktreeService) RefreshWorktreeCommitInfo(ctx context.Context, project return s.ListWorktrees(ctx, projectID) } -// SyncWorktrees ensures git worktrees and the database remain aligned. +// SyncWorktrees 确保 git worktree 与数据库保持同步。 func (s *WorktreeService) SyncWorktrees(ctx context.Context, projectID string) error { if ctx == nil { ctx = context.Background() @@ -486,7 +518,7 @@ func (s *WorktreeService) SyncWorktrees(ctx context.Context, projectID string) e return nil } -// CommitWorktree stages all changes within the worktree and creates a commit with the provided message. +// CommitWorktree 暂存 worktree 中的所有更改并使用指定消息创建提交。 func (s *WorktreeService) CommitWorktree(ctx context.Context, id, message string) (*model.Worktree, error) { if ctx == nil { ctx = context.Background() @@ -545,25 +577,123 @@ func (s *WorktreeService) CommitWorktree(ctx context.Context, id, message string return updated, nil } -func (s *WorktreeService) resolveWorktreePath(project *model.Project, branchName string) (string, error) { - basePath := "" - if project.WorktreeBasePath != nil && strings.TrimSpace(*project.WorktreeBasePath) != "" { - basePath = *project.WorktreeBasePath - } else { - basePath = filepath.Join(project.Path, ".worktrees") +// resolveWorktreePath 根据选项解析 worktree 的完整路径。 +// 返回值: +// - worktreePath: 最终的 worktree 目录路径 +// - baseDirToPersist: 需要持久化到项目的基础目录(仅当使用全局配置时) +// - persistRequested: 是否需要持久化基础目录到项目 +// - err: 错误信息 +func (s *WorktreeService) resolveWorktreePath(project *model.Project, branchName string, opts CreateWorktreeOptions) (worktreePath string, baseDirToPersist string, persistRequested bool, err error) { + if project == nil { + return "", "", false, fmt.Errorf("project is required") + } + + location := strings.TrimSpace(opts.Location) + if location != "" && location != "project" && location != "global" { + return "", "", false, fmt.Errorf("invalid location: %s", location) + } + + pattern := strings.TrimSpace(opts.GlobalDirNamePattern) + if pattern == "" { + pattern = "{projectName}-{branch}" + } + + baseDir := "" + globalMode := false + persistRequested = location != "" + + switch location { + case "project": + baseDir = filepath.Join(project.Path, ".worktrees") + globalMode = false + baseDirToPersist = "" + case "global": + // 优先检查覆盖参数(仅本次生效,不持久化) + overrideDir := strings.TrimSpace(opts.GlobalBaseDirOverride) + configDir := strings.TrimSpace(opts.GlobalBaseDir) + + if overrideDir != "" { + baseDir = overrideDir + // 覆盖参数仅本次生效,不持久化到项目 + baseDirToPersist = "" + persistRequested = false + } else if configDir != "" { + baseDir = configDir + // 使用全局配置,持久化到项目以便后续使用 + baseDirToPersist = filepath.Clean(configDir) + } else { + return "", "", false, fmt.Errorf("global base dir is not configured") + } + + if !filepath.IsAbs(baseDir) { + return "", "", false, fmt.Errorf("global base dir must be an absolute path") + } + // 安全检查:全局基础目录不能是敏感系统目录 + if utils.IsSensitiveSystemDir(baseDir) { + return "", "", false, fmt.Errorf("global base dir cannot be a system directory") + } + globalMode = true + default: + if project.WorktreeBasePath != nil && strings.TrimSpace(*project.WorktreeBasePath) != "" { + baseDir = strings.TrimSpace(*project.WorktreeBasePath) + } else { + baseDir = filepath.Join(project.Path, ".worktrees") + } + + if !filepath.IsAbs(baseDir) { + baseDir = filepath.Join(project.Path, baseDir) + } + + // 安全检查:确保 baseDir 不会通过 ".." 逃逸出项目目录 + absBase := filepath.Clean(baseDir) + absProject := filepath.Clean(project.Path) + rel, relErr := filepath.Rel(absProject, absBase) + if relErr == nil && strings.HasPrefix(rel, "..") { + // baseDir 逃逸出项目目录 - 仅当是绝对路径时允许 + // 对于包含 ".." 的相对路径,拒绝作为安全风险 + if project.WorktreeBasePath != nil && !filepath.IsAbs(*project.WorktreeBasePath) { + return "", "", false, fmt.Errorf("worktree base path escapes project directory") + } + } + + globalMode = isGlobalWorktreeBaseDir(project.Path, baseDir) + baseDirToPersist = "" } - if !filepath.IsAbs(basePath) { - basePath = filepath.Join(project.Path, basePath) + + if err := os.MkdirAll(baseDir, 0o755); err != nil { + return "", "", false, err } - if err := os.MkdirAll(basePath, 0o755); err != nil { - return "", err + + dirName := "" + if globalMode { + dirName, err = expandWorktreeDirNamePattern(pattern, project, branchName) + if err != nil { + return "", "", false, err + } + } else { + dirName = sanitizeBranchName(branchName) } - dirName := sanitizeBranchName(branchName) - return filepath.Join(basePath, dirName), nil + // 最终安全校验:确保解析后的路径在 baseDir 内 + finalPath := filepath.Join(baseDir, dirName) + cleanFinal := filepath.Clean(finalPath) + cleanBase := filepath.Clean(baseDir) + if !strings.HasPrefix(cleanFinal, cleanBase+string(filepath.Separator)) && cleanFinal != cleanBase { + return "", "", false, fmt.Errorf("worktree path escapes base directory") + } + + return finalPath, baseDirToPersist, persistRequested, nil } +// sanitizeBranchName 将分支名称转换为安全的目录名称。 +// 替换路径分隔符和特殊字符,防止路径遍历攻击。 func sanitizeBranchName(branch string) string { + clean := strings.TrimSpace(branch) + // 拒绝可能导致路径遍历的危险目录名 + if clean == "" || clean == "." || clean == ".." { + return "_invalid_branch_" + } + replacer := strings.NewReplacer( "/", "__", "\\", "__", @@ -574,5 +704,75 @@ func sanitizeBranchName(branch string) string { ">", "_", "|", "_", ) - return replacer.Replace(strings.TrimSpace(branch)) + result := replacer.Replace(clean) + + // 二次校验:如果结果仍包含 ".." 则拒绝 + if strings.Contains(result, "..") { + return "_invalid_branch_" + } + return result +} + +// isGlobalWorktreeBaseDir 判断 worktree 基础目录是否在项目目录外(即全局模式)。 +func isGlobalWorktreeBaseDir(projectPath, baseDir string) bool { + projectAbs, err := filepath.Abs(projectPath) + if err != nil { + return false + } + baseAbs, err := filepath.Abs(baseDir) + if err != nil { + return false + } + rel, err := filepath.Rel(projectAbs, baseAbs) + if err != nil { + return false + } + if rel == "." { + return false + } + return strings.HasPrefix(rel, "..") +} + +// sanitizePathSegment 清理路径片段中的特殊字符。 +func sanitizePathSegment(input string) string { + trimmed := strings.TrimSpace(input) + replacer := strings.NewReplacer( + "/", "_", + "\\", "_", + ":", "_", + "*", "_", + "?", "_", + "<", "_", + ">", "_", + "|", "_", + ) + return replacer.Replace(trimmed) +} + +// expandWorktreeDirNamePattern 展开 worktree 目录名模式。 +// 支持的变量:{projectName}、{projectId}、{branch} +func expandWorktreeDirNamePattern(pattern string, project *model.Project, branchName string) (string, error) { + rawProjectName := "" + if project != nil { + rawProjectName = project.Name + } + + // 使用固定顺序替换以避免非确定性行为 + expanded := pattern + expanded = strings.ReplaceAll(expanded, "{projectName}", sanitizePathSegment(rawProjectName)) + expanded = strings.ReplaceAll(expanded, "{projectId}", sanitizePathSegment(project.Id)) + expanded = strings.ReplaceAll(expanded, "{branch}", sanitizeBranchName(branchName)) + + expanded = strings.TrimSpace(expanded) + if expanded == "" { + return "", fmt.Errorf("worktree dir name is empty after pattern expansion") + } + if strings.Contains(expanded, "..") { + return "", fmt.Errorf("invalid worktree dir name: %s", expanded) + } + if strings.ContainsAny(expanded, "/\\") { + return "", fmt.Errorf("invalid worktree dir name: %s", expanded) + } + + return expanded, nil } diff --git a/service/worktree_service_test.go b/service/worktree_service_test.go index 91b87bb..8634bd3 100644 --- a/service/worktree_service_test.go +++ b/service/worktree_service_test.go @@ -33,7 +33,10 @@ func TestWorktreeServiceCreateAndRefresh(t *testing.T) { svc.AsyncRefresh(false) ctx := context.Background() - worktree, err := svc.CreateWorktree(ctx, project.Id, "feature/testing", "main", true) + worktree, err := svc.CreateWorktree(ctx, project.Id, "feature/testing", CreateWorktreeOptions{ + BaseBranch: "main", + CreateBranch: true, + }) if err != nil { t.Fatalf("CreateWorktree returned error: %v", err) } @@ -87,7 +90,10 @@ func TestWorktreeServiceDeleteAndSync(t *testing.T) { svc.AsyncRefresh(false) ctx := context.Background() - worktree, err := svc.CreateWorktree(ctx, project.Id, "feature/delete", "main", true) + worktree, err := svc.CreateWorktree(ctx, project.Id, "feature/delete", CreateWorktreeOptions{ + BaseBranch: "main", + CreateBranch: true, + }) if err != nil { t.Fatalf("CreateWorktree returned error: %v", err) } @@ -140,7 +146,10 @@ func TestWorktreeServiceRefreshAll(t *testing.T) { svc.AsyncRefresh(false) ctx := context.Background() - if _, err := svc.CreateWorktree(ctx, project.Id, "feature/all", "main", true); err != nil { + if _, err := svc.CreateWorktree(ctx, project.Id, "feature/all", CreateWorktreeOptions{ + BaseBranch: "main", + CreateBranch: true, + }); err != nil { t.Fatalf("CreateWorktree returned error: %v", err) } @@ -180,7 +189,10 @@ func TestWorktreeServiceCommit(t *testing.T) { svc.AsyncRefresh(false) ctx := context.Background() - worktree, err := svc.CreateWorktree(ctx, project.Id, "feature/commit", "main", true) + worktree, err := svc.CreateWorktree(ctx, project.Id, "feature/commit", CreateWorktreeOptions{ + BaseBranch: "main", + CreateBranch: true, + }) if err != nil { t.Fatalf("CreateWorktree returned error: %v", err) } @@ -203,6 +215,65 @@ func TestWorktreeServiceCommit(t *testing.T) { } } +func TestWorktreeServiceCreateWorktree_PersistWorktreeBasePath(t *testing.T) { + cleanup := initTestDB(t) + defer cleanup() + + repoPath := createProjectTestRepo(t) + projectService := &model.ProjectService{} + project, err := projectService.CreateProject(context.Background(), model.CreateProjectParams{ + Name: "Persist Project", + Path: repoPath, + }) + if err != nil { + t.Fatalf("create project failed: %v", err) + } + + q, err := model.ResolveQueries(nil) + if err != nil { + t.Fatalf("resolve queries failed: %v", err) + } + + svc := NewWorktreeService() + svc.AsyncRefresh(false) + ctx := context.Background() + + globalBaseDir := t.TempDir() + if _, err := svc.CreateWorktree(ctx, project.Id, "feature/global", CreateWorktreeOptions{ + BaseBranch: "main", + CreateBranch: true, + Location: "global", + GlobalBaseDir: globalBaseDir, + GlobalDirNamePattern: "{projectName}-{branch}", + }); err != nil { + t.Fatalf("CreateWorktree(global) returned error: %v", err) + } + + updated, err := q.ProjectGetByID(ctx, project.Id) + if err != nil { + t.Fatalf("reload project failed: %v", err) + } + if updated.WorktreeBasePath == nil || filepath.Clean(*updated.WorktreeBasePath) != filepath.Clean(globalBaseDir) { + t.Fatalf("expected worktreeBasePath to be persisted to %s, got %v", filepath.Clean(globalBaseDir), updated.WorktreeBasePath) + } + + if _, err := svc.CreateWorktree(ctx, project.Id, "feature/project", CreateWorktreeOptions{ + BaseBranch: "main", + CreateBranch: true, + Location: "project", + }); err != nil { + t.Fatalf("CreateWorktree(project) returned error: %v", err) + } + + cleared, err := q.ProjectGetByID(ctx, project.Id) + if err != nil { + t.Fatalf("reload project failed: %v", err) + } + if cleared.WorktreeBasePath != nil && strings.TrimSpace(*cleared.WorktreeBasePath) != "" { + t.Fatalf("expected worktreeBasePath to be cleared, got %v", cleared.WorktreeBasePath) + } +} + func initTestDB(t *testing.T) func() { t.Helper() dsn := "file:" + t.Name() + "?mode=memory&cache=shared" @@ -258,3 +329,68 @@ func runGitCommand(t *testing.T, dir string, args ...string) { t.Fatalf("git %s failed: %v\n%s", strings.Join(args, " "), err, output) } } + +// TestSanitizeBranchName_Security tests that sanitizeBranchName correctly handles +// potentially dangerous branch names that could lead to path traversal. +func TestSanitizeBranchName_Security(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"empty", "", "_invalid_branch_"}, + {"dot", ".", "_invalid_branch_"}, + {"dotdot", "..", "_invalid_branch_"}, + {"normal branch", "feature/test", "feature__test"}, + {"with backslash", "feature\\test", "feature__test"}, + {"dotdot in name", "feature..test", "_invalid_branch_"}, + {"trailing dots", "feature..", "_invalid_branch_"}, + {"leading dots", "..feature", "_invalid_branch_"}, + {"triple dot", "...", "_invalid_branch_"}, + {"valid dotfile", ".gitignore", ".gitignore"}, + {"spaces only", " ", "_invalid_branch_"}, + {"mixed special chars", "feat:test*?<>|", "feat_test_____"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sanitizeBranchName(tt.input) + if result != tt.expected { + t.Errorf("sanitizeBranchName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// TestExpandWorktreeDirNamePattern_Security tests that pattern expansion +// correctly rejects potentially dangerous patterns. +func TestExpandWorktreeDirNamePattern_Security(t *testing.T) { + project := &model.Project{ + Id: "test-id", + Name: "Test Project", + } + + tests := []struct { + name string + pattern string + branchName string + shouldError bool + }{ + {"normal pattern", "{projectName}-{branch}", "feature/test", false}, + {"dotdot branch", "{projectName}-{branch}", "..", false}, // sanitizeBranchName will handle this + {"path separator in result", "{projectName}/{branch}", "test", true}, + {"backslash in result", "{projectName}\\{branch}", "test", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := expandWorktreeDirNamePattern(tt.pattern, project, tt.branchName) + if tt.shouldError && err == nil { + t.Errorf("expected error for pattern %q with branch %q, got nil", tt.pattern, tt.branchName) + } + if !tt.shouldError && err != nil { + t.Errorf("unexpected error for pattern %q with branch %q: %v", tt.pattern, tt.branchName, err) + } + }) + } +} diff --git a/static/index.html b/static/index.html index a2ca349..74140d2 100644 --- a/static/index.html +++ b/static/index.html @@ -7,8 +7,8 @@