diff --git a/.cargo/audit.toml b/.cargo/audit.toml index 159a1505..0a8cdcda 100644 --- a/.cargo/audit.toml +++ b/.cargo/audit.toml @@ -1,2 +1,8 @@ [advisories] -ignore = ["RUSTSEC-2023-0071", "RUSTSEC-2026-0049"] +ignore = [ + "RUSTSEC-2023-0071", + "RUSTSEC-2026-0049", + # TODO: remove once electrum-client/bdk stops pulling rustls-webpki 0.101.7 + "RUSTSEC-2026-0098", + "RUSTSEC-2026-0099", +] diff --git a/.gitignore b/.gitignore index f3af7694..ca3fc711 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ /target .bria +.bats-e2e .e2e-logs - .direnv diff --git a/.sqlx/query-561f73d1c3ae19876eef43a68dad50bedd904b0e8934a065bad8ef339b3d41eb.json b/.sqlx/query-561f73d1c3ae19876eef43a68dad50bedd904b0e8934a065bad8ef339b3d41eb.json new file mode 100644 index 00000000..37cdcf73 --- /dev/null +++ b/.sqlx/query-561f73d1c3ae19876eef43a68dad50bedd904b0e8934a065bad8ef339b3d41eb.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT script FROM bdk_script_pubkeys\n WHERE keychain_id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "script", + "type_info": "Bytea" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "561f73d1c3ae19876eef43a68dad50bedd904b0e8934a065bad8ef339b3d41eb" +} diff --git a/.sqlx/query-6c30dc4e7c409d797b957058f3f358ceeb651c8f971f92951b5e2d1e113b1286.json b/.sqlx/query-6c30dc4e7c409d797b957058f3f358ceeb651c8f971f92951b5e2d1e113b1286.json deleted file mode 100644 index 9021ebaf..00000000 --- a/.sqlx/query-6c30dc4e7c409d797b957058f3f358ceeb651c8f971f92951b5e2d1e113b1286.json +++ /dev/null @@ -1,44 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT script, keychain_kind as \"keychain_kind: BdkKeychainKind\", path FROM bdk_script_pubkeys\n WHERE keychain_id = $1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "script", - "type_info": "Bytea" - }, - { - "ordinal": 1, - "name": "keychain_kind: BdkKeychainKind", - "type_info": { - "Custom": { - "name": "bdkkeychainkind", - "kind": { - "Enum": [ - "external", - "internal" - ] - } - } - } - }, - { - "ordinal": 2, - "name": "path", - "type_info": "Int4" - } - ], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "6c30dc4e7c409d797b957058f3f358ceeb651c8f971f92951b5e2d1e113b1286" -} diff --git a/.sqlx/query-79eff690e77e488dccc3c066b265e7718128987e44d97f408059cca30fc5124c.json b/.sqlx/query-90eecb0b6744da1d056aae9b32706840136afd1c44dbe358dd1d414f32d01cb4.json similarity index 56% rename from .sqlx/query-79eff690e77e488dccc3c066b265e7718128987e44d97f408059cca30fc5124c.json rename to .sqlx/query-90eecb0b6744da1d056aae9b32706840136afd1c44dbe358dd1d414f32d01cb4.json index e2572075..02f7a0d2 100644 --- a/.sqlx/query-79eff690e77e488dccc3c066b265e7718128987e44d97f408059cca30fc5124c.json +++ b/.sqlx/query-90eecb0b6744da1d056aae9b32706840136afd1c44dbe358dd1d414f32d01cb4.json @@ -1,17 +1,18 @@ { "db_name": "PostgreSQL", - "query": "SELECT script, keychain_kind as \"keychain_kind: BdkKeychainKind\" FROM bdk_script_pubkeys\n WHERE keychain_id = $1", + "query": "SELECT script FROM bdk_script_pubkeys\n WHERE keychain_id = $1 AND keychain_kind = $2", "describe": { "columns": [ { "ordinal": 0, "name": "script", "type_info": "Bytea" - }, - { - "ordinal": 1, - "name": "keychain_kind: BdkKeychainKind", - "type_info": { + } + ], + "parameters": { + "Left": [ + "Uuid", + { "Custom": { "name": "bdkkeychainkind", "kind": { @@ -22,17 +23,11 @@ } } } - } - ], - "parameters": { - "Left": [ - "Uuid" ] }, "nullable": [ - false, false ] }, - "hash": "79eff690e77e488dccc3c066b265e7718128987e44d97f408059cca30fc5124c" + "hash": "90eecb0b6744da1d056aae9b32706840136afd1c44dbe358dd1d414f32d01cb4" } diff --git a/.sqlx/query-dfa815aeb090f27d4fd45326c706b23f3f84c9e843addc1b05160fd0e28fdf2c.json b/.sqlx/query-dfa815aeb090f27d4fd45326c706b23f3f84c9e843addc1b05160fd0e28fdf2c.json new file mode 100644 index 00000000..2e2c29dd --- /dev/null +++ b/.sqlx/query-dfa815aeb090f27d4fd45326c706b23f3f84c9e843addc1b05160fd0e28fdf2c.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT tx_id, sent, height,\n (details_json->>'received')::BIGINT AS \"received?\",\n (details_json->>'fee')::BIGINT AS \"fee?\",\n (details_json->'confirmation_time'->>'timestamp')::BIGINT AS \"confirmation_timestamp?\"\n FROM bdk_transactions\n WHERE keychain_id = $1 AND deleted_at IS NULL", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "tx_id", + "type_info": "Varchar" + }, + { + "ordinal": 1, + "name": "sent", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "height", + "type_info": "Int4" + }, + { + "ordinal": 3, + "name": "received?", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "fee?", + "type_info": "Int8" + }, + { + "ordinal": 5, + "name": "confirmation_timestamp?", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + true, + null, + null, + null + ] + }, + "hash": "dfa815aeb090f27d4fd45326c706b23f3f84c9e843addc1b05160fd0e28fdf2c" +} diff --git a/src/bdk/pg/mod.rs b/src/bdk/pg/mod.rs index b1e5ca45..310fcde1 100644 --- a/src/bdk/pg/mod.rs +++ b/src/bdk/pg/mod.rs @@ -21,60 +21,272 @@ use index::Indexes; use script_pubkeys::ScriptPubkeys; use std::{ collections::HashMap, - sync::{Arc, Mutex}, + sync::atomic::{AtomicBool, Ordering}, + sync::{Arc, Mutex, MutexGuard}, }; pub(super) use sync_times::SyncTimes; pub use transactions::*; pub use utxos::*; -pub struct SqlxWalletDb { +type ScriptPubkeyCache = HashMap; +type TransactionCache = HashMap; + +#[derive(Copy, Clone, Eq, PartialEq)] +enum TxLookupMode { + Any, + RequireRaw, +} + +#[derive(Clone)] +struct WalletDbContext { rt: Handle, pool: PgPool, keychain_id: KeychainId, - utxos: Option>, - cached_spks: Arc>>, - addresses: HashMap, - cached_txs: Arc>>, - txs: HashMap, } -impl SqlxWalletDb { - pub fn new(pool: PgPool, keychain_id: KeychainId) -> Self { +impl WalletDbContext { + fn new(pool: PgPool, keychain_id: KeychainId) -> Self { Self { rt: Handle::current(), - keychain_id, pool, - utxos: None, - addresses: HashMap::new(), - cached_spks: Arc::new(Mutex::new(HashMap::new())), - txs: HashMap::new(), - cached_txs: Arc::new(Mutex::new(HashMap::new())), + keychain_id, } } +} - fn load_all_txs(&self) -> Result<(), bdk::Error> { - let mut txs = self.cached_txs.lock().expect("poisoned txs cache lock"); - if txs.is_empty() { - let loaded = self.rt.block_on(async { - let txs = Transactions::new(self.keychain_id, self.pool.clone()); - txs.load_all().await - })?; - *txs = loaded; +#[derive(Default)] +struct WalletBatchState { + utxos: Vec, + addresses: ScriptPubkeyCache, + txs: TransactionCache, +} + +#[derive(Clone)] +struct WalletCache { + script_pubkeys: Arc>, + transactions: Arc>, + // Process-local hint: true means this instance has already hydrated raw tx details + // from the DB at least once. It is intentionally not synchronized across processes. + raw_txs_fully_loaded: Arc, +} + +impl WalletCache { + fn new() -> Self { + Self { + script_pubkeys: Arc::new(Mutex::new(HashMap::new())), + transactions: Arc::new(Mutex::new(HashMap::new())), + raw_txs_fully_loaded: Arc::new(AtomicBool::new(false)), } + } + + fn lock_script_pubkeys(&self) -> Result, bdk::Error> { + self.script_pubkeys + .lock() + .map_err(|_| bdk::Error::Generic("script pubkeys cache lock poisoned".to_string())) + } + + fn lock_transactions(&self) -> Result, bdk::Error> { + self.transactions + .lock() + .map_err(|_| bdk::Error::Generic("transactions cache lock poisoned".to_string())) + } + + fn get_script_pubkey_path( + &self, + script: &Script, + ) -> Result, bdk::Error> { + let cache = self.lock_script_pubkeys()?; + Ok(cache.get(script).copied()) + } + + fn insert_script_pubkey( + &self, + script: ScriptBuf, + path: (KeychainKind, u32), + ) -> Result<(), bdk::Error> { + let mut cache = self.lock_script_pubkeys()?; + cache.insert(script, path); + Ok(()) + } + + fn extend_script_pubkeys(&self, entries: I) -> Result<(), bdk::Error> + where + I: IntoIterator, + { + let mut cache = self.lock_script_pubkeys()?; + cache.extend(entries); + Ok(()) + } + + fn get_tx(&self, txid: &Txid) -> Result, bdk::Error> { + let cache = self.lock_transactions()?; + Ok(cache.get(txid).cloned()) + } + + fn insert_tx(&self, txid: Txid, tx: TransactionDetails) -> Result<(), bdk::Error> { + let mut cache = self.lock_transactions()?; + cache.insert(txid, tx); + Ok(()) + } + + fn extend_txs(&self, entries: I) -> Result<(), bdk::Error> + where + I: IntoIterator, + { + let mut cache = self.lock_transactions()?; + cache.extend(entries); + Ok(()) + } + + fn all_txs(&self) -> Result, bdk::Error> { + let cache = self.lock_transactions()?; + Ok(cache.values().cloned().collect()) + } + + fn raw_txs_fully_loaded(&self) -> bool { + self.raw_txs_fully_loaded.load(Ordering::Acquire) + } + + fn set_raw_txs_fully_loaded(&self) { + self.raw_txs_fully_loaded.store(true, Ordering::Release); + } + + fn remove_tx(&self, txid: &Txid) -> Result<(), bdk::Error> { + let mut cache = self.lock_transactions()?; + cache.remove(txid); Ok(()) } +} + +pub struct SqlxWalletDb { + ctx: WalletDbContext, + cache: WalletCache, + batch: WalletBatchState, +} + +impl SqlxWalletDb { + fn unsupported_operation(operation: &str) -> bdk::Error { + bdk::Error::Generic(format!("{operation} is not supported by SqlxWalletDb")) + } + + pub fn new(pool: PgPool, keychain_id: KeychainId) -> Self { + Self { + ctx: WalletDbContext::new(pool, keychain_id), + cache: WalletCache::new(), + batch: WalletBatchState::default(), + } + } + + fn script_pubkeys_repo(&self) -> ScriptPubkeys { + ScriptPubkeys::new(self.ctx.keychain_id, self.ctx.pool.clone()) + } + + fn utxos_repo(&self) -> Utxos { + Utxos::new(self.ctx.keychain_id, self.ctx.pool.clone()) + } + + fn transactions_repo(&self) -> Transactions { + Transactions::new(self.ctx.keychain_id, self.ctx.pool.clone()) + } + + fn indexes_repo(&self) -> Indexes { + Indexes::new(self.ctx.keychain_id, self.ctx.pool.clone()) + } + + fn sync_times_repo(&self) -> SyncTimes { + SyncTimes::new(self.ctx.keychain_id, self.ctx.pool.clone()) + } + + fn descriptor_checksums_repo(&self) -> DescriptorChecksums { + DescriptorChecksums::new(self.ctx.keychain_id, self.ctx.pool.clone()) + } + + fn lookup_script_pubkey_path( + &self, + script: &Script, + ) -> Result, bdk::Error> { + if let Some(path) = self.batch.addresses.get(script) { + return Ok(Some(*path)); + } + + if let Some(path) = self.cache.get_script_pubkey_path(script)? { + return Ok(Some(path)); + } + + let script_pubkey = script.to_owned(); + let found = self + .ctx + .rt + .block_on(async { self.script_pubkeys_repo().find_path(&script_pubkey).await })?; + + if let Some((kind, path)) = found { + let value = (KeychainKind::from(kind), path); + self.cache.insert_script_pubkey(script_pubkey, value)?; + return Ok(Some(value)); + } + + Ok(None) + } + + fn lookup_tx_with_mode( + &self, + txid: &Txid, + mode: TxLookupMode, + ) -> Result, bdk::Error> { + if let Some(tx) = self.batch.txs.get(txid) { + if mode == TxLookupMode::Any || tx.transaction.is_some() { + return Ok(Some(tx.clone())); + } + + return Ok(None); + } + + if let Some(tx) = self.cache.get_tx(txid)? { + if mode == TxLookupMode::Any || tx.transaction.is_some() { + return Ok(Some(tx)); + } + + if self.cache.raw_txs_fully_loaded() { + return Ok(None); + } + } + + let found = self + .ctx + .rt + .block_on(async { self.transactions_repo().find_by_id(txid).await })?; + + // DB rows represent persisted TransactionDetails; this store does not persist a + // "summary-only" transaction format. A DB hit is therefore valid for both lookup + // modes (`Any` and `RequireRaw`). + + if let Some(tx) = &found { + self.cache.insert_tx(tx.txid, tx.clone())?; + } + + Ok(found) + } fn lookup_tx(&self, txid: &Txid) -> Result, bdk::Error> { - if let Some(tx) = self.txs.get(txid) { - return Ok(Some(tx.clone())); + self.lookup_tx_with_mode(txid, TxLookupMode::Any) + } + + fn overlay_batch_txs( + mut txs: HashMap, + batch_txs: &HashMap, + include_raw: bool, + ) -> HashMap { + if include_raw { + txs.extend(batch_txs.iter().map(|(id, tx)| (*id, tx.clone()))); + } else { + txs.extend(batch_txs.iter().map(|(id, tx)| { + let mut tx = tx.clone(); + tx.transaction = None; + (*id, tx) + })); } - self.load_all_txs()?; - Ok(self - .cached_txs - .lock() - .expect("poisoned txs cache lock") - .get(txid) - .cloned()) + + txs } } @@ -85,39 +297,38 @@ impl BatchOperations for SqlxWalletDb { keychain: KeychainKind, path: u32, ) -> Result<(), bdk::Error> { - self.addresses.insert(script.into(), (keychain, path)); + self.batch.addresses.insert(script.into(), (keychain, path)); Ok(()) } fn set_utxo(&mut self, utxo: &LocalUtxo) -> Result<(), bdk::Error> { - if self.utxos.is_none() { - self.utxos = Some(Vec::new()); - } - self.utxos.as_mut().unwrap().push(utxo.clone()); + self.batch.utxos.push(utxo.clone()); Ok(()) } fn set_raw_tx(&mut self, _: &Transaction) -> Result<(), bdk::Error> { - unimplemented!() + Err(Self::unsupported_operation("set_raw_tx")) } fn set_tx(&mut self, tx: &TransactionDetails) -> Result<(), bdk::Error> { - self.txs.insert(tx.txid, tx.clone()); + self.batch.txs.insert(tx.txid, tx.clone()); Ok(()) } fn set_last_index(&mut self, kind: KeychainKind, idx: u32) -> Result<(), bdk::Error> { - self.rt.block_on(async { - let indexes = Indexes::new(self.keychain_id, self.pool.clone()); - indexes.persist_last_index(kind, idx).await - }) + // NOTE: This write is intentionally immediate because BDK may call it outside of + // `commit_batch` flow. + self.ctx + .rt + .block_on(async { self.indexes_repo().persist_last_index(kind, idx).await }) } fn set_sync_time(&mut self, time: SyncTime) -> Result<(), bdk::Error> { - self.rt.block_on(async { - let sync_times = SyncTimes::new(self.keychain_id, self.pool.clone()); - sync_times.persist(time).await - }) + // NOTE: This write is intentionally immediate because BDK may call it outside of + // `commit_batch` flow. + self.ctx + .rt + .block_on(async { self.sync_times_repo().persist(time).await }) } fn del_script_pubkey_from_path( @@ -125,23 +336,21 @@ impl BatchOperations for SqlxWalletDb { _: KeychainKind, _: u32, ) -> Result, bdk::Error> { - unimplemented!() + Err(Self::unsupported_operation("del_script_pubkey_from_path")) } fn del_path_from_script_pubkey( &mut self, _: &Script, ) -> Result, bdk::Error> { - unimplemented!() + Err(Self::unsupported_operation("del_path_from_script_pubkey")) } fn del_utxo(&mut self, outpoint: &OutPoint) -> Result, bdk::Error> { - self.rt.block_on(async { - Utxos::new(self.keychain_id, self.pool.clone()) - .delete(outpoint) - .await - }) + self.ctx + .rt + .block_on(async { self.utxos_repo().delete(outpoint).await }) } fn del_raw_tx(&mut self, _: &Txid) -> Result, bdk::Error> { - unimplemented!() + Err(Self::unsupported_operation("del_raw_tx")) } fn del_tx( @@ -149,16 +358,23 @@ impl BatchOperations for SqlxWalletDb { tx_id: &Txid, _include_raw: bool, ) -> Result, bdk::Error> { - self.rt.block_on(async { - let txs = Transactions::new(self.keychain_id, self.pool.clone()); - txs.delete(tx_id).await - }) + let deleted = self + .ctx + .rt + .block_on(async { self.transactions_repo().delete(tx_id).await })?; + + if deleted.is_some() { + self.batch.txs.remove(tx_id); + self.cache.remove_tx(tx_id)?; + } + + Ok(deleted) } fn del_last_index(&mut self, _: KeychainKind) -> Result, bdk::Error> { - unimplemented!() + Err(Self::unsupported_operation("del_last_index")) } fn del_sync_time(&mut self) -> Result, bdk::Error> { - unimplemented!() + Err(Self::unsupported_operation("del_sync_time")) } } @@ -171,8 +387,8 @@ impl Database for SqlxWalletDb { where B: AsRef<[u8]>, { - self.rt.block_on(async { - let checksums = DescriptorChecksums::new(self.keychain_id, self.pool.clone()); + self.ctx.rt.block_on(async { + let checksums = self.descriptor_checksums_repo(); checksums .check_or_persist_descriptor_checksum(keychain, script_bytes.as_ref()) .await?; @@ -184,31 +400,45 @@ impl Database for SqlxWalletDb { &self, keychain: Option, ) -> Result, bdk::Error> { - self.rt.block_on(async { - let script_pubkeys = ScriptPubkeys::new(self.keychain_id, self.pool.clone()); - let scripts = script_pubkeys.list_scripts(keychain).await?; - Ok(scripts) - }) + self.ctx + .rt + .block_on(async { self.script_pubkeys_repo().list_scripts(keychain).await }) } fn iter_utxos(&self) -> Result, bdk::Error> { - self.rt.block_on(async { - Utxos::new(self.keychain_id, self.pool.clone()) - .list_local_utxos() - .await - }) + self.ctx + .rt + .block_on(async { self.utxos_repo().list_local_utxos().await }) } fn iter_raw_txs(&self) -> Result, bdk::Error> { - unimplemented!() + Err(Self::unsupported_operation("iter_raw_txs")) } - fn iter_txs(&self, _: bool) -> Result, bdk::Error> { - self.load_all_txs()?; - Ok(self - .cached_txs - .lock() - .expect("poisoned txs cache lock") - .values() - .cloned() + fn iter_txs(&self, include_raw: bool) -> Result, bdk::Error> { + let txs = if include_raw { + if self.cache.raw_txs_fully_loaded() { + self.cache + .all_txs()? + .into_iter() + .map(|tx| (tx.txid, tx)) + .collect() + } else { + let loaded = self + .ctx + .rt + .block_on(async { self.transactions_repo().load_all().await })?; + self.cache + .extend_txs(loaded.iter().map(|(txid, tx)| (*txid, tx.clone())))?; + self.cache.set_raw_txs_fully_loaded(); + loaded + } + } else { + self.ctx + .rt + .block_on(async { self.transactions_repo().load_all_summaries().await })? + }; + + Ok(Self::overlay_batch_txs(txs, &self.batch.txs, include_raw) + .into_values() .collect()) } @@ -217,67 +447,53 @@ impl Database for SqlxWalletDb { keychain: KeychainKind, path: u32, ) -> Result, bdk::Error> { - self.rt.block_on(async { - let script_pubkeys = ScriptPubkeys::new(self.keychain_id, self.pool.clone()); - script_pubkeys.find_script(keychain, path).await - }) + self.ctx + .rt + .block_on(async { self.script_pubkeys_repo().find_script(keychain, path).await }) } fn get_path_from_script_pubkey( &self, script: &Script, ) -> Result, bdk::Error> { - let mut cache = self.cached_spks.lock().expect("poisoned spk cache lock"); - if cache.is_empty() { - let loaded = self.rt.block_on(async { - let script_pubkeys = ScriptPubkeys::new(self.keychain_id, self.pool.clone()); - script_pubkeys.load_all().await - })?; - *cache = loaded; - } - - if let Some(res) = cache.get(script) { - Ok(Some(*res)) - } else if let Some(res) = self.addresses.get(script) { - Ok(Some(*res)) - } else { - Ok(None) - } + self.lookup_script_pubkey_path(script) } fn get_utxo(&self, outpoint: &OutPoint) -> Result, bdk::Error> { - self.rt.block_on(async { - Utxos::new(self.keychain_id, self.pool.clone()) - .find(outpoint) - .await - }) + self.ctx + .rt + .block_on(async { self.utxos_repo().find(outpoint).await }) } fn get_raw_tx(&self, tx_id: &Txid) -> Result, bdk::Error> { - self.lookup_tx(tx_id) + self.lookup_tx_with_mode(tx_id, TxLookupMode::RequireRaw) .map(|tx| tx.and_then(|tx| tx.transaction)) } fn get_tx( &self, tx_id: &Txid, - _include_raw: bool, + include_raw: bool, ) -> Result, bdk::Error> { - self.lookup_tx(tx_id) + self.lookup_tx(tx_id).map(|tx| { + tx.map(|mut tx| { + if !include_raw { + tx.transaction = None; + } + tx + }) + }) } fn get_last_index(&self, kind: KeychainKind) -> Result, bdk::Error> { - self.rt.block_on(async { - let last_indexes = Indexes::new(self.keychain_id, self.pool.clone()); - last_indexes.get_latest(kind).await - }) + self.ctx + .rt + .block_on(async { self.indexes_repo().get_latest(kind).await }) } fn get_sync_time(&self) -> Result, bdk::Error> { - self.rt.block_on(async { - let sync_times = SyncTimes::new(self.keychain_id, self.pool.clone()); - sync_times.get().await - }) + self.ctx + .rt + .block_on(async { self.sync_times_repo().get().await }) } fn increment_last_index(&mut self, keychain: KeychainKind) -> Result { - self.rt.block_on(async { - let indexes = Indexes::new(self.keychain_id, self.pool.clone()); - indexes.increment(keychain).await - }) + self.ctx + .rt + .block_on(async { self.indexes_repo().increment(keychain).await }) } } @@ -285,52 +501,173 @@ impl BatchDatabase for SqlxWalletDb { type Batch = Self; fn begin_batch(&self) -> ::Batch { - let mut res = SqlxWalletDb::new(self.pool.clone(), self.keychain_id); - res.cached_spks = Arc::clone(&self.cached_spks); - res.cached_txs = Arc::clone(&self.cached_txs); - res + SqlxWalletDb { + ctx: self.ctx.clone(), + cache: self.cache.clone(), + batch: WalletBatchState::default(), + } } fn commit_batch( &mut self, mut batch: ::Batch, ) -> Result<(), bdk::Error> { - self.cached_spks - .lock() - .expect("poisoned spk cache lock") - .extend( - batch - .addresses - .iter() - .map(|(s, (k, p))| (s.clone(), (*k, *p))), - ); - - self.cached_txs - .lock() - .expect("poisoned txs cache lock") - .extend(batch.txs.iter().map(|(id, tx)| (*id, tx.clone()))); - - self.rt.block_on(async move { - if !batch.addresses.is_empty() { - let addresses: Vec<_> = batch - .addresses - .drain() - .map(|(s, (k, p))| (BdkKeychainKind::from(k), p, s)) - .collect(); - let repo = ScriptPubkeys::new(batch.keychain_id, batch.pool.clone()); - repo.persist_all(addresses).await?; + // Atomic scope here is limited to staged script pubkeys, utxos, and transactions. + // `set_last_index` / `set_sync_time` remain immediate writes by design. + let (addresses_for_cache, addresses_for_db): (Vec<_>, Vec<_>) = batch + .batch + .addresses + .drain() + .map(|(script, (keychain, path))| { + let cache_entry = (script.clone(), (keychain, path)); + let db_entry = (BdkKeychainKind::from(keychain), path, script); + (cache_entry, db_entry) + }) + .unzip(); + + let (txs_for_cache, txs_for_db): (Vec<_>, Vec<_>) = batch + .batch + .txs + .drain() + .map(|(txid, tx)| ((txid, tx.clone()), tx)) + .unzip(); + + let utxos_for_db = std::mem::take(&mut batch.batch.utxos); + let keychain_id = batch.ctx.keychain_id; + let pool = batch.ctx.pool.clone(); + + batch.ctx.rt.block_on(async move { + let mut tx = pool + .begin() + .await + .map_err(|e| bdk::Error::Generic(e.to_string()))?; + + if !addresses_for_db.is_empty() { + ScriptPubkeys::new(keychain_id, pool.clone()) + .persist_all_in_tx(&mut tx, addresses_for_db) + .await?; } - if let Some(utxos) = batch.utxos.take() { - let repo = Utxos::new(batch.keychain_id, batch.pool.clone()); - repo.persist_all(utxos).await?; + if !utxos_for_db.is_empty() { + Utxos::new(keychain_id, pool.clone()) + .persist_all_in_tx(&mut tx, utxos_for_db) + .await?; } - if !batch.txs.is_empty() { - let txs = batch.txs.drain().map(|(_, tx)| tx).collect(); - let repo = Transactions::new(batch.keychain_id, batch.pool.clone()); - repo.persist_all(txs).await?; + + if !txs_for_db.is_empty() { + Transactions::new(keychain_id, pool) + .persist_all_in_tx(&mut tx, txs_for_db) + .await?; } + + tx.commit() + .await + .map_err(|e| bdk::Error::Generic(e.to_string()))?; + Ok::<_, bdk::Error>(()) - }) + })?; + + self.cache.extend_script_pubkeys(addresses_for_cache)?; + self.cache.extend_txs(txs_for_cache)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bdk::bitcoin::hashes::Hash; + + fn tx_details(txid: Txid) -> TransactionDetails { + TransactionDetails { + transaction: None, + txid, + received: 0, + sent: 0, + fee: None, + confirmation_time: None, + } + } + + #[test] + fn wallet_cache_can_insert_get_and_remove_transactions() { + let cache = WalletCache::new(); + let txid = Txid::all_zeros(); + let details = tx_details(txid); + + cache + .insert_tx(txid, details.clone()) + .expect("insert should succeed"); + let loaded = cache.get_tx(&txid).expect("get should succeed"); + assert_eq!(loaded, Some(details)); + + cache.remove_tx(&txid).expect("remove should succeed"); + let loaded = cache.get_tx(&txid).expect("get should succeed"); + assert_eq!(loaded, None); + } + + #[test] + fn wallet_cache_can_insert_and_get_script_pubkey_paths() { + let cache = WalletCache::new(); + let script = ScriptBuf::new(); + let path = (KeychainKind::External, 42); + + cache + .insert_script_pubkey(script.clone(), path) + .expect("insert should succeed"); + + let loaded = cache + .get_script_pubkey_path(script.as_script()) + .expect("get should succeed"); + assert_eq!(loaded, Some(path)); + } + + #[test] + fn wallet_cache_raw_txs_loaded_flag_defaults_false_and_can_be_set() { + let cache = WalletCache::new(); + assert!(!cache.raw_txs_fully_loaded()); + + cache.set_raw_txs_fully_loaded(); + assert!(cache.raw_txs_fully_loaded()); + } + + #[test] + fn overlay_batch_txs_strips_raw_when_include_raw_is_false() { + let txid = Txid::all_zeros(); + let mut base = HashMap::new(); + base.insert(txid, tx_details(txid)); + + let raw_tx = bdk::bitcoin::Transaction { + version: 2, + lock_time: bdk::bitcoin::absolute::LockTime::ZERO, + input: Vec::new(), + output: Vec::new(), + }; + + let mut batch = HashMap::new(); + let mut batch_tx = tx_details(txid); + batch_tx.transaction = Some(raw_tx); + batch.insert(txid, batch_tx); + + let merged = SqlxWalletDb::overlay_batch_txs(base, &batch, false); + assert!(merged + .get(&txid) + .expect("merged tx should exist") + .transaction + .is_none()); + } + + #[test] + fn wallet_cache_all_txs_returns_cached_values() { + let cache = WalletCache::new(); + let txid = Txid::all_zeros(); + + cache + .insert_tx(txid, tx_details(txid)) + .expect("insert should succeed"); + + let txs = cache.all_txs().expect("all_txs should succeed"); + assert_eq!(txs.len(), 1); + assert_eq!(txs[0].txid, txid); } } diff --git a/src/bdk/pg/script_pubkeys.rs b/src/bdk/pg/script_pubkeys.rs index 18f11085..e02eeca3 100644 --- a/src/bdk/pg/script_pubkeys.rs +++ b/src/bdk/pg/script_pubkeys.rs @@ -1,5 +1,4 @@ -use sqlx::{PgPool, Postgres, QueryBuilder}; -use std::collections::HashMap; +use sqlx::{PgPool, Postgres, QueryBuilder, Transaction}; use tracing::instrument; use uuid::Uuid; @@ -17,6 +16,8 @@ impl ScriptPubkeys { } #[instrument(name = "bdk.script_pubkeys.persist_all", skip_all)] + // Retained for non-transactional call sites and focused tests. + #[allow(dead_code)] pub async fn persist_all( &self, keys: Vec<(BdkKeychainKind, u32, ScriptBuf)>, @@ -39,12 +40,45 @@ impl ScriptPubkeys { }); query_builder.push("ON CONFLICT DO NOTHING"); - let query = query_builder.build(); - query + query_builder + .build() .execute(&self.pool) .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; } + + Ok(()) + } + + pub async fn persist_all_in_tx( + &self, + tx: &mut Transaction<'_, Postgres>, + keys: Vec<(BdkKeychainKind, u32, ScriptBuf)>, + ) -> Result<(), bdk::Error> { + const BATCH_SIZE: usize = 5000; + let chunks = keys.chunks(BATCH_SIZE); + for chunk in chunks { + let mut query_builder: QueryBuilder = QueryBuilder::new( + r#"INSERT INTO bdk_script_pubkeys + (keychain_id, keychain_kind, path, script, script_hex, script_fmt)"#, + ); + + query_builder.push_values(chunk, |mut builder, (keychain, path, script)| { + builder.push_bind(self.keychain_id); + builder.push_bind(keychain); + builder.push_bind(*path as i32); + builder.push_bind(script.as_bytes()); + builder.push_bind(format!("{script:02x}")); + builder.push_bind(format!("{script:?}")); + }); + query_builder.push("ON CONFLICT DO NOTHING"); + + query_builder + .build() + .execute(tx.as_mut()) + .await + .map_err(|e| bdk::Error::Generic(e.to_string()))?; + } Ok(()) } @@ -55,64 +89,36 @@ impl ScriptPubkeys { path: u32, ) -> Result, bdk::Error> { let kind = keychain.into(); - let rows = sqlx::query!( + let row = sqlx::query!( r#"SELECT script FROM bdk_script_pubkeys WHERE keychain_id = $1 AND keychain_kind = $2 AND path = $3"#, Uuid::from(self.keychain_id), kind as BdkKeychainKind, path as i32, ) - .fetch_all(&self.pool) + .fetch_optional(&self.pool) .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; - Ok(rows - .into_iter() - .next() - .map(|row| ScriptBuf::from(row.script))) - } - #[instrument(name = "bdk.script_pubkeys.load_all", skip_all)] - pub async fn load_all( - &self, - ) -> Result, bdk::Error> { - let rows = sqlx::query!( - r#"SELECT script, keychain_kind as "keychain_kind: BdkKeychainKind", path FROM bdk_script_pubkeys - WHERE keychain_id = $1"#, - Uuid::from(self.keychain_id), - ) - .fetch_all(&self.pool) - .await - .map_err(|e| bdk::Error::Generic(e.to_string()))?; - let mut ret = HashMap::new(); - for row in rows { - ret.insert( - ScriptBuf::from(row.script), - (bdk::KeychainKind::from(row.keychain_kind), row.path as u32), - ); - } - Ok(ret) + Ok(row.map(|row| ScriptBuf::from(row.script))) } - #[allow(dead_code)] #[instrument(name = "bdk.script_pubkeys.find_path", skip_all)] pub async fn find_path( &self, script: &ScriptBuf, ) -> Result, bdk::Error> { - let rows = sqlx::query!( + let row = sqlx::query!( r#"SELECT keychain_kind as "keychain_kind: BdkKeychainKind", path FROM bdk_script_pubkeys WHERE keychain_id = $1 AND script_hex = ENCODE($2, 'hex')"#, Uuid::from(self.keychain_id), script.as_bytes(), ) - .fetch_all(&self.pool) + .fetch_optional(&self.pool) .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; - if let Some(row) = rows.into_iter().next() { - Ok(Some((row.keychain_kind, row.path as u32))) - } else { - Ok(None) - } + + Ok(row.map(|row| (row.keychain_kind, row.path as u32))) } #[instrument(name = "bdk.script_pubkeys.list_scripts", skip_all)] @@ -120,28 +126,28 @@ impl ScriptPubkeys { &self, keychain: Option>, ) -> Result, bdk::Error> { - let kind = keychain.map(|k| k.into()); - let rows = sqlx::query!( - r#"SELECT script, keychain_kind as "keychain_kind: BdkKeychainKind" FROM bdk_script_pubkeys - WHERE keychain_id = $1"#, - Uuid::from(self.keychain_id), - ) - .fetch_all(&self.pool) - .await - .map_err(|e| bdk::Error::Generic(e.to_string()))?; - Ok(rows - .into_iter() - .filter_map(|row| { - if let Some(kind) = kind { - if kind == row.keychain_kind { - Some(ScriptBuf::from(row.script)) - } else { - None - } - } else { - Some(ScriptBuf::from(row.script)) - } - }) - .collect()) + let keychain_id = Uuid::from(self.keychain_id); + let scripts = if let Some(kind) = keychain.map(Into::into) { + sqlx::query_scalar!( + r#"SELECT script FROM bdk_script_pubkeys + WHERE keychain_id = $1 AND keychain_kind = $2"#, + keychain_id, + kind as BdkKeychainKind, + ) + .fetch_all(&self.pool) + .await + .map_err(|e| bdk::Error::Generic(e.to_string()))? + } else { + sqlx::query_scalar!( + r#"SELECT script FROM bdk_script_pubkeys + WHERE keychain_id = $1"#, + keychain_id, + ) + .fetch_all(&self.pool) + .await + .map_err(|e| bdk::Error::Generic(e.to_string()))? + }; + + Ok(scripts.into_iter().map(ScriptBuf::from).collect()) } } diff --git a/src/bdk/pg/transactions.rs b/src/bdk/pg/transactions.rs index 34324a8f..64ca305a 100644 --- a/src/bdk/pg/transactions.rs +++ b/src/bdk/pg/transactions.rs @@ -1,11 +1,13 @@ -use bdk::{bitcoin::Txid, LocalUtxo, TransactionDetails}; -use sqlx::{PgPool, Postgres, QueryBuilder, Transaction}; +use bdk::{bitcoin::Txid, BlockTime, LocalUtxo, TransactionDetails}; +use sqlx::{PgPool, Postgres, QueryBuilder, Transaction as SqlxTransaction}; use tracing::instrument; use std::collections::HashMap; use crate::{bdk::error::BdkError, primitives::*}; +type SerializedTransactionRow = (String, serde_json::Value, i64, Option); + #[derive(Debug)] pub struct UnsyncedTransaction { pub tx_id: bitcoin::Txid, @@ -31,29 +33,54 @@ pub struct Transactions { } impl Transactions { + fn serialize_batch( + batch: &[TransactionDetails], + ) -> Result, bdk::Error> { + batch + .iter() + .map(|tx| { + Ok::<_, bdk::Error>(( + tx.txid.to_string(), + serde_json::to_value(tx).map_err(|e| { + bdk::Error::Generic(format!("failed to serialize tx details: {e}")) + })?, + tx.sent as i64, + tx.confirmation_time.as_ref().map(|t| t.height as i32), + )) + }) + .collect() + } + pub fn new(keychain_id: KeychainId, pool: PgPool) -> Self { Self { keychain_id, pool } } #[instrument(name = "bdk.transactions.persist", skip_all)] + // Retained for non-transactional call sites and focused tests. + #[allow(dead_code)] pub async fn persist_all(&self, txs: Vec) -> Result<(), bdk::Error> { const BATCH_SIZE: usize = 2000; let batches = txs.chunks(BATCH_SIZE); for batch in batches { + let serialized_batch = Self::serialize_batch(batch)?; + let mut query_builder: QueryBuilder = QueryBuilder::new( r#" INSERT INTO bdk_transactions (keychain_id, tx_id, details_json, sent, height)"#, ); - query_builder.push_values(batch, |mut builder, tx| { - builder.push_bind(self.keychain_id as KeychainId); - builder.push_bind(tx.txid.to_string()); - builder.push_bind(serde_json::to_value(tx).unwrap()); - builder.push_bind(tx.sent as i64); - builder.push_bind(tx.confirmation_time.as_ref().map(|t| t.height as i32)); - }); + query_builder.push_values( + serialized_batch, + |mut builder, (tx_id, details_json, sent, height)| { + builder.push_bind(self.keychain_id as KeychainId); + builder.push_bind(tx_id); + builder.push_bind(details_json); + builder.push_bind(sent); + builder.push_bind(height); + }, + ); query_builder.push( "ON CONFLICT (keychain_id, tx_id) DO UPDATE \ @@ -68,8 +95,8 @@ impl Transactions { OR bdk_transactions.deleted_at IS NOT NULL", ); - let query = query_builder.build(); - query + query_builder + .build() .execute(&self.pool) .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; @@ -78,6 +105,57 @@ impl Transactions { Ok(()) } + pub async fn persist_all_in_tx( + &self, + tx: &mut SqlxTransaction<'_, Postgres>, + txs: Vec, + ) -> Result<(), bdk::Error> { + const BATCH_SIZE: usize = 2000; + let batches = txs.chunks(BATCH_SIZE); + + for batch in batches { + let serialized_batch = Self::serialize_batch(batch)?; + + let mut query_builder: QueryBuilder = QueryBuilder::new( + r#" + INSERT INTO bdk_transactions + (keychain_id, tx_id, details_json, sent, height)"#, + ); + + query_builder.push_values( + serialized_batch, + |mut builder, (tx_id, details_json, sent, height)| { + builder.push_bind(self.keychain_id as KeychainId); + builder.push_bind(tx_id); + builder.push_bind(details_json); + builder.push_bind(sent); + builder.push_bind(height); + }, + ); + + query_builder.push( + "ON CONFLICT (keychain_id, tx_id) DO UPDATE \ + SET details_json = EXCLUDED.details_json,\ + sent = EXCLUDED.sent,\ + height = EXCLUDED.height,\ + modified_at = NOW(),\ + deleted_at = NULL \ + WHERE bdk_transactions.details_json IS DISTINCT FROM EXCLUDED.details_json \ + OR bdk_transactions.sent IS DISTINCT FROM EXCLUDED.sent \ + OR bdk_transactions.height IS DISTINCT FROM EXCLUDED.height \ + OR bdk_transactions.deleted_at IS NOT NULL", + ); + + query_builder + .build() + .execute(tx.as_mut()) + .await + .map_err(|e| bdk::Error::Generic(e.to_string()))?; + } + + Ok(()) + } + #[instrument(name = "bdk.transactions.delete", skip_all)] pub async fn delete(&self, tx_id: &Txid) -> Result, bdk::Error> { let tx = sqlx::query!( @@ -92,12 +170,13 @@ impl Transactions { .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; - Ok(tx.map(|tx| { - serde_json::from_value(tx.details_json).expect("could not deserialize tx details") - })) + tx.map(|tx| { + serde_json::from_value(tx.details_json) + .map_err(|e| bdk::Error::Generic(format!("could not deserialize tx details: {e}"))) + }) + .transpose() } - #[allow(dead_code)] #[instrument(name = "bdk.transactions.find_by_id", skip_all)] pub async fn find_by_id(&self, tx_id: &Txid) -> Result, bdk::Error> { let tx = sqlx::query!( @@ -109,7 +188,11 @@ impl Transactions { .fetch_optional(&self.pool) .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; - Ok(tx.map(|tx| serde_json::from_value(tx.details_json).unwrap())) + tx.map(|tx| { + serde_json::from_value(tx.details_json) + .map_err(|e| bdk::Error::Generic(format!("could not deserialize tx details: {e}"))) + }) + .transpose() } #[instrument(name = "bdk.transactions.load_all", skip(self), fields(n_rows))] @@ -123,13 +206,86 @@ impl Transactions { .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; tracing::Span::current().record("n_rows", txs.len()); - Ok(txs - .into_iter() + txs.into_iter() .map(|tx| { - let tx = serde_json::from_value::(tx.details_json).unwrap(); - (tx.txid, tx) + serde_json::from_value::(tx.details_json) + .map(|tx| (tx.txid, tx)) + .map_err(|e| { + bdk::Error::Generic(format!("could not deserialize tx details: {e}")) + }) }) - .collect()) + .collect() + } + + #[instrument( + name = "bdk.transactions.load_all_summaries", + skip(self), + fields(n_rows) + )] + pub async fn load_all_summaries( + &self, + ) -> Result, bdk::Error> { + let rows = sqlx::query!( + r#" + SELECT tx_id, sent, height, + (details_json->>'received')::BIGINT AS "received?", + (details_json->>'fee')::BIGINT AS "fee?", + (details_json->'confirmation_time'->>'timestamp')::BIGINT AS "confirmation_timestamp?" + FROM bdk_transactions + WHERE keychain_id = $1 AND deleted_at IS NULL"#, + self.keychain_id as KeychainId, + ) + .fetch_all(&self.pool) + .await + .map_err(|e| bdk::Error::Generic(e.to_string()))?; + + tracing::Span::current().record("n_rows", rows.len()); + + fn to_u64(value: i64, field: &str) -> Result { + if value < 0 { + return Err(bdk::Error::Generic(format!( + "negative {field} value in bdk_transactions" + ))); + } + Ok(value as u64) + } + + fn to_u32(value: i32, field: &str) -> Result { + if value < 0 { + return Err(bdk::Error::Generic(format!( + "negative {field} value in bdk_transactions" + ))); + } + Ok(value as u32) + } + + rows.into_iter() + .map(|row| { + let txid = row + .tx_id + .parse::() + .map_err(|e| bdk::Error::Generic(format!("invalid tx_id in db: {e}")))?; + + let confirmation_time = match (row.height, row.confirmation_timestamp) { + (Some(height), Some(timestamp)) => Some(BlockTime { + height: to_u32(height, "height")?, + timestamp: to_u64(timestamp, "confirmation timestamp")?, + }), + _ => None, + }; + + let details = TransactionDetails { + txid, + transaction: None, + received: to_u64(row.received.unwrap_or_default(), "received")?, + sent: to_u64(row.sent, "sent")?, + fee: row.fee.map(|f| to_u64(f, "fee")).transpose()?, + confirmation_time, + }; + + Ok((txid, details)) + }) + .collect() } #[instrument(name = "bdk.transactions.find_unsynced_tx", skip(self), fields(n_rows))] @@ -184,11 +340,22 @@ impl Transactions { inputs.push((utxo, row.path as u32)); } if tx_id.is_none() { - tx_id = Some(row.tx_id.parse().expect("couldn't parse tx_id")); + tx_id = Some(row.tx_id.parse().map_err(|e| { + bdk::Error::Generic(format!("invalid tx id from bdk_transactions: {e}")) + })?); let details: TransactionDetails = serde_json::from_value(row.details_json)?; total_utxo_in_sats = Satoshis::from(details.sent); - fee_sats = Satoshis::from(details.fee.expect("Fee")); - vsize = details.transaction.expect("transaction").vsize() as u64; + fee_sats = Satoshis::from(details.fee.ok_or_else(|| { + bdk::Error::Generic("missing fee in unsynced transaction details".to_string()) + })?); + vsize = details + .transaction + .ok_or_else(|| { + bdk::Error::Generic( + "missing raw transaction in unsynced transaction details".to_string(), + ) + })? + .vsize() as u64; confirmation_time = details.confirmation_time; } } @@ -206,7 +373,7 @@ impl Transactions { #[instrument(name = "bdk.transactions.find_confirmed_spend_tx", skip(self, tx))] pub async fn find_confirmed_spend_tx( &self, - tx: &mut Transaction<'_, Postgres>, + tx: &mut SqlxTransaction<'_, Postgres>, min_height: u32, ) -> Result, BdkError> { let rows = sqlx::query!(r#" @@ -257,19 +424,29 @@ impl Transactions { inputs.push(utxo); } if tx_id.is_none() { - tx_id = Some(row.tx_id.parse().expect("couldn't parse tx_id")); + tx_id = Some(row.tx_id.parse().map_err(|e| { + bdk::Error::Generic(format!("invalid tx id from bdk_transactions: {e}")) + })?); let details: TransactionDetails = serde_json::from_value(row.details_json)?; confirmation_time = details.confirmation_time; } } - Ok(tx_id.map(|tx_id| ConfirmedSpendTransaction { - tx_id, - confirmation_time: confirmation_time - .expect("query should always return confirmation_time"), - inputs, - outputs, - })) + if let Some(tx_id) = tx_id { + let confirmation_time = confirmation_time.ok_or_else(|| { + bdk::Error::Generic( + "missing confirmation_time in confirmed spend transaction details".to_string(), + ) + })?; + Ok(Some(ConfirmedSpendTransaction { + tx_id, + confirmation_time, + inputs, + outputs, + })) + } else { + Ok(None) + } } #[instrument(name = "bdk.transactions.mark_as_synced", skip(self))] @@ -288,7 +465,7 @@ impl Transactions { #[instrument(name = "bdk.transactions.mark_confirmed", skip(self))] pub async fn mark_confirmed( &self, - tx: &mut Transaction<'_, Postgres>, + tx: &mut SqlxTransaction<'_, Postgres>, tx_id: bitcoin::Txid, ) -> Result<(), BdkError> { sqlx::query!( @@ -308,7 +485,7 @@ impl Transactions { )] pub async fn delete_transaction_if_no_more_utxos_exist( &self, - tx: &mut Transaction<'_, Postgres>, + tx: &mut SqlxTransaction<'_, Postgres>, outpoint: bitcoin::OutPoint, ) -> Result<(), BdkError> { sqlx::query!( diff --git a/src/bdk/pg/utxos.rs b/src/bdk/pg/utxos.rs index 60628a19..cc0a25c8 100644 --- a/src/bdk/pg/utxos.rs +++ b/src/bdk/pg/utxos.rs @@ -1,10 +1,12 @@ use bdk::{bitcoin::blockdata::transaction::OutPoint, LocalUtxo, TransactionDetails}; -use sqlx::{PgPool, Postgres, QueryBuilder, Transaction}; +use sqlx::{PgPool, Postgres, QueryBuilder, Transaction as SqlxTransaction}; use tracing::instrument; use uuid::Uuid; use crate::{bdk::error::BdkError, primitives::*}; +type SerializedUtxoRow = (String, i32, serde_json::Value, bool); + pub struct ConfirmedIncomeUtxo { pub outpoint: bitcoin::OutPoint, pub spent: bool, @@ -17,28 +19,51 @@ pub struct Utxos { } impl Utxos { + fn serialize_batch(batch: &[LocalUtxo]) -> Result, bdk::Error> { + batch + .iter() + .map(|utxo| { + Ok::<_, bdk::Error>(( + utxo.outpoint.txid.to_string(), + utxo.outpoint.vout as i32, + serde_json::to_value(utxo).map_err(|e| { + bdk::Error::Generic(format!("failed to serialize utxo: {e}")) + })?, + utxo.is_spent, + )) + }) + .collect() + } + pub fn new(keychain_id: KeychainId, pool: PgPool) -> Self { Self { keychain_id, pool } } #[instrument(name = "bdk.utxos.persist_all", skip_all)] + // Retained for non-transactional call sites and focused tests. + #[allow(dead_code)] pub async fn persist_all(&self, utxos: Vec) -> Result<(), bdk::Error> { const BATCH_SIZE: usize = 2000; let batches = utxos.chunks(BATCH_SIZE); for batch in batches { + let serialized_batch = Self::serialize_batch(batch)?; + let mut query_builder: QueryBuilder = QueryBuilder::new( r#"INSERT INTO bdk_utxos (keychain_id, tx_id, vout, utxo_json, is_spent)"#, ); - query_builder.push_values(batch, |mut builder, utxo| { - builder.push_bind(Uuid::from(self.keychain_id)); - builder.push_bind(utxo.outpoint.txid.to_string()); - builder.push_bind(utxo.outpoint.vout as i32); - builder.push_bind(serde_json::to_value(utxo).unwrap()); - builder.push_bind(utxo.is_spent); - }); + query_builder.push_values( + serialized_batch, + |mut builder, (tx_id, vout, utxo_json, is_spent)| { + builder.push_bind(Uuid::from(self.keychain_id)); + builder.push_bind(tx_id); + builder.push_bind(vout); + builder.push_bind(utxo_json); + builder.push_bind(is_spent); + }, + ); query_builder.push( "ON CONFLICT (keychain_id, tx_id, vout) DO UPDATE \ @@ -51,8 +76,8 @@ impl Utxos { OR bdk_utxos.deleted_at IS NOT NULL", ); - let query = query_builder.build(); - query + query_builder + .build() .execute(&self.pool) .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; @@ -61,6 +86,54 @@ impl Utxos { Ok(()) } + pub async fn persist_all_in_tx( + &self, + tx: &mut SqlxTransaction<'_, Postgres>, + utxos: Vec, + ) -> Result<(), bdk::Error> { + const BATCH_SIZE: usize = 2000; + let batches = utxos.chunks(BATCH_SIZE); + + for batch in batches { + let serialized_batch = Self::serialize_batch(batch)?; + + let mut query_builder: QueryBuilder = QueryBuilder::new( + r#"INSERT INTO bdk_utxos + (keychain_id, tx_id, vout, utxo_json, is_spent)"#, + ); + + query_builder.push_values( + serialized_batch, + |mut builder, (tx_id, vout, utxo_json, is_spent)| { + builder.push_bind(Uuid::from(self.keychain_id)); + builder.push_bind(tx_id); + builder.push_bind(vout); + builder.push_bind(utxo_json); + builder.push_bind(is_spent); + }, + ); + + query_builder.push( + "ON CONFLICT (keychain_id, tx_id, vout) DO UPDATE \ + SET utxo_json = EXCLUDED.utxo_json,\ + is_spent = EXCLUDED.is_spent,\ + modified_at = NOW(),\ + deleted_at = NULL \ + WHERE bdk_utxos.utxo_json IS DISTINCT FROM EXCLUDED.utxo_json \ + OR bdk_utxos.is_spent IS DISTINCT FROM EXCLUDED.is_spent \ + OR bdk_utxos.deleted_at IS NOT NULL", + ); + + query_builder + .build() + .execute(tx.as_mut()) + .await + .map_err(|e| bdk::Error::Generic(e.to_string()))?; + } + + Ok(()) + } + #[instrument(name = "bdk.utxos.delete", skip_all)] pub async fn delete( &self, @@ -78,9 +151,11 @@ impl Utxos { .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; - Ok(row.map(|row| { - serde_json::from_value::(row.utxo_json).expect("Could not deserialize utxo") - })) + row.map(|row| { + serde_json::from_value::(row.utxo_json) + .map_err(|e| bdk::Error::Generic(format!("could not deserialize utxo: {e}"))) + }) + .transpose() } #[instrument(name = "bdk.utxos.undelete", skip_all)] @@ -116,9 +191,11 @@ impl Utxos { .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; - Ok(utxo.map(|utxo| { - serde_json::from_value(utxo.utxo_json).expect("Could not deserialize utxo") - })) + utxo.map(|utxo| { + serde_json::from_value(utxo.utxo_json) + .map_err(|e| bdk::Error::Generic(format!("could not deserialize utxo: {e}"))) + }) + .transpose() } #[instrument(name = "bdk.utxos.list_local_utxos", skip_all)] @@ -130,16 +207,19 @@ impl Utxos { .fetch_all(&self.pool) .await .map_err(|e| bdk::Error::Generic(e.to_string()))?; - Ok(utxos + utxos .into_iter() - .map(|utxo| serde_json::from_value(utxo.utxo_json).expect("Could not deserialize utxo")) - .collect()) + .map(|utxo| { + serde_json::from_value(utxo.utxo_json) + .map_err(|e| bdk::Error::Generic(format!("could not deserialize utxo: {e}"))) + }) + .collect() } #[instrument(name = "bdk.utxos.mark_as_synced", skip(self, tx))] pub async fn mark_as_synced( &self, - tx: &mut Transaction<'_, Postgres>, + tx: &mut SqlxTransaction<'_, Postgres>, utxo: &LocalUtxo, ) -> Result<(), BdkError> { sqlx::query!( @@ -157,7 +237,7 @@ impl Utxos { #[instrument(name = "bdk.utxos.mark_confirmed", skip(self, tx))] pub async fn mark_confirmed( &self, - tx: &mut Transaction<'_, Postgres>, + tx: &mut SqlxTransaction<'_, Postgres>, utxo: &LocalUtxo, ) -> Result<(), BdkError> { sqlx::query!( @@ -175,7 +255,7 @@ impl Utxos { #[instrument(name = "bdk.utxos.find_confirmed_income_utxo", skip(self, tx))] pub async fn find_confirmed_income_utxo( &self, - tx: &mut Transaction<'_, Postgres>, + tx: &mut SqlxTransaction<'_, Postgres>, min_height: u32, ) -> Result, BdkError> { let row = sqlx::query!( @@ -206,25 +286,28 @@ impl Utxos { .fetch_optional(&mut **tx) .await?; - Ok(row.map(|row| { - let local_utxo = serde_json::from_value::(row.utxo_json) - .expect("Could not deserialize utxo"); - let tx_details = serde_json::from_value::(row.details_json) - .expect("Could not deserialize tx details"); - ConfirmedIncomeUtxo { + if let Some(row) = row { + let local_utxo = serde_json::from_value::(row.utxo_json)?; + let tx_details = serde_json::from_value::(row.details_json)?; + let confirmation_time = tx_details.confirmation_time.ok_or_else(|| { + bdk::Error::Generic( + "missing confirmation_time in confirmed income transaction details".to_string(), + ) + })?; + Ok(Some(ConfirmedIncomeUtxo { outpoint: local_utxo.outpoint, spent: local_utxo.is_spent, - confirmation_time: tx_details - .confirmation_time - .expect("query should always return confirmation_time"), - } - })) + confirmation_time, + })) + } else { + Ok(None) + } } #[instrument(name = "bdk.utxos.find_and_remove_soft_deleted_utxo", skip_all)] pub async fn find_and_remove_soft_deleted_utxo( &self, - tx: &mut Transaction<'_, Postgres>, + tx: &mut SqlxTransaction<'_, Postgres>, ) -> Result, BdkError> { let row = sqlx::query!( r#"DELETE FROM bdk_utxos @@ -238,11 +321,12 @@ impl Utxos { ) .fetch_optional(&mut **tx) .await?; - Ok(row.map(|row| { - let local_utxo = serde_json::from_value::(row.utxo_json) - .expect("Could not deserialize the utxo"); + if let Some(row) = row { + let local_utxo = serde_json::from_value::(row.utxo_json)?; let keychain_id = KeychainId::from(row.keychain_id); - (local_utxo.outpoint, keychain_id) - })) + Ok(Some((local_utxo.outpoint, keychain_id))) + } else { + Ok(None) + } } }