Skip to content
106 changes: 89 additions & 17 deletions iroh/src/discovery/mdns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub struct MdnsDiscovery {
#[allow(dead_code)]
handle: AbortOnDropHandle<()>,
sender: mpsc::Sender<Message>,
advertise: bool,
/// When `local_addrs` changes, we re-publish our info.
local_addrs: Watchable<Option<NodeData>>,
}
Expand Down Expand Up @@ -127,39 +128,73 @@ impl Subscribers {

/// Builder for [`MdnsDiscovery`].
#[derive(Debug)]
pub struct MdnsDiscoveryBuilder;
pub struct MdnsDiscoveryBuilder {
advertise: bool,
}

impl MdnsDiscoveryBuilder {
/// Creates a new [`MdnsDiscoveryBuilder`] with default settings.
pub fn new() -> Self {
Self { advertise: true }
}

/// Sets whether this node should advertise its presence.
///
/// Default is true.
pub fn advertise(mut self, advertise: bool) -> Self {
self.advertise = advertise;
self
}

/// Builds an [`MdnsDiscovery`] instance with the configured settings.
pub fn build(self, node_id: NodeId) -> Result<MdnsDiscovery, IntoDiscoveryError> {
MdnsDiscovery::new(node_id, self.advertise)
}
}

impl Default for MdnsDiscoveryBuilder {
fn default() -> Self {
Self::new()
}
}

impl IntoDiscovery for MdnsDiscoveryBuilder {
fn into_discovery(
self,
context: &DiscoveryContext,
) -> Result<impl Discovery, IntoDiscoveryError> {
MdnsDiscovery::new(context.node_id())
self.build(context.node_id())
}
}

impl MdnsDiscovery {
/// Returns a [`MdnsDiscoveryBuilder`] that implements [`IntoDiscovery`].
pub fn builder() -> MdnsDiscoveryBuilder {
MdnsDiscoveryBuilder
MdnsDiscoveryBuilder::new()
}

/// Create a new [`MdnsDiscovery`] Service.
///
/// This starts a [`Discoverer`] that broadcasts your addresses and receives addresses from other nodes in your local network.
/// This starts a [`Discoverer`] that broadcasts your addresses (if advertise is set to true)
/// and receives addresses from other nodes in your local network.
///
/// # Errors
/// Returns an error if the network does not allow ipv4 OR ipv6.
///
/// # Panics
/// This relies on [`tokio::runtime::Handle::current`] and will panic if called outside of the context of a tokio runtime.
pub fn new(node_id: NodeId) -> Result<Self, IntoDiscoveryError> {
pub fn new(node_id: NodeId, advertise: bool) -> Result<Self, IntoDiscoveryError> {
debug!("Creating new MdnsDiscovery service");
let (send, mut recv) = mpsc::channel(64);
let task_sender = send.clone();
let rt = tokio::runtime::Handle::current();
let discovery =
MdnsDiscovery::spawn_discoverer(node_id, task_sender.clone(), BTreeSet::new(), &rt)?;
let discovery = MdnsDiscovery::spawn_discoverer(
node_id,
advertise,
task_sender.clone(),
BTreeSet::new(),
&rt,
)?;

let local_addrs: Watchable<Option<NodeData>> = Watchable::default();
let mut addrs_change = local_addrs.watch();
Expand Down Expand Up @@ -311,13 +346,15 @@ impl MdnsDiscovery {
let handle = task::spawn(discovery_fut.instrument(info_span!("swarm-discovery.actor")));
Ok(Self {
handle: AbortOnDropHandle::new(handle),
advertise,
sender: send,
local_addrs,
})
}

fn spawn_discoverer(
node_id: PublicKey,
advertise: bool,
sender: mpsc::Sender<Message>,
socketaddrs: BTreeSet<SocketAddr>,
rt: &tokio::runtime::Handle,
Expand All @@ -337,15 +374,17 @@ impl MdnsDiscovery {
sender.send(Message::Discovery(node_id, peer)).await.ok();
});
};
let addrs = MdnsDiscovery::socketaddrs_to_addrs(&socketaddrs);
let node_id_str = data_encoding::BASE32_NOPAD
.encode(node_id.as_bytes())
.to_ascii_lowercase();
let mut discoverer = Discoverer::new_interactive(N0_LOCAL_SWARM.to_string(), node_id_str)
.with_callback(callback)
.with_ip_class(IpClass::Auto);
for addr in addrs {
discoverer = discoverer.with_addrs(addr.0, addr.1);
if advertise {
let addrs = MdnsDiscovery::socketaddrs_to_addrs(&socketaddrs);
for addr in addrs {
discoverer = discoverer.with_addrs(addr.0, addr.1);
}
}
discoverer
.spawn(rt)
Expand Down Expand Up @@ -406,7 +445,9 @@ impl Discovery for MdnsDiscovery {
}

fn publish(&self, data: &NodeData) {
self.local_addrs.set(Some(data.clone())).ok();
if self.advertise {
self.local_addrs.set(Some(data.clone())).ok();
}
}

fn subscribe(&self) -> Option<BoxStream<DiscoveryItem>> {
Expand Down Expand Up @@ -440,8 +481,10 @@ mod tests {
#[tokio::test]
#[traced_test]
async fn mdns_publish_resolve() -> Result {
let (_, discovery_a) = make_discoverer()?;
let (node_id_b, discovery_b) = make_discoverer()?;
// Create discoverer A with advertise=false (only listens)
let (_, discovery_a) = make_discoverer(false)?;
// Create discoverer B with advertise=true (will broadcast)
let (node_id_b, discovery_b) = make_discoverer(true)?;

// make addr info for discoverer b
let user_data: UserData = "foobar".parse()?;
Expand Down Expand Up @@ -477,11 +520,11 @@ mod tests {
let mut node_ids = BTreeSet::new();
let mut discoverers = vec![];

let (_, discovery) = make_discoverer()?;
let (_, discovery) = make_discoverer(false)?;
let node_data = NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse().unwrap()]));

for i in 0..num_nodes {
let (node_id, discovery) = make_discoverer()?;
let (node_id, discovery) = make_discoverer(true)?;
let user_data: UserData = format!("node{i}").parse()?;
let node_data = node_data.clone().with_user_data(Some(user_data.clone()));
node_ids.insert((node_id, Some(user_data)));
Expand Down Expand Up @@ -513,9 +556,38 @@ mod tests {
.context("timeout")?
}

fn make_discoverer() -> Result<(PublicKey, MdnsDiscovery)> {
#[tokio::test]
#[traced_test]
async fn non_advertising_node_not_discovered() -> Result {
let (_, discovery_a) = make_discoverer(false)?;
let (node_id_b, discovery_b) = make_discoverer(false)?;

let (node_id_c, discovery_c) = make_discoverer(true)?;
let node_data_c =
NodeData::new(None, BTreeSet::from(["0.0.0.0:22222".parse().unwrap()]));
discovery_c.publish(&node_data_c);

let node_data_b =
NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse().unwrap()]));
discovery_b.publish(&node_data_b);

let mut stream_c = discovery_a.resolve(node_id_c).unwrap();
let result_c = tokio::time::timeout(Duration::from_secs(2), stream_c.next()).await;
assert!(result_c.is_ok(), "Advertising node should be discoverable");

let mut stream_b = discovery_a.resolve(node_id_b).unwrap();
let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await;
assert!(
result_b.is_err(),
"Expected timeout since node b isn't advertising"
);

Ok(())
}

fn make_discoverer(advertise: bool) -> Result<(PublicKey, MdnsDiscovery)> {
let node_id = SecretKey::generate(rand::thread_rng()).public();
Ok((node_id, MdnsDiscovery::new(node_id)?))
Ok((node_id, MdnsDiscovery::new(node_id, advertise)?))
}
}
}
Loading