diff --git a/src/protocol/libp2p/ping/config.rs b/src/protocol/libp2p/ping/config.rs
index 085f2542..a78f3681 100644
--- a/src/protocol/libp2p/ping/config.rs
+++ b/src/protocol/libp2p/ping/config.rs
@@ -18,6 +18,7 @@
 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 // DEALINGS IN THE SOFTWARE.
 
+use std::time::Duration;
 use crate::{
     codec::ProtocolCodec, protocol::libp2p::ping::PingEvent, types::protocol::ProtocolName,
     DEFAULT_CHANNEL_SIZE,
@@ -36,6 +37,8 @@ const PING_PAYLOAD_SIZE: usize = 32;
 /// Maximum PING failures.
 const MAX_FAILURES: usize = 3;
 
+/// Default interval between outbound pings.
+pub const PING_INTERVAL: Duration = Duration::from_secs(15);
+
 /// Ping configuration.
 pub struct Config {
     /// Protocol name.
@@ -49,6 +52,8 @@ pub struct Config {
 
     /// TX channel for sending events to the user protocol.
     pub(crate) tx_event: Sender<PingEvent>,
+
+    /// Interval between outbound pings.
+    pub(crate) ping_interval: Duration,
 }
 
 impl Config {
@@ -61,6 +66,7 @@ impl Config {
         (
             Self {
                 tx_event,
+                ping_interval: PING_INTERVAL,
                 max_failures: MAX_FAILURES,
                 protocol: ProtocolName::from(PROTOCOL_NAME),
                 codec: ProtocolCodec::Identity(PING_PAYLOAD_SIZE),
@@ -80,6 +86,7 @@ pub struct ConfigBuilder {
 
     /// Maximum failures before the peer is considered unreachable.
     max_failures: usize,
+
+    /// Interval between outbound pings.
+    ping_interval: Duration,
 }
 
 impl Default for ConfigBuilder {
@@ -92,6 +99,7 @@ impl ConfigBuilder {
     /// Create new default [`Config`] which can be modified by the user.
     pub fn new() -> Self {
         Self {
+            ping_interval: PING_INTERVAL,
            max_failures: MAX_FAILURES,
             protocol: ProtocolName::from(PROTOCOL_NAME),
             codec: ProtocolCodec::Identity(PING_PAYLOAD_SIZE),
@@ -104,6 +112,11 @@ impl ConfigBuilder {
         self
     }
 
+    /// Set the interval between outbound pings.
+    pub fn with_ping_interval(mut self, ping_interval: Duration) -> Self {
+        self.ping_interval = ping_interval;
+        self
+    }
+
     /// Build [`Config`].
     pub fn build(self) -> (Config, Box<dyn Stream<Item = PingEvent> + Send + Unpin>) {
         let (tx_event, rx_event) = channel(DEFAULT_CHANNEL_SIZE);
@@ -111,6 +124,7 @@ impl ConfigBuilder {
         (
             Config {
                 tx_event,
+                ping_interval: self.ping_interval,
                 max_failures: self.max_failures,
                 protocol: self.protocol,
                 codec: self.codec,
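Usage note: the new interval knob composes with the existing builder flow, and the event stream returned by `build()` now carries the new `Failure` variant. A minimal sketch — the re-export path and the 5-second interval are assumptions for illustration, and wiring `ping_config` into the litep2p builder is omitted:

```rust
use futures::StreamExt;
use std::time::Duration;

// Assumed re-export path; adjust to the actual crate layout.
use litep2p::protocol::libp2p::ping::{ConfigBuilder, PingEvent};

async fn drive_ping_events() {
    // Override the 15-second `PING_INTERVAL` default.
    let (ping_config, mut ping_events) = ConfigBuilder::new()
        .with_ping_interval(Duration::from_secs(5))
        .build();

    // `ping_config` would be registered with the litep2p builder here.
    let _ = ping_config;

    while let Some(event) = ping_events.next().await {
        match event {
            PingEvent::Ping { peer, ping } => {
                println!("peer {peer:?} answered in {ping:?}");
            }
            // New in this patch: emitted once `max_failures` consecutive
            // pings fail, just before the connection is force-closed.
            PingEvent::Failure { peer } => {
                println!("peer {peer:?} is unreachable");
            }
        }
    }
}
```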
diff --git a/src/protocol/libp2p/ping/mod.rs b/src/protocol/libp2p/ping/mod.rs
index fa16069f..a0b5e8ba 100644
--- a/src/protocol/libp2p/ping/mod.rs
+++ b/src/protocol/libp2p/ping/mod.rs
@@ -24,26 +24,24 @@ use crate::{
     error::{Error, SubstreamError},
     protocol::{Direction, TransportEvent, TransportService},
     substream::Substream,
-    types::SubstreamId,
     PeerId,
 };
 
-use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
-use tokio::sync::mpsc::Sender;
-
+use futures::StreamExt;
 use std::{
-    collections::HashSet,
+    collections::HashMap,
     time::{Duration, Instant},
 };
+use tokio::sync::mpsc;
 
 pub use config::{Config, ConfigBuilder};
-
 mod config;
 
 // TODO: https://github.com/paritytech/litep2p/issues/132 let the user handle max failures
 
 /// Log target for the file.
 const LOG_TARGET: &str = "litep2p::ipfs::ping";
+
+/// Timeout for a single outbound ping.
+const PING_TIMEOUT: Duration = Duration::from_secs(10);
 
 /// Events emitted by the ping protocol.
 #[derive(Debug)]
@@ -56,158 +54,234 @@ pub enum PingEvent {
         /// Measured ping time with the peer.
         ping: Duration,
     },
+
+    /// Maximum number of ping failures was reached for the peer.
+    Failure {
+        /// Peer ID.
+        peer: PeerId,
+    },
+}
+
+/// Commands sent from the event loop to a per-substream ping task.
+enum PingCommand {
+    /// Send a ping over the substream.
+    SendPing,
+}
+
+/// Outcome of a single ping round trip.
+enum PingResult {
+    /// Ping succeeded, with the measured round-trip time.
+    Success(Duration),
+    /// Ping failed or timed out.
+    Failure,
+}
+
+/// Ping-protocol state of a connected peer.
+enum PeerState {
+    /// Outbound ping substream is being opened.
+    Pending,
+    /// Per-substream ping task is running.
+    Active {
+        /// TX channel for instructing the task to send a ping.
+        command_tx: mpsc::Sender<PingCommand>,
+        /// Number of consecutive ping failures.
+        failures: usize,
+    },
 }
 
 /// Ping protocol.
 pub(crate) struct Ping {
     /// Maximum failures before the peer is considered unreachable.
-    _max_failures: usize,
+    max_failures: usize,
 
     // Connection service.
     service: TransportService,
 
     /// TX channel for sending events to the user protocol.
-    tx: Sender<PingEvent>,
+    tx: mpsc::Sender<PingEvent>,
 
     /// Connected peers.
-    peers: HashSet<PeerId>,
+    peers: HashMap<PeerId, PeerState>,
 
-    /// Pending outbound substreams.
-    pending_outbound: FuturesUnordered<BoxFuture<'static, crate::Result<(PeerId, Duration)>>>,
+    /// Interval between outbound pings.
+    ping_interval: Duration,
 
-    /// Pending inbound substreams.
-    pending_inbound: FuturesUnordered<BoxFuture<'static, crate::Result<()>>>,
+    /// RX channel for receiving ping results from per-substream tasks.
+    result_rx: mpsc::Receiver<(PeerId, PingResult)>,
+
+    /// TX channel handed to per-substream tasks for reporting ping results.
+    result_tx: mpsc::Sender<(PeerId, PingResult)>,
 }
 
 impl Ping {
     /// Create new [`Ping`] protocol.
     pub fn new(service: TransportService, config: Config) -> Self {
+        let (result_tx, result_rx) = mpsc::channel(256);
+
         Self {
             service,
             tx: config.tx_event,
-            peers: HashSet::new(),
-            pending_outbound: FuturesUnordered::new(),
-            pending_inbound: FuturesUnordered::new(),
-            _max_failures: config.max_failures,
+            peers: HashMap::new(),
+            ping_interval: config.ping_interval,
+            max_failures: config.max_failures,
+            result_rx,
+            result_tx,
         }
     }
 
     /// Connection established to remote peer.
-    fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> {
-        tracing::trace!(target: LOG_TARGET, ?peer, "connection established");
-
-        self.service.open_substream(peer)?;
-        self.peers.insert(peer);
+    fn on_connection_established(&mut self, peer: PeerId) {
+        tracing::debug!(target: LOG_TARGET, ?peer, "connection established, opening ping substream");
 
-        Ok(())
+        match self.service.open_substream(peer) {
+            Ok(_) => {
+                self.peers.insert(peer, PeerState::Pending);
+            }
+            Err(error) => {
+                tracing::warn!(target: LOG_TARGET, ?peer, ?error, "failed to open ping substream");
+            }
+        }
     }
 
     /// Connection closed to remote peer.
     fn on_connection_closed(&mut self, peer: PeerId) {
-        tracing::trace!(target: LOG_TARGET, ?peer, "connection closed");
-
+        tracing::debug!(target: LOG_TARGET, ?peer, "connection closed");
         self.peers.remove(&peer);
     }
 
     /// Handle outbound substream.
-    fn on_outbound_substream(
-        &mut self,
-        peer: PeerId,
-        substream_id: SubstreamId,
-        mut substream: Substream,
-    ) {
-        tracing::trace!(target: LOG_TARGET, ?peer, "handle outbound substream");
-
-        self.pending_outbound.push(Box::pin(async move {
-            let future = async move {
-                // TODO: https://github.com/paritytech/litep2p/issues/134 generate random payload and verify it
-                substream.send_framed(vec![0u8; 32].into()).await?;
-                let now = Instant::now();
-                let _ = substream.next().await.ok_or(Error::SubstreamError(
-                    SubstreamError::ReadFailure(Some(substream_id)),
-                ))?;
-                let _ = substream.close().await;
-
-                Ok(now.elapsed())
-            };
-
-            match tokio::time::timeout(Duration::from_secs(10), future).await {
-                Err(_) => Err(Error::Timeout),
-                Ok(Err(error)) => Err(error),
-                Ok(Ok(elapsed)) => Ok((peer, elapsed)),
-            }
-        }));
+    fn on_outbound_substream(&mut self, peer: PeerId, substream: Substream) {
+        tracing::trace!(target: LOG_TARGET, ?peer, "outbound ping substream opened");
+
+        if let Some(PeerState::Pending) = self.peers.get(&peer) {
+            let (command_tx, command_rx) = mpsc::channel(1);
+            let result_tx = self.result_tx.clone();
+
+            tokio::spawn(handle_ping_substream(
+                peer,
+                substream,
+                command_rx,
+                result_tx,
+            ));
+
+            self.peers.insert(
+                peer,
+                PeerState::Active {
+                    command_tx,
+                    failures: 0,
+                },
+            );
+        } else {
+            tracing::warn!(target: LOG_TARGET, ?peer, "ping substream opened for non-pending peer");
+        }
     }
 
     /// Substream opened to remote peer.
-    fn on_inbound_substream(&mut self, peer: PeerId, mut substream: Substream) {
-        tracing::trace!(target: LOG_TARGET, ?peer, "handle inbound substream");
-
-        self.pending_inbound.push(Box::pin(async move {
-            let future = async move {
-                let payload = substream
-                    .next()
-                    .await
-                    .ok_or(Error::SubstreamError(SubstreamError::ReadFailure(None)))??;
-                substream.send_framed(payload.freeze()).await?;
-                let _ = substream.next().await.map(|_| ());
-
-                Ok(())
-            };
-
-            match tokio::time::timeout(Duration::from_secs(10), future).await {
-                Err(_) => Err(Error::Timeout),
-                Ok(Err(error)) => Err(error),
-                Ok(Ok(())) => Ok(()),
-            }
-        }));
+    fn on_inbound_substream(&mut self, peer: PeerId, substream: Substream) {
+        tracing::trace!(target: LOG_TARGET, ?peer, "handling inbound ping substream");
+
+        tokio::spawn(handle_inbound_ping(substream));
+    }
+
+    /// Handle ping result reported by a per-substream task.
+    async fn on_ping_result(&mut self, peer: PeerId, result: PingResult) {
+        match self.peers.get_mut(&peer) {
+            Some(PeerState::Active { failures, .. }) => match result {
+                PingResult::Success(duration) => {
+                    *failures = 0;
+                    let _ = self.tx.send(PingEvent::Ping { peer, ping: duration }).await;
+                }
+                PingResult::Failure => {
+                    *failures += 1;
+                    tracing::debug!(target: LOG_TARGET, ?peer, failures, "ping failure");
+
+                    if *failures >= self.max_failures {
+                        tracing::warn!(
+                            target: LOG_TARGET,
+                            ?peer,
+                            "maximum ping failures reached, closing connection"
+                        );
+                        let _ = self.tx.send(PingEvent::Failure { peer }).await;
+
+                        if let Err(error) = self.service.force_close(peer) {
+                            tracing::error!(target: LOG_TARGET, ?peer, ?error, "failed to force close connection");
+                        }
+                        self.peers.remove(&peer);
+                    }
+                }
+            },
+            _ => {
+                tracing::trace!(target: LOG_TARGET, ?peer, "ping result for inactive peer");
+            }
+        }
     }
 
     /// Start [`Ping`] event loop.
     pub async fn run(mut self) {
         tracing::debug!(target: LOG_TARGET, "starting ping event loop");
 
+        let mut interval = tokio::time::interval(self.ping_interval);
+
         loop {
             tokio::select! {
                 event = self.service.next() => match event {
                     Some(TransportEvent::ConnectionEstablished { peer, .. }) => {
-                        let _ = self.on_connection_established(peer);
+                        self.on_connection_established(peer);
                     }
                     Some(TransportEvent::ConnectionClosed { peer }) => {
                        self.on_connection_closed(peer);
                     }
-                    Some(TransportEvent::SubstreamOpened {
-                        peer,
-                        substream,
-                        direction,
-                        ..
-                    }) => match direction {
-                        Direction::Inbound => {
-                            self.on_inbound_substream(peer, substream);
-                        }
-                        Direction::Outbound(substream_id) => {
-                            self.on_outbound_substream(peer, substream_id, substream);
-                        }
-                    },
+                    Some(TransportEvent::SubstreamOpened { peer, substream, direction, .. }) => match direction {
+                        Direction::Outbound(_) => self.on_outbound_substream(peer, substream),
+                        Direction::Inbound => self.on_inbound_substream(peer, substream),
+                    },
                     Some(_) => {}
-                    None => return,
+                    None => {
+                        tracing::debug!(target: LOG_TARGET, "transport service shut down");
+                        return;
+                    }
                 },
+                _ = interval.tick() => {
+                    for (peer, state) in self.peers.iter() {
+                        if let PeerState::Active { command_tx, .. } = state {
+                            if let Err(error) = command_tx.try_send(PingCommand::SendPing) {
+                                tracing::trace!(target: LOG_TARGET, ?peer, ?error, "failed to send ping command");
+                            }
+                        }
+                    }
+                },
-                _event = self.pending_inbound.next(), if !self.pending_inbound.is_empty() => {}
-                event = self.pending_outbound.next(), if !self.pending_outbound.is_empty() => {
-                    match event {
-                        Some(Ok((peer, elapsed))) => {
-                            let _ = self
-                                .tx
-                                .send(PingEvent::Ping {
-                                    peer,
-                                    ping: elapsed,
-                                })
-                                .await;
-                        }
-                        event => tracing::debug!(target: LOG_TARGET, "failed to handle ping for an outbound peer: {event:?}"),
-                    }
-                }
+                Some((peer, result)) = self.result_rx.recv() => {
+                    self.on_ping_result(peer, result).await;
+                }
             }
         }
     }
 }
+
+/// Per-substream ping task.
+///
+/// Waits for [`PingCommand::SendPing`] from the main event loop, performs one
+/// ping round trip over the substream and reports the outcome via `result_tx`.
+async fn handle_ping_substream(
+    peer: PeerId,
+    mut substream: Substream,
+    mut command_rx: mpsc::Receiver<PingCommand>,
+    result_tx: mpsc::Sender<(PeerId, PingResult)>,
+) {
+    loop {
+        match command_rx.recv().await {
+            Some(PingCommand::SendPing) => {
+                // TODO: https://github.com/paritytech/litep2p/issues/134 generate random payload and verify it
+                let payload = vec![0u8; 32];
+                let future = async {
+                    substream.send_framed(payload.into()).await?;
+                    let now = Instant::now();
+                    let _ = substream
+                        .next()
+                        .await
+                        .ok_or(Error::SubstreamError(SubstreamError::ConnectionClosed))??;
+
+                    Ok::<_, Error>(now.elapsed())
+                };
+
+                match tokio::time::timeout(PING_TIMEOUT, future).await {
+                    Ok(Ok(duration)) => {
+                        if result_tx.send((peer, PingResult::Success(duration))).await.is_err() {
+                            break;
+                        }
+                    }
+                    _ => {
+                        if result_tx.send((peer, PingResult::Failure)).await.is_err() {
+                            break;
+                        }
+                    }
+                }
+            }
+            None => {
+                tracing::trace!(target: LOG_TARGET, ?peer, "ping command channel closed, shutting down task");
+                break;
+            }
+        }
+    }
+}
+
+/// Handle inbound ping substream by echoing received payloads back to the sender.
+async fn handle_inbound_ping(mut substream: Substream) {
+    while let Some(Ok(payload)) = substream.next().await {
+        if substream.send_framed(payload.freeze()).await.is_err() {
+            break;
+        }
+    }
+}
\ No newline at end of file
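Design note: each outbound ping substream is now owned by a spawned task that is driven over a bounded command channel of capacity 1 and reports back on a shared result channel. The sketch below models that pattern in isolation with plain tokio; `u64` stands in for `PeerId` and the timings are arbitrary:

```rust
use std::time::Duration;
use tokio::sync::mpsc;

enum PingCommand {
    SendPing,
}

enum PingResult {
    Success(Duration),
    Failure,
}

#[tokio::main]
async fn main() {
    // One shared result channel: every per-peer task reports to the event loop.
    let (result_tx, mut result_rx) = mpsc::channel::<(u64, PingResult)>(256);
    // Per-peer command channel with capacity 1: if the task is still mid-ping,
    // `try_send` fails and that tick is skipped instead of queueing up.
    let (command_tx, mut command_rx) = mpsc::channel::<PingCommand>(1);

    let peer = 42u64;
    tokio::spawn(async move {
        while let Some(PingCommand::SendPing) = command_rx.recv().await {
            // A real task would write the payload and await the echo here.
            tokio::time::sleep(Duration::from_millis(10)).await;
            let rtt = Duration::from_millis(10);
            if result_tx.send((peer, PingResult::Success(rtt))).await.is_err() {
                break; // event loop is gone, shut down
            }
        }
    });

    let mut interval = tokio::time::interval(Duration::from_millis(100));
    let mut received = 0;
    while received < 3 {
        tokio::select! {
            _ = interval.tick() => {
                // Non-blocking dispatch, as in `Ping::run` above.
                let _ = command_tx.try_send(PingCommand::SendPing);
            }
            Some((peer, result)) = result_rx.recv() => {
                received += 1;
                match result {
                    PingResult::Success(rtt) => println!("peer {peer}: rtt {rtt:?}"),
                    PingResult::Failure => println!("peer {peer}: ping failed"),
                }
            }
        }
    }
}
```

The capacity-1 command channel is the backpressure mechanism here: a peer that is slow to echo simply misses interval ticks instead of accumulating a queue of stale pings or stalling the main event loop.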
diff --git a/src/protocol/transport_service.rs b/src/protocol/transport_service.rs
index b729e931..e2827182 100644
--- a/src/protocol/transport_service.rs
+++ b/src/protocol/transport_service.rs
@@ -287,6 +287,8 @@ pub struct TransportService {
 
     /// Close the connection if no substreams are open within this time frame.
     keep_alive_tracker: KeepAliveTracker,
+
+    /// Whether substream activity on this protocol counts towards keeping the connection alive.
+    counts_towards_keep_alive: bool,
 }
 
 impl TransportService {
@@ -298,6 +300,7 @@ impl TransportService {
         next_substream_id: Arc<AtomicUsize>,
         transport_handle: TransportManagerHandle,
         keep_alive_timeout: Duration,
+        counts_towards_keep_alive: bool,
     ) -> (Self, Sender<InnerTransportEvent>) {
         let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
 
@@ -313,6 +316,7 @@
                 next_substream_id,
                 connections: HashMap::new(),
                 keep_alive_tracker,
+                counts_towards_keep_alive,
             },
             tx,
         )
@@ -507,8 +511,11 @@ impl TransportService {
             "open substream",
         );
 
-        self.keep_alive_tracker.substream_activity(peer, connection_id);
-        connection.try_upgrade();
+        if self.counts_towards_keep_alive {
+            self.keep_alive_tracker.substream_activity(peer, connection_id);
+            connection.try_upgrade();
+        }
+
         connection
             .open_substream(
@@ -592,7 +599,7 @@ impl Stream for TransportService {
                     substream,
                     connection_id,
                 }) => {
-                    if protocol == self.protocol {
+                    if protocol == self.protocol && self.counts_towards_keep_alive {
                         self.keep_alive_tracker.substream_activity(peer, connection_id);
                         if let Some(context) = self.connections.get_mut(&peer) {
                             context.try_upgrade(&connection_id);
diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs
index f44a07a6..ad2c1753 100644
--- a/src/transport/manager/mod.rs
+++ b/src/transport/manager/mod.rs
@@ -348,6 +348,7 @@ impl TransportManager {
             self.next_substream_id.clone(),
             self.transport_manager_handle.clone(),
             keep_alive_timeout,
+            /* counts_towards_keep_alive */ true,
         );
 
         self.protocols.insert(
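Behavioral note: with the flag threaded through, only protocols registered with `counts_towards_keep_alive = true` (as `TransportManager` does above) refresh the keep-alive deadline, so a protocol registered with `false` can exchange substreams without holding an otherwise idle connection open. A toy model of the gating — the types here are illustrative, not the litep2p API:

```rust
use std::time::{Duration, Instant};

/// Toy stand-in for `KeepAliveTracker`: a connection is considered expired
/// once the deadline passes without qualifying substream activity.
struct KeepAlive {
    deadline: Instant,
    timeout: Duration,
}

impl KeepAlive {
    fn new(timeout: Duration) -> Self {
        Self {
            deadline: Instant::now() + timeout,
            timeout,
        }
    }

    /// Mirrors the gated call sites in `TransportService`: activity refreshes
    /// the deadline only when the protocol opted in.
    fn substream_activity(&mut self, counts_towards_keep_alive: bool) {
        if counts_towards_keep_alive {
            self.deadline = Instant::now() + self.timeout;
        }
    }

    fn expired(&self) -> bool {
        Instant::now() >= self.deadline
    }
}

fn main() {
    let mut tracker = KeepAlive::new(Duration::from_secs(5));

    // Periodic ping-style traffic gated with `false` leaves the deadline alone...
    tracker.substream_activity(false);
    // ...while a user protocol registered with `true` pushes it out again.
    tracker.substream_activity(true);

    assert!(!tracker.expired());
}
```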