diff --git a/src/client.rs b/src/client.rs
index 688721d1..68ff32a7 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -251,10 +251,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
         Ok(result)
     }
 
-    /// Execute a `BULK INSERT` statement, efficiantly storing a large number of
+    /// Execute a `BULK INSERT` statement, efficiently storing a large number of
     /// rows to a specified table. Note: make sure the input row follows the same
     /// schema as the table, otherwise calling `send()` will return an error.
     ///
+    /// This is equivalent to calling `bulk_insert_columns("table_name", &["*"])` to merge
+    /// all of a table's columns.
+    ///
     /// # Example
     ///
     /// ```
@@ -299,12 +302,67 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
     pub async fn bulk_insert<'a>(
         &'a mut self,
         table: &'a str,
+    ) -> crate::Result<BulkLoadRequest<'a, S>> {
+        self.bulk_insert_columns(table, &["*"]).await
+    }
+
+    /// Execute a `BULK INSERT` statement, efficiently storing a large number of
+    /// rows to a specified table. Note: make sure the input row follows the same
+    /// schema as the column list, otherwise calling `send()` will return an error.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use tiberius::{Config, IntoRow};
+    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
+    /// # use std::env;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
+    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
+    /// # );
+    /// # let config = Config::from_ado_string(&c_str)?;
+    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
+    /// # tcp.set_nodelay(true)?;
+    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
+    /// let create_table = r#"
+    ///     CREATE TABLE ##bulk_test (
+    ///         id INT IDENTITY PRIMARY KEY,
+    ///         foo INT NOT NULL,
+    ///         bar FLOAT NOT NULL
+    ///     )
+    /// "#;
+    ///
+    /// client.simple_query(create_table).await?;
+    ///
+    /// // Start the bulk insert with the client.
+    /// let mut req = client.bulk_insert_columns("##bulk_test", &["foo", "bar"]).await?;
+    ///
+    /// for (i, j) in [(0i32, 0f64), (1i32, 1f64), (2i32, 2f64)] {
+    ///     let row = (i, j).into_row();
+    ///
+    ///     // The request will handle flushing to the wire in an optimal way,
+    ///     // balancing between memory usage and IO performance.
+    ///     req.send(row).await?;
+    /// }
+    ///
+    /// // The request must be finalized.
+    /// let res = req.finalize().await?;
+    /// assert_eq!(3, res.total());
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub async fn bulk_insert_columns<'a>(
+        &'a mut self,
+        table: &'a str,
+        columns: &'a [&'a str],
     ) -> crate::Result<BulkLoadRequest<'a, S>> {
         // Start the bulk request
         self.connection.flush_stream().await?;
 
         // retrieve column metadata from server
-        let query = format!("SELECT TOP 0 * FROM {}", table);
+        let columns = columns.join(", ");
+        let query = format!("SELECT TOP 0 {columns} FROM {table}");
 
         let req =
             BatchRequest::new(query, self.connection.context().transaction_descriptor());
@@ -371,7 +429,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
         &'a mut self,
         proc_id: RpcProcId,
         mut rpc_params: Vec<RpcParam<'b>>,
-        params: impl Iterator<Item = &'a dyn ToSql>,
+        params: impl Iterator<Item = &'b dyn ToSql>,
     ) -> crate::Result<()>
     where
         'a: 'b,
diff --git a/tests/bulk.rs b/tests/bulk.rs
index 33b90637..110f77a8 100644
--- a/tests/bulk.rs
+++ b/tests/bulk.rs
@@ -218,3 +218,101 @@ test_bulk_type!(datetime2_7(
     100,
     vec![DateTime::from_timestamp(1658524194, 123456789); 100].into_iter()
 ));
+
+macro_rules! test_bulk_columns {
+    ($name:ident($total_generated:literal $(, $sql_type:literal)+ $(, ($cols:expr, $generator:expr ))+ $(,)?)) => {
+        paste::item!
+        {
+            #[test_on_runtimes]
+            async fn [< bulk_load_optional_ $name >]<S>(mut conn: tiberius::Client<S>) -> Result<()>
+            where
+                S: AsyncRead + AsyncWrite + Unpin + Send,
+            {
+                use tiberius::IntoRow;
+
+                let table = format!("##{}", random_table().await);
+                let column_defs = &[$($sql_type,)+];
+
+                conn.execute(
+                    &format!(
+                        "CREATE TABLE {} (id INT IDENTITY PRIMARY KEY, {})",
+                        table,
+                        column_defs.join(", "),
+                    ),
+                    &[],
+                )
+                .await?;
+
+                let mut count = 0;
+
+                $(
+                    let mut req = conn.bulk_insert_columns(&table, $cols).await?;
+                    for i in $generator {
+                        let row = i.into_row();
+                        req.send(row).await?;
+                    }
+
+                    let res = req.finalize().await?;
+                    count += res.total();
+                )+
+                assert_eq!($total_generated, count);
+
+                Ok(())
+            }
+
+            #[test_on_runtimes]
+            async fn [< bulk_load_required_ $name >]<S>(mut conn: tiberius::Client<S>) -> Result<()>
+            where
+                S: AsyncRead + AsyncWrite + Unpin + Send,
+            {
+                use tiberius::IntoRow;
+
+                let table = format!("##{}", random_table().await);
+                let column_defs = &[$(format!("{} NOT NULL", $sql_type),)+];
+
+                conn.execute(
+                    &format!(
+                        "CREATE TABLE {} (id INT IDENTITY PRIMARY KEY, {})",
+                        table,
+                        column_defs.join(", "),
+                    ),
+                    &[],
+                )
+                .await?;
+
+                let mut count = 0;
+
+                $(
+                    let mut req = conn.bulk_insert_columns(&table, $cols).await?;
+                    for i in $generator {
+                        let row = i.into_row();
+                        req.send(row).await?;
+                    }
+
+                    let res = req.finalize().await?;
+                    count += res.total();
+                )+
+                assert_eq!($total_generated, count);
+
+                Ok(())
+            }
+        }
+    };
+}
+
+test_bulk_columns!(ab_ba_default_columns(
+    200,
+    "a INT",
+    "b FLOAT",
+    "c INT DEFAULT 0",
+    (&["a", "b"], vec![(1i32, 1f64); 100]),
+    (&["b", "a"], vec![(2f64, 2i32); 100]),
+));
+
+test_bulk_columns!(ab_ba_override_default_columns(
+    200,
+    "a INT",
+    "b FLOAT",
+    "c INT DEFAULT 0",
+    (&["a", "b", "c"], vec![(1i32, 1f64, 10i32); 100]),
+    (&["b", "c", "a"], vec![(2f64, 20i32, 2i32); 100]),
+));