From c34de745f0cd825f035e9dfa9c1814953b71753e Mon Sep 17 00:00:00 2001 From: m0t0k1ch1 Date: Sun, 24 Aug 2025 17:03:23 +0900 Subject: [PATCH] test: test rollback on cancel --- sqlutil_test.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/sqlutil_test.go b/sqlutil_test.go index 7dea07a..b904392 100644 --- a/sqlutil_test.go +++ b/sqlutil_test.go @@ -204,11 +204,28 @@ func TestTransact(t *testing.T) { require.False(t, isTaskCompleted(t, ctx, tc.db, 2)) }) + t.Run("failure: rollback on cancel", func(t *testing.T) { + ctx := t.Context() + + txCtx, txCancel := context.WithCancel(ctx) + + err := sqlutil.Transact(txCtx, tc.db, func(ctx context.Context, tx *sql.Tx) error { + completeTask(t, ctx, tx, 1) + txCancel() + + return nil + }) + require.ErrorIs(t, err, context.Canceled) + + require.False(t, isTaskCompleted(t, ctx, tc.db, 1)) + require.False(t, isTaskCompleted(t, ctx, tc.db, 2)) + }) + t.Run("success", func(t *testing.T) { ctx := t.Context() - err := sqlutil.Transact(ctx, tc.db, func(txCtx context.Context, tx *sql.Tx) error { - completeTask(t, txCtx, tx, 1) + err := sqlutil.Transact(ctx, tc.db, func(ctx context.Context, tx *sql.Tx) error { + completeTask(t, ctx, tx, 1) return nil })