Skip to content

Commit c0a44c0

Browse files
committed
avoid Message::Protocols in webrtc_listener_negotiate()
1 parent 90f0c86 commit c0a44c0

File tree

4 files changed

+123
-87
lines changed

4 files changed

+123
-87
lines changed

src/multistream_select/dialer_select.rs

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::{
2424
codec::unsigned_varint::UnsignedVarint,
2525
error::{self, Error, ParseError, SubstreamError},
2626
multistream_select::{
27+
drain_trailing_protocols,
2728
protocol::{
2829
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
2930
ProtocolError, PROTO_MULTISTREAM_1_0,
@@ -418,6 +419,10 @@ impl WebRtcDialerState {
418419
}
419420
}
420421
(HandshakeState::WaitingProtocol, Some(protocol)) => {
422+
if protocol == PROTO_MULTISTREAM_1_0 {
423+
return Err(crate::error::NegotiationError::StateMismatch);
424+
}
425+
421426
if self.protocol.as_bytes() == protocol.as_ref() {
422427
return Ok(HandshakeResult::Succeeded(self.protocol.clone()));
423428
}
@@ -440,68 +445,12 @@ impl WebRtcDialerState {
440445
}
441446
}
442447

443-
fn drain_trailing_protocols(
444-
mut remaining: Bytes,
445-
) -> Result<Vec<Protocol>, error::NegotiationError> {
446-
let mut protocols = vec![];
447-
448-
loop {
449-
if remaining.is_empty() {
450-
break;
451-
}
452-
453-
let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| {
454-
tracing::debug!(
455-
target: LOG_TARGET,
456-
?error,
457-
message = ?remaining,
458-
"Failed to decode length-prefix in multistream message");
459-
error::NegotiationError::ParseError(ParseError::InvalidData)
460-
})?;
461-
462-
if len > tail.len() {
463-
tracing::debug!(
464-
target: LOG_TARGET,
465-
message = ?tail,
466-
length_prefix = len,
467-
actual_length = tail.len(),
468-
"Truncated multistream message",
469-
);
470-
471-
return Err(error::NegotiationError::ParseError(ParseError::InvalidData));
472-
}
473-
474-
let len_size = remaining.len() - tail.len();
475-
let payload = remaining.slice(len_size..len_size + len);
476-
477-
match Message::decode(payload) {
478-
Ok(Message::Protocol(protocol)) => protocols.push(protocol),
479-
Err(error) => {
480-
tracing::debug!(
481-
target: LOG_TARGET,
482-
?error,
483-
message = ?tail[..len],
484-
"Failed to decode multistream message",
485-
);
486-
return Err(error::NegotiationError::ParseError(ParseError::InvalidData));
487-
}
488-
_ => return Err(error::NegotiationError::StateMismatch),
489-
}
490-
491-
remaining = remaining.slice(len_size + len..);
492-
}
493-
494-
Ok(protocols)
495-
}
496-
497448
#[cfg(test)]
498449
mod tests {
499450
use super::*;
500451
use crate::multistream_select::{listener_select_proto, protocol::MSG_MULTISTREAM_1_0};
501452
use bytes::BufMut;
502453
use std::time::Duration;
503-
use tokio::net::{TcpListener, TcpStream};
504-
505454
#[tokio::test]
506455
async fn select_proto_basic() {
507456
async fn run(version: Version) {

src/multistream_select/listener_select.rs

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ use crate::{
2525
codec::unsigned_varint::UnsignedVarint,
2626
error::{self, Error},
2727
multistream_select::{
28+
drain_trailing_protocols,
2829
protocol::{
2930
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
30-
ProtocolError,
31+
ProtocolError, PROTO_MULTISTREAM_1_0,
3132
},
3233
Negotiated, NegotiationError,
3334
},
@@ -349,30 +350,14 @@ pub enum ListenerSelectResult {
349350
/// response and the negotiated protocol. If parsing fails or no match is found, return an error.
350351
pub fn webrtc_listener_negotiate<'a>(
351352
supported_protocols: &'a mut impl Iterator<Item = &'a ProtocolName>,
352-
payload: Bytes,
353+
mut payload: Bytes,
353354
) -> crate::Result<ListenerSelectResult> {
354-
let payload = if payload.len() > 2 && payload[0..payload.len() - 2] != b"\n\n"[..] {
355-
let mut buf = BytesMut::from(payload);
356-
buf.extend_from_slice(b"\n");
357-
buf.freeze()
358-
} else {
359-
payload
360-
};
361-
362-
let Message::Protocols(protocols) = Message::decode(payload).map_err(|_| Error::InvalidData)?
363-
else {
364-
return Err(Error::NegotiationError(
365-
error::NegotiationError::MultistreamSelectError(NegotiationError::Failed),
366-
));
367-
};
355+
let protocols = drain_trailing_protocols(payload)?;
356+
let mut protocol_iter = protocols.into_iter();
368357

369358
// skip the multistream-select header because it's not part of user protocols but verify it's
370359
// present
371-
let mut protocol_iter = protocols.into_iter();
372-
let header =
373-
Protocol::try_from(&b"/multistream/1.0.0"[..]).expect("valid multitstream-select header");
374-
375-
if protocol_iter.next() != Some(header) {
360+
if protocol_iter.next() != Some(PROTO_MULTISTREAM_1_0) {
376361
return Err(Error::NegotiationError(
377362
error::NegotiationError::MultistreamSelectError(NegotiationError::Failed),
378363
));
@@ -410,6 +395,8 @@ pub fn webrtc_listener_negotiate<'a>(
410395
#[cfg(test)]
411396
mod tests {
412397
use super::*;
398+
use crate::error;
399+
use bytes::BufMut;
413400

414401
#[test]
415402
fn webrtc_listener_negotiate_works() {
@@ -445,6 +432,21 @@ mod tests {
445432
ProtocolName::from("/13371338/proto/3"),
446433
ProtocolName::from("/13371338/proto/4"),
447434
];
435+
// The invalid message is really two multistream-select messages inside one `WebRtcMessage`:
436+
// 1. the multistream-select header
437+
// 2. an "ls response" message (that does not contain another header)
438+
//
439+
// This is invalid for two reasons:
440+
// 1. It is malformed. Either the header is followed by one or more `Message::Protocol`
441+
// instances or the header is part of the "ls response".
442+
// 2. This sequence of messages is not spec compliant. A listener receives one of the
443+
// following on an inbound substream:
444+
// - a multistream-select header followed by a `Message::Protocol` instance
445+
// - a multistream-select header followed by an "ls" message (<length prefix><ls><\n>)
446+
//
447+
// `webrtc_listener_negotiate()` should reject this invalid message. The error can either be
448+
// `InvalidData` because the message is malformed or `StateMismatch` because the message is
449+
// not expected at this point in the protocol.
448450
let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![
449451
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
450452
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
@@ -453,7 +455,13 @@ mod tests {
453455
.freeze();
454456

455457
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
456-
Err(error) => assert!(std::matches!(error, Error::InvalidData)),
458+
Err(error) => assert!(std::matches!(
459+
error,
460+
// something has gone off the rails here...
461+
Error::NegotiationError(error::NegotiationError::ParseError(
462+
error::ParseError::InvalidData
463+
)),
464+
)),
457465
_ => panic!("invalid event"),
458466
}
459467
}
@@ -474,7 +482,12 @@ mod tests {
474482
message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
475483

476484
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
477-
Err(error) => assert!(std::matches!(error, Error::InvalidData)),
485+
Err(error) => assert!(std::matches!(
486+
error,
487+
Error::NegotiationError(error::NegotiationError::ParseError(
488+
error::ParseError::InvalidData
489+
)),
490+
)),
478491
event => panic!("invalid event: {event:?}"),
479492
}
480493
}
@@ -491,14 +504,23 @@ mod tests {
491504

492505
// header line missing
493506
let mut bytes = BytesMut::with_capacity(256);
494-
let message = Message::Protocols(vec![
495-
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
496-
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
497-
]);
498-
message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
507+
vec![&b"/13371338/proto/1"[..], &b"/sup/proto/1"[..]]
508+
.into_iter()
509+
.for_each(|proto| {
510+
bytes.put_u8((proto.len() + 1) as u8);
511+
512+
Message::Protocol(Protocol::try_from(proto).unwrap())
513+
.encode(&mut bytes)
514+
.unwrap();
515+
});
499516

500517
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
501-
Err(error) => assert!(std::matches!(error, Error::InvalidData)),
518+
Err(error) => assert!(std::matches!(
519+
error,
520+
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
521+
NegotiationError::Failed
522+
))
523+
)),
502524
event => panic!("invalid event: {event:?}"),
503525
}
504526
}

src/multistream_select/mod.rs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,21 @@ mod listener_select;
7575
mod negotiated;
7676
mod protocol;
7777

78+
use crate::error::{self, ParseError};
7879
pub use crate::multistream_select::{
7980
dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState},
8081
listener_select::{
8182
listener_select_proto, webrtc_listener_negotiate, ListenerSelectFuture,
8283
ListenerSelectResult,
8384
},
8485
negotiated::{Negotiated, NegotiatedComplete, NegotiationError},
85-
protocol::{HeaderLine, Message, Protocol, ProtocolError},
86+
protocol::{HeaderLine, Message, Protocol, ProtocolError, PROTO_MULTISTREAM_1_0},
8687
};
8788

89+
use bytes::Bytes;
90+
91+
const LOG_TARGET: &str = "litep2p::multistream-select";
92+
8893
/// Supported multistream-select versions.
8994
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
9095
pub enum Version {
@@ -132,3 +137,63 @@ impl Default for Version {
132137
Version::V1
133138
}
134139
}
140+
141+
// This function is only used in the WebRTC transport. It expects one or more multistream-select
142+
// messages in `remaining` and returns a list of protocols that were decoded from them.
143+
fn drain_trailing_protocols(
144+
mut remaining: Bytes,
145+
) -> Result<Vec<Protocol>, error::NegotiationError> {
146+
let mut protocols = vec![];
147+
148+
loop {
149+
if remaining.is_empty() {
150+
break;
151+
}
152+
153+
let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| {
154+
tracing::debug!(
155+
target: LOG_TARGET,
156+
?error,
157+
message = ?remaining,
158+
"Failed to decode length-prefix in multistream message");
159+
error::NegotiationError::ParseError(ParseError::InvalidData)
160+
})?;
161+
162+
if len > tail.len() {
163+
tracing::debug!(
164+
target: LOG_TARGET,
165+
message = ?tail,
166+
length_prefix = len,
167+
actual_length = tail.len(),
168+
"Truncated multistream message",
169+
);
170+
171+
return Err(error::NegotiationError::ParseError(ParseError::InvalidData));
172+
}
173+
174+
let len_size = remaining.len() - tail.len();
175+
let payload = remaining.slice(len_size..len_size + len);
176+
let res = Message::decode(payload);
177+
178+
match res {
179+
Ok(Message::Header(HeaderLine::V1)) => protocols.push(PROTO_MULTISTREAM_1_0),
180+
Ok(Message::Protocol(protocol)) => protocols.push(protocol),
181+
Ok(Message::Protocols(_)) =>
182+
return Err(error::NegotiationError::ParseError(ParseError::InvalidData)),
183+
Err(error) => {
184+
tracing::debug!(
185+
target: LOG_TARGET,
186+
?error,
187+
message = ?tail[..len],
188+
"Failed to decode multistream message",
189+
);
190+
return Err(error::NegotiationError::ParseError(ParseError::InvalidData));
191+
}
192+
_ => return Err(error::NegotiationError::StateMismatch),
193+
}
194+
195+
remaining = remaining.slice(len_size + len..);
196+
}
197+
198+
Ok(protocols)
199+
}

src/multistream_select/protocol.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const MSG_PROTOCOL_NA: &[u8] = b"na\n";
5555
/// The encoded form of a multistream-select 'ls' message.
5656
const MSG_LS: &[u8] = b"ls\n";
5757
/// A Protocol instance for the `/multistream/1.0.0` header line.
58-
pub const PROTO_MULTISTREAM_1_0: Protocol = Protocol(Bytes::from_static(MSG_MULTISTREAM_1_0));
58+
pub const PROTO_MULTISTREAM_1_0: Protocol = Protocol(Bytes::from_static(b"/multistream/1.0.0"));
5959
/// Logging target.
6060
const LOG_TARGET: &str = "litep2p::multistream-select";
6161

0 commit comments

Comments
 (0)