diff --git a/MODERNIZATION.md b/MODERNIZATION.md index 728f61e..f2b00b2 100644 --- a/MODERNIZATION.md +++ b/MODERNIZATION.md @@ -349,31 +349,16 @@ On closer analysis, this is actually correct. `buildRetVals` calls `handleMappin --- -## 24. Missing `tx.Rollback()` on Error Paths +## ~~24. Missing `tx.Rollback()` on Error Paths~~ (DONE) -**Files:** `cmd/sample/main.go`, `cmd/sample-ctx/main.go`, various test files +Fixed across all affected files. The `defer tx.Commit()` pattern has been replaced with: -The pattern used throughout is: -```go -tx, err := db.Begin() -if err != nil { - panic(err) -} -defer tx.Commit() -``` +- `defer func() { _ = tx.Rollback() }()` as a safety-net cleanup at the top +- An explicit `tx.Commit()` call (with error handling) at the end of the successful path -`defer tx.Commit()` runs even when the function exits due to an error or panic, committing partial transactions. The correct pattern is: -```go -defer func() { - if err != nil { - tx.Rollback() - } else { - tx.Commit() - } -}() -``` +For test code in `proteus_test.go` where all reads occur on the same `tx` (no need to persist data beyond the test), the deferred rollback alone is sufficient — this also provides better test isolation. -Note: `cmd/null/main.go` already uses the correct pattern. +**Files fixed:** `cmd/sample/main.go`, `cmd/sample-ctx/main.go`, `bench/bench_test.go`, `speed/speed.go`, `example_test.go`, `example2_test.go`, `mapper_test.go`, `proteus_test.go` (11 instances) --- @@ -416,7 +401,7 @@ If `Build` returns an error, `productDao` will have nil function fields. Subsequ - ~~#7 — `defer rows.Close()` (resource leak risk)~~ *(DONE)* - ~~#8 — Remove `unsafe`~~ *(DONE)* - ~~#15 — Makefile bug fix~~ *(DONE)* -- #24 — Missing `tx.Rollback()` in samples/tests +- ~~#24 — Missing `tx.Rollback()` in samples/tests~~ *(DONE)* **Medium priority (idiomatic modernization):** - ~~#1 — `interface{}` to `any`~~ *(DONE)* diff --git a/bench/bench_test.go b/bench/bench_test.go index a2374a6..7f747e3 100644 --- a/bench/bench_test.go +++ b/bench/bench_test.go @@ -31,7 +31,7 @@ func populate(ctx context.Context, db *sql.DB) { if err != nil { panic(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() for i := 0; i < 10; i++ { var cost *float64 @@ -44,6 +44,10 @@ func populate(ctx context.Context, db *sql.DB) { panic(err) } } + + if err := tx.Commit(); err != nil { + panic(err) + } } type closer func() @@ -60,8 +64,8 @@ func setup() (*sql.Tx, closer) { panic(err) } return tx, func() { - tx.Commit() - db.Close() + _ = tx.Commit() + _ = db.Close() } } diff --git a/cmd/sample-ctx/main.go b/cmd/sample-ctx/main.go index 08e2312..637a3ef 100644 --- a/cmd/sample-ctx/main.go +++ b/cmd/sample-ctx/main.go @@ -65,7 +65,7 @@ func run(setupDb setupDb, productDAO ProductDAO) { if err != nil { panic(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() slog.Log(ctx, slog.LevelDebug, fmt.Sprintln(productDAO.FindByID(ctx, tx, 10))) cost := new(float64) @@ -97,6 +97,10 @@ func run(setupDb setupDb, productDAO ProductDAO) { //using positional parameters instead of names slog.Log(ctx, slog.LevelDebug, fmt.Sprintln((productDAO.FindByNameAndCostUnlabeled(ctx, tx, "Thingie", 56.23)))) + + if err := tx.Commit(); err != nil { + slog.Log(ctx, slog.LevelError, fmt.Sprintln(err)) + } } func setupDbPostgres(ctx context.Context, productDAO ProductDAO) *sql.DB { @@ -123,8 +127,9 @@ func populate(ctx context.Context, db *sql.DB, productDao ProductDAO) { tx, err := db.Begin() if err != nil { slog.Log(ctx, slog.LevelError, fmt.Sprintln(err)) + return } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() for i := 0; i < 100; i++ { var cost *float64 @@ -138,4 +143,8 @@ func populate(ctx context.Context, db *sql.DB, productDao ProductDAO) { } slog.Log(ctx, slog.LevelDebug, fmt.Sprintln(rowCount)) } + + if err := tx.Commit(); err != nil { + slog.Log(ctx, slog.LevelError, fmt.Sprintln(err)) + } } diff --git a/cmd/sample/main.go b/cmd/sample/main.go index 2a99102..b99311c 100644 --- a/cmd/sample/main.go +++ b/cmd/sample/main.go @@ -64,7 +64,7 @@ func run(setupDb setupDb, productDAO ProductDAO) { if err != nil { panic(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() slog.Log(ctx, slog.LevelDebug, fmt.Sprintln(productDAO.FindByID(tx, 10))) cost := new(float64) @@ -96,6 +96,10 @@ func run(setupDb setupDb, productDAO ProductDAO) { //using positional parameters instead of names slog.Log(ctx, slog.LevelDebug, fmt.Sprintln((productDAO.FindByNameAndCostUnlabeled(tx, "Thingie", 56.23)))) + + if err := tx.Commit(); err != nil { + slog.Log(ctx, slog.LevelError, fmt.Sprintln(err)) + } } func setupDbPostgres(ctx context.Context, productDAO ProductDAO) *sql.DB { @@ -122,8 +126,9 @@ func populate(ctx context.Context, db *sql.DB, productDao ProductDAO) { tx, err := db.Begin() if err != nil { slog.Log(ctx, slog.LevelError, fmt.Sprintln(err)) + return } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() for i := 0; i < 100; i++ { var cost *float64 @@ -137,4 +142,8 @@ func populate(ctx context.Context, db *sql.DB, productDao ProductDAO) { } slog.Log(ctx, slog.LevelDebug, fmt.Sprintln(rowCount)) } + + if err := tx.Commit(); err != nil { + slog.Log(ctx, slog.LevelError, fmt.Sprintln(err)) + } } diff --git a/example2_test.go b/example2_test.go index c1e8d2f..dd8deb2 100644 --- a/example2_test.go +++ b/example2_test.go @@ -27,7 +27,7 @@ func Example_create() { if err != nil { log.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() ctx := context.Background() @@ -37,4 +37,8 @@ func Example_create() { log.Fatal(err) } } + + if err := tx.Commit(); err != nil { + log.Fatal(err) + } } diff --git a/example_test.go b/example_test.go index 1122d7b..381c1cc 100644 --- a/example_test.go +++ b/example_test.go @@ -40,7 +40,7 @@ func Example_readUpdate() { if err != nil { panic(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() fmt.Println(productDao.FindById(ctx, tx, 10)) p := Product{10, "Thingie", 56.23} @@ -61,4 +61,6 @@ func Example_readUpdate() { } fmt.Println(productDao.UpdateMap(ctx, tx, m)) fmt.Println(productDao.FindById(ctx, tx, 11)) + + _ = tx.Commit() } diff --git a/mapper_test.go b/mapper_test.go index f9067b5..eae8232 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -54,7 +54,7 @@ func setupDb(t *testing.T) *sql.DB { slog.Error("err", "error", slog.AnyValue(err)) t.FailNow() } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() stmt, err := tx.Prepare("insert into product(id, name, cost) values($1, $2, $3)") if err != nil { slog.Error("err", "error", slog.AnyValue(err)) @@ -73,6 +73,10 @@ func setupDb(t *testing.T) *sql.DB { t.FailNow() } } + if err := tx.Commit(); err != nil { + slog.Error("err", "error", slog.AnyValue(err)) + t.FailNow() + } return db } diff --git a/proteus_test.go b/proteus_test.go index ff31e9d..aa73f81 100644 --- a/proteus_test.go +++ b/proteus_test.go @@ -289,7 +289,7 @@ func TestNilScanner(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -375,7 +375,7 @@ func TestNoParams(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -460,7 +460,7 @@ func TestUnnamedStructs(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -515,7 +515,7 @@ func TestEmbedded(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -584,7 +584,7 @@ func TestShouldBuildEmbeddedWithNullField(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) @@ -685,7 +685,7 @@ func TestVariableMultipleUsage(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -787,7 +787,7 @@ func TestShouldBuildEmbedded(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -844,7 +844,7 @@ func TestShouldBinaryColumn(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -905,7 +905,7 @@ func TestShouldTimeColumn(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -972,7 +972,7 @@ func TestArray(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { @@ -1055,7 +1055,7 @@ func TestNested(t *testing.T) { if err != nil { t.Fatal(err) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() _, err = tx.Exec(create) if err != nil { diff --git a/speed/speed.go b/speed/speed.go index 9144e34..146f217 100644 --- a/speed/speed.go +++ b/speed/speed.go @@ -147,7 +147,7 @@ func populate(ctx context.Context, db *sql.DB) { slog.Error("error", "err", slog.AnyValue(err)) os.Exit(1) } - defer tx.Commit() + defer func() { _ = tx.Rollback() }() for i := 0; i < 100; i++ { var cost *float64 @@ -162,4 +162,9 @@ func populate(ctx context.Context, db *sql.DB) { } slog.Debug("rowCount", "rowCount", rowCount) } + + if err := tx.Commit(); err != nil { + slog.Error("error", "err", slog.AnyValue(err)) + os.Exit(1) + } }