diff --git a/chain-signatures/node/src/backlog/mod.rs b/chain-signatures/node/src/backlog/mod.rs index 0f400fb5..e5fd2e40 100644 --- a/chain-signatures/node/src/backlog/mod.rs +++ b/chain-signatures/node/src/backlog/mod.rs @@ -9,7 +9,7 @@ use crate::storage::checkpoint_storage::CheckpointStorage; use anyhow::Context; use mpc_primitives::{PendingTx, SignId}; -use std::collections::{hash_map, HashMap}; +use std::collections::{hash_map, HashMap, HashSet}; use std::hash::{Hash, Hasher}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -185,6 +185,7 @@ struct HistoricalCheckpoint { pub struct Backlog { storage: CheckpointStorage, requests: Arc>>, + recovered_requests: Arc>>>, execution_watchers: Arc>>, /// Historical checkpoints kept for 30 minutes, indexed by chain historical_checkpoints: Arc>>>, @@ -197,12 +198,6 @@ pub enum RecoveryRequeueMode { AfterCatchup, } -#[derive(Debug, Default)] -pub struct RecoveredChainRequests { - pub pending: HashMap, - pub requeue_mode: RecoveryRequeueMode, -} - impl Default for Backlog { fn default() -> Self { Self::new() @@ -218,6 +213,7 @@ impl Backlog { Self { storage, requests: Arc::new(RwLock::new(HashMap::new())), + recovered_requests: Arc::new(RwLock::new(HashMap::new())), execution_watchers: Arc::new(RwLock::new(HashMap::new())), historical_checkpoints: Arc::new(RwLock::new(HashMap::new())), } @@ -235,6 +231,7 @@ impl Backlog { }; self.observe_backlog_size(chain, len); + self.unmark_recovered_request(chain, &id).await; prev } @@ -247,6 +244,7 @@ impl Backlog { }; self.observe_backlog_size(chain, len); + self.unmark_recovered_request(chain, id).await; removed } @@ -279,6 +277,60 @@ impl Backlog { .set(len as i64); } + async fn unmark_recovered_request(&self, chain: Chain, id: &SignId) { + let mut recovered_requests = self.recovered_requests.write().await; + let Some(recovered) = recovered_requests.get_mut(&chain) else { + return; + }; + + recovered.remove(id); + if recovered.is_empty() { + recovered_requests.remove(&chain); + } + } + + async fn set_recovered_requests(&self, chain: Chain, sign_ids: HashSet) { + let mut recovered_requests = self.recovered_requests.write().await; + match recovered_requests.entry(chain) { + hash_map::Entry::Vacant(entry) => { + entry.insert(sign_ids); + } + hash_map::Entry::Occupied(entry) => { + tracing::error!( + %chain, + new_requests_len = sign_ids.len(), + old_requests_len = entry.get().len(), + "attempting to set recovered requests but it already has an entry", + ); + } + } + } + + /// Removes recovered requests for a chain and returns a list of them filtered + /// to only those that should be enqueued for processing. + pub async fn take_requeueable_requests(&self, chain: Chain) -> Vec { + let recovered_sign_ids = { + let mut recovered_requests = self.recovered_requests.write().await; + let Some(recovered) = recovered_requests.remove(&chain) else { + return Vec::new(); + }; + recovered + }; + + let requests = self.requests.read().await; + let Some(pending) = requests.get(&chain) else { + return Vec::new(); + }; + + recovered_sign_ids + .into_iter() + .filter_map(|sign_id| pending.get(&sign_id)) + .filter(|entry| entry.status() == SignStatus::AwaitingResponse) + .filter(|entry| entry.execution_tx().is_none()) + .map(|entry| entry.request.clone()) + .collect() + } + /// Returns all sign-respond transactions with a specific status pub async fn get_by_status( &self, @@ -571,7 +623,7 @@ impl Backlog { node_client: &NodeClient, threshold: usize, chains: &[Chain], - ) -> HashMap { + ) -> HashMap { tracing::info!("attempting to recover from latest checkpoints via node selection"); // Load local checkpoints first @@ -635,27 +687,20 @@ impl Backlog { recovered_modes.insert(chain, requeue_mode); } - // Snapshot pending requests for the requested chains + // Mark the following sign_ids as recovered to requeue them after catchup. + // If they're removed before catchup completes, they're unmarked from recovery + // and will not be requeued let requests = self.requests.read().await; - let mut recovered = HashMap::new(); for &chain in chains { if let Some(pending) = requests.get(&chain) { - let requeue_mode = recovered_modes.get(&chain).copied().unwrap_or_default(); - recovered.insert( - chain, - RecoveredChainRequests { - pending: pending - .requests - .iter() - .map(|(id, entry)| (*id, entry.clone())) - .collect(), - requeue_mode, - }, - ); + let sign_ids: HashSet<_> = pending.requests.keys().copied().collect(); + if !sign_ids.is_empty() { + self.set_recovered_requests(chain, sign_ids).await; + } } } - recovered + recovered_modes } } diff --git a/chain-signatures/node/src/stream/mod.rs b/chain-signatures/node/src/stream/mod.rs index b549964c..4836ee27 100644 --- a/chain-signatures/node/src/stream/mod.rs +++ b/chain-signatures/node/src/stream/mod.rs @@ -111,7 +111,7 @@ pub async fn run_stream( tracing::info!(%chain, "starting indexer loop"); - let mut recovered = recover_backlog( + let requeue_mode = recover_backlog( &backlog, &mut contract_watcher, &mut mesh_state, @@ -151,15 +151,8 @@ pub async fn run_stream( } } ChainEvent::CatchupCompleted => { - if recovered.requeue_mode == crate::backlog::RecoveryRequeueMode::AfterCatchup { - requeue_recovered_sign_requests( - &backlog, - chain, - sign_tx.clone(), - &recovered.pending, - ) - .await; - recovered.pending.clear(); + if requeue_mode == crate::backlog::RecoveryRequeueMode::AfterCatchup { + requeue_recovered_sign_requests(&backlog, chain, sign_tx.clone()).await; } } ChainEvent::Block(block) => { @@ -746,4 +739,132 @@ mod tests { } assert!(backlog.get(Chain::Ethereum, &sign_id).await.is_none()); } + + #[tokio::test] + async fn test_stream_does_not_requeue_replaced_ethereum_recovery_entry_after_catchup() { + let storage = CheckpointStorage::in_memory(); + let seeded_backlog = Backlog::persisted(storage.clone()); + let sign_id = SignId::new([100u8; 32]); + let args = SignArgs { + entropy: [5u8; 32], + epsilon: Scalar::from(1u64), + payload: Scalar::from(2u64), + path: "test".to_string(), + key_version: 1, + }; + let recovered_timestamp = current_unix_timestamp(); + let replayed_timestamp = recovered_timestamp.saturating_add(1); + + seeded_backlog + .insert(IndexedSignRequest::sign( + sign_id, + args.clone(), + Chain::Ethereum, + recovered_timestamp, + )) + .await; + seeded_backlog + .set_processed_block(Chain::Ethereum, 100) + .await; + seeded_backlog.checkpoint(Chain::Ethereum).await; + + struct EthereumLocalStream { + events: Vec>, + } + + impl ChainStream for EthereumLocalStream { + const CHAIN: Chain = Chain::Ethereum; + + async fn next_event(&mut self) -> Option { + if self.events.is_empty() { + return None; + } + self.events.remove(0) + } + } + + let replacement = + IndexedSignRequest::sign(sign_id, args.clone(), Chain::Ethereum, replayed_timestamp); + let client = EthereumLocalStream { + events: vec![ + Some(ChainEvent::SignRequest(replacement)), + Some(ChainEvent::CatchupCompleted), + None, + ], + }; + + let backlog = Backlog::persisted(storage); + let (sign_tx, mut sign_rx) = mpsc::channel(8); + + let (contract_watcher, _tx) = ContractStateWatcher::with_running( + &"test.near".parse::().unwrap(), + k256::ProjectivePoint::GENERATOR.to_affine(), + 2, + Default::default(), + ); + + let mut servers = Vec::new(); + for _ in 0..2 { + let mut server = Server::new_async().await; + let mut body = Vec::new(); + ciborium::ser::into_writer( + &std::collections::HashMap::::new(), + &mut body, + ) + .unwrap(); + server + .mock("GET", "/checkpoint") + .with_status(200) + .with_body(body) + .create_async() + .await; + servers.push(server); + } + + let mut mesh_state = MeshState::default(); + for (index, server) in servers.iter().enumerate() { + let mut info = ParticipantInfo::new(index as u32); + info.url = server.url(); + mesh_state.update( + cait_sith::protocol::Participant::from(index as u32), + NodeStatus::Active, + info, + ); + } + let (_mesh_state_tx, mesh_state_rx) = tokio::sync::watch::channel(mesh_state); + let node_client = NodeClient::new(&Default::default()); + + run_stream( + client, + sign_tx, + backlog.clone(), + contract_watcher, + mesh_state_rx, + node_client, + ) + .await; + + let first = timeout(Duration::from_secs(1), sign_rx.recv()) + .await + .unwrap() + .unwrap(); + match first { + Sign::Request(req) => { + assert_eq!(req.id, sign_id); + assert_eq!(req.unix_timestamp_indexed, replayed_timestamp); + } + other => panic!("expected replayed sign request, got {other:?}"), + } + + match timeout(Duration::from_millis(100), sign_rx.recv()).await { + Err(_) | Ok(None) => {} + Ok(Some(msg)) => panic!("unexpected extra sign message after catchup: {msg:?}"), + } + + let entry = backlog + .get(Chain::Ethereum, &sign_id) + .await + .expect("replayed entry should remain in backlog"); + assert_eq!(entry.request.unix_timestamp_indexed, replayed_timestamp); + } } diff --git a/chain-signatures/node/src/stream/ops.rs b/chain-signatures/node/src/stream/ops.rs index f00b580c..5f5de66c 100644 --- a/chain-signatures/node/src/stream/ops.rs +++ b/chain-signatures/node/src/stream/ops.rs @@ -1,4 +1,4 @@ -use crate::backlog::{Backlog, BacklogEntry, RecoveredChainRequests, RecoveryRequeueMode}; +use crate::backlog::{Backlog, RecoveryRequeueMode}; use crate::indexer_hydration::{ HydrationRespondBidirectionalEvent, HydrationSignBidirectionalRequestedEvent, HydrationSignatureRespondedEvent, @@ -288,52 +288,34 @@ pub(crate) async fn recover_backlog( node_client: &NodeClient, source_chain: Chain, sign_tx: mpsc::Sender, -) -> RecoveredChainRequests { +) -> RecoveryRequeueMode { // Recover backlog before doing anything. // Wait for threshold to be available let threshold = contract_watcher.wait_threshold().await; if threshold == 0 { - return RecoveredChainRequests::default(); + return RecoveryRequeueMode::default(); } wait_threshold_active(mesh_state, threshold).await; let mesh_state = mesh_state.borrow().clone(); - let mut recovered = backlog + let mut requeue_modes = backlog .recover(&mesh_state, node_client, threshold, &[source_chain]) .await; - let recovered = recovered.remove(&source_chain).unwrap_or_default(); - - if recovered.requeue_mode == RecoveryRequeueMode::Immediate { - requeue_recovered_sign_requests(backlog, source_chain, sign_tx, &recovered.pending).await; + let requeue_mode = requeue_modes.remove(&source_chain).unwrap_or_default(); + if requeue_mode == RecoveryRequeueMode::Immediate { + requeue_recovered_sign_requests(backlog, source_chain, sign_tx).await; } - - recovered + requeue_mode } pub(crate) async fn requeue_recovered_sign_requests( backlog: &Backlog, source_chain: Chain, sign_tx: mpsc::Sender, - pending: &std::collections::HashMap, ) { - for &sign_id in pending.keys() { - let Some(entry) = backlog.get(source_chain, &sign_id).await else { - continue; - }; - - if entry.status() != SignStatus::AwaitingResponse { - continue; - } - - // This is a bidirectional execution watcher, so let's skip it and have - // the stream/indexer itself enqueue watching. - if entry.execution_tx().is_some() { - continue; - } - - let sign_request = entry.request; - + for sign_request in backlog.take_requeueable_requests(source_chain).await { + let sign_id = sign_request.id; if let Err(err) = sign_tx.send(Sign::Request(sign_request)).await { tracing::error!( ?err,