@@ -25,9 +25,10 @@ use crate::{
2525 codec:: unsigned_varint:: UnsignedVarint ,
2626 error:: { self , Error } ,
2727 multistream_select:: {
28+ drain_trailing_protocols,
2829 protocol:: {
2930 webrtc_encode_multistream_message, HeaderLine , Message , MessageIO , Protocol ,
30- ProtocolError ,
31+ ProtocolError , PROTO_MULTISTREAM_1_0 ,
3132 } ,
3233 Negotiated , NegotiationError ,
3334 } ,
@@ -349,30 +350,14 @@ pub enum ListenerSelectResult {
349350/// response and the negotiated protocol. If parsing fails or no match is found, return an error.
350351pub fn webrtc_listener_negotiate < ' a > (
351352 supported_protocols : & ' a mut impl Iterator < Item = & ' a ProtocolName > ,
352- payload : Bytes ,
353+ mut payload : Bytes ,
353354) -> crate :: Result < ListenerSelectResult > {
354- let payload = if payload. len ( ) > 2 && payload[ 0 ..payload. len ( ) - 2 ] != b"\n \n " [ ..] {
355- let mut buf = BytesMut :: from ( payload) ;
356- buf. extend_from_slice ( b"\n " ) ;
357- buf. freeze ( )
358- } else {
359- payload
360- } ;
361-
362- let Message :: Protocols ( protocols) = Message :: decode ( payload) . map_err ( |_| Error :: InvalidData ) ?
363- else {
364- return Err ( Error :: NegotiationError (
365- error:: NegotiationError :: MultistreamSelectError ( NegotiationError :: Failed ) ,
366- ) ) ;
367- } ;
355+ let protocols = drain_trailing_protocols ( payload) ?;
356+ let mut protocol_iter = protocols. into_iter ( ) ;
368357
369358 // skip the multistream-select header because it's not part of user protocols but verify it's
370359 // present
371- let mut protocol_iter = protocols. into_iter ( ) ;
372- let header =
373- Protocol :: try_from ( & b"/multistream/1.0.0" [ ..] ) . expect ( "valid multitstream-select header" ) ;
374-
375- if protocol_iter. next ( ) != Some ( header) {
360+ if protocol_iter. next ( ) != Some ( PROTO_MULTISTREAM_1_0 ) {
376361 return Err ( Error :: NegotiationError (
377362 error:: NegotiationError :: MultistreamSelectError ( NegotiationError :: Failed ) ,
378363 ) ) ;
@@ -410,6 +395,8 @@ pub fn webrtc_listener_negotiate<'a>(
410395#[ cfg( test) ]
411396mod tests {
412397 use super :: * ;
398+ use crate :: error;
399+ use bytes:: BufMut ;
413400
414401 #[ test]
415402 fn webrtc_listener_negotiate_works ( ) {
@@ -445,6 +432,21 @@ mod tests {
445432 ProtocolName :: from ( "/13371338/proto/3" ) ,
446433 ProtocolName :: from ( "/13371338/proto/4" ) ,
447434 ] ;
435+ // The invalid message is really two multistream-select messages inside one `WebRtcMessage`:
436+ // 1. the multistream-select header
437+ // 2. an "ls response" message (that does not contain another header)
438+ //
439+ // This is invalid for two reasons:
440+ // 1. It is malformed. Either the header is followed by one or more `Message::Protocol`
441+ // instances or the header is part of the "ls response".
442+ // 2. This sequence of messages is not spec compliant. A listener receives one of the
443+ // following on an inbound substream:
444+ // - a multistream-select header followed by a `Message::Protocol` instance
445+ // - a multistream-select header followed by an "ls" message (<length prefix><ls><\n>)
446+ //
447+ // `webrtc_listener_negotiate()` should reject this invalid message. The error can either be
448+ // `InvalidData` because the message is malformed or `StateMismatch` because the message is
449+ // not expected at this point in the protocol.
448450 let message = webrtc_encode_multistream_message ( std:: iter:: once ( Message :: Protocols ( vec ! [
449451 Protocol :: try_from( & b"/13371338/proto/1" [ ..] ) . unwrap( ) ,
450452 Protocol :: try_from( & b"/sup/proto/1" [ ..] ) . unwrap( ) ,
@@ -453,7 +455,13 @@ mod tests {
453455 . freeze ( ) ;
454456
455457 match webrtc_listener_negotiate ( & mut local_protocols. iter ( ) , message) {
456- Err ( error) => assert ! ( std:: matches!( error, Error :: InvalidData ) ) ,
458+ Err ( error) => assert ! ( std:: matches!(
459+ error,
460+ // something has gone off the rails here...
461+ Error :: NegotiationError ( error:: NegotiationError :: ParseError (
462+ error:: ParseError :: InvalidData
463+ ) ) ,
464+ ) ) ,
457465 _ => panic ! ( "invalid event" ) ,
458466 }
459467 }
@@ -474,7 +482,12 @@ mod tests {
474482 message. encode ( & mut bytes) . map_err ( |_| Error :: InvalidData ) . unwrap ( ) ;
475483
476484 match webrtc_listener_negotiate ( & mut local_protocols. iter ( ) , bytes. freeze ( ) ) {
477- Err ( error) => assert ! ( std:: matches!( error, Error :: InvalidData ) ) ,
485+ Err ( error) => assert ! ( std:: matches!(
486+ error,
487+ Error :: NegotiationError ( error:: NegotiationError :: ParseError (
488+ error:: ParseError :: InvalidData
489+ ) ) ,
490+ ) ) ,
478491 event => panic ! ( "invalid event: {event:?}" ) ,
479492 }
480493 }
@@ -491,14 +504,23 @@ mod tests {
491504
492505 // header line missing
493506 let mut bytes = BytesMut :: with_capacity ( 256 ) ;
494- let message = Message :: Protocols ( vec ! [
495- Protocol :: try_from( & b"/13371338/proto/1" [ ..] ) . unwrap( ) ,
496- Protocol :: try_from( & b"/sup/proto/1" [ ..] ) . unwrap( ) ,
497- ] ) ;
498- message. encode ( & mut bytes) . map_err ( |_| Error :: InvalidData ) . unwrap ( ) ;
507+ vec ! [ & b"/13371338/proto/1" [ ..] , & b"/sup/proto/1" [ ..] ]
508+ . into_iter ( )
509+ . for_each ( |proto| {
510+ bytes. put_u8 ( ( proto. len ( ) + 1 ) as u8 ) ;
511+
512+ Message :: Protocol ( Protocol :: try_from ( proto) . unwrap ( ) )
513+ . encode ( & mut bytes)
514+ . unwrap ( ) ;
515+ } ) ;
499516
500517 match webrtc_listener_negotiate ( & mut local_protocols. iter ( ) , bytes. freeze ( ) ) {
501- Err ( error) => assert ! ( std:: matches!( error, Error :: InvalidData ) ) ,
518+ Err ( error) => assert ! ( std:: matches!(
519+ error,
520+ Error :: NegotiationError ( error:: NegotiationError :: MultistreamSelectError (
521+ NegotiationError :: Failed
522+ ) )
523+ ) ) ,
502524 event => panic ! ( "invalid event: {event:?}" ) ,
503525 }
504526 }
0 commit comments