Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 73 additions & 46 deletions crates/hashi/src/mpc/mpc_except_signing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@ const EXPECT_THRESHOLD_VALIDATED: &str = "Threshold already validated";
const EXPECT_THRESHOLD_MET: &str = "Already checked earlier that threshold is met";
const EXPECT_SERIALIZATION_SUCCESS: &str = "Serialization should always succeed";

// DKG protocol
// 1) A dealer sends out a message to all parties containing the encrypted shares and the public keys of the nonces.
// 2) Each party verifies the message and returns a signature. Once sufficient valid signatures are received from the parties, the dealer sends a certificate to Sui (TOB).
// 3) Once sufficient valid certificates are received, a party completes the protocol locally by aggregating the shares from the dealers.
pub struct MpcManager {
// Immutable during the epoch
pub party_id: PartyId,
Expand All @@ -98,7 +94,8 @@ pub struct MpcManager {
previous_output: Option<MpcOutput>,
pub batch_size_per_weight: u16,

// Mutable during the epoch
// TODO: Rename these fields so it is clear at the call site which are
// backed by persistent store and which live only in memory.
pub dealer_outputs: HashMap<DealerOutputsKey, avss::PartialOutput>,
pub dkg_messages: HashMap<Address, avss::Message>,
pub rotation_messages: HashMap<Address, RotationMessages>,
Expand Down Expand Up @@ -258,11 +255,10 @@ impl MpcManager {
if let Some(response) = self.message_responses.get(&sender) {
return Ok(response.clone());
}
return Err(MpcError::InvalidMessage {
sender,
reason: "Message previously received but no valid response was produced"
.to_string(),
});
tracing::info!(
"handle_send_messages_request: existing message from {sender:?} but no \
cached response (e.g. post-restart), re-processing"
);
}
let signature = match &request.messages {
Messages::Dkg(msg) => {
Expand Down Expand Up @@ -546,7 +542,11 @@ impl MpcManager {
}
}
}
// Optimization: a node that fell back to the new-member path has empty
// key shares and cannot generate valid rotation messages.
let has_previous_shares = !previous.key_shares.shares.is_empty();
if is_member_of_previous_committee
&& has_previous_shares
&& {
let certified = ordered_broadcast_channel.certified_dealers().await;
let mgr = mpc_manager.read().unwrap();
Expand All @@ -559,6 +559,10 @@ impl MpcManager {
let certified_share_count: usize = certified
.iter()
.filter_map(|d| {
let messages = mgr.rotation_messages.get(d)?;
if messages.is_empty() {
return None;
}
let party_id = prev_committee.index_of(d)? as u16;
prev_nodes.share_ids_of(party_id).ok()
})
Expand Down Expand Up @@ -876,14 +880,21 @@ impl MpcManager {
.expect("certificate verified above")
};
let epoch = mpc_manager.read().unwrap().mpc_config.epoch;
Self::recover_shares_via_complaint(
let recovered = Self::recover_shares_via_complaint(
mpc_manager,
&dealer,
signers,
p2p_channel,
epoch,
)
.await?;
{
let mut mgr = mpc_manager.write().unwrap();
mgr.dealer_outputs
.insert(DealerOutputsKey::Dkg(dealer), recovered);
mgr.complaints_to_process
.remove(&ComplaintsToProcessKey::Dkg(dealer));
}
}
let dealer_weight = {
let mgr = mpc_manager.read().unwrap();
Expand Down Expand Up @@ -1102,7 +1113,7 @@ impl MpcManager {
.expect("certificate verified above")
};
let epoch = mpc_manager.read().unwrap().mpc_config.epoch;
Self::recover_rotation_shares_via_complaints(
let recovered = Self::recover_rotation_shares_via_complaints(
mpc_manager,
&dealer,
previous,
Expand All @@ -1111,6 +1122,15 @@ impl MpcManager {
epoch,
)
.await?;
{
let mut mgr = mpc_manager.write().unwrap();
for (share_index, output) in recovered {
mgr.dealer_outputs
.insert(DealerOutputsKey::Rotation(share_index), output);
mgr.complaints_to_process
.remove(&ComplaintsToProcessKey::Rotation(dealer, share_index));
}
}
// Only add indices that have outputs (avoids adding indices for
// dealers with empty rotation messages, e.g. a node that rejoined
// with no shares from the new-member fallback).
Expand Down Expand Up @@ -1381,6 +1401,8 @@ impl MpcManager {
Ok(())
}

// TODO: Change return type to `MpcResult<()>` and propagate disk errors
// (mirroring `store_dkg_message` and `store_rotation_messages`).
fn store_nonce_message(&mut self, dealer: Address, nonce: &NonceMessage) {
self.nonce_messages.insert(dealer, nonce.clone());
if let Err(e) = self.public_messages_store.store_nonce_message(
Expand Down Expand Up @@ -2035,7 +2057,7 @@ impl MpcManager {
signers: Vec<Address>,
p2p_channel: &impl P2PChannel,
epoch: u64,
) -> MpcResult<()> {
) -> MpcResult<avss::PartialOutput> {
let (complaint_request, receiver, message) = {
let mgr = mpc_manager.read().unwrap();
let complaint = mgr
Expand Down Expand Up @@ -2098,12 +2120,7 @@ impl MpcManager {
};
match result {
Ok(partial_output) => {
let mut mgr = mpc_manager.write().unwrap();
mgr.dealer_outputs
.insert(DealerOutputsKey::Dkg(*dealer), partial_output);
mgr.complaints_to_process
.remove(&ComplaintsToProcessKey::Dkg(*dealer));
return Ok(());
return Ok(partial_output);
}
Err(FastCryptoError::InputTooShort(_)) => {
continue;
Expand Down Expand Up @@ -2225,15 +2242,15 @@ impl MpcManager {
signers: Vec<Address>,
p2p_channel: &impl P2PChannel,
epoch: u64,
) -> MpcResult<()> {
) -> MpcResult<HashMap<ShareIndex, avss::PartialOutput>> {
let (request, recovery_contexts) = {
let mgr = mpc_manager.read().unwrap();
let Some(RotationComplainContext {
request,
recovery_contexts,
}) = mgr.prepare_rotation_complain_request(dealer, previous_dkg_output, epoch)?
else {
return Ok(());
return Ok(HashMap::new());
};
tracing::info!(
"Rotation complaint detected for dealer {:?}, recovering via Complain RPC",
Expand All @@ -2252,6 +2269,7 @@ impl MpcManager {
Vec<complaint::ComplaintResponse<avss::SharesForNode>>,
> = HashMap::new();
let mut pending_shares: HashSet<ShareIndex> = HashSet::new();
let mut recovered_outputs: HashMap<ShareIndex, avss::PartialOutput> = HashMap::new();
for &share_index in recovery_contexts.keys() {
all_responses.insert(share_index, Vec::new());
pending_shares.insert(share_index);
Expand Down Expand Up @@ -2287,15 +2305,7 @@ impl MpcManager {
};
match result {
Ok(partial_output) => {
let mut mgr = mpc_manager.write().unwrap();
mgr.dealer_outputs.insert(
DealerOutputsKey::Rotation(share_index),
partial_output,
);
mgr.complaints_to_process.remove(
&ComplaintsToProcessKey::Rotation(*dealer, share_index),
);
drop(mgr);
recovered_outputs.insert(share_index, partial_output);
pending_shares.remove(&share_index);
}
Err(FastCryptoError::InputTooShort(_)) => {
Expand Down Expand Up @@ -2327,7 +2337,7 @@ impl MpcManager {
dealer, pending_shares
)));
}
Ok(())
Ok(recovered_outputs)
}

fn load_stored_messages(&mut self) -> MpcResult<()> {
Expand Down Expand Up @@ -2594,16 +2604,21 @@ impl MpcManager {
pub fn reconstruct_previous_output(
&self,
certificates: &[CertificateV1],
complaint_cache: &HashMap<DealerOutputsKey, avss::PartialOutput>,
) -> MpcResult<ReconstructionOutcome> {
match certificates.first() {
Some(CertificateV1::Dkg(_)) | None => {
self.reconstruct_from_dkg_certificates(certificates)
self.reconstruct_from_dkg_certificates(certificates, complaint_cache)
}
Some(CertificateV1::Rotation(_)) => {
let previous_threshold = self.previous_threshold.ok_or_else(|| {
MpcError::InvalidConfig("Key rotation requires previous threshold".into())
})?;
self.reconstruct_from_rotation_certificates(certificates, previous_threshold)
self.reconstruct_from_rotation_certificates(
certificates,
previous_threshold,
complaint_cache,
)
}
Some(CertificateV1::NonceGeneration { .. }) => {
unreachable!(
Expand All @@ -2616,6 +2631,7 @@ impl MpcManager {
fn reconstruct_from_dkg_certificates(
&self,
certificates: &[CertificateV1],
complaint_cache: &HashMap<DealerOutputsKey, avss::PartialOutput>,
) -> MpcResult<ReconstructionOutcome> {
let previous_committee = self.previous_committee.clone().ok_or_else(|| {
MpcError::InvalidConfig("DKG reconstruction requires previous committee".into())
Expand Down Expand Up @@ -2672,11 +2688,7 @@ impl MpcManager {
let session_id = source_session_id
.dealer_session_id(&dealer_address)
.to_vec();
// Check for previously recovered output (from complaint recovery on a prior attempt).
if let Some(output) = self
.dealer_outputs
.get(&DealerOutputsKey::Dkg(dealer_address))
{
if let Some(output) = complaint_cache.get(&DealerOutputsKey::Dkg(dealer_address)) {
outputs.insert(dealer_party_id, output.clone());
let dealer_weight = previous_nodes
.weight_of(dealer_party_id)
Expand Down Expand Up @@ -2746,6 +2758,7 @@ impl MpcManager {
&self,
certificates: &[CertificateV1],
previous_threshold: u16,
complaint_cache: &HashMap<DealerOutputsKey, avss::PartialOutput>,
) -> MpcResult<ReconstructionOutcome> {
let previous_nodes = self.previous_nodes.clone().ok_or_else(|| {
MpcError::InvalidConfig("Rotation reconstruction requires previous nodes".into())
Expand Down Expand Up @@ -2793,13 +2806,10 @@ impl MpcManager {
)));
}
for (share_index, message) in rotation_msgs {
// Check for previously recovered output (from complaint recovery on a prior attempt).
if let Some(output) = self
.dealer_outputs
.get(&DealerOutputsKey::Rotation(share_index))
if let Some(output) = complaint_cache.get(&DealerOutputsKey::Rotation(share_index))
{
tracing::info!(
"reconstruct_from_rotation_certificates: cache hit for \
"reconstruct_from_rotation_certificates: complaint cache hit for \
dealer {:?} share_index={share_index}",
dealer_address,
);
Expand Down Expand Up @@ -3020,12 +3030,14 @@ impl MpcManager {
previous_certificates: &[CertificateV1],
p2p_channel: &impl P2PChannel,
) -> MpcResult<MpcOutput> {
let mut complaint_cache: HashMap<DealerOutputsKey, avss::PartialOutput> = HashMap::new();
loop {
let mgr = Arc::clone(mpc_manager);
let certs = previous_certificates.to_vec();
let cache_snapshot = complaint_cache.clone();
match spawn_blocking(move || {
let mgr = mgr.read().unwrap();
mgr.reconstruct_previous_output(&certs)
mgr.reconstruct_previous_output(&certs, &cache_snapshot)
})
.await?
{
Expand Down Expand Up @@ -3085,14 +3097,21 @@ impl MpcManager {
match protocol_type {
ProtocolTypeIndicator::Dkg => {
let source_epoch = mpc_manager.read().unwrap().source_epoch;
Self::recover_shares_via_complaint(
let recovered = Self::recover_shares_via_complaint(
mpc_manager,
&dealer_address,
signers,
p2p_channel,
source_epoch,
)
.await?;
complaint_cache
.insert(DealerOutputsKey::Dkg(dealer_address), recovered);
mpc_manager
.write()
.unwrap()
.complaints_to_process
.remove(&ComplaintsToProcessKey::Dkg(dealer_address));
}
ProtocolTypeIndicator::KeyRotation => {
let (previous_output, source_epoch) = {
Expand All @@ -3104,7 +3123,7 @@ impl MpcManager {
mgr.source_epoch,
)
};
Self::recover_rotation_shares_via_complaints(
let recovered = Self::recover_rotation_shares_via_complaints(
mpc_manager,
&dealer_address,
&previous_output,
Expand All @@ -3113,6 +3132,14 @@ impl MpcManager {
source_epoch,
)
.await?;
let mut mgr = mpc_manager.write().unwrap();
for (share_index, output) in recovered {
complaint_cache
.insert(DealerOutputsKey::Rotation(share_index), output);
mgr.complaints_to_process.remove(
&ComplaintsToProcessKey::Rotation(dealer_address, share_index),
);
}
}
ProtocolTypeIndicator::NonceGeneration => {}
}
Expand Down
Loading
Loading