diff --git a/Cargo.lock b/Cargo.lock index 2a722feb..e1f75f9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5641,6 +5641,7 @@ dependencies = [ "serde_json", "sha2", "tempfile", + "test-case", "thiserror 2.0.18", "tokio", "tokio-util", diff --git a/crates/dkg/Cargo.toml b/crates/dkg/Cargo.toml index 1d1a19b9..cb8f0a7a 100644 --- a/crates/dkg/Cargo.toml +++ b/crates/dkg/Cargo.toml @@ -13,6 +13,7 @@ thiserror.workspace = true libp2p.workspace = true futures.workspace = true tokio.workspace = true +tokio-util.workspace = true sha2.workspace = true tracing.workspace = true either.workspace = true @@ -35,6 +36,7 @@ pluto-build-proto.workspace = true [dev-dependencies] anyhow.workspace = true +test-case.workspace = true clap.workspace = true hex.workspace = true pluto-cluster = { workspace = true, features = ["test-cluster"] } diff --git a/crates/dkg/examples/bcast.rs b/crates/dkg/examples/bcast.rs index ce3f006f..03027336 100644 --- a/crates/dkg/examples/bcast.rs +++ b/crates/dkg/examples/bcast.rs @@ -299,14 +299,16 @@ async fn register_message(component: &Component, local_node_number: u32) -> bcas Ok(()) }), Box::new(move |peer_id, received_msg_id, msg| { - info!( - local_node = local_node_number, - sender = %peer_id, - msg_id = received_msg_id, - msg = ?msg, - "Received broadcast" - ); - Ok(()) + Box::pin(async move { + info!( + local_node = local_node_number, + sender = %peer_id, + msg_id = received_msg_id, + msg = ?msg, + "Received broadcast" + ); + Ok(()) + }) }), ) .await diff --git a/crates/dkg/src/bcast/behaviour.rs b/crates/dkg/src/bcast/behaviour.rs index 6f04dd1d..1e2d7798 100644 --- a/crates/dkg/src/bcast/behaviour.rs +++ b/crates/dkg/src/bcast/behaviour.rs @@ -585,15 +585,18 @@ mod tests { "timestamp", Box::new(|_peer_id, _msg| Ok(())), Box::new(move |peer_id, msg_id, msg| { - receipt_tx - .send(Receipt { - target: node_index, - source: peer_id, - msg_id: msg_id.to_string(), - seconds: msg.seconds, - }) - .map_err(|_| Error::ReceiptChannelClosed)?; - Ok(()) + 
let receipt_tx = receipt_tx.clone(); + Box::pin(async move { + receipt_tx + .send(Receipt { + target: node_index, + source: peer_id, + msg_id, + seconds: msg.seconds, + }) + .map_err(|_| Error::ReceiptChannelClosed)?; + Ok(()) + }) }), ) .await diff --git a/crates/dkg/src/bcast/component.rs b/crates/dkg/src/bcast/component.rs index 5423b992..c585dd2c 100644 --- a/crates/dkg/src/bcast/component.rs +++ b/crates/dkg/src/bcast/component.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc}; +use futures::future::BoxFuture; use libp2p::PeerId; use prost::{Message, Name}; use prost_types::Any; @@ -13,7 +14,12 @@ use super::error::{Error, Result}; pub type CheckFn = Box Result<()> + Send + Sync + 'static>; /// Typed message callback invoked for validated broadcast messages. -pub type CallbackFn = Box Result<()> + Send + Sync + 'static>; +/// +/// The returned future is awaited by the inbound message handler, allowing +/// the callback to perform async operations (e.g. waiting for state that +/// becomes available later). +pub type CallbackFn = + Box BoxFuture<'static, Result<()>> + Send + Sync + 'static>; pub(crate) type Registry = Arc>>>; @@ -33,7 +39,8 @@ pub(crate) trait RegisteredMessage: Send + Sync { fn check(&self, peer_id: PeerId, any: &Any) -> Result<()>; /// Dispatches the incoming wrapped protobuf message to the typed callback. 
- fn callback(&self, peer_id: PeerId, msg_id: &str, any: &Any) -> Result<()>; + fn callback(&self, peer_id: PeerId, msg_id: String, any: Any) + -> BoxFuture<'static, Result<()>>; } struct TypedRegistration { @@ -50,9 +57,16 @@ where (self.check)(peer_id, &message) } - fn callback(&self, peer_id: PeerId, msg_id: &str, any: &Any) -> Result<()> { - let message = any.to_msg::()?; - (self.callback)(peer_id, msg_id, message) + fn callback( + &self, + peer_id: PeerId, + msg_id: String, + any: Any, + ) -> BoxFuture<'static, Result<()>> { + match any.to_msg::() { + Ok(message) => (self.callback)(peer_id, msg_id, message), + Err(e) => Box::pin(async move { Err(e.into()) }), + } } } @@ -138,7 +152,7 @@ mod tests { .register_message::( "timestamp", Box::new(|_, _| Ok(())), - Box::new(|_, _, _| Ok(())), + Box::new(|_, _, _| Box::pin(async { Ok(()) })), ) .await .unwrap(); @@ -147,7 +161,7 @@ .register_message::( "timestamp", Box::new(|_, _| Ok(())), - Box::new(|_, _, _| Ok(())), + Box::new(|_, _, _| Box::pin(async { Ok(()) })), ) .await .unwrap_err(); diff --git a/crates/dkg/src/bcast/error.rs b/crates/dkg/src/bcast/error.rs index e8e20468..60c58b35 100644 --- a/crates/dkg/src/bcast/error.rs +++ b/crates/dkg/src/bcast/error.rs @@ -46,6 +46,25 @@ impl Failure { } } +/// Peer IDs involved in an [`Error::InvalidSenderPeerIndex`] error. +#[derive(Debug)] +pub struct SenderPeerMismatch { + /// The peer ID of the actual sender. + pub sender: PeerId, + /// The peer ID expected at the claimed index. + pub expected: PeerId, +} + +impl fmt::Display for SenderPeerMismatch { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "sender peer ID ({}) does not match the peer ID ({}) registered at the claimed index", + self.sender, self.expected + ) + } +} + +/// User-facing reliable-broadcast error. 
#[derive(Debug, thiserror::Error)] pub enum Error { @@ -95,6 +114,15 @@ pub enum Error { #[error("invalid signature for peer {0}")] InvalidSignature(PeerId), + /// The peer index in the message is out of range or matches the local node. + #[error("invalid peer index: {0}")] + InvalidPeerIndex(PeerId), + + /// The sender's peer index in the message does not match the sender's + /// actual index. + #[error("{0}")] + InvalidSenderPeerIndex(Box), + /// The repeated hash for the same `(peer, msg_id)` differed. #[error("duplicate id with mismatching hash")] DuplicateMismatchingHash, @@ -119,6 +147,10 @@ pub enum Error { #[error("receipt channel closed")] ReceiptChannelClosed, + /// The operation was cancelled. + #[error("cancelled")] + Cancelled, + /// A required message body field was absent. #[error("missing protobuf field: {0}")] MissingField(&'static str), diff --git a/crates/dkg/src/bcast/handler.rs b/crates/dkg/src/bcast/handler.rs index d36cc7de..1f7d9769 100644 --- a/crates/dkg/src/bcast/handler.rs +++ b/crates/dkg/src/bcast/handler.rs @@ -441,7 +441,7 @@ async fn handle_inbound_msg( .ok_or_else(|| Error::UnknownMessageId(message.id.clone()))? }; - handler.callback(peer_id, &message.id, &any)?; + handler.callback(peer_id, message.id, any).await?; stream.close().await?; Ok(()) } diff --git a/crates/dkg/src/bcast/mod.rs b/crates/dkg/src/bcast/mod.rs index 3ab8116d..def5e76e 100644 --- a/crates/dkg/src/bcast/mod.rs +++ b/crates/dkg/src/bcast/mod.rs @@ -10,7 +10,7 @@ mod protocol; pub use behaviour::{Behaviour, Event}; pub use component::{CallbackFn, CheckFn, Component}; -pub use error::{Error, Failure, Result}; +pub use error::{Error, Failure, Result, SenderPeerMismatch}; /// The request-response protocol used to gather peer signatures. 
pub const SIG_PROTOCOL_NAME: StreamProtocol = StreamProtocol::new("/charon/dkg/bcast/1.0.0/sig"); diff --git a/crates/dkg/src/lib.rs b/crates/dkg/src/lib.rs index ce0e6cde..79e95f2c 100644 --- a/crates/dkg/src/lib.rs +++ b/crates/dkg/src/lib.rs @@ -17,5 +17,8 @@ pub mod disk; /// Main DKG protocol implementation. pub mod dkg; +/// Node signature exchange over the lock hash. +pub mod nodesigs; + /// Shares distributed to each node in the cluster. pub mod share; diff --git a/crates/dkg/src/nodesigs.rs b/crates/dkg/src/nodesigs.rs new file mode 100644 index 00000000..1d0c0adb --- /dev/null +++ b/crates/dkg/src/nodesigs.rs @@ -0,0 +1,615 @@ +//! Handles broadcasting of K1 signatures over the lock hash via the bcast +//! protocol. + +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; + +use k256::SecretKey; +use libp2p::PeerId; +use pluto_p2p::peer::Peer; +use tokio::sync::watch; +use tokio_util::sync::CancellationToken; + +use crate::{ + bcast::{self, Component}, + dkgpb::v1::nodesigs::MsgNodeSig, +}; + +/// The message ID used for node signature broadcasts. +const NODE_SIG_MSG_ID: &str = "/charon/dkg/node_sig"; + +/// Sentinel value used in place of a real signature when a peer has nothing to +/// sign. Filling the slot with this value unblocks `all_sigs` without +/// contributing a real signature to the result. +const NONE_DATA: [u8; 4] = [0xde, 0xad, 0xbe, 0xef]; + +/// Error returned by [`NodeSigBcast`] operations. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Signing the lock hash with the local K1 key failed. + #[error("k1 lock hash signature: {0}")] + Sign(#[from] pluto_k1util::K1UtilError), + + /// Broadcasting or registering the broadcast message failed. + #[error("k1 lock hash signature broadcast: {0}")] + Broadcast(#[from] bcast::Error), + + /// The exchange was cancelled before all signatures were collected. + #[error("cancelled")] + Cancelled, + + /// The local node index cannot be represented as a u32. 
+ #[error("node index {0} exceeds u32 range")] + NodeIndexOutOfRange(usize), +} + +/// Alias for `Result`. +pub type Result = std::result::Result; + +/// Handles broadcasting of K1 signatures over the lock hash via the bcast +/// protocol. +pub struct NodeSigBcast { + sigs: Arc>>>>, + bcast: Component, + node_idx: usize, + lock_hash_tx: watch::Sender>>, +} + +impl NodeSigBcast { + /// Returns a new instance, registering bcast handlers on `bcast_comp`. + pub async fn new( + peers: Vec, + node_idx: usize, + bcast_comp: Component, + token: CancellationToken, + ) -> Result { + let sigs = Arc::new(Mutex::new(vec![None::>; peers.len()])); + let (lock_hash_tx, lock_hash_rx) = watch::channel(None::>); + + let sigs_cb = Arc::clone(&sigs); + let peers = Arc::new(peers); + + bcast_comp + .register_message::( + NODE_SIG_MSG_ID, + Box::new(|_peer_id, _msg| Ok(())), + Box::new(move |peer_id, _msg_id, msg| { + let peers = Arc::clone(&peers); + let lock_hash_rx = lock_hash_rx.clone(); + let sigs = Arc::clone(&sigs_cb); + let token = token.clone(); + Box::pin(async move { + receive(peer_id, msg, node_idx, &peers, lock_hash_rx, &sigs, token).await + }) + }), + ) + .await?; + + Ok(Self { + sigs, + bcast: bcast_comp, + node_idx, + lock_hash_tx, + }) + } + + /// Exchanges K1 signatures over the lock hash with all peers. + /// + /// Signs `lock_hash` with `key`, broadcasts the signature to all peers, and + /// polls until every peer's signature has been received and verified. + /// Returns all collected signatures ordered by peer index. + pub async fn exchange( + self, + key: Option<&SecretKey>, + lock_hash: impl AsRef<[u8]>, + token: CancellationToken, + ) -> Result>> { + let (local_sig, lock_hash) = if let Some(k) = key { + let sig = pluto_k1util::sign(k, lock_hash.as_ref())?.to_vec(); + (sig, lock_hash.as_ref().to_vec()) + } else { + (NONE_DATA.to_vec(), NONE_DATA.to_vec()) + }; + + // Make the lock hash available to incoming callbacks before broadcasting. 
+ // Only fails if all receivers are dropped, which cannot happen here. + let _ = self.lock_hash_tx.send(Some(lock_hash)); + + let peer_index = + u32::try_from(self.node_idx).map_err(|_| Error::NodeIndexOutOfRange(self.node_idx))?; + + let bcast_data = MsgNodeSig { + signature: local_sig.clone().into(), + peer_index, + }; + + tracing::debug!("Exchanging node signatures"); + + self.bcast.broadcast(NODE_SIG_MSG_ID, &bcast_data).await?; + + { + let mut sigs = self.sigs.lock().unwrap_or_else(|e| e.into_inner()); + sigs[self.node_idx] = Some(local_sig); + } + + let mut ticker = tokio::time::interval(Duration::from_millis(100)); + + loop { + tokio::select! { + () = token.cancelled() => return Err(Error::Cancelled), + _ = ticker.tick() => { + if let Some(all) = all_sigs(&self.sigs.lock().unwrap_or_else(|e| e.into_inner())) { + return Ok(all); + } + } + } + } + } +} + +/// Returns a copy of all signatures if every slot is filled (slots holding the +/// `NONE_DATA` sentinel are skipped and excluded from the result), otherwise `None`. +fn all_sigs(sigs: &[Option>]) -> Option>> { + sigs.iter() + .filter(|slot| slot.as_deref() != Some(&NONE_DATA)) + .cloned() + .collect() +} + +/// Validates and stores an incoming node signature message. +/// +/// Waits for the lock hash to become available via the watch channel before +/// verifying the signature. Returns [`bcast::Error::Cancelled`] if `token` is +/// cancelled while waiting. 
+async fn receive( + peer_id: PeerId, + msg: MsgNodeSig, + node_idx: usize, + peers: &[Peer], + lock_hash_rx: watch::Receiver>>, + sigs: &Mutex>>>, + token: CancellationToken, +) -> bcast::Result<()> { + let peer_idx = usize::try_from(msg.peer_index).expect("peer_index out of usize range"); + + if peer_idx == node_idx || peer_idx >= peers.len() { + return Err(bcast::Error::InvalidPeerIndex(peer_id)); + } + + if peers[peer_idx].id != peer_id { + return Err(bcast::Error::InvalidSenderPeerIndex(Box::new( + bcast::SenderPeerMismatch { + sender: peer_id, + expected: peers[peer_idx].id, + }, + ))); + } + + if msg.signature.as_ref() == NONE_DATA { + sigs.lock().unwrap_or_else(|e| e.into_inner())[peer_idx] = Some(NONE_DATA.to_vec()); + return Ok(()); + } + + let pubkey = peers[peer_idx].public_key()?; + + let lock_hash = { + let mut rx = lock_hash_rx.clone(); + tokio::select! { + result = rx.wait_for(|v| v.is_some()) => { + let guard = result.map_err(|_| bcast::Error::MissingField("lock_hash"))?; + guard + .clone() + .ok_or(bcast::Error::MissingField("lock_hash"))? + } + () = token.cancelled() => return Err(bcast::Error::Cancelled), + } + }; + + if lock_hash.as_slice() == NONE_DATA { + sigs.lock().unwrap_or_else(|e| e.into_inner())[peer_idx] = Some(NONE_DATA.to_vec()); + return Ok(()); + } + + if !pluto_k1util::verify_65(&pubkey, &lock_hash, msg.signature.as_ref())? 
{ + return Err(bcast::Error::InvalidSignature(peer_id)); + } + + sigs.lock().unwrap_or_else(|e| e.into_inner())[peer_idx] = Some(msg.signature.to_vec()); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::{collections::HashSet, net::TcpListener}; + + use anyhow::Context as _; + use futures::StreamExt as _; + use libp2p::{Multiaddr, swarm::SwarmEvent}; + use pluto_p2p::{ + config::P2PConfig, + p2p::{Node, NodeType}, + p2p_context::P2PContext, + peer::{Peer, peer_id_from_key}, + }; + use pluto_testutil::random::generate_insecure_k1_key; + use test_case::test_case; + use tokio::{ + sync::{mpsc, oneshot, watch}, + task::JoinSet, + }; + + use crate::bcast::Behaviour; + + use super::*; + + fn make_peer(seed: u8, index: usize) -> (SecretKey, Peer) { + let key = generate_insecure_k1_key(seed); + let id = peer_id_from_key(key.public_key()).unwrap(); + let peer = Peer { + id, + addresses: vec![], + index, + name: format!("peer-{seed}"), + }; + (key, peer) + } + + #[test] + fn all_sigs_returns_none_when_slot_empty() { + assert!(all_sigs(&[None, Some(vec![1]), Some(vec![2])]).is_none()); + assert!(all_sigs(&[Some(vec![1]), None, Some(vec![2])]).is_none()); + } + + #[test] + fn all_sigs_returns_vec_when_all_filled() { + let result = all_sigs(&[Some(vec![1u8]), Some(vec![2u8])]).unwrap(); + assert_eq!(result, vec![vec![1u8], vec![2u8]]); + } + + #[test] + fn all_sigs_empty_input() { + assert_eq!(all_sigs(&[]), Some(vec![])); + } + + #[test] + fn all_sigs_filters_none_data() { + let none_data = NONE_DATA.to_vec(); + let real_sig = vec![1u8, 2, 3]; + let result = all_sigs(&[ + Some(none_data.clone()), + Some(real_sig.clone()), + Some(none_data), + ]) + .unwrap(); + assert_eq!(result, vec![real_sig]); + } + + #[test] + fn all_sigs_returns_none_when_slot_empty_with_none_data() { + let none_data = NONE_DATA.to_vec(); + assert!(all_sigs(&[None, Some(none_data)]).is_none()); + } + + // Ports TestSigsCallbacks from charon/dkg/nodesigs_internal_test.go. 
+ // n=10 peers; peer_index 11 = n+1, 10 = n. + // sender_peer_idx is the index into `peers` used as the transport-layer PeerId. + #[test_case(0, 0, Some(vec![0u8; 32]), 65, "invalid peer index" ; "wrong_peer_index_equal_to_ours")] + #[test_case(0, 11, Some(vec![0u8; 32]), 65, "invalid peer index" ; "wrong_peer_index_more_than_operators")] + #[test_case(0, 10, Some(vec![0u8; 32]), 65, "invalid peer index" ; "wrong_peer_index_exactly_at_len")] + #[test_case(0, 1, Some(vec![0u8; 32]), 65, "does not match" ; "sender_peer_id_mismatch")] + #[test_case(1, 1, None, 65, "missing protobuf field: lock_hash" ; "missing_lock_hash")] + #[test_case(1, 1, Some(vec![42u8; 32]), 65, "The signature recovery id byte 42 is invalid" ; "signature_verification_failed")] + #[test_case(1, 1, Some(vec![42u8; 32]), 2, "The signature length is invalid: expected 65, actual 2" ; "malformed_signature")] + #[tokio::test] + async fn sigs_callbacks( + sender_peer_idx: usize, + peer_index: u32, + lock_hash: Option>, + sig_len: usize, + expected_msg: &str, + ) { + const N: usize = 10; + let peers: Vec = (0..N) + .map(|i| make_peer(u8::try_from(i).expect("The number fits into u8"), i).1) + .collect(); + let (_, rx) = watch::channel(lock_hash); + let sigs = Mutex::new(vec![None::>; N]); + + let msg = MsgNodeSig { + signature: vec![42u8; sig_len].into(), + peer_index, + }; + + let err = receive( + peers[sender_peer_idx].id, + msg, + 0, + &peers, + rx, + &sigs, + CancellationToken::new(), + ) + .await + .unwrap_err(); + assert!( + err.to_string().contains(expected_msg), + "expected '{expected_msg}' in '{err}'" + ); + } + + #[tokio::test] + async fn sigs_callbacks_ok() { + let (_, peer0) = make_peer(0, 0); + let (key1, peer1) = make_peer(1, 1); + let peers = vec![peer0, peer1.clone()]; + let lock_hash = vec![42u8; 32]; + let (_, rx) = watch::channel(Some(lock_hash.clone())); + let sigs = Mutex::new(vec![None::>; 2]); + + let sig = pluto_k1util::sign(&key1, &lock_hash).unwrap(); + let msg = MsgNodeSig { + 
signature: sig.to_vec().into(), + peer_index: 1, + }; + + receive( + peer1.id, + msg, + 0, + &peers, + rx, + &sigs, + CancellationToken::new(), + ) + .await + .unwrap(); + + let guard = sigs.lock().unwrap(); + assert_eq!(guard[1], Some(sig.to_vec())); + } + + #[tokio::test] + async fn receive_none_sig_stores_sentinel() { + let (_, peer0) = make_peer(0, 0); + let (_, peer1) = make_peer(1, 1); + let peers = vec![peer0, peer1.clone()]; + let (_, rx) = watch::channel(None::>); + let sigs = Mutex::new(vec![None::>; 2]); + + let msg = MsgNodeSig { + signature: NONE_DATA.to_vec().into(), + peer_index: 1, + }; + + receive( + peer1.id, + msg, + 0, + &peers, + rx, + &sigs, + CancellationToken::new(), + ) + .await + .unwrap(); + + let guard = sigs.lock().unwrap(); + assert_eq!(guard[1], Some(NONE_DATA.to_vec())); + } + + #[tokio::test] + async fn receive_none_lock_hash_stores_sentinel() { + let (_, peer0) = make_peer(0, 0); + let (key1, peer1) = make_peer(1, 1); + let peers = vec![peer0, peer1.clone()]; + let lock_hash = vec![42u8; 32]; + let sig = pluto_k1util::sign(&key1, &lock_hash).unwrap(); + let (_, rx) = watch::channel(Some(NONE_DATA.to_vec())); + let sigs = Mutex::new(vec![None::>; 2]); + + let msg = MsgNodeSig { + signature: sig.to_vec().into(), + peer_index: 1, + }; + + receive( + peer1.id, + msg, + 0, + &peers, + rx, + &sigs, + CancellationToken::new(), + ) + .await + .unwrap(); + + let guard = sigs.lock().unwrap(); + assert_eq!(guard[1], Some(NONE_DATA.to_vec())); + } + + struct TestNode { + node: Node, + addr: Multiaddr, + } + + struct RunningNode { + stop_tx: oneshot::Sender<()>, + join: tokio::task::JoinHandle>, + } + + fn available_tcp_port() -> anyhow::Result { + let listener = TcpListener::bind("127.0.0.1:0")?; + Ok(listener.local_addr()?.port()) + } + + async fn wait_for_all_connections( + conn_rx: &mut mpsc::UnboundedReceiver<(usize, PeerId)>, + n: usize, + ) -> anyhow::Result<()> { + let mut seen = vec![HashSet::::new(); n]; + 
tokio::time::timeout(Duration::from_secs(10), async { + loop { + if seen.iter().all(|peers| peers.len() == n.saturating_sub(1)) { + return Ok(()); + } + let (index, peer_id) = conn_rx.recv().await.context("connection channel closed")?; + seen[index].insert(peer_id); + } + }) + .await + .context("timed out waiting for connections")? + } + + async fn spawn_swarm_tasks( + mut nodes: Vec, + conn_tx: mpsc::UnboundedSender<(usize, PeerId)>, + ) -> anyhow::Result> { + for node in &mut nodes { + node.node.listen_on(node.addr.clone())?; + } + + let dial_targets: Vec> = (0..nodes.len()) + .map(|index| { + nodes + .iter() + .enumerate() + .filter(|(other, _)| *other > index) + .map(|(_, n)| n.addr.clone()) + .collect() + }) + .collect(); + + let mut running = Vec::with_capacity(nodes.len()); + for (index, (test_node, targets)) in nodes.into_iter().zip(dial_targets).enumerate() { + let mut node = test_node.node; + let conn_tx = conn_tx.clone(); + let (stop_tx, mut stop_rx) = oneshot::channel::<()>(); + + let join = tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + for target in targets { + node.dial(target)?; + } + loop { + tokio::select! { + _ = &mut stop_rx => break, + event = node.select_next_some() => { + if let SwarmEvent::ConnectionEstablished { peer_id, .. } = event { + let _ = conn_tx.send((index, peer_id)); + } + } + } + } + Ok(()) + }); + + running.push(RunningNode { stop_tx, join }); + } + + Ok(running) + } + + async fn shutdown_swarm_tasks(tasks: Vec) -> anyhow::Result<()> { + for task in tasks { + let _ = task.stop_tx.send(()); + task.join.await??; + } + Ok(()) + } + + // Ports `TestSigsExchange` from charon/dkg/nodesigs_internal_test.go. 
+ #[tokio::test] + async fn test_sigs_exchange() -> anyhow::Result<()> { + const N: usize = 7; + + let keys: Vec = (0..N) + .map(|i| generate_insecure_k1_key(u8::try_from(i).expect("N fits in u8"))) + .collect(); + let peer_ids: Vec = keys + .iter() + .map(|k| peer_id_from_key(k.public_key())) + .collect::, _>>()?; + + let cluster_peers: Vec = peer_ids + .iter() + .enumerate() + .map(|(i, &id)| Peer { + id, + addresses: vec![], + index: i, + name: format!("peer-{i}"), + }) + .collect(); + + let ports = (0..N) + .map(|_| available_tcp_port()) + .collect::>>()?; + + let (conn_tx, mut conn_rx) = mpsc::unbounded_channel(); + let token = CancellationToken::new(); + + let mut test_nodes = Vec::with_capacity(N); + let mut nsig_list = Vec::with_capacity(N); + + for (index, key) in keys.iter().enumerate() { + let p2p_context = P2PContext::new(peer_ids.clone()); + let (behaviour, component) = + Behaviour::new(peer_ids.clone(), p2p_context.clone(), key.clone()); + let nsig = + NodeSigBcast::new(cluster_peers.clone(), index, component, token.clone()).await?; + nsig_list.push(nsig); + + let node = Node::new_server( + P2PConfig::default(), + key.clone(), + NodeType::TCP, + false, + peer_ids.clone(), + move |builder, _| builder.with_p2p_context(p2p_context).with_inner(behaviour), + )?; + + let addr: Multiaddr = format!("/ip4/127.0.0.1/tcp/{}", ports[index]).parse()?; + test_nodes.push(TestNode { node, addr }); + } + + let running = spawn_swarm_tasks(test_nodes, conn_tx).await?; + wait_for_all_connections(&mut conn_rx, N).await?; + + let lock_hash = [42u8; 32]; + let mut handles = JoinSet::new(); + + for (i, nsig) in nsig_list.into_iter().enumerate() { + let key = keys[i].clone(); + let token = token.clone(); + handles.spawn(async move { nsig.exchange(Some(&key), lock_hash, token).await }); + } + + let results = tokio::time::timeout(Duration::from_secs(45), async { + let mut results = Vec::with_capacity(N); + while let Some(res) = handles.join_next().await { + results.push(res??); 
+ } + anyhow::Ok(results) + }) + .await + .context("exchange timed out")??; + + assert_eq!(results.len(), N); + let first = &results[0]; + assert_eq!(first.len(), N); + for sig in first { + assert!(!sig.is_empty()); + } + for result in &results[1..] { + assert_eq!(result, first, "all nodes must collect identical signatures"); + } + + token.cancel(); + shutdown_swarm_tasks(running).await?; + + Ok(()) + } +}