Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 68 additions & 23 deletions chain-signatures/node/src/backlog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::storage::checkpoint_storage::CheckpointStorage;

use anyhow::Context;
use mpc_primitives::{PendingTx, SignId};
use std::collections::{hash_map, HashMap};
use std::collections::{hash_map, HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -185,6 +185,7 @@ struct HistoricalCheckpoint {
pub struct Backlog {
storage: CheckpointStorage,
requests: Arc<RwLock<HashMap<Chain, PendingRequests>>>,
recovered_requests: Arc<RwLock<HashMap<Chain, HashSet<SignId>>>>,
execution_watchers: Arc<RwLock<HashMap<Chain, ExecutionWatchers>>>,
/// Historical checkpoints kept for 30 minutes, indexed by chain
historical_checkpoints: Arc<RwLock<HashMap<Chain, Vec<HistoricalCheckpoint>>>>,
Expand All @@ -197,12 +198,6 @@ pub enum RecoveryRequeueMode {
AfterCatchup,
}

#[derive(Debug, Default)]
pub struct RecoveredChainRequests {
pub pending: HashMap<SignId, BacklogEntry>,
pub requeue_mode: RecoveryRequeueMode,
}

impl Default for Backlog {
fn default() -> Self {
Self::new()
Expand All @@ -218,6 +213,7 @@ impl Backlog {
Self {
storage,
requests: Arc::new(RwLock::new(HashMap::new())),
recovered_requests: Arc::new(RwLock::new(HashMap::new())),
execution_watchers: Arc::new(RwLock::new(HashMap::new())),
historical_checkpoints: Arc::new(RwLock::new(HashMap::new())),
}
Expand All @@ -235,6 +231,7 @@ impl Backlog {
};

self.observe_backlog_size(chain, len);
self.unmark_recovered_request(chain, &id).await;
prev
}

Expand All @@ -247,6 +244,7 @@ impl Backlog {
};

self.observe_backlog_size(chain, len);
self.unmark_recovered_request(chain, id).await;
removed
}

Expand Down Expand Up @@ -279,6 +277,60 @@ impl Backlog {
.set(len as i64);
}

async fn unmark_recovered_request(&self, chain: Chain, id: &SignId) {
let mut recovered_requests = self.recovered_requests.write().await;
let Some(recovered) = recovered_requests.get_mut(&chain) else {
return;
};

recovered.remove(id);
if recovered.is_empty() {
recovered_requests.remove(&chain);
}
}

/// Records the set of sign ids recovered for `chain`.
///
/// Only one recovery set may exist per chain at a time: if an entry is
/// already present, the new set is dropped and an error is logged instead of
/// overwriting or merging.
async fn set_recovered_requests(&self, chain: Chain, sign_ids: HashSet<SignId>) {
    let mut guard = self.recovered_requests.write().await;
    if let Some(existing) = guard.get(&chain) {
        tracing::error!(
            %chain,
            new_requests_len = sign_ids.len(),
            old_requests_len = existing.len(),
            "attempting to set recovered requests but it already has an entry",
        );
    } else {
        guard.insert(chain, sign_ids);
    }
}

/// Removes recovered requests for a chain and returns a list of them filtered
/// to only those that should be enqueued for processing.
pub async fn take_requeueable_requests(&self, chain: Chain) -> Vec<IndexedSignRequest> {
    // Take ownership of the recovered set first; the write guard is a
    // temporary of this statement, so it is released before we touch the
    // pending-request map below.
    let sign_ids = match self.recovered_requests.write().await.remove(&chain) {
        Some(ids) => ids,
        None => return Vec::new(),
    };

    let requests = self.requests.read().await;
    let Some(pending) = requests.get(&chain) else {
        return Vec::new();
    };

    let mut requeueable = Vec::new();
    for sign_id in sign_ids {
        let Some(entry) = pending.get(&sign_id) else {
            // Removed from the backlog since recovery; nothing to requeue.
            continue;
        };
        // Only entries still awaiting a response and without a bidirectional
        // execution watcher are requeued; watched entries are enqueued by the
        // stream/indexer itself.
        if entry.status() == SignStatus::AwaitingResponse && entry.execution_tx().is_none() {
            requeueable.push(entry.request.clone());
        }
    }
    requeueable
}

/// Returns all sign-respond transactions with a specific status
pub async fn get_by_status(
&self,
Expand Down Expand Up @@ -571,7 +623,7 @@ impl Backlog {
node_client: &NodeClient,
threshold: usize,
chains: &[Chain],
) -> HashMap<Chain, RecoveredChainRequests> {
) -> HashMap<Chain, RecoveryRequeueMode> {
tracing::info!("attempting to recover from latest checkpoints via node selection");

// Load local checkpoints first
Expand Down Expand Up @@ -635,27 +687,20 @@ impl Backlog {
recovered_modes.insert(chain, requeue_mode);
}

// Snapshot pending requests for the requested chains
// Mark the following sign_ids as recovered to requeue them after catchup.
// If they're removed before catchup completes, they're unmarked from recovery
// and will not be requeued
let requests = self.requests.read().await;
let mut recovered = HashMap::new();
for &chain in chains {
if let Some(pending) = requests.get(&chain) {
let requeue_mode = recovered_modes.get(&chain).copied().unwrap_or_default();
recovered.insert(
chain,
RecoveredChainRequests {
pending: pending
.requests
.iter()
.map(|(id, entry)| (*id, entry.clone()))
.collect(),
requeue_mode,
},
);
let sign_ids: HashSet<_> = pending.requests.keys().copied().collect();
if !sign_ids.is_empty() {
self.set_recovered_requests(chain, sign_ids).await;
}
Comment on lines 693 to +699
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backlog::recover holds self.requests.read().await across an .await when calling set_recovered_requests(...). Awaiting while holding a Tokio RwLock guard can cause unnecessary contention (blocking writers like insert/remove/advance) and can contribute to deadlock scenarios if other code later introduces a different lock ordering. Consider collecting the recovered sign_ids for each chain while holding the read lock, then dropping the guard before awaiting (or acquiring recovered_requests once and updating it without further awaits).

Copilot uses AI. Check for mistakes.
}
}

recovered
recovered_modes
}
}

Expand Down
141 changes: 131 additions & 10 deletions chain-signatures/node/src/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub async fn run_stream<S: ChainStream>(

tracing::info!(%chain, "starting indexer loop");

let mut recovered = recover_backlog(
let requeue_mode = recover_backlog(
&backlog,
&mut contract_watcher,
&mut mesh_state,
Expand Down Expand Up @@ -151,15 +151,8 @@ pub async fn run_stream<S: ChainStream>(
}
}
ChainEvent::CatchupCompleted => {
if recovered.requeue_mode == crate::backlog::RecoveryRequeueMode::AfterCatchup {
requeue_recovered_sign_requests(
&backlog,
chain,
sign_tx.clone(),
&recovered.pending,
)
.await;
recovered.pending.clear();
if requeue_mode == crate::backlog::RecoveryRequeueMode::AfterCatchup {
requeue_recovered_sign_requests(&backlog, chain, sign_tx.clone()).await;
}
}
ChainEvent::Block(block) => {
Expand Down Expand Up @@ -746,4 +739,132 @@ mod tests {
}
assert!(backlog.get(Chain::Ethereum, &sign_id).await.is_none());
}

// Regression test: a request recovered from a checkpoint must NOT be requeued
// after catchup when the live stream has already replayed a newer copy of the
// same SignId — inserting the replacement unmarks the id from the recovery
// set, so only the replayed request reaches the sign channel.
#[tokio::test]
async fn test_stream_does_not_requeue_replaced_ethereum_recovery_entry_after_catchup() {
    // Seed persistent storage with one pending request and checkpoint it so a
    // fresh Backlog built over the same storage will recover it on startup.
    let storage = CheckpointStorage::in_memory();
    let seeded_backlog = Backlog::persisted(storage.clone());
    let sign_id = SignId::new([100u8; 32]);
    let args = SignArgs {
        entropy: [5u8; 32],
        epsilon: Scalar::from(1u64),
        payload: Scalar::from(2u64),
        path: "test".to_string(),
        key_version: 1,
    };
    // The replayed copy gets a strictly newer timestamp so we can tell the
    // two apart in the assertions below.
    let recovered_timestamp = current_unix_timestamp();
    let replayed_timestamp = recovered_timestamp.saturating_add(1);

    seeded_backlog
        .insert(IndexedSignRequest::sign(
            sign_id,
            args.clone(),
            Chain::Ethereum,
            recovered_timestamp,
        ))
        .await;
    seeded_backlog
        .set_processed_block(Chain::Ethereum, 100)
        .await;
    seeded_backlog.checkpoint(Chain::Ethereum).await;

    // Minimal in-memory chain stream: pops pre-scripted events in order and
    // terminates the stream once the script is exhausted.
    struct EthereumLocalStream {
        events: Vec<Option<ChainEvent>>,
    }

    impl ChainStream for EthereumLocalStream {
        const CHAIN: Chain = Chain::Ethereum;

        async fn next_event(&mut self) -> Option<ChainEvent> {
            if self.events.is_empty() {
                return None;
            }
            self.events.remove(0)
        }
    }

    // Script: replay the same sign id with a newer timestamp, then signal
    // catchup completion, then end the stream.
    let replacement =
        IndexedSignRequest::sign(sign_id, args.clone(), Chain::Ethereum, replayed_timestamp);
    let client = EthereumLocalStream {
        events: vec![
            Some(ChainEvent::SignRequest(replacement)),
            Some(ChainEvent::CatchupCompleted),
            None,
        ],
    };

    // Fresh backlog over the seeded storage — this is the one run_stream will
    // drive through recovery.
    let backlog = Backlog::persisted(storage);
    let (sign_tx, mut sign_rx) = mpsc::channel(8);

    let (contract_watcher, _tx) = ContractStateWatcher::with_running(
        &"test.near".parse::<AccountId>().unwrap(),
        k256::ProjectivePoint::GENERATOR.to_affine(),
        2,
        Default::default(),
    );

    // Two mock peers, each answering GET /checkpoint with an empty
    // CBOR-encoded checkpoint map, so recovery can query the mesh.
    let mut servers = Vec::new();
    for _ in 0..2 {
        let mut server = Server::new_async().await;
        let mut body = Vec::new();
        ciborium::ser::into_writer(
            &std::collections::HashMap::<Chain, crate::backlog::Checkpoint>::new(),
            &mut body,
        )
        .unwrap();
        server
            .mock("GET", "/checkpoint")
            .with_status(200)
            .with_body(body)
            .create_async()
            .await;
        servers.push(server);
    }

    // Register both mock servers as active mesh participants so the
    // threshold-active wait inside recovery can complete.
    let mut mesh_state = MeshState::default();
    for (index, server) in servers.iter().enumerate() {
        let mut info = ParticipantInfo::new(index as u32);
        info.url = server.url();
        mesh_state.update(
            cait_sith::protocol::Participant::from(index as u32),
            NodeStatus::Active,
            info,
        );
    }
    let (_mesh_state_tx, mesh_state_rx) = tokio::sync::watch::channel(mesh_state);
    let node_client = NodeClient::new(&Default::default());

    // Runs to completion because the scripted stream returns None at the end.
    run_stream(
        client,
        sign_tx,
        backlog.clone(),
        contract_watcher,
        mesh_state_rx,
        node_client,
    )
    .await;

    // The only message on the channel must be the replayed request (newer
    // timestamp), not the recovered one.
    let first = timeout(Duration::from_secs(1), sign_rx.recv())
        .await
        .unwrap()
        .unwrap();
    match first {
        Sign::Request(req) => {
            assert_eq!(req.id, sign_id);
            assert_eq!(req.unix_timestamp_indexed, replayed_timestamp);
        }
        other => panic!("expected replayed sign request, got {other:?}"),
    }

    // No second message: the recovered copy was unmarked when the replacement
    // was inserted, so catchup must not requeue it.
    match timeout(Duration::from_millis(100), sign_rx.recv()).await {
        Err(_) | Ok(None) => {}
        Ok(Some(msg)) => panic!("unexpected extra sign message after catchup: {msg:?}"),
    }

    // The backlog keeps the replayed entry (newer timestamp) for this id.
    let entry = backlog
        .get(Chain::Ethereum, &sign_id)
        .await
        .expect("replayed entry should remain in backlog");
    assert_eq!(entry.request.unix_timestamp_indexed, replayed_timestamp);
}
}
38 changes: 10 additions & 28 deletions chain-signatures/node/src/stream/ops.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::backlog::{Backlog, BacklogEntry, RecoveredChainRequests, RecoveryRequeueMode};
use crate::backlog::{Backlog, RecoveryRequeueMode};
use crate::indexer_hydration::{
HydrationRespondBidirectionalEvent, HydrationSignBidirectionalRequestedEvent,
HydrationSignatureRespondedEvent,
Expand Down Expand Up @@ -288,52 +288,34 @@ pub(crate) async fn recover_backlog(
node_client: &NodeClient,
source_chain: Chain,
sign_tx: mpsc::Sender<Sign>,
) -> RecoveredChainRequests {
) -> RecoveryRequeueMode {
// Recover backlog before doing anything.
// Wait for threshold to be available
let threshold = contract_watcher.wait_threshold().await;
if threshold == 0 {
return RecoveredChainRequests::default();
return RecoveryRequeueMode::default();
}
wait_threshold_active(mesh_state, threshold).await;

let mesh_state = mesh_state.borrow().clone();
let mut recovered = backlog
let mut requeue_modes = backlog
.recover(&mesh_state, node_client, threshold, &[source_chain])
.await;

let recovered = recovered.remove(&source_chain).unwrap_or_default();

if recovered.requeue_mode == RecoveryRequeueMode::Immediate {
requeue_recovered_sign_requests(backlog, source_chain, sign_tx, &recovered.pending).await;
let requeue_mode = requeue_modes.remove(&source_chain).unwrap_or_default();
if requeue_mode == RecoveryRequeueMode::Immediate {
requeue_recovered_sign_requests(backlog, source_chain, sign_tx).await;
}

recovered
requeue_mode
}

pub(crate) async fn requeue_recovered_sign_requests(
backlog: &Backlog,
source_chain: Chain,
sign_tx: mpsc::Sender<Sign>,
pending: &std::collections::HashMap<SignId, BacklogEntry>,
) {
for &sign_id in pending.keys() {
let Some(entry) = backlog.get(source_chain, &sign_id).await else {
continue;
};

if entry.status() != SignStatus::AwaitingResponse {
continue;
}

// This is a bidirectional execution watcher, so let's skip it and have
// the stream/indexer itself enqueue watching.
if entry.execution_tx().is_some() {
continue;
}

let sign_request = entry.request;

for sign_request in backlog.take_requeueable_requests(source_chain).await {
let sign_id = sign_request.id;
if let Err(err) = sign_tx.send(Sign::Request(sign_request)).await {
tracing::error!(
?err,
Expand Down
Loading