Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 7 additions & 22 deletions MODERNIZATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,31 +349,16 @@ On closer analysis, this is actually correct. `buildRetVals` calls `handleMappin

---

## 24. Missing `tx.Rollback()` on Error Paths
## ~~24. Missing `tx.Rollback()` on Error Paths~~ (DONE)

**Files:** `cmd/sample/main.go`, `cmd/sample-ctx/main.go`, various test files
Fixed across all affected files. The `defer tx.Commit()` pattern has been replaced with:

The pattern used throughout is:
```go
tx, err := db.Begin()
if err != nil {
panic(err)
}
defer tx.Commit()
```
- `defer func() { _ = tx.Rollback() }()` as a safety-net cleanup at the top
- An explicit `tx.Commit()` call (with error handling) at the end of the successful path

`defer tx.Commit()` runs even when the function exits due to an error or panic, committing partial transactions. The correct pattern is:
```go
defer func() {
if err != nil {
tx.Rollback()
} else {
tx.Commit()
}
}()
```
For test code in `proteus_test.go` where all reads occur on the same `tx` (no need to persist data beyond the test), the deferred rollback alone is sufficient — this also provides better test isolation.

Note: `cmd/null/main.go` already uses the correct pattern.
**Files fixed:** `cmd/sample/main.go`, `cmd/sample-ctx/main.go`, `bench/bench_test.go`, `speed/speed.go`, `example_test.go`, `example2_test.go`, `mapper_test.go`, `proteus_test.go` (11 instances)

---

Expand Down Expand Up @@ -416,7 +401,7 @@ If `Build` returns an error, `productDao` will have nil function fields. Subsequ
- ~~#7 — `defer rows.Close()` (resource leak risk)~~ *(DONE)*
- ~~#8 — Remove `unsafe`~~ *(DONE)*
- ~~#15 — Makefile bug fix~~ *(DONE)*
- #24 — Missing `tx.Rollback()` in samples/tests
- ~~#24 — Missing `tx.Rollback()` in samples/tests~~ *(DONE)*

**Medium priority (idiomatic modernization):**
- ~~#1 — `interface{}` to `any`~~ *(DONE)*
Expand Down
10 changes: 7 additions & 3 deletions bench/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func populate(ctx context.Context, db *sql.DB) {
if err != nil {
panic(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

for i := 0; i < 10; i++ {
var cost *float64
Expand All @@ -44,6 +44,10 @@ func populate(ctx context.Context, db *sql.DB) {
panic(err)
}
}

if err := tx.Commit(); err != nil {
panic(err)
}
}

type closer func()
Expand All @@ -60,8 +64,8 @@ func setup() (*sql.Tx, closer) {
panic(err)
}
return tx, func() {
tx.Commit()
db.Close()
_ = tx.Commit()
_ = db.Close()
}
}

Expand Down
13 changes: 11 additions & 2 deletions cmd/sample-ctx/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func run(setupDb setupDb, productDAO ProductDAO) {
if err != nil {
panic(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

slog.Log(ctx, slog.LevelDebug, fmt.Sprintln(productDAO.FindByID(ctx, tx, 10)))
cost := new(float64)
Expand Down Expand Up @@ -97,6 +97,10 @@ func run(setupDb setupDb, productDAO ProductDAO) {

//using positional parameters instead of names
slog.Log(ctx, slog.LevelDebug, fmt.Sprintln((productDAO.FindByNameAndCostUnlabeled(ctx, tx, "Thingie", 56.23))))

if err := tx.Commit(); err != nil {
slog.Log(ctx, slog.LevelError, fmt.Sprintln(err))
}
}

func setupDbPostgres(ctx context.Context, productDAO ProductDAO) *sql.DB {
Expand All @@ -123,8 +127,9 @@ func populate(ctx context.Context, db *sql.DB, productDao ProductDAO) {
tx, err := db.Begin()
if err != nil {
slog.Log(ctx, slog.LevelError, fmt.Sprintln(err))
return
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

for i := 0; i < 100; i++ {
var cost *float64
Expand All @@ -138,4 +143,8 @@ func populate(ctx context.Context, db *sql.DB, productDao ProductDAO) {
}
slog.Log(ctx, slog.LevelDebug, fmt.Sprintln(rowCount))
}

if err := tx.Commit(); err != nil {
slog.Log(ctx, slog.LevelError, fmt.Sprintln(err))
}
}
13 changes: 11 additions & 2 deletions cmd/sample/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func run(setupDb setupDb, productDAO ProductDAO) {
if err != nil {
panic(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

slog.Log(ctx, slog.LevelDebug, fmt.Sprintln(productDAO.FindByID(tx, 10)))
cost := new(float64)
Expand Down Expand Up @@ -96,6 +96,10 @@ func run(setupDb setupDb, productDAO ProductDAO) {

//using positional parameters instead of names
slog.Log(ctx, slog.LevelDebug, fmt.Sprintln((productDAO.FindByNameAndCostUnlabeled(tx, "Thingie", 56.23))))

if err := tx.Commit(); err != nil {
slog.Log(ctx, slog.LevelError, fmt.Sprintln(err))
}
}

func setupDbPostgres(ctx context.Context, productDAO ProductDAO) *sql.DB {
Expand All @@ -122,8 +126,9 @@ func populate(ctx context.Context, db *sql.DB, productDao ProductDAO) {
tx, err := db.Begin()
if err != nil {
slog.Log(ctx, slog.LevelError, fmt.Sprintln(err))
return
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

for i := 0; i < 100; i++ {
var cost *float64
Expand All @@ -137,4 +142,8 @@ func populate(ctx context.Context, db *sql.DB, productDao ProductDAO) {
}
slog.Log(ctx, slog.LevelDebug, fmt.Sprintln(rowCount))
}

if err := tx.Commit(); err != nil {
slog.Log(ctx, slog.LevelError, fmt.Sprintln(err))
}
}
6 changes: 5 additions & 1 deletion example2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func Example_create() {
if err != nil {
log.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

ctx := context.Background()

Expand All @@ -37,4 +37,8 @@ func Example_create() {
log.Fatal(err)
}
}

if err := tx.Commit(); err != nil {
log.Fatal(err)
}
}
4 changes: 3 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func Example_readUpdate() {
if err != nil {
panic(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

fmt.Println(productDao.FindById(ctx, tx, 10))
p := Product{10, "Thingie", 56.23}
Expand All @@ -61,4 +61,6 @@ func Example_readUpdate() {
}
fmt.Println(productDao.UpdateMap(ctx, tx, m))
fmt.Println(productDao.FindById(ctx, tx, 11))

_ = tx.Commit()
}
6 changes: 5 additions & 1 deletion mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func setupDb(t *testing.T) *sql.DB {
slog.Error("err", "error", slog.AnyValue(err))
t.FailNow()
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()
stmt, err := tx.Prepare("insert into product(id, name, cost) values($1, $2, $3)")
if err != nil {
slog.Error("err", "error", slog.AnyValue(err))
Expand All @@ -73,6 +73,10 @@ func setupDb(t *testing.T) *sql.DB {
t.FailNow()
}
}
if err := tx.Commit(); err != nil {
slog.Error("err", "error", slog.AnyValue(err))
t.FailNow()
}
return db
}

Expand Down
22 changes: 11 additions & 11 deletions proteus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func TestNilScanner(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -375,7 +375,7 @@ func TestNoParams(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -460,7 +460,7 @@ func TestUnnamedStructs(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -515,7 +515,7 @@ func TestEmbedded(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -584,7 +584,7 @@ func TestShouldBuildEmbeddedWithNullField(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)

Expand Down Expand Up @@ -685,7 +685,7 @@ func TestVariableMultipleUsage(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -787,7 +787,7 @@ func TestShouldBuildEmbedded(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -844,7 +844,7 @@ func TestShouldBinaryColumn(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -905,7 +905,7 @@ func TestShouldTimeColumn(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -972,7 +972,7 @@ func TestArray(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down Expand Up @@ -1055,7 +1055,7 @@ func TestNested(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

_, err = tx.Exec(create)
if err != nil {
Expand Down
7 changes: 6 additions & 1 deletion speed/speed.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func populate(ctx context.Context, db *sql.DB) {
slog.Error("error", "err", slog.AnyValue(err))
os.Exit(1)
}
defer tx.Commit()
defer func() { _ = tx.Rollback() }()

for i := 0; i < 100; i++ {
var cost *float64
Expand All @@ -162,4 +162,9 @@ func populate(ctx context.Context, db *sql.DB) {
}
slog.Debug("rowCount", "rowCount", rowCount)
}

if err := tx.Commit(); err != nil {
slog.Error("error", "err", slog.AnyValue(err))
os.Exit(1)
}
}
Loading