Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 62 additions & 14 deletions sqlutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,6 @@ func failMain(err error) int {
}

func TestTransact(t *testing.T) {
fixturePath, err := filepath.Abs("./testdata/fixture.sql")
require.NoError(t, err)

cleanerPath, err := filepath.Abs("./testdata/cleaner.sql")
require.NoError(t, err)

tcs := []struct {
name string
db *sql.DB
Expand All @@ -166,16 +160,15 @@ func TestTransact(t *testing.T) {
// should not use t.Context()
ctx := context.Background()

err = sqlutil.ExecFile(ctx, tc.db, cleanerPath)
require.NoError(t, err)
truncateTask(t, ctx, tc.db)

require.Zero(t, countAllTasks(t, ctx, tc.db))
})

ctx := t.Context()

err = sqlutil.ExecFile(ctx, tc.db, fixturePath)
require.NoError(t, err)
createTask(t, ctx, tc.db, 1, "task1.title")
createTask(t, ctx, tc.db, 2, "task2.title")

require.Equal(t, 2, countAllTasks(t, ctx, tc.db))

Expand All @@ -186,8 +179,8 @@ func TestTransact(t *testing.T) {
ctx := t.Context()

require.PanicsWithError(t, errPanic.Error(), func() {
sqlutil.Transact(ctx, tc.db, func(txCtx context.Context, tx *sql.Tx) error {
completeTask(t, txCtx, tx, 1)
sqlutil.Transact(ctx, tc.db, func(ctx context.Context, tx *sql.Tx) error {
completeTask(t, ctx, tx, 1)

panic(errPanic)
})
Expand All @@ -200,8 +193,8 @@ func TestTransact(t *testing.T) {
t.Run("failure: rollback on error", func(t *testing.T) {
ctx := t.Context()

err := sqlutil.Transact(ctx, tc.db, func(txCtx context.Context, tx *sql.Tx) error {
completeTask(t, txCtx, tx, 1)
err := sqlutil.Transact(ctx, tc.db, func(ctx context.Context, tx *sql.Tx) error {
completeTask(t, ctx, tx, 1)

return errSomethingWentWrong
})
Expand All @@ -228,6 +221,54 @@ func TestTransact(t *testing.T) {
}
}

func TestExecFile(t *testing.T) {
tcs := []struct {
name string
db *sql.DB
}{
{
"mysql",
mysqlDB,
},
{
"postgresql",
psqlDB,
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
t.Run("failure: path must be absolute", func(t *testing.T) {
ctx := t.Context()

err := sqlutil.ExecFile(ctx, tc.db, "./testdata/fixture.sql")
require.ErrorContains(t, err, "path must be absolute")

require.Zero(t, countAllTasks(t, ctx, tc.db))
})

t.Run("success", func(t *testing.T) {
ctx := t.Context()

fPath, err := filepath.Abs("./testdata/fixture.sql")
require.NoError(t, err)

err = sqlutil.ExecFile(ctx, tc.db, fPath)
require.NoError(t, err)

require.Equal(t, 2, countAllTasks(t, ctx, tc.db))
})
})
}
}

func createTask(t *testing.T, ctx context.Context, dbtx DBTX, id int, title string) {
t.Helper()

_, err := dbtx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO task (id, title) VALUES (%d, '%s')`, id, title))
require.NoError(t, err)
}

func countAllTasks(t *testing.T, ctx context.Context, dbtx DBTX) (cnt int) {
t.Helper()

Expand All @@ -252,3 +293,10 @@ func completeTask(t *testing.T, ctx context.Context, dbtx DBTX, taskID int) {
_, err := dbtx.ExecContext(ctx, fmt.Sprintf(`UPDATE task SET is_completed = true WHERE id = %d`, taskID))
require.NoError(t, err)
}

func truncateTask(t *testing.T, ctx context.Context, dbtx DBTX) {
t.Helper()

_, err := dbtx.ExecContext(ctx, `TRUNCATE task`)
require.NoError(t, err)
}
1 change: 0 additions & 1 deletion testdata/cleaner.sql

This file was deleted.

4 changes: 2 additions & 2 deletions testdata/fixture.sql
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
INSERT INTO task (id, title) VALUES (1, 'Ask questions');
INSERT INTO task (id, title) VALUES (2, 'Brainstorm ideas');
INSERT INTO task (id, title) VALUES (1, 'task1.title');
INSERT INTO task (id, title) VALUES (2, 'task2.title');
Loading