diff --git a/src/config.rs b/src/config.rs index 81befc43..120619d7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -29,8 +29,9 @@ use crate::{ notification, request_response, UserProtocol, }, transport::{ - manager::limits::ConnectionLimitsConfig, tcp::config::Config as TcpConfig, - KEEP_ALIVE_TIMEOUT, MAX_PARALLEL_DIALS, + manager::{limits::ConnectionLimitsConfig, TransportHandle}, + tcp::config::Config as TcpConfig, + Transport, TransportEvent, KEEP_ALIVE_TIMEOUT, MAX_PARALLEL_DIALS, }, types::protocol::ProtocolName, PeerId, @@ -43,6 +44,7 @@ use crate::transport::webrtc::config::Config as WebRtcConfig; #[cfg(feature = "websocket")] use crate::transport::websocket::config::Config as WebSocketConfig; +use hickory_resolver::TokioResolver; use multiaddr::Multiaddr; use std::{collections::HashMap, sync::Arc, time::Duration}; @@ -83,6 +85,15 @@ pub struct ConfigBuilder { #[cfg(feature = "websocket")] websocket: Option, + /// List of custom transports. + custom_transports: Vec<( + &'static str, + fn( + TransportHandle, + Arc, + ) -> crate::Result<(Box>, Vec)>, + )>, + /// Keypair. keypair: Option, @@ -146,6 +157,7 @@ impl ConfigBuilder { webrtc: None, #[cfg(feature = "websocket")] websocket: None, + custom_transports: Vec::new(), keypair: None, ping: None, identify: None, @@ -191,6 +203,20 @@ impl ConfigBuilder { self } + /// Add a custom transport configuration, enabling the transport. + pub fn with_custom_transport( + mut self, + name: &'static str, + transport: fn( + TransportHandle, + Arc, + ) + -> crate::Result<(Box>, Vec)>, + ) -> Self { + self.custom_transports.push((name, transport)); + self + } + /// Add keypair. /// /// If no keypair is specified, litep2p creates a new keypair. @@ -305,6 +331,7 @@ impl ConfigBuilder { webrtc: self.webrtc.take(), #[cfg(feature = "websocket")] websocket: self.websocket.take(), + custom_transports: self.custom_transports, ping: self.ping.take(), identify: self.identify.take(), kademlia: self.kademlia.take(), @@ -339,6 +366,15 @@ pub struct Litep2pConfig { #[cfg(feature = "websocket")] pub(crate) websocket: Option, + /// Custom transports. + pub(crate) custom_transports: Vec<( + &'static str, + fn( + TransportHandle, + Arc, + ) -> crate::Result<(Box>, Vec)>, + )>, + /// Keypair. pub(crate) keypair: Keypair, diff --git a/src/lib.rs b/src/lib.rs index 77e8a8d4..f8e49736 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -372,7 +372,7 @@ impl Litep2p { if let Some(config) = litep2p_config.websocket.take() { let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); let (transport, transport_listen_addresses) = - ::new(handle, config, resolver)?; + ::new(handle, config, resolver.clone())?; for address in transport_listen_addresses { transport_manager.register_listen_address(address.clone()); @@ -383,6 +383,21 @@ impl Litep2p { .register_transport(SupportedTransport::WebSocket, Box::new(transport)); } + // enable custom transports + for (name, transport_factory) in litep2p_config.custom_transports { + let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); + + let (transport, transport_listen_addresses) = + transport_factory(handle, resolver.clone())?; + + for address in transport_listen_addresses { + transport_manager.register_listen_address(address.clone()); + listen_addresses.push(address.with(Protocol::P2p(*local_peer_id.as_ref()))); + } + + transport_manager.register_transport(SupportedTransport::Custom(name), transport); + } + // enable mdns if the config exists if let Some(config) = litep2p_config.mdns.take() { let mdns = Mdns::new(transport_handle, config, listen_addresses.clone()); diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index 5a44a1b4..27eb5ef0 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -162,6 +162,18 @@ impl TransportContext { ) { assert!(self.transports.insert(name, transport).is_none()); } + + /// Iterate through all transports + pub fn iter_mut( + &mut self, + ) -> impl Iterator< + Item = ( + &SupportedTransport, + &mut (dyn Transport + 'static), + ), + > { + self.transports.iter_mut().map(|(a, b)| (a, &mut **b)) + } } impl Stream for TransportContext { @@ -615,64 +627,6 @@ impl TransportManager { tracing::debug!(target: LOG_TARGET, address = ?address_record.address(), "dial address"); - let mut protocol_stack = address_record.as_ref().iter(); - match protocol_stack - .next() - .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? - { - Protocol::Ip4(_) | Protocol::Ip6(_) => {} - Protocol::Dns(_) | Protocol::Dns4(_) | Protocol::Dns6(_) => {} - transport => { - tracing::error!( - target: LOG_TARGET, - ?transport, - "invalid transport, expected `ip4`/`ip6`" - ); - return Err(Error::TransportNotSupported( - address_record.address().clone(), - )); - } - }; - - let supported_transport = match protocol_stack - .next() - .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? - { - Protocol::Tcp(_) => match protocol_stack.next() { - #[cfg(feature = "websocket")] - Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) => SupportedTransport::WebSocket, - Some(Protocol::P2p(_)) => SupportedTransport::Tcp, - _ => - return Err(Error::TransportNotSupported( - address_record.address().clone(), - )), - }, - #[cfg(feature = "quic")] - Protocol::Udp(_) => match protocol_stack - .next() - .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? - { - Protocol::QuicV1 => SupportedTransport::Quic, - _ => { - tracing::debug!(target: LOG_TARGET, address = ?address_record.address(), "expected `quic-v1`"); - return Err(Error::TransportNotSupported( - address_record.address().clone(), - )); - } - }, - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid protocol" - ); - - return Err(Error::TransportNotSupported( - address_record.address().clone(), - )); - } - }; - // when constructing `AddressRecord`, `PeerId` was verified to be part of the address let remote_peer_id = PeerId::try_from_multiaddr(address_record.address()).expect("`PeerId` to exist"); @@ -699,12 +653,27 @@ impl TransportManager { }; } - self.transports - .get_mut(&supported_transport) - .ok_or(Error::TransportNotSupported( + let mut dailed = false; + + for (_, transport) in self.transports.iter_mut() { + if let Err(err) = transport.dial(connection_id, address_record.address().clone()) { + if let Error::AddressError(AddressError::InvalidProtocol) = err { + continue; + } + + return Err(err); + } + + dailed = true; + break; + } + + if !dailed { + return Err(Error::TransportNotSupported( address_record.address().clone(), - ))? - .dial(connection_id, address_record.address().clone())?; + )); + } + self.pending_connections.insert(connection_id, remote_peer_id); Ok(()) @@ -1687,7 +1656,7 @@ mod tests { } #[tokio::test] - async fn try_to_dial_over_disabled_transport() { + async fn try_to_dial_over_custom_transport() { let mut manager = TransportManagerBuilder::new().build(); let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); @@ -1700,10 +1669,7 @@ mod tests { Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), )); - assert!(std::matches!( - manager.dial_address(address).await, - Err(Error::TransportNotSupported(_)) - )); + assert!(std::matches!(manager.dial_address(address).await, Ok(()))); } #[tokio::test] diff --git a/src/transport/manager/types.rs b/src/transport/manager/types.rs index 15eb2c50..98d03e96 100644 --- a/src/transport/manager/types.rs +++ b/src/transport/manager/types.rs @@ -37,6 +37,9 @@ pub enum SupportedTransport { /// WebSocket #[cfg(feature = "websocket")] WebSocket, + + /// Custom transport + Custom(&'static str), } /// Peer context. diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 79c582c0..a36d7c42 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -20,7 +20,7 @@ //! Transport protocol implementations provided by [`Litep2p`](`crate::Litep2p`). -use crate::{error::DialError, transport::manager::TransportHandle, types::ConnectionId, PeerId}; +use crate::{error::DialError, types::ConnectionId, PeerId}; use futures::Stream; use hickory_resolver::TokioResolver; @@ -42,7 +42,10 @@ pub(crate) mod dummy; pub(crate) mod manager; -pub use manager::limits::{ConnectionLimitsConfig, ConnectionLimitsError}; +pub use manager::{ + limits::{ConnectionLimitsConfig, ConnectionLimitsError}, + TransportHandle, +}; /// Timeout for opening a connection. pub(crate) const CONNECTION_OPEN_TIMEOUT: Duration = Duration::from_secs(10); @@ -119,7 +122,7 @@ impl Endpoint { /// Transport event. #[derive(Debug)] -pub(crate) enum TransportEvent { +pub enum TransportEvent { /// Fully negotiated connection established to remote peer. ConnectionEstablished { /// Peer ID. @@ -175,7 +178,7 @@ pub(crate) enum TransportEvent { }, } -pub(crate) trait TransportBuilder { +pub trait TransportBuilder { type Config: Debug; type Transport: Transport; @@ -189,7 +192,7 @@ pub(crate) trait TransportBuilder { Self: Sized; } -pub(crate) trait Transport: Stream + Unpin + Send { +pub trait Transport: Stream + Unpin + Send { /// Dial `address` and negotiate connection. fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()>; diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs index 9d752433..4e999be1 100644 --- a/src/transport/tcp/mod.rs +++ b/src/transport/tcp/mod.rs @@ -94,7 +94,7 @@ enum RawConnectionResult { } /// TCP transport. -pub(crate) struct TcpTransport { +pub struct TcpTransport { /// Transport context. context: TransportHandle, diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 50026109..b88f8f32 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -18,7 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use litep2p::{config::ConfigBuilder, transport::tcp::config::Config as TcpConfig}; +use std::sync::Arc; + +use hickory_resolver::TokioResolver; +use libp2p::Multiaddr; +use litep2p::{ + config::ConfigBuilder, + transport::{tcp::config::Config as TcpConfig, TransportEvent, TransportHandle}, +}; #[cfg(feature = "quic")] use litep2p::transport::quic::config::Config as QuicConfig; @@ -31,6 +38,18 @@ pub(crate) enum Transport { Quic(QuicConfig), #[cfg(feature = "websocket")] WebSocket(WebSocketConfig), + Custom( + ( + &'static str, + fn( + TransportHandle, + Arc, + ) -> litep2p::Result<( + Box>, + Vec, + )>, + ), + ), } pub(crate) fn add_transport(config: ConfigBuilder, transport: Transport) -> ConfigBuilder { @@ -40,5 +59,6 @@ pub(crate) fn add_transport(config: ConfigBuilder, transport: Transport) -> Conf Transport::Quic(transport) => config.with_quic(transport), #[cfg(feature = "websocket")] Transport::WebSocket(transport) => config.with_websocket(transport), + Transport::Custom((name, transport)) => config.with_custom_transport(name, transport), } } diff --git a/tests/connection/mod.rs b/tests/connection/mod.rs index ebe55169..05787d42 100644 --- a/tests/connection/mod.rs +++ b/tests/connection/mod.rs @@ -23,7 +23,10 @@ use litep2p::{ crypto::ed25519::Keypair, error::{DialError, Error, NegotiationError}, protocol::libp2p::ping::{Config as PingConfig, PingEvent}, - transport::tcp::config::Config as TcpConfig, + transport::{ + tcp::{config::Config as TcpConfig, TcpTransport}, + TransportBuilder, + }, Litep2p, Litep2pEvent, PeerId, }; @@ -88,6 +91,31 @@ async fn two_litep2ps_work_websocket() { .await; } +#[tokio::test] +async fn two_litep2ps_work_custom() { + let transport1 = Transport::Custom(("tcp1", |handle, resolver| { + let config = TcpConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], + ..Default::default() + }; + let (transport, transport_listen_addresses) = + ::new(handle, config, resolver)?; + Ok((Box::new(transport), transport_listen_addresses)) + })); + + let transport2 = Transport::Custom(("tcp2", |handle, resolver| { + let config = TcpConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], + ..Default::default() + }; + let (transport, transport_listen_addresses) = + ::new(handle, config, resolver)?; + Ok((Box::new(transport), transport_listen_addresses)) + })); + + two_litep2ps_work(transport1, transport2).await; +} + async fn two_litep2ps_work(transport1: Transport, transport2: Transport) { let _ = tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) @@ -479,34 +507,6 @@ async fn attempt_to_dial_using_unsupported_transport_tcp() { )); } -#[cfg(feature = "quic")] -#[tokio::test] -async fn attempt_to_dial_using_unsupported_transport_quic() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config, _ping_event_stream) = PingConfig::default(); - let config = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_quic(Default::default()) - .with_libp2p_ping(ping_config) - .build(); - - let mut litep2p = Litep2p::new(config).unwrap(); - let address = Multiaddr::empty() - .with(Protocol::from(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p( - Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), - )); - - assert!(std::matches!( - litep2p.dial_address(address.clone()).await, - Err(Error::TransportNotSupported(_)) - )); -} - #[tokio::test] async fn keep_alive_timeout_tcp() { keep_alive_timeout( @@ -1023,6 +1023,8 @@ async fn make_dummy_litep2p( Transport::Quic(config) => litep2p_config.with_quic(config), #[cfg(feature = "websocket")] Transport::WebSocket(config) => litep2p_config.with_websocket(config), + Transport::Custom((name, transport)) => + litep2p_config.with_custom_transport(name, transport), } .build(); diff --git a/tests/protocol/notification.rs b/tests/protocol/notification.rs index afb6d74d..4db42db1 100644 --- a/tests/protocol/notification.rs +++ b/tests/protocol/notification.rs @@ -821,6 +821,7 @@ async fn set_new_handshake(transport1: Transport, transport2: Transport) { Transport::Quic(config) => config1.with_quic(config), #[cfg(feature = "websocket")] Transport::WebSocket(config) => config1.with_websocket(config), + Transport::Custom((name, transport)) => config1.with_custom_transport(name, transport), } .build(); @@ -3707,6 +3708,10 @@ async fn dial_failure(transport1: Transport, transport2: Transport) { .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) .with(Protocol::Tcp(5)) .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), + Transport::Custom(_) => Multiaddr::empty() + .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) + .with(Protocol::Tcp(5)) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), }; let config2 = add_transport(config2, transport2).build(); diff --git a/tests/protocol/request_response.rs b/tests/protocol/request_response.rs index 208c0ed6..e80cbb00 100644 --- a/tests/protocol/request_response.rs +++ b/tests/protocol/request_response.rs @@ -2318,6 +2318,11 @@ async fn dial_failure(transport: Transport) { .with(Protocol::Tcp(5)) .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) .with(Protocol::P2p(Multihash::from(peer))), + Transport::Custom(_) => Multiaddr::empty() + .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) + .with(Protocol::Tcp(5)) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) + .with(Protocol::P2p(Multihash::from(peer))), }; let config = add_transport(litep2p_config, transport).build();