Skip to content

Commit 1d2e39d

Browse files
authored
Merge pull request #253 from lsunsi/reproduce-rollback-after-commit
Fix rollback being attempted on no transaction because commit already rolled it back
2 parents 53c52a4 + 4cdaf87 commit 1d2e39d

File tree

3 files changed

+146
-5
lines changed

3 files changed

+146
-5
lines changed

src/pg/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,11 @@ fn update_transaction_manager_status<T>(
387387
if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
388388
query_result
389389
{
390-
transaction_manager
391-
.status
392-
.set_requires_rollback_maybe_up_to_top_level(true)
390+
if !transaction_manager.is_commit {
391+
transaction_manager
392+
.status
393+
.set_requires_rollback_maybe_up_to_top_level(true);
394+
}
393395
}
394396
query_result
395397
}

src/transaction_manager.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ pub struct AnsiTransactionManager {
146146
// See https://github.com/weiznich/diesel_async/issues/198 for
147147
// details
148148
pub(crate) is_broken: Arc<AtomicBool>,
149+
// this boolean flag tracks whether we are currently in this process
150+
// of trying to commit the transaction. this is useful because if we
151+
// are and we get a serialization failure, we might not want to attempt
152+
// a rollback up the chain.
153+
pub(crate) is_commit: bool,
149154
}
150155

151156
impl AnsiTransactionManager {
@@ -355,9 +360,18 @@ where
355360
conn.instrumentation()
356361
.on_connection_event(InstrumentationEvent::commit_transaction(depth));
357362

358-
let is_broken = conn.transaction_state().is_broken.clone();
363+
let is_broken = {
364+
let transaction_state = conn.transaction_state();
365+
transaction_state.is_commit = true;
366+
transaction_state.is_broken.clone()
367+
};
368+
369+
let res =
370+
Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await;
371+
372+
conn.transaction_state().is_commit = false;
359373

360-
match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await {
374+
match res {
361375
Ok(()) => {
362376
match Self::get_transaction_state(conn)?
363377
.change_transaction_depth(TransactionDepthChange::DecreaseDepth)
@@ -392,6 +406,9 @@ where
392406
});
393407
}
394408
}
409+
} else {
410+
Self::get_transaction_state(conn)?
411+
.change_transaction_depth(TransactionDepthChange::DecreaseDepth)?;
395412
}
396413
Err(commit_error)
397414
}

tests/transactions.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,125 @@ async fn concurrent_serializable_transactions_behave_correctly() {
104104
res.unwrap_err()
105105
);
106106
}
107+
108+
#[cfg(feature = "postgres")]
109+
#[tokio::test]
110+
async fn commit_with_serialization_failure_already_ends_transaction() {
111+
use diesel::prelude::*;
112+
use diesel_async::{AsyncConnection, RunQueryDsl};
113+
use std::sync::Arc;
114+
use tokio::sync::Barrier;
115+
116+
table! {
117+
users4 {
118+
id -> Integer,
119+
}
120+
}
121+
122+
// create an async connection
123+
let mut conn = super::connection_without_transaction().await;
124+
125+
struct A(Vec<&'static str>);
126+
impl diesel::connection::Instrumentation for A {
127+
fn on_connection_event(&mut self, event: diesel::connection::InstrumentationEvent<'_>) {
128+
if let diesel::connection::InstrumentationEvent::StartQuery { query, .. } = event {
129+
let q = query.to_string();
130+
let q = q.split_once(' ').map(|(a, _)| a).unwrap_or(&q);
131+
132+
if matches!(q, "BEGIN" | "COMMIT" | "ROLLBACK") {
133+
assert_eq!(q, self.0.pop().unwrap());
134+
}
135+
}
136+
}
137+
}
138+
conn.set_instrumentation(A(vec!["COMMIT", "BEGIN", "COMMIT", "BEGIN"]));
139+
140+
let mut conn1 = super::connection_without_transaction().await;
141+
142+
diesel::sql_query("CREATE TABLE IF NOT EXISTS users4 (id int);")
143+
.execute(&mut conn)
144+
.await
145+
.unwrap();
146+
147+
let barrier_1 = Arc::new(Barrier::new(2));
148+
let barrier_2 = Arc::new(Barrier::new(2));
149+
let barrier_1_for_tx1 = barrier_1.clone();
150+
let barrier_1_for_tx2 = barrier_1.clone();
151+
let barrier_2_for_tx1 = barrier_2.clone();
152+
let barrier_2_for_tx2 = barrier_2.clone();
153+
154+
let mut tx = conn.build_transaction().serializable().read_write();
155+
156+
let res = tx.run(|conn| {
157+
Box::pin(async {
158+
users4::table.select(users4::id).load::<i32>(conn).await?;
159+
160+
barrier_1_for_tx1.wait().await;
161+
diesel::insert_into(users4::table)
162+
.values(users4::id.eq(1))
163+
.execute(conn)
164+
.await?;
165+
barrier_2_for_tx1.wait().await;
166+
167+
Ok::<_, diesel::result::Error>(())
168+
})
169+
});
170+
171+
let mut tx1 = conn1.build_transaction().serializable().read_write();
172+
173+
let res1 = async {
174+
let res = tx1
175+
.run(|conn| {
176+
Box::pin(async {
177+
users4::table.select(users4::id).load::<i32>(conn).await?;
178+
179+
barrier_1_for_tx2.wait().await;
180+
diesel::insert_into(users4::table)
181+
.values(users4::id.eq(1))
182+
.execute(conn)
183+
.await?;
184+
185+
Ok::<_, diesel::result::Error>(())
186+
})
187+
})
188+
.await;
189+
barrier_2_for_tx2.wait().await;
190+
res
191+
};
192+
193+
let (res, res1) = tokio::join!(res, res1);
194+
let _ = diesel::sql_query("DROP TABLE users4")
195+
.execute(&mut conn1)
196+
.await;
197+
198+
assert!(
199+
res1.is_ok(),
200+
"Expected the second transaction to be succussfull, but got an error: {:?}",
201+
res1.unwrap_err()
202+
);
203+
204+
assert!(res.is_err(), "Expected the first transaction to fail");
205+
let err = res.unwrap_err();
206+
assert!(
207+
matches!(
208+
&err,
209+
diesel::result::Error::DatabaseError(
210+
diesel::result::DatabaseErrorKind::SerializationFailure,
211+
_
212+
)
213+
),
214+
"Expected an serialization failure but got another error: {err:?}"
215+
);
216+
217+
let mut tx = conn.build_transaction();
218+
219+
let res = tx
220+
.run(|_| Box::pin(async { Ok::<_, diesel::result::Error>(()) }))
221+
.await;
222+
223+
assert!(
224+
res.is_ok(),
225+
"Expect transaction to run fine but got an error: {:?}",
226+
res.unwrap_err()
227+
);
228+
}

0 commit comments

Comments
 (0)