From 2d092497573c61cc17d010a9d629c9d1d85e3526 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Mon, 29 Jan 2024 19:19:28 +0200 Subject: [PATCH 01/21] Rpc exec impl/and basic docs. Ready for beta. (#1) * named proc/func rpc call skeleton * out param testing * more practial output params accessor * TVP type info skel draft * TVP type info metadata out * RpcValue introduced * TVP public binding, quick-n-dirty concept * TVP public binding, simplified * tvp debug version, pre-alfa * TVP working concept proof, limited set of types supported for pure debug purposes * TVP fixed len type mapping as per 2.2.5.5.5.3 * minor refactoring and dbg! cleanup * exec_query added with minor overall refactoring; filter-stream introduced * CommandStream introduced, FilterStream removed as no longer needed. Alfa-ready * public API optimized, query_result accessors added for CommandResult * bind_table goes behind "tds73" feature * Derive macro for TableValueRow introduced * tvp macro tidy up, basic docs added for Command and CommandResult * CommandStream few more accessors/converters for CommandItem, basic docs for CommandStream --- Cargo.toml | 5 +- src/client.rs | 55 ++++- src/command.rs | 265 ++++++++++++++++++++++++ src/lib.rs | 2 + src/result.rs | 101 +++++++++- src/tds/codec.rs | 2 + src/tds/codec/rpc_request.rs | 35 +++- src/tds/codec/type_info_tvp.rs | 228 +++++++++++++++++++++ src/tds/stream.rs | 2 + src/tds/stream/command.rs | 312 +++++++++++++++++++++++++++++ src/tds/stream/query.rs | 4 +- tvp-macro/Cargo.toml | 14 ++ tvp-macro/src/attr.rs | 45 +++++ tvp-macro/src/lib.rs | 44 ++++ tvp-macro/src/table_value_param.rs | 176 ++++++++++++++++ 15 files changed, 1275 insertions(+), 15 deletions(-) create mode 100644 src/command.rs create mode 100644 src/tds/codec/type_info_tvp.rs create mode 100644 src/tds/stream/command.rs create mode 100644 tvp-macro/Cargo.toml create mode 100644 tvp-macro/src/attr.rs create mode 100644 tvp-macro/src/lib.rs create mode 100644 tvp-macro/src/table_value_param.rs diff --git a/Cargo.toml b/Cargo.toml index 4f96e962..a3894bf7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ repository = "https://github.com/prisma/tiberius" version = "0.12.2" [workspace] -members = ["runtimes-macro"] +members = ["runtimes-macro", "tvp-macro"] [[test]] path = "tests/query.rs" @@ -162,6 +162,9 @@ version = "1" [dev-dependencies.runtimes-macro] path = "./runtimes-macro" +[dependencies.tvp-macro] +path = "./tvp-macro" + [dev-dependencies] names = "0.14" anyhow = "1" diff --git a/src/client.rs b/src/client.rs index 688721d1..ca7ba16f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -14,6 +14,7 @@ pub use auth::*; pub use config::*; pub(crate) use connection::*; +use crate::tds::codec::{MetaDataColumn, RpcValue}; use crate::tds::stream::ReceivedToken; use crate::{ result::ExecuteResult, @@ -357,12 +358,12 @@ impl Client { RpcParam { name: Cow::Borrowed("stmt"), flags: BitFlags::empty(), - value: ColumnData::String(Some(query.into())), + value: RpcValue::Scalar(ColumnData::String(Some(query.into()))), }, RpcParam { name: Cow::Borrowed("params"), flags: BitFlags::empty(), - value: ColumnData::I32(Some(0)), + value: RpcValue::Scalar(ColumnData::I32(Some(0))), }, ] } @@ -388,12 +389,12 @@ impl Client { rpc_params.push(RpcParam { name: Cow::Owned(format!("@P{}", i + 1)), flags: BitFlags::empty(), - value: param, + value: RpcValue::Scalar(param), // for compat purposes, review later }); } if let Some(params) = rpc_params.iter_mut().find(|x| x.name == "params") { - params.value = ColumnData::String(Some(param_str.into())); + params.value = RpcValue::Scalar(ColumnData::String(Some(param_str.into()))); } let req = TokenRpcRequest::new( @@ -407,4 +408,50 @@ impl Client { Ok(()) } + + pub(crate) async fn rpc_run_command<'a, 'b>( + &'a mut self, + command_name: Cow<'b, str>, + rpc_params: Vec>, + ) -> crate::Result<()> + where + 'a: 'b, + { + let req = TokenRpcRequest::new( + command_name, + rpc_params, + self.connection.context().transaction_descriptor(), + ); + + let id = self.connection.context_mut().next_packet_id(); + self.connection.send(PacketHeader::rpc(id), req).await?; + + Ok(()) + } + + pub(crate) async fn query_run_for_metadata<'a, 'b>( + &'a mut self, + query: String, + ) -> crate::Result>>> { + self.connection.flush_stream().await?; + + let req = BatchRequest::new(query, self.connection.context().transaction_descriptor()); + + let id = self.connection.context_mut().next_packet_id(); + self.connection.send(PacketHeader::batch(id), req).await?; + + let token_stream = TokenStream::new(&mut self.connection).try_unfold(); + + let columns = token_stream + .try_fold(None, |mut columns, token| async move { + if let ReceivedToken::NewResultset(metadata) = token { + columns = Some(metadata.columns.clone()); + }; + + Ok(columns) + }) + .await?; + + Ok(columns) + } } diff --git a/src/command.rs b/src/command.rs new file mode 100644 index 00000000..454b85ac --- /dev/null +++ b/src/command.rs @@ -0,0 +1,265 @@ +use std::borrow::Cow; +pub use tvp_macro; + +use enumflags2::BitFlags; +use futures_util::io::{AsyncRead, AsyncWrite}; + +use crate::{ + tds::{ + codec::{RpcParam, RpcStatus::ByRefValue, RpcValue, TypeInfoTvp}, + stream::{CommandStream, TokenStream}, + }, + Client, ColumnData, IntoSql, +}; + +#[doc(inline)] +pub use tvp_macro::TableValueRow; + +/// Any structure that represents a row in a Table Value parameter must implement this trait +pub trait TableValueRow<'a> { + /// Bind row field values, called by `Command` instance before making the call to the server + fn bind_fields(&self, data_row: &mut SqlTableDataRow<'a>); // call data_row.add_field(val) for each field + /// Database type name that represents this TVP, like `dbo.MyType`. + fn get_db_type() -> &'static str; +} + +/// Implemented as generic for `IntoIterator` +pub trait TableValue<'a> { + fn into_sql(self) -> SqlTableData<'a>; +} + +impl<'a, R, C> TableValue<'a> for C +where + R: TableValueRow<'a> + 'a, + C: IntoIterator, +{ + fn into_sql(self) -> SqlTableData<'a> { + let mut data = Vec::new(); + for row in self.into_iter() { + let mut data_row = SqlTableDataRow::new(); + row.bind_fields(&mut data_row); + data.push(data_row); + } + + SqlTableData { + rows: data, + db_type: R::get_db_type(), + } + } +} + +/// Remote command (Stored Procedure of UDF) with bound parameters +#[derive(Debug)] +pub struct Command<'a> { + name: Cow<'a, str>, + params: Vec>, // TODO: might make sense to check if param names are unique, but server would recject repeating params anyway +} + +#[derive(Debug)] +struct CommandParam<'a> { + name: Cow<'a, str>, + out: bool, + data: CommandParamData<'a>, +} + +#[derive(Debug)] +enum CommandParamData<'a> { + Scalar(ColumnData<'a>), + Table(SqlTableData<'a>), +} + +#[derive(Debug)] +pub struct SqlTableData<'a> { + rows: Vec>, + db_type: &'static str, +} + +#[derive(Debug)] +/// TVP row binding public API +pub struct SqlTableDataRow<'a> { + col_data: Vec>, +} +impl<'a> SqlTableDataRow<'a> { + fn new() -> SqlTableDataRow<'a> { + SqlTableDataRow { + col_data: Vec::new(), + } + } + /// Adds TVP field value to the row. Must be called for each column. + /// The values are sent to the server in the same order as these calls. + pub fn add_field(&mut self, data: impl IntoSql<'a> + 'a) { + self.col_data.push(data.into_sql()); + } +} + +impl<'a> Command<'a> { + /// Constructs a new command object with given name. + pub fn new(proc_name: impl Into>) -> Self { + Self { + name: proc_name.into(), + params: Vec::new(), + } + } + + /// Binds scalar parameter with the given name to the command. + pub fn bind_param(&mut self, name: impl Into>, data: impl IntoSql<'a> + 'a) { + self.params.push(CommandParam { + name: name.into(), + out: false, + data: CommandParamData::Scalar(data.into_sql()), + }); + } + + /// Binds by-ref (OUT) scalar parameter to the command. + /// Returned value can be found by the same name in the returned values collection. + pub fn bind_out_param(&mut self, name: impl Into>, data: impl IntoSql<'a> + 'a) { + self.params.push(CommandParam { + name: name.into(), + out: true, + data: CommandParamData::Scalar(data.into_sql()), + }); + } + + /// Binds table-valued parameter to the command. + /// Provided argument must implement `TableValue` trait. + /// + /// Example + /// + /// ```no_run + /// # use tiberius::{numeric::Numeric, Client, Command, TableValueRow}; + /// # use tokio_util::compat::TokioAsyncWriteCompatExt; + /// # use std::env; + /// #[derive(TableValueRow)] + /// struct SomeGeoList { + /// eid: i32, + /// lat: Numeric, + /// lon: Numeric, + /// } + /// # async fn main() -> Result<(), Box> { + /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or( + /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(), + /// # ); + /// # let config = Config::from_ado_string(&c_str)?; + /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; + /// # tcp.set_nodelay(true)?; + /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; + /// + /// let r1 = SomeGeoList { + /// eid: 1, + /// lat: Numeric::new_with_scale(10, 6), + /// lon: Numeric::new_with_scale(14, 6), + /// }; + /// let r2 = SomeGeoList { + /// eid: 4, + /// lat: Numeric::new_with_scale(101, 6), + /// lon: Numeric::new_with_scale(142, 6), + /// }; + /// + /// let tbl = vec![r1, r2]; + /// + /// let mut cmd = Command::new("dbo.usp_TheGeoProcedure"); + /// + /// cmd.bind_table("@table", tbl); + /// # Ok(()) + /// # } + /// ``` + /// + #[cfg(feature = "tds73")] + pub fn bind_table(&mut self, name: impl Into>, data: impl TableValue<'a> + 'a) { + self.params.push(CommandParam { + name: name.into(), + out: false, + data: CommandParamData::Table(data.into_sql()), + }); + } + + /// Executes the `Command` in the SQL Server, returning `CommandStream` that + /// can be collected into `CommandResult` for convinience. + /// + /// Example + /// + /// ```no_run + /// # use tiberius::{numeric::Numeric, Client, Command}; + /// # use tokio_util::compat::TokioAsyncWriteCompatExt; + /// # use std::env; + /// # async fn main() -> Result<(), Box> { + /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or( + /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(), + /// # ); + /// # let config = Config::from_ado_string(&c_str)?; + /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; + /// # tcp.set_nodelay(true)?; + /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; + /// let mut cmd = Command::new("dbo.usp_SomeStoredProc"); + /// + /// cmd.bind_param("@foo", 34i32); + /// cmd.bind_out_param("@bar", "bar"); + /// let res = cmd.exec(&mut client).await?.into_command_result().await?; + /// + /// let rv: Option = res.try_return_value("@bar")?; + /// let rc = res.return_code(); + /// + /// println!("And we got bar: {:#?}, return_code: {}", rv, rc); + /// # Ok(()) + /// # } + /// ``` + /// + pub async fn exec<'b, S>(self, client: &'b mut Client) -> crate::Result> + where + S: AsyncRead + AsyncWrite + Unpin + Send, + { + let rpc_params = Command::build_rpc_params(self.params, client).await?; + + client.connection.flush_stream().await?; + client.rpc_run_command(self.name, rpc_params).await?; + + let ts = TokenStream::new(&mut client.connection); + let result = CommandStream::new(ts.try_unfold()); + + Ok(result) + } + + async fn build_rpc_params<'b, S>( + cmd_params: Vec>, + client: &'b mut Client, + ) -> crate::Result>> + where + S: AsyncRead + AsyncWrite + Unpin + Send, + { + let mut rpc_params = Vec::new(); + for p in cmd_params.into_iter() { + let rpc_val = match p.data { + CommandParamData::Scalar(col) => RpcValue::Scalar(col), + CommandParamData::Table(t) => { + let type_info_tvp = TypeInfoTvp::new( + t.db_type, + t.rows.into_iter().map(|r| r.col_data).collect(), + ); + // it might make sense to expose some API for the caller so they could cache metadata + let cols_metadata = client + .query_run_for_metadata(format!( + "DECLARE @P AS {};SELECT TOP 0 * FROM @P", + t.db_type + )) + .await?; + RpcValue::Table(if let Some(cm) = cols_metadata { + type_info_tvp.with_metadata(cm) + } else { + type_info_tvp + }) + } + }; + let rpc_param = RpcParam { + name: p.name, + flags: if p.out { + BitFlags::from_flag(ByRefValue) + } else { + BitFlags::empty() + }, + value: rpc_val, + }; + rpc_params.push(rpc_param); + } + Ok(rpc_params) + } +} diff --git a/src/lib.rs b/src/lib.rs index 882f5ad3..c2e2c034 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -257,6 +257,7 @@ pub(crate) extern crate bigdecimal_ as bigdecimal; mod macros; mod client; +mod command; mod from_sql; mod query; mod sql_read_bytes; @@ -270,6 +271,7 @@ mod tds; mod sql_browser; pub use client::{AuthMethod, Client, Config}; +pub use command::{Command, SqlTableDataRow, TableValueRow}; pub(crate) use error::Error; pub use from_sql::{FromSql, FromSqlOwned}; pub use query::Query; diff --git a/src/result.rs b/src/result.rs index 19ba6faf..a2edae35 100644 --- a/src/result.rs +++ b/src/result.rs @@ -1,7 +1,9 @@ pub use crate::tds::stream::{QueryItem, ResultMetadata}; use crate::{ client::Connection, - tds::stream::{ReceivedToken, TokenStream}, + error::Error, + tds::stream::{CommandReturnValue, ReceivedToken, TokenStream}, + FromSql, Row, }; use futures_util::io::{AsyncRead, AsyncWrite}; use futures_util::stream::TryStreamExt; @@ -113,3 +115,100 @@ impl IntoIterator for ExecuteResult { self.rows_affected.into_iter() } } + +/// A result from a command execution, listing the number of affected rows, +/// return code, values of the OUT params and any possible record sets returned +/// by the command. +/// +/// # Example +/// +/// ```no_run +/// # use tiberius::{numeric::Numeric, Client, Command}; +/// # use tokio_util::compat::TokioAsyncWriteCompatExt; +/// # use std::env; +/// # async fn main() -> Result<(), Box> { +/// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or( +/// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(), +/// # ); +/// # let config = Config::from_ado_string(&c_str)?; +/// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; +/// # tcp.set_nodelay(true)?; +/// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; +/// let mut cmd = Command::new("dbo.usp_SomeStoredProc"); +/// +/// cmd.bind_param("@foo", 34i32); +/// cmd.bind_out_param("@bar", "bar"); +/// let res = cmd.exec(&mut client).await?.into_command_result().await?; +/// +/// let rv: Option = res.try_return_value("@bar")?; +/// let rc = res.return_code(); +/// let ra = res.rows_affected(); +/// +/// println!("And we got bar: {:#?}, return_code: {}", rv, rc); +/// +/// let rs0 = res.to_query_result(0) +/// if let Some(rows) = rs0 { +/// printls!("First record set: {:#?}", rows); +/// } +/// # Ok(()) +/// # } +/// ``` +/// +#[derive(Debug)] +pub struct CommandResult { + pub(crate) rows_affected: Vec, + pub(crate) return_code: u32, + pub(crate) return_values: Vec, + pub(crate) query_results: Vec>, +} + +impl<'a> CommandResult { + /// A slice of numbers of rows affected in the same order as the given + /// queries. + pub fn rows_affected(&self) -> &[u64] { + self.rows_affected.as_slice() + } + + /// Return Code for the command. The server must return the code. + pub fn return_code(&self) -> u32 { + self.return_code + } + + /// Number of actually returned values (OUT parameters) available. + pub fn return_values_len(&self) -> usize { + self.return_values.len() + } + + /// Try to get returned value by the OUT parameter name. + /// If the value is NUL, `None` returned. + pub fn try_return_value(&'a self, name: &str) -> crate::Result> + where + T: FromSql<'a>, + { + let idx = self + .return_values + .iter() + .position(|p| p.name.eq(name)) + .ok_or_else(|| { + Error::Conversion(format!("Could not find return value {}", name).into()) + })?; + let col_data = self.return_values.get(idx).unwrap(); + + T::from_sql(&col_data.data) + } + + /// Get returned record set by its index (zero-based). Ruturns `None` if the index + /// is out of range. + pub fn to_query_result(&self, idx: usize) -> Option<&Vec> { + self.query_results.get(idx) + } +} + +impl IntoIterator for CommandResult { + type Item = Vec; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.query_results.into_iter() + } +} diff --git a/src/tds/codec.rs b/src/tds/codec.rs index 07f13310..ee0b8630 100644 --- a/src/tds/codec.rs +++ b/src/tds/codec.rs @@ -12,6 +12,7 @@ mod pre_login; mod rpc_request; mod token; mod type_info; +mod type_info_tvp; pub use batch_request::*; pub use bulk_load::*; @@ -28,6 +29,7 @@ pub use pre_login::*; pub use rpc_request::*; pub use token::*; pub use type_info::*; +pub use type_info_tvp::*; const HEADER_BYTES: usize = 8; const ALL_HEADERS_LEN_TX: usize = 22; diff --git a/src/tds/codec/rpc_request.rs b/src/tds/codec/rpc_request.rs index 368cbd3c..37b857de 100644 --- a/src/tds/codec/rpc_request.rs +++ b/src/tds/codec/rpc_request.rs @@ -1,3 +1,4 @@ +use super::TypeInfoTvp; use super::{AllHeaderTy, Encode, ALL_HEADERS_LEN_TX}; use crate::{tds::codec::ColumnData, BytesMutWithTypeInfo, Result}; use bytes::{BufMut, BytesMut}; @@ -46,11 +47,18 @@ impl<'a> TokenRpcRequest<'a> { } } +#[derive(Debug)] +pub enum RpcValue<'a> { + Scalar(ColumnData<'a>), + Table(TypeInfoTvp<'a>), // as per grammar, TYPE_INFO_TVP contains data rows (looks quite odd) +} + #[derive(Debug)] pub struct RpcParam<'a> { pub name: Cow<'a, str>, pub flags: BitFlags, - pub value: ColumnData<'a>, + // pub value: ColumnData<'a>, + pub value: RpcValue<'a>, } /// 2.2.6.6 RPC Request @@ -103,10 +111,18 @@ impl<'a> Encode for TokenRpcRequest<'a> { let val = (0xffff_u32) | ((*id as u16) as u32) << 16; dst.put_u32_le(val); } - RpcProcIdValue::Name(ref _name) => { - //let (left_bytes, _) = try!(write_varchar::(&mut cursor, name, 0)); - //assert_eq!(left_bytes, 0); - todo!() + RpcProcIdValue::Name(ref name) => { + let len_pos = dst.len(); + dst.put_u16_le(0u16); + let mut length = 0_u16; + + for chr in name.encode_utf16() { + dst.put_u16_le(chr); + length += 1; + } + let dst: &mut [u8] = dst.borrow_mut(); + let mut dst = &mut dst[len_pos..]; + dst.put_u16_le(length as u16); } } @@ -134,8 +150,13 @@ impl<'a> Encode for RpcParam<'a> { dst.put_u8(self.flags.bits()); - let mut dst_fi = BytesMutWithTypeInfo::new(dst); - self.value.encode(&mut dst_fi)?; + match self.value { + RpcValue::Scalar(value) => { + let mut dst_ti = BytesMutWithTypeInfo::new(dst); + value.encode(&mut dst_ti)?; + } + RpcValue::Table(value) => value.encode(dst)?, + } let dst: &mut [u8] = dst.borrow_mut(); dst[len_pos] = length; diff --git a/src/tds/codec/type_info_tvp.rs b/src/tds/codec/type_info_tvp.rs new file mode 100644 index 00000000..22e6c2fc --- /dev/null +++ b/src/tds/codec/type_info_tvp.rs @@ -0,0 +1,228 @@ +use std::borrow::BorrowMut; + +use asynchronous_codec::BytesMut; +use bytes::BufMut; + +use crate::ColumnData; + +use super::{BytesMutWithTypeInfo, Encode, FixedLenType, MetaDataColumn, TypeInfo, VarLenContext}; + +const TVPTYPE: u8 = 0xF3; + +#[derive(Debug)] +pub struct TypeInfoTvp<'a> { + scheema_name: &'a str, + db_type_name: &'a str, + columns: Option>>, + data: Vec>>, +} + +impl<'a> Encode for TypeInfoTvp<'a> { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + // TVPTYPE = %xF3 + // TVP_TYPE_INFO = TVPTYPE + // TVP_TYPENAME + // TVP_COLMETADATA + // [TVP_ORDER_UNIQUE] + // [TVP_COLUMN_ORDERING] + // TVP_END_TOKEN + + dst.put_u8(TVPTYPE); + put_b_varchar("", dst); // DB name + put_b_varchar(self.scheema_name, dst); + put_b_varchar(self.db_type_name, dst); + + if let Some(ref columns_metadata) = self.columns { + dst.put_u16_le(columns_metadata.len() as u16); + for col in columns_metadata { + // TvpColumnMetaData = UserType + // Flags + // TYPE_INFO + // ColName ; Column metadata instance + dst.put_u32_le(0_u32); + col.base.clone().encode(dst)?; // Arc would look better than this clone, but might be actually slower + put_b_varchar("", dst); // 2.2.5.5.5.1: ColName MUST be a zero-length string in the TVP. + // put_b_varchar(col.col_name, dst); + } + } else { + dst.put_u16_le(0xFFFF_u16); // TVP_NULL_TOKEN, server knows the type (never worked in practice) + } + + dst.put_u8(0_u8); // TVP_END_TOKEN + + for row in self.data.into_iter() { + dst.put_u8(0x01u8); // TVP_ROW_TOKEN = %x01 + for (i, col) in row.into_iter().enumerate() { + let mut dst_ti = BytesMutWithTypeInfo::new(dst); + if let Some(ref metadata) = self.columns { + dst_ti = dst_ti.with_type_info(&metadata[i].base.ty); + } + col.encode(&mut dst_ti)?; + } + } + // TVP_ROW_TOKEN = %x01 ; A row as defined by TVP_COLMETADATA follows + // TvpColumnData = TYPE_VARBYTE ; Actual value must match metadata for the column + // AllColumnData = *TvpColumnData ; Chunks of data, one per non-default column defined + // ; in TVP_COLMETADATA + // TVP_ROW = TVP_ROW_TOKEN + // AllColumnData + + dst.put_u8(0_u8); // TVP_END_TOKEN + + Ok(()) + } +} + +fn put_b_varchar>(s: T, dst: &mut BytesMut) { + let len_pos = dst.len(); + dst.put_u8(0u8); + let mut length = 0_u8; + + for chr in s.as_ref().encode_utf16() { + dst.put_u16_le(chr); + length += 1; + } + let dst: &mut [u8] = dst.borrow_mut(); + dst[len_pos] = length; +} + +impl<'a> TypeInfoTvp<'a> { + pub fn new(type_name: &'a str, rows: Vec>>) -> TypeInfoTvp<'a> { + let (scheema_name, db_type_name) = if let Some((s, t)) = type_name.split_once(".") { + (s, t) + } else { + ("", type_name.as_ref()) + }; + TypeInfoTvp { + scheema_name, + db_type_name, + columns: None, + data: rows, + } + } + + pub fn with_metadata(self, metadata: Vec>) -> TypeInfoTvp<'_> { + let mut metadata = metadata; + // 2.2.5.5.5.3 + for mdc in metadata.iter_mut() { + let ty_replace = + match mdc.base.ty { + TypeInfo::FixedLen(ref ty) => { + match ty { + FixedLenType::Int1 => Some(TypeInfo::VarLenSized(VarLenContext::new( + super::VarLenType::Intn, + 1, + None, + ))), + FixedLenType::Bit => Some(TypeInfo::VarLenSized(VarLenContext::new( + super::VarLenType::Bitn, + 1, + None, + ))), + FixedLenType::Int2 => Some(TypeInfo::VarLenSized(VarLenContext::new( + super::VarLenType::Intn, + 2, + None, + ))), + FixedLenType::Int4 => Some(TypeInfo::VarLenSized(VarLenContext::new( + super::VarLenType::Intn, + 4, + None, + ))), + FixedLenType::Datetime4 => Some(TypeInfo::VarLenSized( + VarLenContext::new(super::VarLenType::Datetimen, 4, None), + )), + FixedLenType::Float4 => Some(TypeInfo::VarLenSized( + VarLenContext::new(super::VarLenType::Floatn, 4, None), + )), + FixedLenType::Money => Some(TypeInfo::VarLenSized(VarLenContext::new( + super::VarLenType::Money, + 8, + None, + ))), + FixedLenType::Datetime => Some(TypeInfo::VarLenSized( + VarLenContext::new(super::VarLenType::Datetimen, 8, None), + )), + FixedLenType::Float8 => Some(TypeInfo::VarLenSized( + VarLenContext::new(super::VarLenType::Floatn, 8, None), + )), + FixedLenType::Money4 => Some(TypeInfo::VarLenSized( + VarLenContext::new(super::VarLenType::Money, 4, None), + )), + FixedLenType::Int8 => Some(TypeInfo::VarLenSized(VarLenContext::new( + super::VarLenType::Intn, + 8, + None, + ))), + _ => None, + } + } + // TypeInfo::VarLenSized(ref ctx) => match ctx.r#type() { + // super::VarLenType::Guid => todo!(), + // super::VarLenType::Intn => todo!(), + // super::VarLenType::Bitn => todo!(), + // super::VarLenType::Decimaln => todo!(), + // super::VarLenType::Numericn => todo!(), + // super::VarLenType::Floatn => todo!(), + // super::VarLenType::Money => todo!(), + // super::VarLenType::Datetimen => todo!(), + // super::VarLenType::Daten => todo!(), + // super::VarLenType::Timen => todo!(), + // super::VarLenType::Datetime2 => todo!(), + // super::VarLenType::DatetimeOffsetn => todo!(), + // super::VarLenType::BigVarBin => todo!(), + // super::VarLenType::BigVarChar => todo!(), + // super::VarLenType::BigBinary => todo!(), + // super::VarLenType::BigChar => todo!(), + // super::VarLenType::NVarchar => todo!(), + // super::VarLenType::NChar => todo!(), + // super::VarLenType::Xml => todo!(), + // super::VarLenType::Udt => todo!(), + // super::VarLenType::Text => todo!(), + // super::VarLenType::Image => todo!(), + // super::VarLenType::NText => todo!(), + // super::VarLenType::SSVariant => todo!(), + // }, + // TypeInfo::VarLenSizedPrecision { + // ty, + // size, + // precision, + // scale, + // } => match ty { + // super::VarLenType::Guid => todo!(), + // super::VarLenType::Intn => todo!(), + // super::VarLenType::Bitn => todo!(), + // super::VarLenType::Decimaln => todo!(), + // super::VarLenType::Numericn => todo!(), + // super::VarLenType::Floatn => todo!(), + // super::VarLenType::Money => todo!(), + // super::VarLenType::Datetimen => todo!(), + // super::VarLenType::Daten => todo!(), + // super::VarLenType::Timen => todo!(), + // super::VarLenType::Datetime2 => todo!(), + // super::VarLenType::DatetimeOffsetn => todo!(), + // super::VarLenType::BigVarBin => todo!(), + // super::VarLenType::BigVarChar => todo!(), + // super::VarLenType::BigBinary => todo!(), + // super::VarLenType::BigChar => todo!(), + // super::VarLenType::NVarchar => todo!(), + // super::VarLenType::NChar => todo!(), + // super::VarLenType::Xml => todo!(), + // super::VarLenType::Udt => todo!(), + // super::VarLenType::Text => todo!(), + // super::VarLenType::Image => todo!(), + // super::VarLenType::NText => todo!(), + // super::VarLenType::SSVariant => todo!(), + // }, + _ => None, + }; + if let Some(ty) = ty_replace { + mdc.base.ty = ty; + } + } + TypeInfoTvp { + columns: Some(metadata), + ..self + } + } +} diff --git a/src/tds/stream.rs b/src/tds/stream.rs index e6454876..d01e6b9b 100644 --- a/src/tds/stream.rs +++ b/src/tds/stream.rs @@ -1,5 +1,7 @@ +mod command; mod query; mod token; +pub use command::*; pub use query::*; pub use token::*; diff --git a/src/tds/stream/command.rs b/src/tds/stream/command.rs new file mode 100644 index 00000000..e90b0559 --- /dev/null +++ b/src/tds/stream/command.rs @@ -0,0 +1,312 @@ +use crate::tds::stream::ReceivedToken; +use crate::{row::ColumnType, Column, Row}; +use crate::{ColumnData, CommandResult, ResultMetadata}; +use futures_util::{ + ready, + stream::{BoxStream, Peekable, Stream, StreamExt, TryStreamExt}, +}; +use std::{ + fmt::Debug, + pin::Pin, + sync::Arc, + task::{self, Poll}, +}; + +/// A set of `Streams` of [`CommandItem`] values, which can be either result +/// metadata, row, return status, return value, etc. +/// +/// # Example +/// +/// ```no_run +/// # use futures::TryStreamExt; +/// # use tiberius::{numeric::Numeric, Client, Command}; +/// # use tokio::net::TcpStream; +/// # use tokio_util::compat::TokioAsyncWriteCompatExt; +/// # #[tokio::main] +/// # async fn main() -> Result<(), Box> { +/// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or( +/// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(), +/// # ); +/// # let config = Config::from_ado_string(&c_str)?; +/// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; +/// # tcp.set_nodelay(true)?; +/// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; +/// let mut cmd = Command::new("dbo.usp_SomeStoredProc"); +/// +/// cmd.bind_param("@foo", 34i32); +/// cmd.bind_param("@zoo", "the zoo string prm"); +/// cmd.bind_out_param("@bar", "bar"); +/// let stream = cmd.exec(&mut client).await?; +/// +/// while let Some(item) = stream.try_next().await? { +/// match item { +/// // our first item is the column data always +/// CommandItem::Metadata(meta) if meta.result_index() == 0 => { +/// // the first result column info can be handled here +/// } +/// // ... and from there on from 0..N rows +/// CommandItem::Row(row) if row.result_index() == 0 => { +/// let var = row.get(0); +/// } +/// // the second result set returns first another metadata item +/// CommandItem::Metadata(meta) => { +/// // .. handling +/// } +/// // ...and, again, we get rows from the second resultset +/// CommandItem::Row(row) => { +/// let var = row.get(0); +/// } +/// // check return status (mandatory, returned always) +/// CommandItem::ReturnStatus(rs) => { +/// // .... do something +/// } +/// // check return status (mandatory, returned always) +/// CommandItem::ReturnValue(rv) => { +/// // .... do something, like push to a collection +/// } +/// // get affected row count +/// CommandItem::RowsAffected(ra) => { +/// // .... do something, like push to a collection +/// } +/// } +/// } +/// # Ok(()) +/// # } +/// ``` +/// +pub struct CommandStream<'a> { + token_stream: Peekable>>, + columns: Option>>, + result_set_index: Option, +} + +impl<'a> Debug for CommandStream<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CommandStream") + .field( + "token_stream", + &"BoxStream<'a, crate::Result>", + ) + .finish() + } +} + +impl<'a> CommandStream<'a> { + pub(crate) fn new(token_stream: BoxStream<'a, crate::Result>) -> Self { + Self { + token_stream: token_stream.peekable(), + columns: None, + result_set_index: None, + } + } + + /// Collects all results from the command in the stream into memory in the order + /// of querying. + pub async fn into_command_result(mut self) -> crate::Result { + let mut results: Vec> = Vec::new(); + let mut result: Option> = None; + let mut return_status = 0; + let mut return_values = Vec::new(); + let mut rows_affected = Vec::new(); + + while let Some(item) = self.try_next().await? { + match (item, &mut result) { + (CommandItem::Row(row), None) => { + result = Some(vec![row]); + } + (CommandItem::Row(row), Some(ref mut result)) => result.push(row), + (CommandItem::Metadata(_), None) => { + result = Some(Vec::new()); + } + (CommandItem::Metadata(_), ref mut previous_result) => { + results.push(previous_result.take().unwrap()); + result = None; + } + (CommandItem::ReturnStatus(rs), _) => return_status = rs, + (CommandItem::ReturnValue(rv), _) => return_values.push(rv), + (CommandItem::RowsAffected(rows), _) => rows_affected.push(rows), + } + } + + if let Some(result) = result { + results.push(result); + } + + Ok(CommandResult { + return_code: return_status, + return_values, + query_results: results, + rows_affected, + }) + } + + /// Convert the stream into a stream of rows, skipping all other items. + pub fn into_row_stream(self) -> BoxStream<'a, crate::Result> { + let s = self.try_filter_map(|item| async { + match item { + CommandItem::Row(row) => Ok(Some(row)), + _ => Ok(None), + } + }); + + Box::pin(s) + } +} + +#[derive(Debug)] +pub struct CommandReturnValue { + pub(crate) name: String, + pub(crate) _ord: u16, // TODO: remove? do we need it? + pub(crate) data: ColumnData<'static>, +} + +/// Resulting data from a command. +#[derive(Debug)] +pub enum CommandItem { + /// A single row of data. + Row(Row), + /// Information of the upcoming row data. + Metadata(ResultMetadata), + /// Return Status from the server + ReturnStatus(u32), + /// Return Value, matching OUT parameter(s) + ReturnValue(CommandReturnValue), + /// Rows Affected, for one of the statements ran in server + RowsAffected(u64), +} + +impl CommandItem { + pub(crate) fn metadata(columns: Arc>, result_index: usize) -> Self { + Self::Metadata(ResultMetadata { + columns, + result_index, + }) + } + + /// Returns a reference to the metadata, if the item is of a correct variant. + pub fn as_metadata(&self) -> Option<&ResultMetadata> { + match self { + CommandItem::Metadata(ref metadata) => Some(metadata), + _ => None, + } + } + + /// Returns a reference to the row, if the item is of a correct variant. + pub fn as_row(&self) -> Option<&Row> { + match self { + CommandItem::Row(ref row) => Some(row), + _ => None, + } + } + + /// Returns the metadata, if the item is of a correct variant. + pub fn into_metadata(self) -> Option { + match self { + CommandItem::Metadata(metadata) => Some(metadata), + _ => None, + } + } + + /// Returns the row, if the item is of a correct variant. + pub fn into_row(self) -> Option { + match self { + CommandItem::Row(row) => Some(row), + _ => None, + } + } + + /// Returns the return status, if the item if on a correct variant. + pub fn as_return_status(&self) -> Option { + match self { + CommandItem::ReturnStatus(rs) => Some(*rs), + _ => None, + } + } + + /// Returns the return value, if the item if on a correct variant. + pub fn as_return_value(&self) -> Option<&CommandReturnValue> { + match self { + CommandItem::ReturnValue(rv) => Some(rv), + _ => None, + } + } + + /// Returns the return value, if the item if on a correct variant. + pub fn into_return_value(self) -> Option { + match self { + CommandItem::ReturnValue(rv) => Some(rv), + _ => None, + } + } +} + +impl<'a> Stream for CommandStream<'a> { + type Item = crate::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + let this = self.get_mut(); + + loop { + let token = match ready!(this.token_stream.poll_next_unpin(cx)) { + Some(res) => res?, + None => return Poll::Ready(None), + }; + + return match token { + ReceivedToken::NewResultset(meta) => { + let column_meta = meta + .columns + .iter() + .map(|x| Column { + name: x.col_name.to_string(), + column_type: ColumnType::from(&x.base.ty), + }) + .collect::>(); + + let column_meta = Arc::new(column_meta); + this.columns = Some(column_meta.clone()); + + this.result_set_index = this.result_set_index.map(|i| i + 1); + + let query_item = + CommandItem::metadata(column_meta, *this.result_set_index.get_or_insert(0)); + + return Poll::Ready(Some(Ok(query_item))); + } + ReceivedToken::Row(data) => { + let columns = this.columns.as_ref().unwrap().clone(); + let result_index = this.result_set_index.unwrap(); + + let row = Row { + columns, + data, + result_index, + }; + + Poll::Ready(Some(Ok(CommandItem::Row(row)))) + } + ReceivedToken::ReturnStatus(rs) => { + Poll::Ready(Some(Ok(CommandItem::ReturnStatus(rs)))) + } + ReceivedToken::ReturnValue(rv) => { + Poll::Ready(Some(Ok(CommandItem::ReturnValue(CommandReturnValue { + name: rv.param_name, + _ord: rv.param_ordinal, + data: rv.value, + })))) + } + ReceivedToken::DoneProc(done) if done.is_final() => continue, + ReceivedToken::DoneProc(done) => { + Poll::Ready(Some(Ok(CommandItem::RowsAffected(done.rows())))) + } + ReceivedToken::DoneInProc(done) => { + Poll::Ready(Some(Ok(CommandItem::RowsAffected(done.rows())))) + } + ReceivedToken::Done(done) => { + Poll::Ready(Some(Ok(CommandItem::RowsAffected(done.rows())))) + } + _ => continue, + }; + } + } +} diff --git a/src/tds/stream/query.rs b/src/tds/stream/query.rs index 0dc69474..59eef42a 100644 --- a/src/tds/stream/query.rs +++ b/src/tds/stream/query.rs @@ -280,8 +280,8 @@ impl<'a> QueryStream<'a> { /// Info about the following stream of rows. #[derive(Debug, Clone)] pub struct ResultMetadata { - columns: Arc>, - result_index: usize, + pub(crate) columns: Arc>, + pub(crate) result_index: usize, } impl ResultMetadata { diff --git a/tvp-macro/Cargo.toml b/tvp-macro/Cargo.toml new file mode 100644 index 00000000..cf67a90f --- /dev/null +++ b/tvp-macro/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "tvp-macro" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +syn = "^2" +quote = "^1" +proc-macro2 = "^1" + +[lib] +proc-macro = true diff --git a/tvp-macro/src/attr.rs b/tvp-macro/src/attr.rs new file mode 100644 index 00000000..300b6dac --- /dev/null +++ b/tvp-macro/src/attr.rs @@ -0,0 +1,45 @@ +pub(crate) struct FieldAttr { + pub colname: Option, +} + +impl FieldAttr { + pub(crate) fn parse(attrs: &[syn::Attribute]) -> Option { + let mut result = None; + for attr in attrs.iter() { + match attr.style { + syn::AttrStyle::Outer => {} + _ => continue, + } + let last_attr_path = attr + .path() + .segments + .last() + .expect("Expected at least one segment where #[segment[::segment*](..)]"); + if (*last_attr_path).ident != "colname" { + continue; + } + let kv = match attr.meta { + syn::Meta::NameValue(ref kv) => kv, + _ if attr.path().is_ident("colname") => { + panic!("Invalid #[colname] attribute, expected #[colname = \"SomeColName\"]") + } + _ => continue, + }; + if result.is_some() { + panic!("Expected at most one #[colname] attribute"); + } + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(ref s), + .. + }) = kv.value + { + result = Some(FieldAttr { + colname: Some(s.value()), + }); + } else { + panic!("Non-string literal value in #[colname] attribute"); + } + } + result + } +} diff --git a/tvp-macro/src/lib.rs b/tvp-macro/src/lib.rs new file mode 100644 index 00000000..4175a3f2 --- /dev/null +++ b/tvp-macro/src/lib.rs @@ -0,0 +1,44 @@ +//! This is a utility macro crate used to generate trivial trait implementations used in rust-to-SQL data exchange +extern crate proc_macro; + +#[macro_use] +extern crate quote; +#[macro_use] +extern crate syn; + +use proc_macro::TokenStream; + +macro_rules! sp_quote { + ($($t:tt)*) => (quote_spanned!(proc_macro2::Span::call_site() => $($t)*)) +} + +mod attr; +mod table_value_param; + +/// This macro generates a trivial implementation of the `TableValueRow` trait. +/// # Applications +/// Could be applied to structures that represent rows of a Table Value params. +/// # Example +/// ```rust,ignore +/// # use tiberius::*; +/// #[derive(TableValueRow)] +/// pub struct SomeGeoList { +/// #[colname = "SomeID"] +/// pub id: i32, +/// #[colname = "LastSyncIPGeoLat"] +/// pub lat: Numeric, +/// #[colname = "LastSyncIPGeoLong"] +/// pub lon: Numeric, +/// } +/// ``` +#[proc_macro_derive(TableValueRow, attributes(colname))] +pub fn table_value_param(input: TokenStream) -> TokenStream { + //println!("intput: {}", input.to_string()); + let ast: syn::DeriveInput = syn::parse(input).expect("Couldn't parse item"); + let result = match ast.data { + syn::Data::Enum(_) => panic!("n/a for enums, makes sense for structs only"), + syn::Data::Struct(ref s) => table_value_param::for_struct(&ast, &s.fields), + syn::Data::Union(_) => panic!("doesn't work with unions"), + }; + result.into() +} diff --git a/tvp-macro/src/table_value_param.rs b/tvp-macro/src/table_value_param.rs new file mode 100644 index 00000000..3b3aefc9 --- /dev/null +++ b/tvp-macro/src/table_value_param.rs @@ -0,0 +1,176 @@ +use proc_macro2::TokenStream; +use syn::punctuated::Punctuated; + +use crate::attr::FieldAttr; + +pub(crate) fn for_struct(ast: &syn::DeriveInput, fields: &syn::Fields) -> TokenStream { + match *fields { + syn::Fields::Named(ref fields) => { + table_value_param_impl(&ast, Some(&fields.named) /*, true*/) + } + _ => panic!("Only named fields are supported so far"), // syn::Fields::Unit => try_from_row_impl(&ast, None, false, variant), + // syn::Fields::Unnamed(ref fields) => { + // try_from_row_impl(&ast, Some(&fields.unnamed), false, variant) + // } + } +} + +fn table_value_param_impl( + ast: &syn::DeriveInput, + fields: Option<&Punctuated>, + // named: bool, +) -> TokenStream { + let name = &ast.ident; + let (lt_impl, lt_struct) = + { + let mut lifetimes: Vec<&syn::Ident> = Vec::new(); + for gp in ast.generics.params.iter() { + if let syn::GenericParam::Lifetime(ltp) = gp { + lifetimes.push(<p.lifetime.ident); + } + } + if lifetimes.len() > 1 { + panic!( + "Only one lifetime specifier is supported for the structure. Found multiple: {}", + lifetimes.iter().map(|lt| lt.to_string()).collect::>().join(", ")); + } + if lifetimes.is_empty() { + ( + sp_quote!(<'query>), + sp_quote!(), // "".parse::().unwrap(), + ) + } else { + let lt = lifetimes[0]; + let ts: proc_macro2::TokenStream = format!("< '{} >", lt).parse().unwrap(); + (ts.clone(), ts) + } + }; + // let unit = fields.is_none(); + let empty = Default::default(); + let fields: Vec<_> = fields + .unwrap_or(&empty) + .iter() + .map(|f| FieldExt::new(f)) + .collect(); + let col_names: Vec<_> = fields.iter().map(|f| f.get_col_name()).collect(); + let _col_names = sp_quote!( #(#col_names),* ); // this comes later, TVPs for queries would need col names, while SPs do not + let col_binds: Vec<_> = fields.iter().map(|f| f.as_bind()).collect(); + let col_binds = sp_quote!( #(#col_binds);*); + sp_quote! { + impl #lt_impl ::tiberius::TableValueRow #lt_impl for #name #lt_struct { + fn get_db_type() -> &'static str { + stringify!{ #name } + } + + fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow #lt_impl) { + #col_binds; + } + } + } +} + +struct FieldExt { + attr: Option, + ident: syn::Ident, +} + +impl FieldExt { + pub fn new(field: &syn::Field) -> FieldExt { + if let Some(ident) = field.ident.clone() { + FieldExt { + attr: FieldAttr::parse(&field.attrs), + ident, + } + } else { + panic!("Field ident is required"); + } + } + pub(crate) fn get_col_name(&self) -> String { + if let Some(attr) = self.attr.as_ref() { + if let Some(colname) = attr.colname.as_ref() { + return colname.to_string(); + } + } + self.ident.to_string() + } + pub(crate) fn as_bind(&self) -> TokenStream { + let name = &self.ident; + sp_quote!(data_row.add_field(self.#name)) + } +} + +#[cfg(test)] +mod tests { + use super::for_struct; + + #[test] + fn basic_nolifetime() { + // just parse a struct + let ast: syn::DeriveInput = syn::parse_str( + r#" + pub struct SomeGeoList { + #[colname = "SomeID"] + pub id: i32, + #[colname = "LastSyncIPGeoLat"] + pub lat: Numeric, // decimal(9,6) + #[colname = "LastSyncIPGeoLong"] + pub lon: Numeric, // decimal(9,6) + } + "#, + ) + .unwrap(); + let result = match ast.data { + syn::Data::Enum(_) => panic!("n/a for enums, makes sense for structs only"), + syn::Data::Struct(ref s) => for_struct(&ast, &s.fields), + syn::Data::Union(_) => panic!("doesn't work with unions"), + }; + let etalon = sp_quote!( + impl<'query> ::tiberius::TableValueRow<'query> for SomeGeoList { + fn get_db_type() -> &'static str { + stringify! { SomeGeoList } + } + fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow<'query>) { + data_row.add_field(self.id); + data_row.add_field(self.lat); + data_row.add_field(self.lon); + } + } + ); + + assert_eq!(result.to_string(), etalon.to_string()); + } + + #[test] + fn basic_lifetime() { + let ast: syn::DeriveInput = syn::parse_str( + r#" + pub struct AnotherGeoList<'e> { + #[colname = "SomeID"] + pub id: i32, + #[colname = "SomeStr"] + pub s: &'e str, + } + "#, + ) + .unwrap(); + let result = match ast.data { + syn::Data::Enum(_) => panic!("n/a for enums, makes sense for structs only"), + syn::Data::Struct(ref s) => for_struct(&ast, &s.fields), + syn::Data::Union(_) => panic!("doesn't work with unions"), + }; + + let etalon = sp_quote!( + impl<'e> ::tiberius::TableValueRow<'e> for AnotherGeoList<'e> { + fn get_db_type() -> &'static str { + stringify! { AnotherGeoList } + } + fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow<'e>) { + data_row.add_field(self.id); + data_row.add_field(self.s); + } + } + ); + + assert_eq!(result.to_string(), etalon.to_string()); + } +} From 99c09e7bd51b858281ebfe819b11fa063322ad3e Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Tue, 30 Jan 2024 15:45:51 +0200 Subject: [PATCH 02/21] ::tiberius root ref removed from tvp macro --- tvp-macro/src/table_value_param.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tvp-macro/src/table_value_param.rs b/tvp-macro/src/table_value_param.rs index 3b3aefc9..8fd16984 100644 --- a/tvp-macro/src/table_value_param.rs +++ b/tvp-macro/src/table_value_param.rs @@ -57,7 +57,7 @@ fn table_value_param_impl( let col_binds: Vec<_> = fields.iter().map(|f| f.as_bind()).collect(); let col_binds = sp_quote!( #(#col_binds);*); sp_quote! { - impl #lt_impl ::tiberius::TableValueRow #lt_impl for #name #lt_struct { + impl #lt_impl tiberius::TableValueRow #lt_impl for #name #lt_struct { fn get_db_type() -> &'static str { stringify!{ #name } } From c1e18698beeaa7bee863b9a7daf9d35d11822e8f Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Tue, 30 Jan 2024 18:48:56 +0200 Subject: [PATCH 03/21] ::hygine is bask after testing --- tvp-macro/src/table_value_param.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tvp-macro/src/table_value_param.rs b/tvp-macro/src/table_value_param.rs index 8fd16984..1e1826c1 100644 --- a/tvp-macro/src/table_value_param.rs +++ b/tvp-macro/src/table_value_param.rs @@ -57,12 +57,12 @@ fn table_value_param_impl( let col_binds: Vec<_> = fields.iter().map(|f| f.as_bind()).collect(); let col_binds = sp_quote!( #(#col_binds);*); sp_quote! { - impl #lt_impl tiberius::TableValueRow #lt_impl for #name #lt_struct { + impl #lt_impl ::tiberius::TableValueRow #lt_impl for #name #lt_struct { fn get_db_type() -> &'static str { stringify!{ #name } } - fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow #lt_impl) { + fn bind_fields(&self, data_row: &mut ::tiberius::SqlTableDataRow #lt_impl) { #col_binds; } } @@ -129,7 +129,7 @@ mod tests { fn get_db_type() -> &'static str { stringify! { SomeGeoList } } - fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow<'query>) { + fn bind_fields(&self, data_row: &mut ::tiberius::SqlTableDataRow<'query>) { data_row.add_field(self.id); data_row.add_field(self.lat); data_row.add_field(self.lon); @@ -164,7 +164,7 @@ mod tests { fn get_db_type() -> &'static str { stringify! { AnotherGeoList } } - fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow<'e>) { + fn bind_fields(&self, data_row: &mut ::tiberius::SqlTableDataRow<'e>) { data_row.add_field(self.id); data_row.add_field(self.s); } From d84eee72a21004aa9aad5fdacdc601cb878c13ff Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Tue, 30 Jan 2024 19:08:34 +0200 Subject: [PATCH 04/21] back to relative path in macro --- tvp-macro/src/table_value_param.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tvp-macro/src/table_value_param.rs b/tvp-macro/src/table_value_param.rs index 1e1826c1..e1f39128 100644 --- a/tvp-macro/src/table_value_param.rs +++ b/tvp-macro/src/table_value_param.rs @@ -57,12 +57,12 @@ fn table_value_param_impl( let col_binds: Vec<_> = fields.iter().map(|f| f.as_bind()).collect(); let col_binds = sp_quote!( #(#col_binds);*); sp_quote! { - impl #lt_impl ::tiberius::TableValueRow #lt_impl for #name #lt_struct { + impl #lt_impl tiberius::TableValueRow #lt_impl for #name #lt_struct { fn get_db_type() -> &'static str { stringify!{ #name } } - fn bind_fields(&self, data_row: &mut ::tiberius::SqlTableDataRow #lt_impl) { + fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow #lt_impl) { #col_binds; } } @@ -125,11 +125,11 @@ mod tests { syn::Data::Union(_) => panic!("doesn't work with unions"), }; let etalon = sp_quote!( - impl<'query> ::tiberius::TableValueRow<'query> for SomeGeoList { + impl<'query> tiberius::TableValueRow<'query> for SomeGeoList { fn get_db_type() -> &'static str { stringify! { SomeGeoList } } - fn bind_fields(&self, data_row: &mut ::tiberius::SqlTableDataRow<'query>) { + fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow<'query>) { data_row.add_field(self.id); data_row.add_field(self.lat); data_row.add_field(self.lon); @@ -160,11 +160,11 @@ mod tests { }; let etalon = sp_quote!( - impl<'e> ::tiberius::TableValueRow<'e> for AnotherGeoList<'e> { + impl<'e> tiberius::TableValueRow<'e> for AnotherGeoList<'e> { fn get_db_type() -> &'static str { stringify! { AnotherGeoList } } - fn bind_fields(&self, data_row: &mut ::tiberius::SqlTableDataRow<'e>) { + fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow<'e>) { data_row.add_field(self.id); data_row.add_field(self.s); } From 7cc2562ea10560da271908fa66d611049f39a011 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Tue, 30 Jan 2024 19:35:39 +0200 Subject: [PATCH 05/21] CommandStream export fixed --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index c2e2c034..3c203910 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -281,6 +281,7 @@ pub use sql_browser::SqlBrowser; pub use tds::{ codec::{BulkLoadRequest, ColumnData, ColumnFlag, IntoRow, TokenRow, TypeLength}, numeric, + stream::CommandStream, stream::QueryStream, time, xml, EncryptionLevel, }; From d7b0e0be618e6e4206a9d6bd68d2cca3cf5a3347 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Tue, 30 Jan 2024 19:40:48 +0200 Subject: [PATCH 06/21] CommandItem export fixed --- src/result.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/result.rs b/src/result.rs index a2edae35..f5ad1026 100644 --- a/src/result.rs +++ b/src/result.rs @@ -1,4 +1,4 @@ -pub use crate::tds::stream::{QueryItem, ResultMetadata}; +pub use crate::tds::stream::{CommandItem, QueryItem, ResultMetadata}; use crate::{ client::Connection, error::Error, From 396f44956f0e98547db7056dc4c100d8c90cf7be Mon Sep 17 00:00:00 2001 From: Andrii Voytenkov Date: Wed, 31 Jan 2024 09:09:44 +0200 Subject: [PATCH 07/21] very basic command test added --- Cargo.toml | 4 +++ tests/command.rs | 87 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 tests/command.rs diff --git a/Cargo.toml b/Cargo.toml index a3894bf7..2e9534bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,10 @@ members = ["runtimes-macro", "tvp-macro"] path = "tests/query.rs" name = "query" +[[test]] +path = "tests/command.rs" +name = "command" + [[test]] path = "tests/named-instance-async.rs" name = "named-instance-async" diff --git a/tests/command.rs b/tests/command.rs new file mode 100644 index 00000000..742173f5 --- /dev/null +++ b/tests/command.rs @@ -0,0 +1,87 @@ +use futures_util::io::{AsyncRead, AsyncWrite}; +use futures_util::stream::TryStreamExt; +use names::{Generator, Name}; +use once_cell::sync::Lazy; +use std::cell::RefCell; +use std::env; +use std::sync::Once; + +use tiberius::FromSql; +use tiberius::{numeric::Numeric, xml::XmlData, ColumnType, Command, CommandItem, Result}; +use uuid::Uuid; + +use runtimes_macro::test_on_runtimes; + +// This is used in the testing macro :) +#[allow(dead_code)] +static LOGGER_SETUP: Once = Once::new(); + +static CONN_STR: Lazy = Lazy::new(|| { + env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or_else(|_| { + "server=tcp:localhost,1433;user=SA;password=;IntegratedSecurity=true;TrustServerCertificate=true".to_owned() + }) +}); + +thread_local! { + static NAMES: RefCell>> = + RefCell::new(None); +} + +async fn random_table() -> String { + NAMES.with(|maybe_generator| { + maybe_generator + .borrow_mut() + .get_or_insert_with(|| Generator::with_naming(Name::Plain)) + .next() + .unwrap() + .replace('-', "") + }) +} + +#[test_on_runtimes] +async fn basic_proc_exec(mut conn: tiberius::Client) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send, +{ + let table = random_table().await; + let proc = random_table().await; + + conn.simple_query(format!( + r#" + create table ##{} ( + id int identity(1,1), + other varchar(50), + ) + "#, + table + )) + .await?; + + conn.simple_query(format!( + r#" + create or alter procedure {} + @Param1 varchar(50) + as + insert into ##{} (other) + values (@Param1) + + return scope_identity() + "#, + proc, table, + )) + .await?; + + let mut ins_cmd = Command::new(&proc); + ins_cmd.bind_param("@Param1", "some text"); + + let result = ins_cmd.exec(&mut conn).await?.into_command_result().await?; + assert_eq!(1, result.return_code()); + + let mut ins_cmd = Command::new(&proc); + ins_cmd.bind_param("@Param1", "another text"); + + let result = ins_cmd.exec(&mut conn).await?.into_command_result().await?; + assert_eq!(2, result.return_code()); + + Ok(()) +} From 5fb8737656e1aa374a7ae7fab1b9561853e00c90 Mon Sep 17 00:00:00 2001 From: Andrii Voytenkov Date: Wed, 31 Jan 2024 09:46:14 +0200 Subject: [PATCH 08/21] bind_table feature dependency removed --- src/command.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/command.rs b/src/command.rs index 454b85ac..8e5abb0a 100644 --- a/src/command.rs +++ b/src/command.rs @@ -164,7 +164,6 @@ impl<'a> Command<'a> { /// # } /// ``` /// - #[cfg(feature = "tds73")] pub fn bind_table(&mut self, name: impl Into>, data: impl TableValue<'a> + 'a) { self.params.push(CommandParam { name: name.into(), From 8f981ffbd32a1ec276b6e3d2bdd72f162d33686f Mon Sep 17 00:00:00 2001 From: Andrii Voytenkov Date: Wed, 31 Jan 2024 10:15:08 +0200 Subject: [PATCH 09/21] hidden_glob_reexports warn fixed for time feature --- src/tds/time/time.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tds/time/time.rs b/src/tds/time/time.rs index 70d1c4d3..98906f0a 100644 --- a/src/tds/time/time.rs +++ b/src/tds/time/time.rs @@ -7,7 +7,7 @@ pub use time::*; use crate::tds::codec::ColumnData; -use std::time::Duration; +pub use std::time::Duration; #[inline] fn from_days(days: u64, start_year: i32) -> Date { From d821830016dcd6e8207dd03e4e737b23eef990e4 Mon Sep 17 00:00:00 2001 From: Andrii Voytenkov Date: Wed, 31 Jan 2024 10:21:36 +0200 Subject: [PATCH 10/21] chrono deprecated from_utc warning fixed --- src/tds/time/chrono.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tds/time/chrono.rs b/src/tds/time/chrono.rs index 4bdb59a5..b448f8d4 100644 --- a/src/tds/time/chrono.rs +++ b/src/tds/time/chrono.rs @@ -81,7 +81,7 @@ from_sql!( let offset = chrono::Duration::minutes(dto.offset as i64); let naive = NaiveDateTime::new(date, time).sub(offset); - chrono::DateTime::from_utc(naive, Utc) + chrono::DateTime::from_naive_utc_and_offset(naive, Utc) }); chrono::DateTime: ColumnData::DateTimeOffset(ref dto) => dto.map(|dto| { let date = from_days(dto.datetime2.date.days() as i64, 1); @@ -91,7 +91,7 @@ from_sql!( let offset = FixedOffset::east_opt((dto.offset as i32) * 60).unwrap(); let naive = NaiveDateTime::new(date, time); - chrono::DateTime::from_utc(naive, offset) + chrono::DateTime::from_naive_utc_and_offset(naive, offset) }) ); From 1a40f4932ccb4cb48deec2f0ebd79d8167a50705 Mon Sep 17 00:00:00 2001 From: Andrii Voytenkov Date: Wed, 31 Jan 2024 10:52:00 +0200 Subject: [PATCH 11/21] command test build cleanup --- tests/command.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/command.rs b/tests/command.rs index 742173f5..06df4e63 100644 --- a/tests/command.rs +++ b/tests/command.rs @@ -1,14 +1,12 @@ use futures_util::io::{AsyncRead, AsyncWrite}; -use futures_util::stream::TryStreamExt; +// use futures_util::stream::TryStreamExt; use names::{Generator, Name}; use once_cell::sync::Lazy; use std::cell::RefCell; use std::env; use std::sync::Once; -use tiberius::FromSql; -use tiberius::{numeric::Numeric, xml::XmlData, ColumnType, Command, CommandItem, Result}; -use uuid::Uuid; +use tiberius::{Command, Result}; use runtimes_macro::test_on_runtimes; From 83e50838a7860f2e84818efe9da00dec50dbc087 Mon Sep 17 00:00:00 2001 From: Andrii Voytenkov Date: Wed, 31 Jan 2024 11:23:43 +0200 Subject: [PATCH 12/21] chrono from_utc warning fixed in query tests --- tests/query.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/query.rs b/tests/query.rs index 527ab025..f2177ea2 100644 --- a/tests/query.rs +++ b/tests/query.rs @@ -2237,7 +2237,7 @@ where .unwrap() .and_hms_opt(16, 20, 0) .unwrap(); - let dt: DateTime = DateTime::from_utc(naive, Utc); + let dt: DateTime = DateTime::from_naive_utc_and_offset(naive, Utc); let row = conn .query("SELECT @P1", &[&dt]) @@ -2276,7 +2276,7 @@ where .unwrap(); let fixed = FixedOffset::east_opt(3600 * 3).unwrap(); - let dt: DateTime = DateTime::from_utc(naive, fixed); + let dt: DateTime = DateTime::from_naive_utc_and_offset(naive, fixed); let row = conn .query("SELECT @P1", &[&dt]) @@ -2314,7 +2314,7 @@ where .and_hms_opt(16, 20, 0) .unwrap(); let fixed = FixedOffset::east_opt(3600 * 3).unwrap(); - let dt: DateTime = DateTime::from_utc(naive, fixed); + let dt: DateTime = DateTime::from_naive_utc_and_offset(naive, fixed); let row = conn .query(format!("SELECT CAST('{}' AS datetimeoffset(7))", dt), &[]) From 8193f15b8c4dbbb1ac99c33498b7a6f21a862d41 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Wed, 31 Jan 2024 15:47:13 +0200 Subject: [PATCH 13/21] doctests tidy up, multiple fixes --- src/command.rs | 14 ++++++++------ src/result.rs | 12 ++++-------- src/tds/stream/command.rs | 13 +++++++------ 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/command.rs b/src/command.rs index 8e5abb0a..efe34a1a 100644 --- a/src/command.rs +++ b/src/command.rs @@ -126,15 +126,17 @@ impl<'a> Command<'a> { /// Example /// /// ```no_run - /// # use tiberius::{numeric::Numeric, Client, Command, TableValueRow}; - /// # use tokio_util::compat::TokioAsyncWriteCompatExt; /// # use std::env; + /// # use tiberius::Config; + /// # use tiberius::{numeric::Numeric, Command, TableValueRow}; + /// # use tokio_util::compat::TokioAsyncWriteCompatExt; /// #[derive(TableValueRow)] /// struct SomeGeoList { /// eid: i32, /// lat: Numeric, /// lon: Numeric, /// } + /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or( /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(), @@ -142,7 +144,7 @@ impl<'a> Command<'a> { /// # let config = Config::from_ado_string(&c_str)?; /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; /// # tcp.set_nodelay(true)?; - /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; + /// # let client = tiberius::Client::connect(config, tcp.compat_write()).await?; /// /// let r1 = SomeGeoList { /// eid: 1, @@ -178,9 +180,10 @@ impl<'a> Command<'a> { /// Example /// /// ```no_run - /// # use tiberius::{numeric::Numeric, Client, Command}; + /// # use tiberius::{Config, Command}; /// # use tokio_util::compat::TokioAsyncWriteCompatExt; /// # use std::env; + /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or( /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(), @@ -195,10 +198,9 @@ impl<'a> Command<'a> { /// cmd.bind_out_param("@bar", "bar"); /// let res = cmd.exec(&mut client).await?.into_command_result().await?; /// - /// let rv: Option = res.try_return_value("@bar")?; + /// let rv: Option<&str> = res.try_return_value("@bar")?; /// let rc = res.return_code(); /// - /// println!("And we got bar: {:#?}, return_code: {}", rv, rc); /// # Ok(()) /// # } /// ``` diff --git a/src/result.rs b/src/result.rs index f5ad1026..93e3022c 100644 --- a/src/result.rs +++ b/src/result.rs @@ -123,9 +123,10 @@ impl IntoIterator for ExecuteResult { /// # Example /// /// ```no_run -/// # use tiberius::{numeric::Numeric, Client, Command}; +/// # use tiberius::{Config, Command}; /// # use tokio_util::compat::TokioAsyncWriteCompatExt; /// # use std::env; +/// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or( /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(), @@ -140,16 +141,11 @@ impl IntoIterator for ExecuteResult { /// cmd.bind_out_param("@bar", "bar"); /// let res = cmd.exec(&mut client).await?.into_command_result().await?; /// -/// let rv: Option = res.try_return_value("@bar")?; +/// let rv: Option<&str> = res.try_return_value("@bar")?; /// let rc = res.return_code(); /// let ra = res.rows_affected(); /// -/// println!("And we got bar: {:#?}, return_code: {}", rv, rc); -/// -/// let rs0 = res.to_query_result(0) -/// if let Some(rows) = rs0 { -/// printls!("First record set: {:#?}", rows); -/// } +/// let rs0 = res.to_query_result(0); /// # Ok(()) /// # } /// ``` diff --git a/src/tds/stream/command.rs b/src/tds/stream/command.rs index e90b0559..d92a2ad8 100644 --- a/src/tds/stream/command.rs +++ b/src/tds/stream/command.rs @@ -18,9 +18,10 @@ use std::{ /// # Example /// /// ```no_run -/// # use futures::TryStreamExt; -/// # use tiberius::{numeric::Numeric, Client, Command}; -/// # use tokio::net::TcpStream; +/// # use std::env; +/// # use tiberius::Config; +/// # use tiberius::{Command, CommandItem}; +/// # use futures_util::TryStreamExt; /// # use tokio_util::compat::TokioAsyncWriteCompatExt; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -36,7 +37,7 @@ use std::{ /// cmd.bind_param("@foo", 34i32); /// cmd.bind_param("@zoo", "the zoo string prm"); /// cmd.bind_out_param("@bar", "bar"); -/// let stream = cmd.exec(&mut client).await?; +/// let mut stream = cmd.exec(&mut client).await?; /// /// while let Some(item) = stream.try_next().await? { /// match item { @@ -46,7 +47,7 @@ use std::{ /// } /// // ... and from there on from 0..N rows /// CommandItem::Row(row) if row.result_index() == 0 => { -/// let var = row.get(0); +/// let var: Option = row.get(0); /// } /// // the second result set returns first another metadata item /// CommandItem::Metadata(meta) => { @@ -54,7 +55,7 @@ use std::{ /// } /// // ...and, again, we get rows from the second resultset /// CommandItem::Row(row) => { -/// let var = row.get(0); +/// let var: Option = row.get(0); /// } /// // check return status (mandatory, returned always) /// CommandItem::ReturnStatus(rs) => { From e4c668613c1a57786b7ea29960d0878f824d524f Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Wed, 31 Jan 2024 18:56:09 +0200 Subject: [PATCH 14/21] tvp test added --- src/command.rs | 2 +- tests/command.rs | 148 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 4 deletions(-) diff --git a/src/command.rs b/src/command.rs index efe34a1a..13d433d8 100644 --- a/src/command.rs +++ b/src/command.rs @@ -71,7 +71,7 @@ enum CommandParamData<'a> { #[derive(Debug)] pub struct SqlTableData<'a> { rows: Vec>, - db_type: &'static str, + db_type: &'a str, } #[derive(Debug)] diff --git a/tests/command.rs b/tests/command.rs index 06df4e63..ed5c8363 100644 --- a/tests/command.rs +++ b/tests/command.rs @@ -6,7 +6,7 @@ use std::cell::RefCell; use std::env; use std::sync::Once; -use tiberius::{Command, Result}; +use tiberius::{numeric::Numeric, Command, Result, TableValueRow}; use runtimes_macro::test_on_runtimes; @@ -50,7 +50,7 @@ where id int identity(1,1), other varchar(50), ) - "#, + "#, table )) .await?; @@ -64,7 +64,7 @@ where values (@Param1) return scope_identity() - "#, + "#, proc, table, )) .await?; @@ -83,3 +83,145 @@ where Ok(()) } + +struct GeoTest { + id: i32, + lat: Numeric, + lon: Numeric, +} + +impl<'a> TableValueRow<'a> for GeoTest { + fn bind_fields(&self, data_row: &mut tiberius::SqlTableDataRow<'a>) { + data_row.add_field(self.id); + data_row.add_field(self.lat); + data_row.add_field(self.lon); + } + + fn get_db_type() -> &'static str { + "GeoTest" + } +} + +#[test_on_runtimes] +async fn tvp_proc_exec(mut conn: tiberius::Client) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send, +{ + let table = random_table().await; + let proc_tvp = random_table().await; + let proc_ins = random_table().await; + let proc_get = random_table().await; + + conn.simple_query(format!( + r#"if not exists(select * from sys.types where name = 'GeoTest') + create type dbo.[GeoTest] as table + ( + [ID] int not null, + [lat] decimal(9,6), + [lon] decimal(9,6) + ) + "# + )) + .await?; + + conn.simple_query(format!( + r#" + create table ##{} ( + id int not null, + lat decimal(9,6), + lon decimal(9,6), + n varchar(50) + ) + "#, + table + )) + .await?; + + conn.simple_query(format!( + r#" + create or alter procedure {} + @id int, + @geo dbo.[GeoTest] readonly + as + update t set + [lat] = g.[lat], + [lon] = g.[lon] + from + ##{} t + inner join @geo g on g.id = t.id + + "#, + proc_tvp, table, + )) + .await?; + + conn.simple_query(format!( + r#" + create or alter procedure {} + @id int, + @name varchar(50) + as + insert into ##{} (id, n) + values (@id, @name) + + "#, + proc_ins, table, + )) + .await?; + + conn.simple_query(format!( + r#" + create or alter procedure {} + @id int, + @count int out + as + set @count = (select count(*) from ##{}) + select * from ##{} + "#, + proc_get, table, table + )) + .await?; + + let mut ins_cmd = Command::new(&proc_ins); + ins_cmd.bind_param("@id", 23); + ins_cmd.bind_param("@name", "the twenty three"); + + let result = ins_cmd.exec(&mut conn).await?.into_command_result().await?; + assert_eq!(0, result.return_code()); + + let g1 = GeoTest { + id: 23, + lon: Numeric::new_with_scale(141, 6), + lat: Numeric::new_with_scale(192, 6), + }; + let g2 = GeoTest { + id: 78, + lon: Numeric::new_with_scale(1141, 6), + lat: Numeric::new_with_scale(8192, 6), + }; + let tbl = vec![g1, g2]; + let mut tvp_cmd = Command::new(&proc_tvp); + tvp_cmd.bind_param("@id", 23); + tvp_cmd.bind_table("@geo", tbl); + + let result = tvp_cmd.exec(&mut conn).await?.into_command_result().await?; + assert_eq!(0, result.return_code()); + + let count = 0; + let mut get_cmd = Command::new(&proc_get); + get_cmd.bind_param("@id", 23); + get_cmd.bind_out_param("@count", count); + + let result = get_cmd.exec(&mut conn).await?.into_command_result().await?; + assert_eq!(0, result.return_code()); + let count: i32 = result.try_return_value("@count")?.unwrap(); + assert_eq!(1, count); + + let rows = result.to_query_result(0).unwrap(); + let lat: Numeric = rows[0].get("lat").unwrap(); + let lon: Numeric = rows[0].get("lon").unwrap(); + assert_eq!(Numeric::new_with_scale(141, 6), lon); + assert_eq!(Numeric::new_with_scale(192, 6), lat); + + Ok(()) +} From 8309803e9604d1cc4fb33687bc990884be6631a9 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Wed, 31 Jan 2024 19:23:07 +0200 Subject: [PATCH 15/21] tvp test wrapped in transaction --- tests/command.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/command.rs b/tests/command.rs index ed5c8363..e3071cac 100644 --- a/tests/command.rs +++ b/tests/command.rs @@ -112,6 +112,8 @@ where let proc_ins = random_table().await; let proc_get = random_table().await; + conn.simple_query("BEGIN TRAN").await?; + conn.simple_query(format!( r#"if not exists(select * from sys.types where name = 'GeoTest') create type dbo.[GeoTest] as table @@ -223,5 +225,7 @@ where assert_eq!(Numeric::new_with_scale(141, 6), lon); assert_eq!(Numeric::new_with_scale(192, 6), lat); + conn.simple_query("COMMIT").await?; + Ok(()) } From 6d5aab1b6743f4c6be2aa1679ae21f1049c20a43 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Wed, 31 Jan 2024 20:02:56 +0200 Subject: [PATCH 16/21] smaller transaction for db type create/check --- tests/command.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/command.rs b/tests/command.rs index e3071cac..1d36f9bc 100644 --- a/tests/command.rs +++ b/tests/command.rs @@ -113,7 +113,6 @@ where let proc_get = random_table().await; conn.simple_query("BEGIN TRAN").await?; - conn.simple_query(format!( r#"if not exists(select * from sys.types where name = 'GeoTest') create type dbo.[GeoTest] as table @@ -125,6 +124,7 @@ where "# )) .await?; + conn.simple_query("COMMIT").await?; conn.simple_query(format!( r#" @@ -225,7 +225,5 @@ where assert_eq!(Numeric::new_with_scale(141, 6), lon); assert_eq!(Numeric::new_with_scale(192, 6), lat); - conn.simple_query("COMMIT").await?; - Ok(()) } From 5474daba1caf1906c98fe5cca3fcbe0aa84f8baf Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Wed, 31 Jan 2024 20:15:31 +0200 Subject: [PATCH 17/21] tvp type name override in table bind call --- src/command.rs | 17 +++++++++++++++++ tests/command.rs | 16 ++++++++-------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/command.rs b/src/command.rs index 13d433d8..04daf78d 100644 --- a/src/command.rs +++ b/src/command.rs @@ -174,6 +174,23 @@ impl<'a> Command<'a> { }); } + /// The same as `bind_table`, but with optional DB type name override + pub fn bind_table_with_dbtype( + &mut self, + db_type: &'a str, + name: impl Into>, + data: impl TableValue<'a> + 'a, + ) { + self.params.push(CommandParam { + name: name.into(), + out: false, + data: CommandParamData::Table(SqlTableData { + db_type, + ..data.into_sql() + }), + }); + } + /// Executes the `Command` in the SQL Server, returning `CommandStream` that /// can be collected into `CommandResult` for convinience. /// diff --git a/tests/command.rs b/tests/command.rs index 1d36f9bc..8c2bf115 100644 --- a/tests/command.rs +++ b/tests/command.rs @@ -111,20 +111,20 @@ where let proc_tvp = random_table().await; let proc_ins = random_table().await; let proc_get = random_table().await; + let db_type = random_table().await; - conn.simple_query("BEGIN TRAN").await?; conn.simple_query(format!( - r#"if not exists(select * from sys.types where name = 'GeoTest') - create type dbo.[GeoTest] as table + r#" + create type dbo.{} as table ( [ID] int not null, [lat] decimal(9,6), [lon] decimal(9,6) ) - "# + "#, + db_type )) .await?; - conn.simple_query("COMMIT").await?; conn.simple_query(format!( r#" @@ -143,7 +143,7 @@ where r#" create or alter procedure {} @id int, - @geo dbo.[GeoTest] readonly + @geo dbo.{} readonly as update t set [lat] = g.[lat], @@ -153,7 +153,7 @@ where inner join @geo g on g.id = t.id "#, - proc_tvp, table, + proc_tvp, db_type, table, )) .await?; @@ -204,7 +204,7 @@ where let tbl = vec![g1, g2]; let mut tvp_cmd = Command::new(&proc_tvp); tvp_cmd.bind_param("@id", 23); - tvp_cmd.bind_table("@geo", tbl); + tvp_cmd.bind_table_with_dbtype("@geo", db_type, tbl); let result = tvp_cmd.exec(&mut conn).await?.into_command_result().await?; assert_eq!(0, result.return_code()); From 330393a744b524b56dbc3f246ba9b6a8f81439d2 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Wed, 31 Jan 2024 20:24:58 +0200 Subject: [PATCH 18/21] tvp test fixed --- src/command.rs | 2 +- tests/command.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/command.rs b/src/command.rs index 04daf78d..60ba9abe 100644 --- a/src/command.rs +++ b/src/command.rs @@ -177,8 +177,8 @@ impl<'a> Command<'a> { /// The same as `bind_table`, but with optional DB type name override pub fn bind_table_with_dbtype( &mut self, - db_type: &'a str, name: impl Into>, + db_type: &'a str, data: impl TableValue<'a> + 'a, ) { self.params.push(CommandParam { diff --git a/tests/command.rs b/tests/command.rs index 8c2bf115..29ea0f18 100644 --- a/tests/command.rs +++ b/tests/command.rs @@ -115,7 +115,7 @@ where conn.simple_query(format!( r#" - create type dbo.{} as table + create type dbo.[{}] as table ( [ID] int not null, [lat] decimal(9,6), @@ -143,7 +143,7 @@ where r#" create or alter procedure {} @id int, - @geo dbo.{} readonly + @geo dbo.[{}] readonly as update t set [lat] = g.[lat], @@ -204,7 +204,7 @@ where let tbl = vec![g1, g2]; let mut tvp_cmd = Command::new(&proc_tvp); tvp_cmd.bind_param("@id", 23); - tvp_cmd.bind_table_with_dbtype("@geo", db_type, tbl); + tvp_cmd.bind_table_with_dbtype("@geo", &db_type, tbl); let result = tvp_cmd.exec(&mut conn).await?.into_command_result().await?; assert_eq!(0, result.return_code()); From 6f781e3c7fae60e0af752d5093a90d32002d9b44 Mon Sep 17 00:00:00 2001 From: Andrii Voytenkov Date: Sun, 25 Feb 2024 16:07:41 +0200 Subject: [PATCH 19/21] I32 Option None as VarLen (experimental) --- src/tds/codec/column_data.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index ada32781..2055dba0 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -216,10 +216,15 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u8(0); } } - (ColumnData::I32(Some(val)), None) => { - let header = [VarLenType::Intn as u8, 4, 4]; + (ColumnData::I32(opt), None) => { + let header = [VarLenType::Intn as u8, 4]; dst.extend_from_slice(&header); - dst.put_i32_le(val); + if let Some(val) = opt { + dst.put_u8(4); + dst.put_i32_le(val); + } else { + dst.put_u8(0); + } } (ColumnData::I64(Some(val)), Some(TypeInfo::FixedLen(FixedLenType::Int8))) => { dst.put_i64_le(val); From f4ddb2af0d2cd968fb9e9d833defdaf011070c21 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Sun, 25 Feb 2024 17:10:31 +0200 Subject: [PATCH 20/21] Guid Option None as VarLen --- src/tds/codec/column_data.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index 2055dba0..9f888649 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -293,13 +293,17 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u8(0); } } - (ColumnData::Guid(Some(uuid)), None) => { - let header = [VarLenType::Guid as u8, 16, 16]; + (ColumnData::Guid(opt), None) => { + let header = [VarLenType::Guid as u8, 16]; dst.extend_from_slice(&header); - - let mut data = *uuid.as_bytes(); - super::guid::reorder_bytes(&mut data); - dst.extend_from_slice(&data); + if let Some(uuid) = opt { + dst.put_u8(16); + let mut data = *uuid.as_bytes(); + super::guid::reorder_bytes(&mut data); + dst.extend_from_slice(&data); + } else { + dst.put_u8(0); + } } (ColumnData::String(opt), Some(TypeInfo::VarLenSized(vlc))) if vlc.r#type() == VarLenType::BigChar From 5c577c286158216a50c5247899553ed3b05e3573 Mon Sep 17 00:00:00 2001 From: Andrey Voitenkov Date: Sun, 25 Feb 2024 18:26:14 +0200 Subject: [PATCH 21/21] All other Bit/Int/Float Option None as VarLen --- src/tds/codec/column_data.rs | 69 +++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/src/tds/codec/column_data.rs b/src/tds/codec/column_data.rs index 9f888649..5a923818 100644 --- a/src/tds/codec/column_data.rs +++ b/src/tds/codec/column_data.rs @@ -158,13 +158,18 @@ impl<'a> Encode> for ColumnData<'a> { (ColumnData::Bit(Some(val)), Some(TypeInfo::FixedLen(FixedLenType::Bit))) => { dst.put_u8(val as u8); } - (ColumnData::Bit(Some(val)), None) => { + (ColumnData::Bit(opt), None) => { // if TypeInfo was not given, encode a TypeInfo // the first 1 is part of TYPE_INFO - // the second 1 is part of TYPE_VARBYTE - let header = [VarLenType::Bitn as u8, 1, 1]; + let header = [VarLenType::Bitn as u8, 1]; dst.extend_from_slice(&header); - dst.put_u8(val as u8); + if let Some(val) = opt { + // the second 1 is part of TYPE_VARBYTE + dst.put_u8(1); + dst.put_u8(val as u8); + } else { + dst.put_u8(0); + } } (ColumnData::U8(opt), Some(TypeInfo::VarLenSized(vlc))) if vlc.r#type() == VarLenType::Intn => @@ -179,10 +184,15 @@ impl<'a> Encode> for ColumnData<'a> { (ColumnData::U8(Some(val)), Some(TypeInfo::FixedLen(FixedLenType::Int1))) => { dst.put_u8(val); } - (ColumnData::U8(Some(val)), None) => { - let header = [VarLenType::Intn as u8, 1, 1]; + (ColumnData::U8(opt), None) => { + let header = [VarLenType::Intn as u8, 1]; dst.extend_from_slice(&header); - dst.put_u8(val); + if let Some(val) = opt { + dst.put_u8(1); + dst.put_u8(val); + } else { + dst.put_u8(0); + } } (ColumnData::I16(Some(val)), Some(TypeInfo::FixedLen(FixedLenType::Int2))) => { dst.put_i16_le(val); @@ -197,11 +207,15 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u8(0); } } - (ColumnData::I16(Some(val)), None) => { - let header = [VarLenType::Intn as u8, 2, 2]; + (ColumnData::I16(opt), None) => { + let header = [VarLenType::Intn as u8, 2]; dst.extend_from_slice(&header); - - dst.put_i16_le(val); + if let Some(val) = opt { + dst.put_u8(2); + dst.put_i16_le(val); + } else { + dst.put_u8(0); + } } (ColumnData::I32(Some(val)), Some(TypeInfo::FixedLen(FixedLenType::Int4))) => { dst.put_i32_le(val); @@ -239,10 +253,15 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u8(0); } } - (ColumnData::I64(Some(val)), None) => { - let header = [VarLenType::Intn as u8, 8, 8]; + (ColumnData::I64(opt), None) => { + let header = [VarLenType::Intn as u8, 8]; dst.extend_from_slice(&header); - dst.put_i64_le(val); + if let Some(val) = opt { + dst.put_u8(8); + dst.put_i64_le(val); + } else { + dst.put_u8(0); + } } (ColumnData::F32(Some(val)), Some(TypeInfo::FixedLen(FixedLenType::Float4))) => { dst.put_f32_le(val); @@ -257,10 +276,15 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u8(0); } } - (ColumnData::F32(Some(val)), None) => { - let header = [VarLenType::Floatn as u8, 4, 4]; + (ColumnData::F32(opt), None) => { + let header = [VarLenType::Floatn as u8, 4]; dst.extend_from_slice(&header); - dst.put_f32_le(val); + if let Some(val) = opt { + dst.put_u8(4); + dst.put_f32_le(val); + } else { + dst.put_u8(0); + } } (ColumnData::F64(Some(val)), Some(TypeInfo::FixedLen(FixedLenType::Float8))) => { dst.put_f64_le(val); @@ -275,10 +299,15 @@ impl<'a> Encode> for ColumnData<'a> { dst.put_u8(0); } } - (ColumnData::F64(Some(val)), None) => { - let header = [VarLenType::Floatn as u8, 8, 8]; + (ColumnData::F64(opt), None) => { + let header = [VarLenType::Floatn as u8, 8]; dst.extend_from_slice(&header); - dst.put_f64_le(val); + if let Some(val) = opt { + dst.put_u8(8); + dst.put_f64_le(val); + } else { + dst.put_u8(0); + } } (ColumnData::Guid(opt), Some(TypeInfo::VarLenSized(vlc))) if vlc.r#type() == VarLenType::Guid =>