diff --git a/tpcc/src/main.rs b/tpcc/src/main.rs
index 3c7fd232..1c418cee 100644
--- a/tpcc/src/main.rs
+++ b/tpcc/src/main.rs
@@ -12,6 +12,8 @@ use kite_sql::errors::DatabaseError;
 use kite_sql::storage::Storage;
 use rand::prelude::ThreadRng;
 use rand::Rng;
+use std::sync::{Arc, Mutex};
+use std::thread;
 use std::time::{Duration, Instant};
 
 mod delivery;
@@ -55,6 +57,7 @@
     ) -> Result<(), TpccError>;
 }
 
+#[derive(Clone)]
 struct TpccArgs {
     joins: bool,
 }
@@ -68,25 +71,54 @@ struct Args {
     path: String,
     #[clap(long, default_value = "5")]
    max_retry: usize,
-    #[clap(long, default_value = "720")]
+    #[clap(long, default_value = "60")]
     measure_time: u64,
     #[clap(long, default_value = "1")]
     num_ware: usize,
+    #[clap(long, default_value = "2")]
+    threads: usize,
 }
 
-// TODO: Support multi-threaded TPCC
+// Multi-threaded TPCC implementation
+// Each thread runs an independent transaction loop with shared statistics collection.
+//
+// Architecture Overview:
+// 1. Data Loading Phase (serial): load the initial data to avoid race conditions
+//    - Load::load_items(), Load::load_warehouses(), Load::load_custs(), Load::load_ord() run sequentially
+// 2. Multi-threaded Execution Phase: each thread executes transactions concurrently
+//    - Each thread has its own RNG, its own SeqGen transaction mix, and its own test instances
+//    - Threads share the database handle, the prepared statements, and the statistics counters
+// 3. Thread-local Statistics Collection: reduce lock contention with local counters
+//    - Local success/late/failure arrays and an RtHist per thread
+//    - Periodic sync to the shared counters every 100 transactions
+// 4. Result Aggregation: combine the results from all threads for the final report
+//    - Extract final statistics from the shared Mutex-protected data
+//    - Calculate TPC-C compliance metrics and response-time histograms
+//
+// Threading Strategy:
+// - Per-thread state: RNG, SeqGen with thread_id-based variations, transaction test instances
+// - Shared state: database handle (Arc), prepared statements (Arc), statistics (Arc<Mutex<..>>)
+// - Synchronization: Mutex for statistics updates, Arc for read-only data sharing
+// - Error handling: unified exponential-backoff retry strategy for all transaction phases
+//
+// Example Usage:
+//   cargo run -- --threads 8 --measure-time 60 --num-ware 10
 fn main() -> Result<(), TpccError> {
     let args = Args::parse();
     let mut rng = rand::thread_rng();
-    let database = DataBaseBuilder::path(&args.path).build()?;
+    let database = Arc::new(DataBaseBuilder::path(&args.path).build()?);
 
+    // Data loading phase - keep serial to avoid data races
+    println!("Loading data with {} warehouses...", args.num_ware);
     Load::load_items(&mut rng, &database)?;
     Load::load_warehouses(&mut rng, &database, args.num_ware)?;
     Load::load_custs(&mut rng, &database, args.num_ware)?;
     Load::load_ord(&mut rng, &database, args.num_ware)?;
+    println!("Data loading completed.");
 
-    let test_statements = vec![
+    // Prepare statements for each transaction type - shared across threads
+    let test_statements = Arc::new(vec![
         vec![
             database.prepare("SELECT c.c_discount, c.c_last, c.c_credit, w.w_tax FROM customer AS c JOIN warehouse AS w ON c.c_w_id = w_id AND w.w_id = ?1 AND c.c_w_id = ?2 AND c.c_d_id = ?3 AND c.c_id = ?4")?,
             database.prepare("SELECT c_discount, c_last, c_credit FROM customer WHERE c_w_id = ?1 AND c_d_id = ?2 AND c_id = ?3")?,
@@ -134,77 +166,282 @@ fn main() -> Result<(), TpccError> {
             database.prepare("SELECT DISTINCT ol_i_id FROM order_line WHERE ol_w_id = ?1 AND ol_d_id = ?2 AND ol_o_id < ?3 AND ol_o_id >= (?4 - 20)")?,
             database.prepare("SELECT count(*) FROM stock WHERE s_w_id = ?1 AND s_i_id = ?2 AND s_quantity < ?3")?,
         ],
-    ];
-
-    let mut rt_hist = RtHist::new();
-    let mut success = [0usize; 5];
-    let mut late = [0usize; 5];
-    let mut failure = [0usize; 5];
-    let tests = vec![
-        Box::new(NewOrdTest) as Box<dyn TpccTest<_>>,
-        Box::new(PaymentTest),
-        Box::new(OrderStatTest),
-        Box::new(DeliveryTest),
-        Box::new(SlevTest),
-    ];
-    let tpcc_args = TpccArgs { joins: args.joins };
+    ]);
+
+    // Shared statistics - thread-safe collections with reduced lock contention
+    let shared_rt_hist = Arc::new(Mutex::new(RtHist::new()));
+    let shared_success = Arc::new(Mutex::new([0usize; 5]));
+    let shared_late = Arc::new(Mutex::new([0usize; 5]));
+    let shared_failure = Arc::new(Mutex::new([0usize; 5]));
+    let shared_round_count = Arc::new(Mutex::new(0usize));
+    // Transaction test instances (trait objects) are created per worker thread below
+    let tpcc_args = TpccArgs { joins: args.joins };
 
     let duration = Duration::new(args.measure_time, 0);
-    let mut round_count = 0;
-    let mut seq_gen = SeqGen::new(10, 10, 1, 1, 1);
     let tpcc_start = Instant::now();
-    while tpcc_start.elapsed() < duration {
-        let i = seq_gen.get();
-        let tpcc_test = &tests[i];
-        let statement = &test_statements[i];
+    println!(
+        "Starting multi-threaded TPCC with {} threads for {} seconds...",
+        args.threads, args.measure_time
+    );
+
+    // Create worker threads
+    let mut handles = Vec::new();
+    for thread_id in 0..args.threads {
+        // Clone shared data for each thread
+        let database_clone = Arc::clone(&database);
+        let test_statements_clone = Arc::clone(&test_statements);
+        let tpcc_args_clone = tpcc_args.clone();
+        let shared_rt_hist_clone = Arc::clone(&shared_rt_hist);
+        let shared_success_clone = Arc::clone(&shared_success);
+        let shared_late_clone = Arc::clone(&shared_late);
+        let shared_failure_clone = Arc::clone(&shared_failure);
+        let shared_round_count_clone = Arc::clone(&shared_round_count);
+
+        let num_ware = args.num_ware;
+        let max_retry = args.max_retry;
+
+        let handle = thread::spawn(move || -> Result<(), TpccError> {
+            // Each thread has its own RNG and its own sequence generator
+            let mut thread_rng = rand::thread_rng();
+            // Use thread_id to vary the transaction mix per thread
+            let mut seq_gen = SeqGen::new(
+                10 + (thread_id % 3), // Vary NewOrder weight
+                10 + (thread_id % 2), // Vary Payment weight
+                1 + (thread_id % 2),  // Vary OrderStatus weight
+                1 + (thread_id % 2),  // Vary Delivery weight
+                1 + (thread_id % 2),  // Vary StockLevel weight
+            );
+
+            // Create test instances for this thread
+            let tests: Vec<Box<dyn TpccTest<_>>> = vec![
+                Box::new(NewOrdTest),
+                Box::new(PaymentTest),
+                Box::new(OrderStatTest),
+                Box::new(DeliveryTest),
+                Box::new(SlevTest),
+            ];
+
+            // Thread-local statistics to reduce lock contention
+            let mut local_success = [0usize; 5];
+            let mut local_late = [0usize; 5];
+            let mut local_failure = [0usize; 5];
+            let mut local_rt_hist = RtHist::new();
+            let mut local_round_count = 0;
+
+            // Thread-local transaction loop
+            while tpcc_start.elapsed() < duration {
+                let i = seq_gen.get();
+                let tpcc_test = &tests[i];
+                let statement = &test_statements_clone[i];
+
+                let mut is_succeed = false;
+                for j in 0..max_retry {
+                    let transaction_start = Instant::now();
+
+                    // 1. Create transaction - unified retry strategy
+                    let mut tx = match database_clone.new_transaction() {
+                        Ok(tx) => tx,
+                        Err(err) => {
+                            eprintln!(
+                                "[Thread {}][{}] Error creating transaction: {}",
+                                thread_id,
+                                tpcc_test.name(),
+                                err
+                            );
+
+                            // Unified exponential backoff: 10ms, 20ms, 40ms, 80ms...
+                            std::thread::sleep(std::time::Duration::from_millis(
+                                10 * (1 << j),
+                            ));
+                            continue;
+                        }
+                    };
+
+                    // 2. Execute transaction - unified retry strategy
+                    let tx_result = tpcc_test.do_transaction(
+                        &mut thread_rng,
+                        &mut tx,
+                        num_ware,
+                        &tpcc_args_clone,
+                        &statement,
+                    );
+
+                    if let Err(err) = tx_result {
+                        eprintln!(
+                            "[Thread {}][{}] Error while doing transaction: {}",
+                            thread_id,
+                            tpcc_test.name(),
+                            err
+                        );
+
+                        // Unified exponential backoff
+                        std::thread::sleep(std::time::Duration::from_millis(
+                            10 * (1 << j),
+                        ));
+                        continue;
+                    }
+
+                    // 3. Commit transaction - unified retry strategy
+                    match tx.commit() {
+                        Ok(_) => {
+                            // Commit successful
+                            let rt = transaction_start.elapsed();
+
+                            // Update local statistics
+                            local_rt_hist.hist_inc(i, rt);
+                            is_succeed = true;
+
+                            if rt <= RT_LIMITS[i] {
+                                local_success[i] += 1;
+                            } else {
+                                local_late[i] += 1;
+                            }
+
+                            // Successfully committed, break retry loop
+                            break;
+                        }
+                        Err(err) => {
+                            eprintln!(
+                                "[Thread {}][{}] Error committing transaction: {}",
+                                thread_id,
+                                tpcc_test.name(),
+                                err
+                            );
+
+                            // Unified exponential backoff
+                            std::thread::sleep(std::time::Duration::from_millis(
+                                10 * (1 << j),
+                            ));
+                            continue;
+                        }
+                    }
+                }
+
+                // 4. Handle retry exhaustion
+                if !is_succeed {
+                    eprintln!(
+                        "[Thread {}][{}] Transaction failed after {} attempts, continuing with next transaction",
+                        thread_id,
+                        tpcc_test.name(),
+                        max_retry
+                    );
+
+                    // Record the failure only once, at final failure
+                    local_failure[i] += 1;
+                    continue;
+                }
 
-        let mut is_succeed = false;
-        for j in 0..args.max_retry + 1 {
-            let transaction_start = Instant::now();
-            let mut tx = database.new_transaction()?;
+                // Periodic checkpoint reporting and statistics sync (thread-safe)
+                local_round_count += 1;
+                if local_round_count % 100 == 0 {
+                    // Sync local statistics to shared statistics
+                    {
+                        let mut rt_hist = shared_rt_hist_clone.lock().unwrap();
+                        rt_hist.merge(&local_rt_hist);
+                        // Reset so the next sync does not merge these samples again
+                        local_rt_hist = RtHist::new();
+                    }
+                    {
+                        let mut success = shared_success_clone.lock().unwrap();
+                        for k in 0..5 {
+                            success[k] += local_success[k];
+                            local_success[k] = 0;
+                        }
+                    }
+                    {
+                        let mut late = shared_late_clone.lock().unwrap();
+                        for k in 0..5 {
+                            late[k] += local_late[k];
+                            local_late[k] = 0;
+                        }
+                    }
+                    {
+                        let mut failure = shared_failure_clone.lock().unwrap();
+                        for k in 0..5 {
+                            failure[k] += local_failure[k];
+                            local_failure[k] = 0;
+                        }
+                    }
+                    {
+                        let mut round_count = shared_round_count_clone.lock().unwrap();
+                        *round_count += local_round_count;
+                        local_round_count = 0;
+
+                        let mut rt_hist = shared_rt_hist_clone.lock().unwrap();
+                        println!(
+                            "[TPCC CheckPoint {} on round {}][Thread {}][{}]: 90th Percentile RT: {:.3}",
+                            *round_count / 100,
+                            *round_count,
+                            thread_id,
+                            tpcc_test.name(),
+                            rt_hist.hist_ckp(i)
+                        );
+                    }
+                }
+            }
 
-            if let Err(err) =
-                tpcc_test.do_transaction(&mut rng, &mut tx, args.num_ware, &tpcc_args, &statement)
+            // Final sync of remaining local statistics
+            {
+                let mut rt_hist = shared_rt_hist_clone.lock().unwrap();
+                rt_hist.merge(&local_rt_hist);
+            }
             {
-                failure[i] += 1;
-                eprintln!(
-                    "[{}] Error while doing transaction: {}",
-                    tpcc_test.name(),
-                    err
-                );
-            } else {
-                let rt = transaction_start.elapsed();
-                rt_hist.hist_inc(i, rt);
-                is_succeed = true;
-
-                if rt <= RT_LIMITS[i] {
-                    success[i] += 1;
-                } else {
-                    late[i] += 1;
+                let mut success = shared_success_clone.lock().unwrap();
+                for k in 0..5 {
+                    success[k] += local_success[k];
                 }
-                tx.commit()?;
-                break;
             }
-            if j < args.max_retry {
-                println!("[{}] Retry for the {}th time", tpcc_test.name(), j + 1);
+            {
+                let mut late = shared_late_clone.lock().unwrap();
+                for k in 0..5 {
+                    late[k] += local_late[k];
+                }
+            }
+            {
+                let mut failure = shared_failure_clone.lock().unwrap();
+                for k in 0..5 {
+                    failure[k] += local_failure[k];
+                }
+            }
+            {
+                let mut round_count = shared_round_count_clone.lock().unwrap();
+                *round_count += local_round_count;
+            }
+
+            Ok(())
+        });
+
+        handles.push(handle);
+    }
+
+    // Wait for all threads to complete
+    for (thread_id, handle) in handles.into_iter().enumerate() {
+        match handle.join() {
+            Ok(result) => {
+                if let Err(e) = result {
+                    eprintln!("Thread {} failed: {}", thread_id, e);
+                    return Err(e);
+                }
+            }
+            Err(_) => {
+                eprintln!("Thread {} panicked", thread_id);
+                return Err(TpccError::MaxRetry);
             }
         }
-        if !is_succeed {
-            return Err(TpccError::MaxRetry);
-        }
-        if round_count != 0 && round_count % 100 == 0 {
-            println!(
-                "[TPCC CheckPoint {} on round {round_count}][{}]: 90th Percentile RT: {:.3}",
-                round_count / 100,
-                tpcc_test.name(),
-                rt_hist.hist_ckp(i)
-            );
-        }
-        round_count += 1;
     }
+    let actual_tpcc_time = tpcc_start.elapsed();
+
+    // Extract final statistics from shared data
+    let success = *shared_success.lock().unwrap();
+    let late = *shared_late.lock().unwrap();
+    let failure = *shared_failure.lock().unwrap();
+    let rt_hist = shared_rt_hist.lock().unwrap();
+    println!("---------------------------------------------------");
+    println!(
+        "Multi-threaded TPCC completed with {} threads",
+        args.threads
+    );
     // Raw Results
     print_transaction(&success, &late, &failure, |name, success, late, failure| {
         println!("|{}| sc: {} lt: {} fl: {}", name, success, late, failure)
diff --git a/tpcc/src/rt_hist.rs b/tpcc/src/rt_hist.rs
index 145783d8..51d77c1d 100644
--- a/tpcc/src/rt_hist.rs
+++ b/tpcc/src/rt_hist.rs
@@ -45,6 +45,21 @@ impl RtHist {
         self.cur_hist[transaction][i] += 1;
     }
 
+    // Merge a thread-local histogram into this shared histogram
+    pub fn merge(&mut self, other: &RtHist) {
+        for transaction in 0..5 {
+            // Merge current histograms
+            for i in 0..(MAX_REC * REC_PER_SEC) {
+                self.cur_hist[transaction][i] += other.cur_hist[transaction][i];
+            }
+
+            // Update max response time
+            if other.cur_max_rt[transaction] > self.cur_max_rt[transaction] {
+                self.cur_max_rt[transaction] = other.cur_max_rt[transaction];
+            }
+        }
+    }
+
     // Checkpoint and add to the total histogram
     pub fn hist_ckp(&mut self, transaction: usize) -> f64 {
         let mut total = 0;
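
A note on the batched statistics sync that both hunks serve: workers keep local counters and drain them into the shared Mutex-protected totals only once per batch, so the lock is taken roughly once per 100 transactions instead of once per transaction. The sketch below is a minimal, self-contained illustration of that pattern, not code from this patch; the names (SYNC_EVERY, drain) and the fixed 4-thread / 1000-round workload are invented for the example. The key detail is that every drain zeroes the local state, which is also why the patch resets local_rt_hist after each merge: without the reset, the same samples would be merged into the shared histogram on every subsequent sync.

use std::sync::{Arc, Mutex};
use std::thread;

// Illustrative batch size; the patch syncs every 100 transactions.
const SYNC_EVERY: usize = 100;

/// Drain `local` into the shared totals, zeroing `local` so the next
/// batch starts fresh and nothing is counted twice.
fn drain(totals: &Mutex<[usize; 5]>, local: &mut [usize; 5]) {
    let mut shared = totals.lock().unwrap();
    for k in 0..5 {
        shared[k] += local[k];
        local[k] = 0;
    }
}

fn main() {
    let totals = Arc::new(Mutex::new([0usize; 5]));
    let mut handles = Vec::new();

    for thread_id in 0..4usize {
        let totals = Arc::clone(&totals);
        handles.push(thread::spawn(move || {
            let mut local = [0usize; 5];
            for round in 1..=1000usize {
                // Stand-in for one committed transaction of type `i`.
                let i = (thread_id + round) % 5;
                local[i] += 1;
                // The lock is taken once per batch, not once per transaction.
                if round % SYNC_EVERY == 0 {
                    drain(&totals, &mut local);
                }
            }
            // Final drain for the tail that never filled a whole batch.
            drain(&totals, &mut local);
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }
    // 4 threads * 1000 rounds = 4000 events, none lost and none doubled.
    assert_eq!(totals.lock().unwrap().iter().sum::<usize>(), 4_000);
    println!("totals: {:?}", *totals.lock().unwrap());
}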