diff --git a/application_defined.go b/application_defined.go index ca5f844..35a9ae0 100644 --- a/application_defined.go +++ b/application_defined.go @@ -85,7 +85,7 @@ func (a *ApplicationDefined) Unmarshal(rawPacket []byte) error { return err } if len(rawPacket) < 12 { - return errPacketTooShort + return errPacketTooShortFor(a) } if int(header.Length+1)*4 != len(rawPacket) { diff --git a/errors.go b/errors.go index 05e6847..3dd1334 100644 --- a/errors.go +++ b/errors.go @@ -3,7 +3,11 @@ package rtcp -import "errors" +import ( + "errors" + "fmt" + "reflect" +) var ( errWrongMarshalSize = errors.New("rtcp: wrong marshal size") @@ -39,3 +43,53 @@ var ( errAppDefinedDataTooLarge = errors.New("rtcp: application defined data is too large") errAppDefinedInvalidName = errors.New("rtcp: application defined name must be 4 ASCII chars") ) + +type packetTooShortError struct { + packet string +} + +func (e packetTooShortError) Error() string { + if e.packet == "" { + return errPacketTooShort.Error() + } + + return fmt.Sprintf("%s (%s)", errPacketTooShort.Error(), e.packet) +} + +func (e packetTooShortError) Unwrap() error { + return errPacketTooShort +} + +func errPacketTooShortFor(packet any) error { + name := packetTypeName(packet) + if name == "" { + return errPacketTooShort + } + + return packetTooShortError{packet: name} +} + +func packetTypeName(packet any) string { + if packet == nil { + return "" + } + + if name, ok := packet.(string); ok { + return name + } + + typ := reflect.TypeOf(packet) + if typ == nil { + return "" + } + + for typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + + if typ.Name() != "" { + return typ.Name() + } + + return typ.String() +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..20b18ab --- /dev/null +++ b/errors_test.go @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrPacketTooShortForPacket(t *testing.T) { + t.Parallel() + + err := errPacketTooShortFor(&PictureLossIndication{}) + + assert.ErrorIs(t, err, errPacketTooShort) + assert.Contains(t, err.Error(), "PictureLossIndication") +} + +func TestErrPacketTooShortForString(t *testing.T) { + t.Parallel() + + err := errPacketTooShortFor("CustomPacket") + + assert.ErrorIs(t, err, errPacketTooShort) + assert.Contains(t, err.Error(), "CustomPacket") +} + +func TestErrPacketTooShortForNil(t *testing.T) { + t.Parallel() + + err := errPacketTooShortFor(nil) + + assert.True(t, errors.Is(err, errPacketTooShort)) + assert.Equal(t, errPacketTooShort, err) +} + +func TestPacketNameFromHeader(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + header Header + want string + }{ + { + name: "sender report", + header: Header{ + Type: TypeSenderReport, + }, + want: "SenderReport", + }, + { + name: "transport feedback", + header: Header{ + Type: TypeTransportSpecificFeedback, + Count: FormatTLN, + }, + want: "TransportLayerNack", + }, + { + name: "transport feedback fallback", + header: Header{ + Type: TypeTransportSpecificFeedback, + Count: 99, + }, + want: "TransportSpecificFeedback(FMT=99)", + }, + { + name: "payload specific fallback", + header: Header{ + Type: TypePayloadSpecificFeedback, + Count: 5, + }, + want: "PayloadSpecificFeedback(FMT=5)", + }, + { + name: "payload specific known", + header: Header{ + Type: TypePayloadSpecificFeedback, + Count: FormatPLI, + }, + want: "PictureLossIndication", + }, + { + name: "unknown type", + header: Header{ + Type: 199, + }, + want: "PacketType(199)", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, test.want, packetNameFromHeader(test.header)) + }) + } +} diff --git a/full_intra_request.go b/full_intra_request.go index 0d93c5c..ed03d3d 100644 --- a/full_intra_request.go +++ b/full_intra_request.go @@ -51,7 +51,7 @@ func (p FullIntraRequest) Marshal() ([]byte, error) { // Unmarshal decodes the TransportLayerNack. func (p *FullIntraRequest) Unmarshal(rawPacket []byte) error { if len(rawPacket) < (headerLength + ssrcLength) { - return errPacketTooShort + return errPacketTooShortFor(p) } var header Header @@ -60,7 +60,7 @@ func (p *FullIntraRequest) Unmarshal(rawPacket []byte) error { } if len(rawPacket) < (headerLength + int(4*header.Length)) { - return errPacketTooShort + return errPacketTooShortFor(p) } if header.Type != TypePayloadSpecificFeedback || header.Count != FormatFIR { diff --git a/goodbye.go b/goodbye.go index c647743..8bdfbd9 100644 --- a/goodbye.go +++ b/goodbye.go @@ -90,14 +90,14 @@ func (g *Goodbye) Unmarshal(rawPacket []byte) error { } if getPadding(len(rawPacket)) != 0 { - return errPacketTooShort + return errPacketTooShortFor(g) } g.Sources = make([]uint32, header.Count) reasonOffset := int(headerLength + header.Count*ssrcLength) if reasonOffset > len(rawPacket) { - return errPacketTooShort + return errPacketTooShortFor(g) } for i := 0; i < int(header.Count); i++ { @@ -111,7 +111,7 @@ func (g *Goodbye) Unmarshal(rawPacket []byte) error { reasonEnd := reasonOffset + 1 + reasonLen if reasonEnd > len(rawPacket) { - return errPacketTooShort + return errPacketTooShortFor(g) } g.Reason = string(rawPacket[reasonOffset+1 : reasonEnd]) diff --git a/header.go b/header.go index f392e14..5a950bf 100644 --- a/header.go +++ b/header.go @@ -123,7 +123,7 @@ func (h Header) Marshal() ([]byte, error) { // Unmarshal decodes the Header from binary. func (h *Header) Unmarshal(rawPacket []byte) error { if len(rawPacket) < headerLength { - return errPacketTooShort + return errPacketTooShortFor(h) } /* diff --git a/packet.go b/packet.go index 637d9cd..e45c8df 100644 --- a/packet.go +++ b/packet.go @@ -3,6 +3,31 @@ package rtcp +import "fmt" + +var packetTypeNames = map[PacketType]string{ + TypeSenderReport: "SenderReport", + TypeReceiverReport: "ReceiverReport", + TypeSourceDescription: "SourceDescription", + TypeGoodbye: "Goodbye", + TypeExtendedReport: "ExtendedReport", + TypeApplicationDefined: "ApplicationDefined", +} + +var transportSpecificFeedbackNames = map[uint8]string{ + FormatTLN: "TransportLayerNack", + FormatRRR: "RapidResynchronizationRequest", + FormatTCC: "TransportLayerCC", + FormatCCFB: "CCFeedbackReport", +} + +var payloadSpecificFeedbackNames = map[uint8]string{ + FormatPLI: "PictureLossIndication", + FormatSLI: "SliceLossIndication", + FormatREMB: "ReceiverEstimatedMaximumBitrate", + FormatFIR: "FullIntraRequest", +} + // Packet represents an RTCP packet, a protocol used for out-of-band statistics // and control information for an RTP session. type Packet interface { @@ -70,7 +95,7 @@ func unmarshal(rawData []byte) (packet Packet, bytesprocessed int, err error) { bytesprocessed = int(header.Length+1) * 4 if bytesprocessed > len(rawData) { - return nil, 0, errPacketTooShort + return nil, 0, errPacketTooShortFor(packetNameFromHeader(header)) } inPacket := rawData[:bytesprocessed] @@ -129,3 +154,35 @@ func unmarshal(rawData []byte) (packet Packet, bytesprocessed int, err error) { return packet, bytesprocessed, err } + +func packetNameFromHeader(header Header) string { + if header.Type == TypeTransportSpecificFeedback { + return transportSpecificFeedbackName(header.Count) + } + + if header.Type == TypePayloadSpecificFeedback { + return payloadSpecificFeedbackName(header.Count) + } + + if name, ok := packetTypeNames[header.Type]; ok { + return name + } + + return fmt.Sprintf("PacketType(%d)", header.Type) +} + +func transportSpecificFeedbackName(count uint8) string { + if name, ok := transportSpecificFeedbackNames[count]; ok { + return name + } + + return fmt.Sprintf("TransportSpecificFeedback(FMT=%d)", count) +} + +func payloadSpecificFeedbackName(count uint8) string { + if name, ok := payloadSpecificFeedbackNames[count]; ok { + return name + } + + return fmt.Sprintf("PayloadSpecificFeedback(FMT=%d)", count) +} diff --git a/packet_test.go b/packet_test.go index b3fae06..cc8579f 100644 --- a/packet_test.go +++ b/packet_test.go @@ -136,4 +136,5 @@ func TestInvalidHeaderLength(t *testing.T) { _, err := Unmarshal(invalidPacket) assert.ErrorIs(t, err, errPacketTooShort) + assert.Contains(t, err.Error(), "ReceiverReport") } diff --git a/picture_loss_indication.go b/picture_loss_indication.go index 17379cd..87392a9 100644 --- a/picture_loss_indication.go +++ b/picture_loss_indication.go @@ -53,7 +53,7 @@ func (p PictureLossIndication) Marshal() ([]byte, error) { // Unmarshal decodes the PictureLossIndication from binary. func (p *PictureLossIndication) Unmarshal(rawPacket []byte) error { if len(rawPacket) < (headerLength + (ssrcLength * 2)) { - return errPacketTooShort + return errPacketTooShortFor(p) } var h Header diff --git a/rapid_resynchronization_request.go b/rapid_resynchronization_request.go index d422033..b1dcc79 100644 --- a/rapid_resynchronization_request.go +++ b/rapid_resynchronization_request.go @@ -54,7 +54,7 @@ func (p RapidResynchronizationRequest) Marshal() ([]byte, error) { // Unmarshal decodes the RapidResynchronizationRequest from binary. func (p *RapidResynchronizationRequest) Unmarshal(rawPacket []byte) error { if len(rawPacket) < (headerLength + (ssrcLength * 2)) { - return errPacketTooShort + return errPacketTooShortFor(p) } var h Header diff --git a/raw_packet.go b/raw_packet.go index 71ac152..4956443 100644 --- a/raw_packet.go +++ b/raw_packet.go @@ -17,7 +17,7 @@ func (r RawPacket) Marshal() ([]byte, error) { // Unmarshal decodes the packet from binary. func (r *RawPacket) Unmarshal(b []byte) error { if len(b) < (headerLength) { - return errPacketTooShort + return errPacketTooShortFor(r) } *r = b diff --git a/receiver_estimated_maximum_bitrate.go b/receiver_estimated_maximum_bitrate.go index cb6cdae..982964c 100644 --- a/receiver_estimated_maximum_bitrate.go +++ b/receiver_estimated_maximum_bitrate.go @@ -71,7 +71,7 @@ func (p ReceiverEstimatedMaximumBitrate) MarshalTo(buf []byte) (n int, err error size := p.MarshalSize() if len(buf) < size { - return 0, errPacketTooShort + return 0, errPacketTooShortFor(p) } buf[0] = 143 // v=2, p=0, fmt=15 @@ -158,7 +158,7 @@ func (p *ReceiverEstimatedMaximumBitrate) Unmarshal(buf []byte) (err error) { // 20 bytes is the size of the packet with no SSRCs if len(buf) < 20 { - return errPacketTooShort + return errPacketTooShortFor(p) } // version must be 2 @@ -195,7 +195,7 @@ func (p *ReceiverEstimatedMaximumBitrate) Unmarshal(buf []byte) (err error) { // Make sure the buffer is large enough. if len(buf) < size { - return errPacketTooShort + return errPacketTooShortFor(p) } // The sender SSRC is 32-bits diff --git a/receiver_report.go b/receiver_report.go index 84c682d..1e72d69 100644 --- a/receiver_report.go +++ b/receiver_report.go @@ -127,7 +127,7 @@ func (r *ReceiverReport) Unmarshal(rawPacket []byte) error { */ if len(rawPacket) < (headerLength + ssrcLength) { - return errPacketTooShort + return errPacketTooShortFor(r) } var header Header diff --git a/reception_report.go b/reception_report.go index f2b4548..ae91b41 100644 --- a/reception_report.go +++ b/reception_report.go @@ -93,7 +93,7 @@ func (r ReceptionReport) Marshal() ([]byte, error) { // Unmarshal decodes the ReceptionReport from binary. func (r *ReceptionReport) Unmarshal(rawPacket []byte) error { if len(rawPacket) < receptionReportLength { - return errPacketTooShort + return errPacketTooShortFor(r) } /* diff --git a/rfc8888.go b/rfc8888.go index ce5dfb1..b40262a 100644 --- a/rfc8888.go +++ b/rfc8888.go @@ -179,7 +179,7 @@ func (b CCFeedbackReport) String() string { // Unmarshal decodes the Congestion Control Feedback Report from binary. func (b *CCFeedbackReport) Unmarshal(rawPacket []byte) error { if len(rawPacket) < headerLength+ssrcLength+reportTimestampLength { - return errPacketTooShort + return errPacketTooShortFor(b) } var h Header diff --git a/sender_report.go b/sender_report.go index 0ad8699..142761f 100644 --- a/sender_report.go +++ b/sender_report.go @@ -171,7 +171,7 @@ func (r *SenderReport) Unmarshal(rawPacket []byte) error { */ if len(rawPacket) < (headerLength + srHeaderLength) { - return errPacketTooShort + return errPacketTooShortFor(r) } var header Header @@ -195,7 +195,7 @@ func (r *SenderReport) Unmarshal(rawPacket []byte) error { for i := 0; i < int(header.Count); i++ { rrEnd := offset + receptionReportLength if rrEnd > len(packetBody) { - return errPacketTooShort + return errPacketTooShortFor(r) } rrBody := packetBody[offset : offset+receptionReportLength] offset = rrEnd diff --git a/slice_loss_indication.go b/slice_loss_indication.go index 43e2d80..79bba69 100644 --- a/slice_loss_indication.go +++ b/slice_loss_indication.go @@ -64,7 +64,7 @@ func (p SliceLossIndication) Marshal() ([]byte, error) { // Unmarshal decodes the SliceLossIndication from binary. func (p *SliceLossIndication) Unmarshal(rawPacket []byte) error { if len(rawPacket) < (headerLength + ssrcLength) { - return errPacketTooShort + return errPacketTooShortFor(p) } var header Header @@ -73,7 +73,7 @@ func (p *SliceLossIndication) Unmarshal(rawPacket []byte) error { } if len(rawPacket) < (headerLength + int(4*header.Length)) { - return errPacketTooShort + return errPacketTooShortFor(p) } if header.Type != TypeTransportSpecificFeedback || header.Count != FormatSLI { diff --git a/source_description.go b/source_description.go index 6d6f2a0..201ad85 100644 --- a/source_description.go +++ b/source_description.go @@ -241,7 +241,7 @@ func (s *SourceDescriptionChunk) Unmarshal(rawPacket []byte) error { */ if len(rawPacket) < (sdesSourceLen + sdesTypeLen) { - return errPacketTooShort + return errPacketTooShortFor(s) } s.Source = binary.BigEndian.Uint32(rawPacket) @@ -259,7 +259,7 @@ func (s *SourceDescriptionChunk) Unmarshal(rawPacket []byte) error { i += it.Len() } - return errPacketTooShort + return errPacketTooShortFor(s) } func (s SourceDescriptionChunk) len() int { @@ -338,14 +338,14 @@ func (s *SourceDescriptionItem) Unmarshal(rawPacket []byte) error { */ if len(rawPacket) < (sdesTypeLen + sdesOctetCountLen) { - return errPacketTooShort + return errPacketTooShortFor(s) } s.Type = SDESType(rawPacket[sdesTypeOffset]) octetCount := int(rawPacket[sdesOctetCountOffset]) if sdesTextOffset+octetCount > len(rawPacket) { - return errPacketTooShort + return errPacketTooShortFor(s) } txtBytes := rawPacket[sdesTextOffset : sdesTextOffset+octetCount] diff --git a/transport_layer_cc.go b/transport_layer_cc.go index de464cb..8b93ef9 100644 --- a/transport_layer_cc.go +++ b/transport_layer_cc.go @@ -461,7 +461,7 @@ func (t TransportLayerCC) Marshal() ([]byte, error) { //nolint:gocognit,cyclop func (t *TransportLayerCC) Unmarshal(rawPacket []byte) error { if len(rawPacket) < (headerLength + ssrcLength) { - return errPacketTooShort + return errPacketTooShortFor(t) } if err := t.Header.Unmarshal(rawPacket); err != nil { @@ -473,11 +473,11 @@ func (t *TransportLayerCC) Unmarshal(rawPacket []byte) error { totalLength := 4 * (t.Header.Length + 1) if totalLength < headerLength+packetChunkOffset { - return errPacketTooShort + return errPacketTooShortFor(t) } if len(rawPacket) < int(totalLength) { - return errPacketTooShort + return errPacketTooShortFor(t) } if t.Header.Type != TypeTransportSpecificFeedback || t.Header.Count != FormatTCC { @@ -495,7 +495,7 @@ func (t *TransportLayerCC) Unmarshal(rawPacket []byte) error { var processedPacketNum uint16 for processedPacketNum < t.PacketStatusCount { if packetStatusPos+packetStatusChunkLength >= totalLength { - return errPacketTooShort + return errPacketTooShortFor(t) } typ := getNBitsFromByte(rawPacket[packetStatusPos : packetStatusPos+1][0], 0, 1) var iPacketStatus PacketStatusChunk @@ -548,7 +548,7 @@ func (t *TransportLayerCC) Unmarshal(rawPacket []byte) error { for _, delta := range t.RecvDeltas { if delta.Type == TypeTCCPacketReceivedSmallDelta { if recvDeltasPos+1 > totalLength { - return errPacketTooShort + return errPacketTooShortFor(t) } err := delta.Unmarshal(rawPacket[recvDeltasPos : recvDeltasPos+1]) if err != nil { @@ -558,7 +558,7 @@ func (t *TransportLayerCC) Unmarshal(rawPacket []byte) error { } if delta.Type == TypeTCCPacketReceivedLargeDelta { if recvDeltasPos+2 > totalLength { - return errPacketTooShort + return errPacketTooShortFor(t) } err := delta.Unmarshal(rawPacket[recvDeltasPos : recvDeltasPos+2]) if err != nil { diff --git a/transport_layer_nack.go b/transport_layer_nack.go index c0e3b8f..19d52eb 100644 --- a/transport_layer_nack.go +++ b/transport_layer_nack.go @@ -123,7 +123,7 @@ func (p TransportLayerNack) Marshal() ([]byte, error) { // Unmarshal decodes the TransportLayerNack from binary. func (p *TransportLayerNack) Unmarshal(rawPacket []byte) error { if len(rawPacket) < (headerLength + ssrcLength) { - return errPacketTooShort + return errPacketTooShortFor(p) } var header Header @@ -132,7 +132,7 @@ func (p *TransportLayerNack) Unmarshal(rawPacket []byte) error { } if len(rawPacket) < (headerLength + int(4*header.Length)) { - return errPacketTooShort + return errPacketTooShortFor(p) } if header.Type != TypeTransportSpecificFeedback || header.Count != FormatTLN {