diff --git a/src/codec/identity.rs b/src/codec/identity.rs index 92ea7916..5451b5b2 100644 --- a/src/codec/identity.rs +++ b/src/codec/identity.rs @@ -100,9 +100,11 @@ mod tests { fn decoding_smaller_payloads() { let mut codec = Identity::new(100); let bytes = vec![3u8; 64]; + let copy = bytes.clone(); let mut bytes = BytesMut::from(&bytes[..]); - let decoded = codec.decode(&mut bytes); + // The smaller payload will not be decoded as the identity code needs 100 bytes. + assert!(codec.decode(&mut bytes).unwrap().is_none()); } #[test] diff --git a/src/config.rs b/src/config.rs index c2956021..8cd9ff98 100644 --- a/src/config.rs +++ b/src/config.rs @@ -29,7 +29,7 @@ use crate::{ notification, request_response, UserProtocol, }, transport::{ - manager::limits::ConnectionLimitsConfig, tcp::config::Config as TcpConfig, + manager::limits::ConnectionMiddleware, tcp::config::Config as TcpConfig, KEEP_ALIVE_TIMEOUT, MAX_PARALLEL_DIALS, }, types::protocol::ProtocolName, @@ -119,11 +119,11 @@ pub struct ConfigBuilder { /// Maximum number of parallel dial attempts. max_parallel_dials: usize, - /// Connection limits config. - connection_limits: ConnectionLimitsConfig, - /// Close the connection if no substreams are open within this time frame. keep_alive_timeout: Duration, + + /// Connection middleware. + connection_middleware: Option>, } impl Default for ConfigBuilder { @@ -155,8 +155,8 @@ impl ConfigBuilder { notification_protocols: HashMap::new(), request_response_protocols: HashMap::new(), known_addresses: Vec::new(), - connection_limits: ConnectionLimitsConfig::default(), keep_alive_timeout: KEEP_ALIVE_TIMEOUT, + connection_middleware: None, } } @@ -266,9 +266,9 @@ impl ConfigBuilder { self } - /// Set connection limits configuration. - pub fn with_connection_limits(mut self, config: ConnectionLimitsConfig) -> Self { - self.connection_limits = config; + /// Set connection middleware. + pub fn with_connection_middleware(mut self, middleware: Box) -> Self { + self.connection_middleware = Some(middleware); self } @@ -305,8 +305,8 @@ impl ConfigBuilder { notification_protocols: self.notification_protocols, request_response_protocols: self.request_response_protocols, known_addresses: self.known_addresses, - connection_limits: self.connection_limits, keep_alive_timeout: self.keep_alive_timeout, + connection_middleware: self.connection_middleware.take(), } } } @@ -364,9 +364,9 @@ pub struct Litep2pConfig { /// Known addresses. pub(crate) known_addresses: Vec<(PeerId, Vec)>, - /// Connection limits config. - pub(crate) connection_limits: ConnectionLimitsConfig, - /// Close the connection if no substreams are open within this time frame. pub(crate) keep_alive_timeout: Duration, + + /// Connection middleware. + pub(crate) connection_middleware: Option>, } diff --git a/src/lib.rs b/src/lib.rs index 66e03289..c1f42d99 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,7 +163,7 @@ impl Litep2p { supported_transports, bandwidth_sink.clone(), litep2p_config.max_parallel_dials, - litep2p_config.connection_limits, + litep2p_config.connection_middleware, ); // add known addresses to `TransportManager`, if any exist diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs index aa5c20f4..e993aeac 100644 --- a/src/protocol/libp2p/kademlia/mod.rs +++ b/src/protocol/libp2p/kademlia/mod.rs @@ -1230,10 +1230,7 @@ mod tests { use crate::{ codec::ProtocolCodec, crypto::ed25519::Keypair, - transport::{ - manager::{limits::ConnectionLimitsConfig, TransportManager}, - KEEP_ALIVE_TIMEOUT, - }, + transport::{manager::TransportManager, ConnectionLimits, KEEP_ALIVE_TIMEOUT}, types::protocol::ProtocolName, BandwidthSink, }; @@ -1251,7 +1248,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = PeerId::random(); diff --git a/src/protocol/mdns.rs b/src/protocol/mdns.rs index 7ea8f30d..53538fae 100644 --- a/src/protocol/mdns.rs +++ b/src/protocol/mdns.rs @@ -336,7 +336,7 @@ mod tests { use super::*; use crate::{ crypto::ed25519::Keypair, - transport::manager::{limits::ConnectionLimitsConfig, TransportManager}, + transport::{manager::TransportManager, ConnectionLimits}, BandwidthSink, }; use futures::StreamExt; @@ -354,7 +354,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let mdns1 = Mdns::new( @@ -377,7 +377,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let mdns2 = Mdns::new( diff --git a/src/protocol/notification/tests/mod.rs b/src/protocol/notification/tests/mod.rs index 4aa48aa4..5b40dd3b 100644 --- a/src/protocol/notification/tests/mod.rs +++ b/src/protocol/notification/tests/mod.rs @@ -29,10 +29,7 @@ use crate::{ }, InnerTransportEvent, ProtocolCommand, TransportService, }, - transport::{ - manager::{limits::ConnectionLimitsConfig, TransportManager}, - KEEP_ALIVE_TIMEOUT, - }, + transport::{manager::TransportManager, ConnectionLimits, KEEP_ALIVE_TIMEOUT}, types::protocol::ProtocolName, BandwidthSink, PeerId, }; @@ -56,7 +53,7 @@ fn make_notification_protocol() -> ( HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = PeerId::random(); diff --git a/src/protocol/request_response/tests.rs b/src/protocol/request_response/tests.rs index 32cc65e7..522269bb 100644 --- a/src/protocol/request_response/tests.rs +++ b/src/protocol/request_response/tests.rs @@ -29,10 +29,7 @@ use crate::{ InnerTransportEvent, SubstreamError, TransportService, }, substream::Substream, - transport::{ - manager::{limits::ConnectionLimitsConfig, TransportManager}, - KEEP_ALIVE_TIMEOUT, - }, + transport::{manager::TransportManager, ConnectionLimits, KEEP_ALIVE_TIMEOUT}, types::{RequestId, SubstreamId}, BandwidthSink, Error, PeerId, ProtocolName, }; @@ -54,7 +51,7 @@ fn protocol() -> ( HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = PeerId::random(); diff --git a/src/transport/manager/limits.rs b/src/transport/manager/limits.rs index 0af49eb1..bcb48c95 100644 --- a/src/transport/manager/limits.rs +++ b/src/transport/manager/limits.rs @@ -20,9 +20,100 @@ //! Limits for the transport manager. -use crate::types::ConnectionId; +use crate::{transport::Endpoint, types::ConnectionId, PeerId}; + +use std::{collections::HashSet, net::SocketAddr}; + +/// A middleware trait for managing connections. +/// +/// This middleware allows developers to implement custom connection policies, +/// enabling a wide range of use cases by exposing hooks into the connection lifecycle. +/// +/// It interacts with the transport manager at two stages: +/// +/// ## 1. Before Negotiation +/// +/// At this stage, the connection has not yet been negotiated. In the context of litep2p, +/// "negotiation" refers to the handshake and setup of `crypto/noise` (encryption and peer ID +/// validation) and `yamux` (multiplexing). +/// +/// The node is either attempting to establish an outbound connection or accept an inbound one. +/// +/// - Returning an error here will prevent the negotiation from proceeding, saving resources. +/// +/// - [`Self::outbound_capacity`] is called to determine the number of outbound +/// connections that can be established. The peerID is provided to further provide connection +/// details. +/// +/// - [`Self::check_inbound`] is called to evaluate whether an inbound connection can be accepted. +/// The peer ID is not yet known, but the socket address is provided to identify the connection. +/// +/// ## 2. After Negotiation +/// +/// At this point, the connection has been successfully negotiated and the peer ID is known. +/// +/// - [`Self::can_accept_connection`] is invoked to determine if the fully negotiated connection +/// should be accepted. The peer ID, endpoint, and connection ID are provided. Implementations +/// should check internal limits but **must not** store the connection ID or endpoint here, as the +/// transport manager might still reject the connection later. +/// +/// - If the connection is accepted, [`Self::on_connection_established`] is called with the same +/// peer ID and endpoint. At this point, implementations should begin tracking the connection ID. +/// +/// - When a connection is closed, [`Self::on_connection_closed`] is called. Implementations must +/// clean up any resources associated with the connection ID to prevent memory leaks. +pub trait ConnectionMiddleware: Send { + /// Determines the number of outbound connections permitted to be established. + /// + /// This method is called before the node attempts to dial a remote peer. + /// + /// Returns the number of allowed outbound connections. + /// - If there is no limit, returns `Ok(usize::MAX)`. + /// - If the node cannot accept any more outbound connections, returns an error. + fn outbound_capacity(&mut self, peer: PeerId) -> crate::Result; + + /// Checks whether a new inbound connection can be accepted before processing it. + /// + /// At this point, no protocol negotiation has occurred and the peer identity is + /// unknown. The connection ID provided is the one that will be used for the + /// connection. + fn check_inbound( + &mut self, + connection_id: ConnectionId, + address: SocketAddr, + ) -> crate::Result<()>; + + /// Verifies if a new connection (inbound or outbound) can be established. + /// + /// Returns an error if connection limits or policy constraints prevent + /// establishing the connection. + /// + /// # Note + /// + /// This method is called before the connection is established. However, + /// the transport manager can decide to reject the connection even if this + /// method returns `Ok(())`. Therefore, the API makes no guarantees of + /// further calling [`Self::on_connection_established`]. + /// + /// Implementations should inspect the provided parameters. To avoid leaking + /// memory, the implementation should not store the connection ID or endpoint + /// at this point in time. + fn can_accept_connection(&mut self, peer: PeerId, endpoint: &Endpoint) -> crate::Result<()>; + + /// Registers a connection as established. + /// + /// This method will be called after a successful check using [`Self::can_accept_connection`]. + /// The peer ID and endpoint are provided to identify the connection and are identical + /// to the ones used in [`Self::can_accept_connection`]. + fn on_connection_established(&mut self, peer: PeerId, endpoint: &Endpoint); -use std::collections::HashSet; + /// Deregisters a connection when it is closed. + /// + /// This method will be called after a [`Self::on_connection_established`] call. + /// The connection ID corresponds the endpoint provided in the + /// [`Self::on_connection_established`] method. + fn on_connection_closed(&mut self, peer: PeerId, connection_id: ConnectionId); +} /// Configuration for the connection limits. #[derive(Debug, Clone, Default)] @@ -56,7 +147,10 @@ pub enum ConnectionLimitsError { MaxOutgoingConnectionsExceeded, } -/// Connection limits. +/// General connection limits. +/// +/// This is a type of connection middleware that places limits on the number +/// of incoming and outgoing connections. #[derive(Debug, Clone)] pub struct ConnectionLimits { /// Configuration for the connection limits. @@ -80,19 +174,13 @@ impl ConnectionLimits { outgoing_connections: HashSet::with_capacity(max_outgoing_connections), } } +} - /// Called when dialing an address. - /// - /// Returns the number of outgoing connections permitted to be established. - /// It is guaranteed that at least one connection can be established if the method returns `Ok`. - /// The number of available outgoing connections can influence the maximum parallel dials to a - /// single address. - /// - /// If the maximum number of outgoing connections is not set, `Ok(usize::MAX)` is returned. - pub fn on_dial_address(&mut self) -> Result { +impl ConnectionMiddleware for ConnectionLimits { + fn outbound_capacity(&mut self, _peer: PeerId) -> crate::Result { if let Some(max_outgoing_connections) = self.config.max_outgoing_connections { if self.outgoing_connections.len() >= max_outgoing_connections { - return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded); + return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded.into()); } return Ok(max_outgoing_connections - self.outgoing_connections.len()); @@ -101,62 +189,48 @@ impl ConnectionLimits { Ok(usize::MAX) } - /// Called before accepting a new incoming connection. - pub fn on_incoming(&mut self) -> Result<(), ConnectionLimitsError> { + fn check_inbound( + &mut self, + _connection_id: ConnectionId, + _address: SocketAddr, + ) -> crate::Result<()> { if let Some(max_incoming_connections) = self.config.max_incoming_connections { if self.incoming_connections.len() >= max_incoming_connections { - return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded); + return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded.into()); } } Ok(()) } - /// Called when a new connection is established. - /// - /// Returns an error if the connection cannot be accepted due to connection limits. - pub fn can_accept_connection( - &mut self, - is_listener: bool, - ) -> Result<(), ConnectionLimitsError> { + fn can_accept_connection(&mut self, _peer: PeerId, endpoint: &Endpoint) -> crate::Result<()> { // Check connection limits. - if is_listener { + if endpoint.is_listener() { if let Some(max_incoming_connections) = self.config.max_incoming_connections { if self.incoming_connections.len() >= max_incoming_connections { - return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded); + return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded.into()); } } } else if let Some(max_outgoing_connections) = self.config.max_outgoing_connections { if self.outgoing_connections.len() >= max_outgoing_connections { - return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded); + return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded.into()); } } Ok(()) } - /// Accept an established connection. - /// - /// # Note - /// - /// This method should be called after the `Self::can_accept_connection` method - /// to ensure that the connection can be accepted. - pub fn accept_established_connection( - &mut self, - connection_id: ConnectionId, - is_listener: bool, - ) { - if is_listener { + fn on_connection_established(&mut self, _peer: PeerId, endpoint: &Endpoint) { + if endpoint.is_listener() { if self.config.max_incoming_connections.is_some() { - self.incoming_connections.insert(connection_id); + self.incoming_connections.insert(endpoint.connection_id()); } } else if self.config.max_outgoing_connections.is_some() { - self.outgoing_connections.insert(connection_id); + self.outgoing_connections.insert(endpoint.connection_id()); } } - /// Called when a connection is closed. - pub fn on_connection_closed(&mut self, connection_id: ConnectionId) { + fn on_connection_closed(&mut self, _peer: PeerId, connection_id: ConnectionId) { self.incoming_connections.remove(&connection_id); self.outgoing_connections.remove(&connection_id); } @@ -181,46 +255,59 @@ mod tests { let connection_id_in_3 = ConnectionId::random(); // Establish incoming connection. - assert!(limits.can_accept_connection(true).is_ok()); - limits.accept_established_connection(connection_id_in_1, true); + let endpoint = Endpoint::Listener { + address: multiaddr::Multiaddr::empty(), + connection_id: connection_id_in_1, + }; + assert!(limits.can_accept_connection(PeerId::random(), &endpoint).is_ok()); + limits.on_connection_established(PeerId::random(), &endpoint); assert_eq!(limits.incoming_connections.len(), 1); - assert!(limits.can_accept_connection(true).is_ok()); - limits.accept_established_connection(connection_id_in_2, true); + let endpoint = Endpoint::Listener { + address: multiaddr::Multiaddr::empty(), + connection_id: connection_id_in_2, + }; + assert!(limits.can_accept_connection(PeerId::random(), &endpoint).is_ok()); + limits.on_connection_established(PeerId::random(), &endpoint); assert_eq!(limits.incoming_connections.len(), 2); - assert!(limits.can_accept_connection(true).is_ok()); - limits.accept_established_connection(connection_id_in_3, true); + let endpoint = Endpoint::Listener { + address: multiaddr::Multiaddr::empty(), + connection_id: connection_id_in_3, + }; + assert!(limits.can_accept_connection(PeerId::random(), &endpoint).is_ok()); + limits.on_connection_established(PeerId::random(), &endpoint); assert_eq!(limits.incoming_connections.len(), 3); - assert_eq!( - limits.can_accept_connection(true).unwrap_err(), - ConnectionLimitsError::MaxIncomingConnectionsExceeded - ); + assert!(limits.can_accept_connection(PeerId::random(), &endpoint).is_err()); assert_eq!(limits.incoming_connections.len(), 3); // Establish outgoing connection. - assert!(limits.can_accept_connection(false).is_ok()); - limits.accept_established_connection(connection_id_out_1, false); + let endpoint = Endpoint::Dialer { + address: multiaddr::Multiaddr::empty(), + connection_id: connection_id_out_1, + }; + assert!(limits.can_accept_connection(PeerId::random(), &endpoint).is_ok()); + limits.on_connection_established(PeerId::random(), &endpoint); assert_eq!(limits.incoming_connections.len(), 3); assert_eq!(limits.outgoing_connections.len(), 1); - assert!(limits.can_accept_connection(false).is_ok()); - limits.accept_established_connection(connection_id_out_2, false); + let endpoint = Endpoint::Dialer { + address: multiaddr::Multiaddr::empty(), + connection_id: connection_id_out_2, + }; + assert!(limits.can_accept_connection(PeerId::random(), &endpoint).is_ok()); + limits.on_connection_established(PeerId::random(), &endpoint); assert_eq!(limits.incoming_connections.len(), 3); assert_eq!(limits.outgoing_connections.len(), 2); + assert!(limits.can_accept_connection(PeerId::random(), &endpoint).is_err()); - assert_eq!( - limits.can_accept_connection(false).unwrap_err(), - ConnectionLimitsError::MaxOutgoingConnectionsExceeded - ); - - // Close connections with peer a. - limits.on_connection_closed(connection_id_in_1); + // Close connections with 1. + limits.on_connection_closed(PeerId::random(), connection_id_in_1); assert_eq!(limits.incoming_connections.len(), 2); assert_eq!(limits.outgoing_connections.len(), 2); - limits.on_connection_closed(connection_id_out_1); + limits.on_connection_closed(PeerId::random(), connection_id_out_1); assert_eq!(limits.incoming_connections.len(), 2); assert_eq!(limits.outgoing_connections.len(), 1); } diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index dec079a7..20144699 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -41,6 +41,7 @@ use crate::{ use address::{scores, AddressStore}; use futures::{Stream, StreamExt}; use indexmap::IndexMap; +use limits::ConnectionMiddleware; use multiaddr::{Multiaddr, Protocol}; use multihash::Multihash; use parking_lot::RwLock; @@ -48,6 +49,7 @@ use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ collections::{HashMap, HashSet}, + net::SocketAddr, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, @@ -247,11 +249,11 @@ pub struct TransportManager { /// Pending connections. pending_connections: HashMap, - /// Connection limits. - connection_limits: limits::ConnectionLimits, - /// Opening connections errors. opening_errors: HashMap>, + + /// Connection middleware. + connection_middleware: Option>, } impl TransportManager { @@ -262,7 +264,7 @@ impl TransportManager { supported_transports: HashSet, bandwidth_sink: BandwidthSink, max_parallel_dials: usize, - connection_limits_config: limits::ConnectionLimitsConfig, + connection_middleware: Option>, ) -> (Self, TransportManagerHandle) { let local_peer_id = PeerId::from_public_key(&keypair.public().into()); let peers = Arc::new(RwLock::new(HashMap::new())); @@ -298,8 +300,8 @@ impl TransportManager { pending_connections: HashMap::new(), next_substream_id: Arc::new(AtomicUsize::new(0usize)), next_connection_id: Arc::new(AtomicUsize::new(0usize)), - connection_limits: limits::ConnectionLimits::new(connection_limits_config), opening_errors: HashMap::new(), + connection_middleware, }, handle, ) @@ -441,7 +443,12 @@ impl TransportManager { /// Returns an error if the peer is unknown or the peer is already connected. pub async fn dial(&mut self, peer: PeerId) -> crate::Result<()> { // Don't alter the peer state if there's no capacity to dial. - let available_capacity = self.connection_limits.on_dial_address()?; + let available_capacity = if let Some(middleware) = &mut self.connection_middleware { + middleware.outbound_capacity(peer)? + } else { + usize::MAX + }; + // The available capacity is the maximum number of connections that can be established, // so we limit the number of parallel dials to the minimum of these values. let limit = available_capacity.min(self.max_parallel_dials); @@ -514,7 +521,12 @@ impl TransportManager { /// /// Returns an error if address it not valid. pub async fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { - self.connection_limits.on_dial_address()?; + if let Some(middleware) = &mut self.connection_middleware { + let peer = PeerId::try_from_multiaddr(&address) + .ok_or(Error::AddressError(AddressError::PeerIdMissing))?; + + middleware.outbound_capacity(peer)?; + } let address_record = AddressRecord::from_multiaddr(address) .ok_or(Error::AddressError(AddressError::PeerIdMissing))?; @@ -682,8 +694,15 @@ impl TransportManager { Ok(()) } - fn on_pending_incoming_connection(&mut self) -> crate::Result<()> { - self.connection_limits.on_incoming()?; + fn on_pending_incoming_connection( + &mut self, + connection_id: ConnectionId, + address: SocketAddr, + ) -> crate::Result<()> { + if let Some(middleware) = &mut self.connection_middleware { + middleware.check_inbound(connection_id, address)?; + } + Ok(()) } @@ -695,7 +714,9 @@ impl TransportManager { ) -> Option { tracing::trace!(target: LOG_TARGET, ?peer, ?connection_id, "connection closed"); - self.connection_limits.on_connection_closed(connection_id); + if let Some(middleware) = &mut self.connection_middleware { + middleware.on_connection_closed(peer, connection_id); + } let mut peers = self.peers.write(); let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); @@ -772,15 +793,17 @@ impl TransportManager { }; // Reject the connection if exceeded limits. - if let Err(error) = self.connection_limits.can_accept_connection(endpoint.is_listener()) { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?endpoint, - ?error, - "connection limit exceeded, rejecting connection", - ); - return Ok(ConnectionEstablishedResult::Reject); + if let Some(middleware) = &mut self.connection_middleware { + if let Err(error) = middleware.can_accept_connection(peer, endpoint) { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?endpoint, + ?error, + "connection middleware rejected connection", + ); + return Ok(ConnectionEstablishedResult::Reject); + } } let mut peers = self.peers.write(); @@ -801,8 +824,9 @@ impl TransportManager { ); if connection_accepted { - self.connection_limits - .accept_established_connection(endpoint.connection_id(), endpoint.is_listener()); + if let Some(middleware) = &mut self.connection_middleware { + middleware.on_connection_established(peer, endpoint); + } // Cancel all pending dials if the connection was established. if let PeerState::Opening { @@ -1280,8 +1304,8 @@ impl TransportManager { } } }, - TransportEvent::PendingInboundConnection { connection_id } => { - if self.on_pending_incoming_connection().is_ok() { + TransportEvent::PendingInboundConnection { connection_id, address } => { + if self.on_pending_incoming_connection(connection_id, address).is_ok() { tracing::trace!( target: LOG_TARGET, ?connection_id, @@ -1318,9 +1342,10 @@ impl TransportManager { #[cfg(test)] mod tests { use crate::transport::manager::{address::AddressStore, peer_state::SecondaryOrDialing}; - use limits::ConnectionLimitsConfig; + use limits::{ConnectionLimits, ConnectionLimitsConfig}; use multihash::Multihash; + use std::net::IpAddr; use super::*; use crate::{ @@ -1428,6 +1453,7 @@ mod tests { tx_ws .send(TransportEvent::PendingInboundConnection { connection_id: ConnectionId::from(1), + address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), }) .await .expect("chanel to be open"); @@ -1446,6 +1472,7 @@ mod tests { tx_tcp .send(TransportEvent::PendingInboundConnection { connection_id: ConnectionId::from(2), + address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), }) .await .expect("chanel to be open"); @@ -1464,12 +1491,14 @@ mod tests { tx_ws .send(TransportEvent::PendingInboundConnection { connection_id: ConnectionId::from(3), + address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), }) .await .expect("chanel to be open"); tx_tcp .send(TransportEvent::PendingInboundConnection { connection_id: ConnectionId::from(4), + address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), }) .await .expect("chanel to be open"); @@ -1505,7 +1534,7 @@ mod tests { HashSet::new(), sink, 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_protocol( @@ -1532,7 +1561,7 @@ mod tests { HashSet::new(), sink, 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_protocol( @@ -1562,7 +1591,7 @@ mod tests { HashSet::new(), sink, 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_protocol( @@ -1595,7 +1624,7 @@ mod tests { HashSet::new(), sink, 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1612,7 +1641,7 @@ mod tests { HashSet::new(), sink, 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); assert!(manager.dial(local_peer_id).await.is_err()); @@ -1625,7 +1654,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1655,7 +1684,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = PeerId::random(); let dial_address = Multiaddr::empty() @@ -1717,7 +1746,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1748,7 +1777,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1793,7 +1822,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1812,7 +1841,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1845,7 +1874,7 @@ mod tests { transports, BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); // ipv6 @@ -1907,7 +1936,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1974,7 +2003,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2061,7 +2090,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2146,7 +2175,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2255,7 +2284,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2351,7 +2380,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2460,7 +2489,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2564,7 +2593,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -2708,7 +2737,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.on_dial_failure(ConnectionId::random()).unwrap(); @@ -2727,7 +2756,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.on_connection_closed(PeerId::random(), ConnectionId::random()).unwrap(); } @@ -2745,7 +2774,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager .on_connection_opened( @@ -2769,7 +2798,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let connection_id = ConnectionId::random(); let peer = PeerId::random(); @@ -2793,7 +2822,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let connection_id = ConnectionId::random(); let peer = PeerId::random(); @@ -2820,7 +2849,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager @@ -2841,7 +2870,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let connection_id = ConnectionId::random(); let peer = PeerId::random(); @@ -2861,7 +2890,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); assert!(manager.next().await.is_none()); @@ -2874,7 +2903,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = { @@ -2922,7 +2951,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = { @@ -2985,7 +3014,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = { @@ -3028,7 +3057,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); // transport doesn't start with ip/dns @@ -3094,7 +3123,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); async fn call_manager(manager: &mut TransportManager, address: Multiaddr) { @@ -3148,7 +3177,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = PeerId::random(); let dial_address = Multiaddr::empty() @@ -3234,7 +3263,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = PeerId::random(); let dial_address = Multiaddr::empty() @@ -3322,9 +3351,11 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default() - .max_incoming_connections(Some(3)) - .max_outgoing_connections(Some(2)), + Some(Box::new(ConnectionLimits::new( + ConnectionLimitsConfig::default() + .max_incoming_connections(Some(3)) + .max_outgoing_connections(Some(2)), + ))), ); // The connection limit is agnostic of the underlying transports. manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -3398,9 +3429,11 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default() - .max_incoming_connections(Some(3)) - .max_outgoing_connections(Some(2)), + Some(Box::new(ConnectionLimits::new( + ConnectionLimitsConfig::default() + .max_incoming_connections(Some(3)) + .max_outgoing_connections(Some(2)), + ))), ); // The connection limit is agnostic of the underlying transports. manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -3487,7 +3520,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -3540,7 +3573,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -3692,7 +3725,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = PeerId::random(); let dial_address = Multiaddr::empty() @@ -3778,7 +3811,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + Some(Box::new(ConnectionLimits::new(Default::default()))), ); let peer = PeerId::random(); let connection_id = ConnectionId::from(0); diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 1bf9b7d9..d71cb258 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -25,7 +25,7 @@ use crate::{error::DialError, transport::manager::TransportHandle, types::Connec use futures::Stream; use multiaddr::Multiaddr; -use std::{fmt::Debug, time::Duration}; +use std::{fmt::Debug, net::SocketAddr, time::Duration}; pub(crate) mod common; #[cfg(feature = "quic")] @@ -39,7 +39,9 @@ pub mod websocket; pub(crate) mod dummy; pub(crate) mod manager; -pub use manager::limits::{ConnectionLimitsConfig, ConnectionLimitsError}; +pub use manager::limits::{ + ConnectionLimits, ConnectionLimitsConfig, ConnectionLimitsError, ConnectionMiddleware, +}; /// Timeout for opening a connection. pub(crate) const CONNECTION_OPEN_TIMEOUT: Duration = Duration::from_secs(10); @@ -129,6 +131,9 @@ pub(crate) enum TransportEvent { PendingInboundConnection { /// Connection ID. connection_id: ConnectionId, + + /// The socket address which initiated the connection. + address: SocketAddr, }, /// Connection opened to remote but not yet negotiated. diff --git a/src/transport/quic/mod.rs b/src/transport/quic/mod.rs index 0cf5e255..32c75c2f 100644 --- a/src/transport/quic/mod.rs +++ b/src/transport/quic/mod.rs @@ -501,6 +501,7 @@ impl Stream for QuicTransport { return Poll::Ready(Some(TransportEvent::PendingInboundConnection { connection_id, + address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), })); } @@ -659,7 +660,7 @@ mod tests { let event = transport1.next().await.unwrap(); match event { - TransportEvent::PendingInboundConnection { connection_id } => { + TransportEvent::PendingInboundConnection { connection_id, .. } => { transport1.accept_pending(connection_id).unwrap(); } _ => panic!("unexpected event"), diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs index 748e138d..27d105f3 100644 --- a/src/transport/tcp/mod.rs +++ b/src/transport/tcp/mod.rs @@ -579,6 +579,7 @@ impl Stream for TcpTransport { Poll::Ready(Some(TransportEvent::PendingInboundConnection { connection_id, + address, })) } }; @@ -684,9 +685,7 @@ mod tests { codec::ProtocolCodec, crypto::ed25519::Keypair, executor::DefaultExecutor, - transport::manager::{ - limits::ConnectionLimitsConfig, ProtocolContext, SupportedTransport, TransportManager, - }, + transport::manager::{ProtocolContext, SupportedTransport, TransportManager}, types::protocol::ProtocolName, BandwidthSink, PeerId, }; @@ -769,7 +768,7 @@ mod tests { let event = transport1.next().await.unwrap(); match event { - TransportEvent::PendingInboundConnection { connection_id } => { + TransportEvent::PendingInboundConnection { connection_id, .. } => { transport1.accept_pending(connection_id).unwrap(); } _ => panic!("unexpected event"), @@ -863,7 +862,7 @@ mod tests { // Reject connection. let event = transport1.next().await.unwrap(); match event { - TransportEvent::PendingInboundConnection { connection_id } => { + TransportEvent::PendingInboundConnection { connection_id, .. } => { transport1.reject_pending(connection_id).unwrap(); } _ => panic!("unexpected event"), @@ -979,7 +978,7 @@ mod tests { HashSet::new(), BandwidthSink::new(), 8usize, - ConnectionLimitsConfig::default(), + None, ); let handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport( diff --git a/src/transport/websocket/mod.rs b/src/transport/websocket/mod.rs index 2435f639..d65a0106 100644 --- a/src/transport/websocket/mod.rs +++ b/src/transport/websocket/mod.rs @@ -625,6 +625,7 @@ impl Stream for WebSocketTransport { Poll::Ready(Some(TransportEvent::PendingInboundConnection { connection_id, + address, })) } };