diff --git a/engine/internal/cloning/base.go b/engine/internal/cloning/base.go index 952a9436..092d49d9 100644 --- a/engine/internal/cloning/base.go +++ b/engine/internal/cloning/base.go @@ -30,6 +30,7 @@ import ( "gitlab.com/postgres-ai/database-lab/v3/pkg/log" "gitlab.com/postgres-ai/database-lab/v3/pkg/models" "gitlab.com/postgres-ai/database-lab/v3/pkg/util" + "gitlab.com/postgres-ai/database-lab/v3/pkg/util/branching" "gitlab.com/postgres-ai/database-lab/v3/pkg/util/pglog" ) @@ -172,7 +173,11 @@ func (c *Base) CreateClone(cloneRequest *types.CloneCreateRequest) (*models.Clon } if cloneRequest.Branch == "" { - cloneRequest.Branch = snapshot.Branch + if cloneRequest.Snapshot != nil { + cloneRequest.Branch = snapshot.Branch + } else { + cloneRequest.Branch = branching.DefaultBranch + } } clone := &models.Clone{ diff --git a/engine/internal/cloning/base_test.go b/engine/internal/cloning/base_test.go index a63c28bc..57c990a3 100644 --- a/engine/internal/cloning/base_test.go +++ b/engine/internal/cloning/base_test.go @@ -12,7 +12,9 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "gitlab.com/postgres-ai/database-lab/v3/pkg/client/dblabapi/types" "gitlab.com/postgres-ai/database-lab/v3/pkg/models" + "gitlab.com/postgres-ai/database-lab/v3/pkg/util/branching" ) func TestBaseCloningSuite(t *testing.T) { @@ -133,3 +135,35 @@ func (s *BaseCloningSuite) TestLenClones() { lenClones = s.cloning.lenClones() assert.Equal(s.T(), 1, lenClones) } + +func TestDefaultBranchForCloneCreation(t *testing.T) { + testCases := []struct { + name string + inputBranch string + snapshotSpecified bool + snapshotBranch string + expectedBranch string + }{ + {name: "no branch no snapshot defaults to main", inputBranch: "", snapshotSpecified: false, snapshotBranch: "", expectedBranch: "main"}, + {name: "no branch with dev snapshot uses snapshot branch", inputBranch: "", snapshotSpecified: true, snapshotBranch: "dev", expectedBranch: "dev"}, + {name: "no branch with feature snapshot uses snapshot branch", inputBranch: "", snapshotSpecified: true, snapshotBranch: "feature", expectedBranch: "feature"}, + {name: "explicit dev branch preserved", inputBranch: "dev", snapshotSpecified: false, snapshotBranch: "", expectedBranch: "dev"}, + {name: "explicit main branch preserved", inputBranch: "main", snapshotSpecified: true, snapshotBranch: "dev", expectedBranch: "main"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + request := &types.CloneCreateRequest{Branch: tc.inputBranch} + + if request.Branch == "" { + if tc.snapshotSpecified { + request.Branch = tc.snapshotBranch + } else { + request.Branch = branching.DefaultBranch + } + } + + assert.Equal(t, tc.expectedBranch, request.Branch) + }) + } +} diff --git a/ui/packages/shared/pages/CreateClone/index.tsx b/ui/packages/shared/pages/CreateClone/index.tsx index 2629f5e0..c089a706 100644 --- a/ui/packages/shared/pages/CreateClone/index.tsx +++ b/ui/packages/shared/pages/CreateClone/index.tsx @@ -103,7 +103,8 @@ export const CreateClone = observer((props: Props) => { const branches = (await stores.main.getBranches(props.instanceId)) ?? [] - let initiallySelectedBranch = branches[0]?.name; + const mainBranch = branches.find((branch) => branch.name === 'main') + let initiallySelectedBranch = mainBranch?.name ?? branches[0]?.name; if (initialBranch && branches.find((branch) => branch.name === initialBranch)) { initiallySelectedBranch = initialBranch;