diff --git a/sqlutil_test.go b/sqlutil_test.go index 7dea07a..b904392 100644 --- a/sqlutil_test.go +++ b/sqlutil_test.go @@ -204,11 +204,28 @@ func TestTransact(t *testing.T) { require.False(t, isTaskCompleted(t, ctx, tc.db, 2)) }) + t.Run("failure: rollback on cancel", func(t *testing.T) { + ctx := t.Context() + + txCtx, txCancel := context.WithCancel(ctx) + + err := sqlutil.Transact(txCtx, tc.db, func(ctx context.Context, tx *sql.Tx) error { + completeTask(t, ctx, tx, 1) + txCancel() + + return nil + }) + require.ErrorIs(t, err, context.Canceled) + + require.False(t, isTaskCompleted(t, ctx, tc.db, 1)) + require.False(t, isTaskCompleted(t, ctx, tc.db, 2)) + }) + t.Run("success", func(t *testing.T) { ctx := t.Context() - err := sqlutil.Transact(ctx, tc.db, func(txCtx context.Context, tx *sql.Tx) error { - completeTask(t, txCtx, tx, 1) + err := sqlutil.Transact(ctx, tc.db, func(ctx context.Context, tx *sql.Tx) error { + completeTask(t, ctx, tx, 1) return nil })