From 4e1ef119a6e4eb218c984fa5bd254f6015731a60 Mon Sep 17 00:00:00 2001 From: Zhou Fang Date: Fri, 10 Apr 2026 15:56:00 -0700 Subject: [PATCH] fix!: prune MPC message keyspaces by hashi epoch instead of chain epoch --- crates/hashi/src/db.rs | 341 +++++++++++++++----------------- crates/hashi/src/mpc/service.rs | 3 + 2 files changed, 159 insertions(+), 185 deletions(-) diff --git a/crates/hashi/src/db.rs b/crates/hashi/src/db.rs index 82dfbe830..d11776a67 100644 --- a/crates/hashi/src/db.rs +++ b/crates/hashi/src/db.rs @@ -77,7 +77,6 @@ impl Database { /// Store encryption key for the given epoch. /// /// No-op if a key already exists for this epoch (idempotent for restart safety). - /// Also cleans up old encryption keys (keeps only current and previous epoch). pub fn store_encryption_key( &self, epoch: u64, @@ -88,7 +87,6 @@ impl Database { let value = bcs::to_bytes(encryption_key).unwrap(); self.encryption_keys.insert(key, value)?; } - self.cleanup_old_encryption_keys(epoch)?; Ok(()) } @@ -126,26 +124,6 @@ impl Database { Ok(Some(EncryptionPrivateKey::from(scalar))) } - /// Clear encryption keys older than `current_epoch - 1` to limit exposure if the node is - /// compromised. - fn cleanup_old_encryption_keys(&self, current_epoch: u64) -> Result<()> { - let cutoff = current_epoch.saturating_sub(1); - let keys_to_delete: Vec<_> = self - .encryption_keys - .iter() - .filter_map(|guard| { - let key = guard.key().ok()?; - let epoch_bytes: [u8; 8] = key.as_ref().try_into().ok()?; - let epoch = u64::from_be_bytes(epoch_bytes); - if epoch < cutoff { Some(epoch) } else { None } - }) - .collect(); - for epoch in keys_to_delete { - self.encryption_keys.remove(epoch.to_be_bytes())?; - } - Ok(()) - } - pub fn store_dealer_message( &self, epoch: u64, @@ -154,8 +132,7 @@ impl Database { ) -> Result<()> { let key = [epoch.to_be_bytes().as_slice(), dealer.as_bytes()].concat(); let value = bcs::to_bytes(message).unwrap(); - self.dealer_messages.insert(key, value)?; - clean_up_old_epochs(&self.dealer_messages, epoch) + self.dealer_messages.insert(key, value) } pub fn get_dealer_message( @@ -193,8 +170,7 @@ impl Database { ) -> Result<()> { let key = [epoch.to_be_bytes().as_slice(), dealer.as_bytes()].concat(); let value = bcs::to_bytes(messages).unwrap(); - self.rotation_messages.insert(key, value)?; - clean_up_old_epochs(&self.rotation_messages, epoch) + self.rotation_messages.insert(key, value) } pub fn get_rotation_messages( @@ -238,8 +214,7 @@ impl Database { ] .concat(); let value = bcs::to_bytes(message).unwrap(); - self.nonce_messages.insert(key, value)?; - clean_up_old_epochs(&self.nonce_messages, epoch) + self.nonce_messages.insert(key, value) } pub fn get_nonce_message( @@ -305,6 +280,15 @@ impl Database { .concat(); self.nonce_messages.remove(key) } + + /// Prune all MPC keyspaces, deleting entries with `epoch < cutoff_epoch`. + pub fn prune_messages_below(&self, cutoff_epoch: u64) -> Result<()> { + prune_keyspace(&self.encryption_keys, cutoff_epoch)?; + prune_keyspace(&self.dealer_messages, cutoff_epoch)?; + prune_keyspace(&self.rotation_messages, cutoff_epoch)?; + prune_keyspace(&self.nonce_messages, cutoff_epoch)?; + Ok(()) + } } /// List all `(Address, T)` pairs from a keyspace where keys match the given prefix. @@ -341,21 +325,15 @@ fn list_messages_by_prefix( Ok(results) } -/// Delete entries from keyspace where key starts with epoch (big-endian u64) < cutoff. -/// Cutoff is `current_epoch - 1`, keeping current and previous epoch. -fn clean_up_old_epochs(keyspace: &Keyspace, current_epoch: u64) -> Result<()> { - let cutoff = current_epoch.saturating_sub(1); +/// Delete entries from `keyspace` whose leading big-endian u64 epoch is `< cutoff_epoch`. +fn prune_keyspace(keyspace: &Keyspace, cutoff_epoch: u64) -> Result<()> { let keys_to_delete: Vec<_> = keyspace .iter() .filter_map(|guard| { let key = guard.key().ok()?; let epoch_bytes: [u8; 8] = key.as_ref().get(..8)?.try_into().ok()?; let epoch = u64::from_be_bytes(epoch_bytes); - if epoch < cutoff { - Some(key.to_vec()) - } else { - None - } + (epoch < cutoff_epoch).then(|| key.to_vec()) }) .collect(); for key in keys_to_delete { @@ -446,45 +424,6 @@ mod tests { assert_eq!(private_key, db.get_encryption_key(100).unwrap().unwrap()); } - #[test] - fn test_automatic_cleanup_on_store() { - let tmpdir = tempfile::Builder::new().tempdir().unwrap(); - let db = Database::open(tmpdir.path()).unwrap(); - - let key1 = EncryptionPrivateKey::new(&mut rand::thread_rng()); - let key2 = EncryptionPrivateKey::new(&mut rand::thread_rng()); - let key3 = EncryptionPrivateKey::new(&mut rand::thread_rng()); - let key4 = EncryptionPrivateKey::new(&mut rand::thread_rng()); - let key5 = EncryptionPrivateKey::new(&mut rand::thread_rng()); - - // Store epoch 1 - cleanup(1) is no-op (epoch < 2) - db.store_encryption_key(1, &key1).unwrap(); - assert!(db.get_encryption_key(1).unwrap().is_some()); - - // Store epoch 2 - cleanup(2) cutoff=1, deletes nothing - db.store_encryption_key(2, &key2).unwrap(); - assert!(db.get_encryption_key(1).unwrap().is_some()); - assert!(db.get_encryption_key(2).unwrap().is_some()); - - // Store epoch 3 - cleanup(3) cutoff=2, deletes epoch 1 - db.store_encryption_key(3, &key3).unwrap(); - assert!(db.get_encryption_key(1).unwrap().is_none()); // deleted - assert!(db.get_encryption_key(2).unwrap().is_some()); - assert!(db.get_encryption_key(3).unwrap().is_some()); - - // Store epoch 4 - cleanup(4) cutoff=3, deletes epoch 2 - db.store_encryption_key(4, &key4).unwrap(); - assert!(db.get_encryption_key(2).unwrap().is_none()); // deleted - assert!(db.get_encryption_key(3).unwrap().is_some()); - assert!(db.get_encryption_key(4).unwrap().is_some()); - - // Store epoch 5 - cleanup(5) cutoff=4, deletes epoch 3 - db.store_encryption_key(5, &key5).unwrap(); - assert!(db.get_encryption_key(3).unwrap().is_none()); // deleted - assert_eq!(key4, db.get_encryption_key(4).unwrap().unwrap()); - assert_eq!(key5, db.get_encryption_key(5).unwrap().unwrap()); - } - #[test] fn test_latest_encryption_key_epoch() { let tmpdir = tempfile::Builder::new().tempdir().unwrap(); @@ -503,11 +442,10 @@ mod tests { db.store_encryption_key(8, &key2).unwrap(); assert_eq!(db.latest_encryption_key_epoch().unwrap(), Some(8)); - // After cleanup (store epoch 10 cleans up epoch 5), still returns latest + // Storing more keys keeps returning the latest let key3 = EncryptionPrivateKey::new(&mut rand::thread_rng()); db.store_encryption_key(10, &key3).unwrap(); assert_eq!(db.latest_encryption_key_epoch().unwrap(), Some(10)); - assert!(db.get_encryption_key(5).unwrap().is_none()); } #[test] @@ -552,39 +490,6 @@ mod tests { assert!(db.get_dealer_message(2, &dealer1).unwrap().is_some()); } - #[test] - fn test_dealer_messages_auto_cleanup() { - let tmpdir = tempfile::Builder::new().tempdir().unwrap(); - let db = Database::open(tmpdir.path()).unwrap(); - - let dealer = Address::new([1u8; 32]); - let message = create_test_message(); - - // Store in epoch 5 - cleanup happens for epochs < 4, but nothing exists yet - db.store_dealer_message(5, &dealer, &message).unwrap(); - assert!(db.get_dealer_message(5, &dealer).unwrap().is_some()); - - // Store in epoch 6 - cleanup happens for epochs < 5, so epoch 5 remains - db.store_dealer_message(6, &dealer, &message).unwrap(); - assert!(db.get_dealer_message(5, &dealer).unwrap().is_some()); - assert!(db.get_dealer_message(6, &dealer).unwrap().is_some()); - - // Store in epoch 7 - cleanup happens for epochs < 6, so epoch 5 is deleted - db.store_dealer_message(7, &dealer, &message).unwrap(); - assert!( - db.get_dealer_message(5, &dealer).unwrap().is_none(), - "epoch 5 should be cleaned up" - ); - assert!( - db.get_dealer_message(6, &dealer).unwrap().is_some(), - "epoch 6 should remain" - ); - assert!( - db.get_dealer_message(7, &dealer).unwrap().is_some(), - "epoch 7 should remain" - ); - } - #[test] fn test_list_all_dealer_messages() { let tmpdir = tempfile::Builder::new().tempdir().unwrap(); @@ -683,45 +588,6 @@ mod tests { assert_eq!(all.len(), 2); } - #[test] - fn test_rotation_messages_auto_cleanup() { - use std::collections::BTreeMap; - use std::num::NonZeroU16; - - let tmpdir = tempfile::Builder::new().tempdir().unwrap(); - let db = Database::open(tmpdir.path()).unwrap(); - - let dealer = Address::new([1u8; 32]); - let mut messages: BTreeMap = BTreeMap::new(); - messages.insert(NonZeroU16::new(1).unwrap(), create_test_message()); - - // Store in epoch 5 - cleanup happens for epochs < 4, but nothing exists yet - db.store_rotation_messages(5, &dealer, &messages).unwrap(); - assert_eq!(db.list_all_rotation_messages(5).unwrap().len(), 1); - - // Store in epoch 6 - cleanup happens for epochs < 5, so epoch 5 remains - db.store_rotation_messages(6, &dealer, &messages).unwrap(); - assert_eq!(db.list_all_rotation_messages(5).unwrap().len(), 1); - assert_eq!(db.list_all_rotation_messages(6).unwrap().len(), 1); - - // Store in epoch 7 - cleanup happens for epochs < 6, so epoch 5 is deleted - db.store_rotation_messages(7, &dealer, &messages).unwrap(); - assert!( - db.list_all_rotation_messages(5).unwrap().is_empty(), - "epoch 5 should be cleaned up" - ); - assert_eq!( - db.list_all_rotation_messages(6).unwrap().len(), - 1, - "epoch 6 should remain" - ); - assert_eq!( - db.list_all_rotation_messages(7).unwrap().len(), - 1, - "epoch 7 should remain" - ); - } - #[test] fn test_list_all_rotation_messages() { use std::collections::BTreeMap; @@ -867,41 +733,6 @@ mod tests { ); } - #[test] - fn test_nonce_messages_auto_cleanup() { - let tmpdir = tempfile::Builder::new().tempdir().unwrap(); - let db = Database::open(tmpdir.path()).unwrap(); - - let dealer = Address::new([1u8; 32]); - let message = create_test_nonce_message(); - - // Store in epoch 5 - db.store_nonce_message(5, 0, &dealer, &message).unwrap(); - assert_eq!(db.list_nonce_messages(5, 0).unwrap().len(), 1); - - // Store in epoch 6 - cleanup for epochs < 5, epoch 5 remains - db.store_nonce_message(6, 0, &dealer, &message).unwrap(); - assert_eq!(db.list_nonce_messages(5, 0).unwrap().len(), 1); - assert_eq!(db.list_nonce_messages(6, 0).unwrap().len(), 1); - - // Store in epoch 7 - cleanup for epochs < 6, epoch 5 is deleted - db.store_nonce_message(7, 0, &dealer, &message).unwrap(); - assert!( - db.list_nonce_messages(5, 0).unwrap().is_empty(), - "epoch 5 should be cleaned up" - ); - assert_eq!( - db.list_nonce_messages(6, 0).unwrap().len(), - 1, - "epoch 6 should remain" - ); - assert_eq!( - db.list_nonce_messages(7, 0).unwrap().len(), - 1, - "epoch 7 should remain" - ); - } - #[test] fn test_delete_dealer_message() { let tmpdir = tempfile::Builder::new().tempdir().unwrap(); @@ -998,4 +829,144 @@ mod tests { bcs::to_bytes(&message2).unwrap() ); } + + #[test] + fn test_store_does_not_prune() { + use std::collections::BTreeMap; + use std::num::NonZeroU16; + + let tmpdir = tempfile::Builder::new().tempdir().unwrap(); + let db = Database::open(tmpdir.path()).unwrap(); + + let dealer = Address::new([1u8; 32]); + let dealer_msg = create_test_message(); + let mut rotation_msgs: BTreeMap = BTreeMap::new(); + rotation_msgs.insert(NonZeroU16::new(1).unwrap(), create_test_message()); + let nonce_msg = create_test_nonce_message(); + let enc_key = EncryptionPrivateKey::new(&mut rand::thread_rng()); + + // Store at the "stuck" source epoch. + db.store_dealer_message(71, &dealer, &dealer_msg).unwrap(); + db.store_rotation_messages(71, &dealer, &rotation_msgs) + .unwrap(); + db.store_nonce_message(71, 0, &dealer, &nonce_msg).unwrap(); + db.store_encryption_key(71, &enc_key).unwrap(); + + // Chain advanced 16 epochs while hashi was stuck. Validator stores at the + // new target epoch. + db.store_dealer_message(87, &dealer, &dealer_msg).unwrap(); + db.store_rotation_messages(87, &dealer, &rotation_msgs) + .unwrap(); + db.store_nonce_message(87, 0, &dealer, &nonce_msg).unwrap(); + db.store_encryption_key(87, &enc_key).unwrap(); + + // The (epoch=71, *) entries must still be present. + assert!( + db.get_dealer_message(71, &dealer).unwrap().is_some(), + "dealer message at source epoch must survive a write at a much later epoch" + ); + assert!( + db.get_rotation_messages(71, &dealer).unwrap().is_some(), + "rotation messages at source epoch must survive a write at a much later epoch" + ); + assert!( + db.get_nonce_message(71, 0, &dealer).unwrap().is_some(), + "nonce message at source epoch must survive a write at a much later epoch" + ); + assert!( + db.get_encryption_key(71).unwrap().is_some(), + "encryption key at source epoch must survive a write at a much later epoch" + ); + } + + #[test] + fn test_prune_messages_below_basic() { + use std::collections::BTreeMap; + use std::num::NonZeroU16; + + let tmpdir = tempfile::Builder::new().tempdir().unwrap(); + let db = Database::open(tmpdir.path()).unwrap(); + + let dealer = Address::new([1u8; 32]); + let dealer_msg = create_test_message(); + let mut rotation_msgs: BTreeMap = BTreeMap::new(); + rotation_msgs.insert(NonZeroU16::new(1).unwrap(), create_test_message()); + let nonce_msg = create_test_nonce_message(); + let enc_key = EncryptionPrivateKey::new(&mut rand::thread_rng()); + + for epoch in 1..=10 { + db.store_dealer_message(epoch, &dealer, &dealer_msg) + .unwrap(); + db.store_rotation_messages(epoch, &dealer, &rotation_msgs) + .unwrap(); + db.store_nonce_message(epoch, 0, &dealer, &nonce_msg) + .unwrap(); + db.store_encryption_key(epoch, &enc_key).unwrap(); + } + + db.prune_messages_below(8).unwrap(); + + for epoch in 1..8 { + assert!( + db.get_dealer_message(epoch, &dealer).unwrap().is_none(), + "dealer message at epoch {epoch} should be pruned" + ); + assert!( + db.get_rotation_messages(epoch, &dealer).unwrap().is_none(), + "rotation messages at epoch {epoch} should be pruned" + ); + assert!( + db.get_nonce_message(epoch, 0, &dealer).unwrap().is_none(), + "nonce message at epoch {epoch} should be pruned" + ); + assert!( + db.get_encryption_key(epoch).unwrap().is_none(), + "encryption key at epoch {epoch} should be pruned" + ); + } + for epoch in 8..=10 { + assert!( + db.get_dealer_message(epoch, &dealer).unwrap().is_some(), + "dealer message at epoch {epoch} should be kept" + ); + assert!( + db.get_rotation_messages(epoch, &dealer).unwrap().is_some(), + "rotation messages at epoch {epoch} should be kept" + ); + assert!( + db.get_nonce_message(epoch, 0, &dealer).unwrap().is_some(), + "nonce message at epoch {epoch} should be kept" + ); + assert!( + db.get_encryption_key(epoch).unwrap().is_some(), + "encryption key at epoch {epoch} should be kept" + ); + } + } + + #[test] + fn test_prune_messages_below_zero_is_no_op() { + let tmpdir = tempfile::Builder::new().tempdir().unwrap(); + let db = Database::open(tmpdir.path()).unwrap(); + + let dealer = Address::new([1u8; 32]); + let message = create_test_message(); + for epoch in 5..=10 { + db.store_dealer_message(epoch, &dealer, &message).unwrap(); + } + + db.prune_messages_below(0).unwrap(); + + for epoch in 5..=10 { + assert!(db.get_dealer_message(epoch, &dealer).unwrap().is_some()); + } + } + + #[test] + fn test_prune_messages_below_empty_db() { + let tmpdir = tempfile::Builder::new().tempdir().unwrap(); + let db = Database::open(tmpdir.path()).unwrap(); + // Should be a no-op, not an error. + db.prune_messages_below(100).unwrap(); + } } diff --git a/crates/hashi/src/mpc/service.rs b/crates/hashi/src/mpc/service.rs index 308c88229..34d853fd1 100644 --- a/crates/hashi/src/mpc/service.rs +++ b/crates/hashi/src/mpc/service.rs @@ -605,6 +605,9 @@ impl MpcService { } } info!("end_reconfig complete for epoch {target_epoch}, running prepare_signing"); + if let Err(e) = self.inner.db.prune_messages_below(target_epoch) { + error!("Failed to prune old MPC messages below epoch {target_epoch}: {e}"); + } for attempt in 1..=MAX_PROTOCOL_ATTEMPTS { match self.prepare_signing(target_epoch, &output).await { Ok(()) => break,