diff --git a/application_defined.go b/application_defined.go index 77a1193..f501557 100644 --- a/application_defined.go +++ b/application_defined.go @@ -120,3 +120,9 @@ func (a *ApplicationDefined) MarshalSize() int { } return 12 + dataLength + paddingSize } + +// Release returns the packet to its pool and resets it +func (p *ApplicationDefined) Release() { + *p = ApplicationDefined{} // Reset the packet + applicationDefinedPool.Put(p) +} diff --git a/compound_packet.go b/compound_packet.go index a621c61..497f56e 100644 --- a/compound_packet.go +++ b/compound_packet.go @@ -159,3 +159,11 @@ func (c CompoundPacket) String() string { out = strings.TrimSuffix(strings.ReplaceAll(out, "\n", "\n\t"), "\t") return out } + +// Release returns the packet to its pool and resets it +func (p *CompoundPacket) Release() { + // CompoundPacket is a slice of pointers, so we need to release each one + for _, packet := range *p { + packet.Release() + } +} diff --git a/extended_report.go b/extended_report.go index 0f9ee2c..c864b6f 100644 --- a/extended_report.go +++ b/extended_report.go @@ -657,3 +657,9 @@ func (x *ExtendedReport) DestinationSSRC() []uint32 { func (x *ExtendedReport) String() string { return stringify(x) } + +// Release returns the packet to its pool and resets it +func (p *ExtendedReport) Release() { + *p = ExtendedReport{} // Reset the packet + extendedReportPool.Put(p) +} diff --git a/full_intra_request.go b/full_intra_request.go index 7c67c50..36d68ef 100644 --- a/full_intra_request.go +++ b/full_intra_request.go @@ -114,3 +114,9 @@ func (p *FullIntraRequest) DestinationSSRC() []uint32 { } return ssrcs } + +// Release returns the packet to its pool and resets it +func (p *FullIntraRequest) Release() { + *p = FullIntraRequest{} // Reset the packet + fullIntraRequestPool.Put(p) +} diff --git a/goodbye.go b/goodbye.go index f87731c..39e8d56 100644 --- a/goodbye.go +++ b/goodbye.go @@ -161,3 +161,9 @@ func (g Goodbye) String() string { return out } + +// Release returns the packet to its pool and resets it +func (p *Goodbye) Release() { + *p = Goodbye{} // Reset the packet + goodbyePool.Put(p) +} diff --git a/packet.go b/packet.go index fc3c9a3..ea43ffb 100644 --- a/packet.go +++ b/packet.go @@ -3,6 +3,11 @@ package rtcp +import ( + "bytes" + "sync" +) + // Packet represents an RTCP packet, a protocol used for out-of-band statistics and control information for an RTP session type Packet interface { // DestinationSSRC returns an array of SSRC values that this packet refers to. @@ -11,8 +16,30 @@ type Packet interface { Marshal() ([]byte, error) Unmarshal(rawPacket []byte) error MarshalSize() int + + // Release returns the packet to its pool + Release() } +//nolint:gochecknoglobals +var ( + senderReportPool = sync.Pool{New: func() interface{} { return new(SenderReport) }} + receiverReportPool = sync.Pool{New: func() interface{} { return new(ReceiverReport) }} + sourceDescriptionPool = sync.Pool{New: func() interface{} { return new(SourceDescription) }} + goodbyePool = sync.Pool{New: func() interface{} { return new(Goodbye) }} + transportLayerNackPool = sync.Pool{New: func() interface{} { return new(TransportLayerNack) }} + rapidResynchronizationRequestPool = sync.Pool{New: func() interface{} { return new(RapidResynchronizationRequest) }} + transportLayerCCPool = sync.Pool{New: func() interface{} { return new(TransportLayerCC) }} + ccFeedbackReportPool = sync.Pool{New: func() interface{} { return new(CCFeedbackReport) }} + pictureLossIndicationPool = sync.Pool{New: func() interface{} { return new(PictureLossIndication) }} + sliceLossIndicationPool = sync.Pool{New: func() interface{} { return new(SliceLossIndication) }} + receiverEstimatedMaximumBitratePool = sync.Pool{New: func() interface{} { return new(ReceiverEstimatedMaximumBitrate) }} + fullIntraRequestPool = sync.Pool{New: func() interface{} { return new(FullIntraRequest) }} + extendedReportPool = sync.Pool{New: func() interface{} { return new(ExtendedReport) }} + applicationDefinedPool = sync.Pool{New: func() interface{} { return new(ApplicationDefined) }} + rawPacketPool = sync.Pool{New: func() interface{} { return new(RawPacket) }} +) + // Unmarshal takes an entire udp datagram (which may consist of multiple RTCP packets) and // returns the unmarshaled packets it contains. // @@ -20,10 +47,17 @@ type Packet interface { // will be returned. Otherwise, the underlying type of the returned packet will be // CompoundPacket. func Unmarshal(rawData []byte) ([]Packet, error) { - var packets []Packet + // Preallocate a slice with a reasonable initial capacity + estimatedPackets := len(rawData) / 100 // Estimate based on average packet size + packets := make([]Packet, 0, estimatedPackets) + for len(rawData) != 0 { p, processed, err := unmarshal(rawData) if err != nil { + // Release already allocated packets in case of error + for _, packet := range packets { + packet.Release() + } return nil, err } @@ -43,15 +77,16 @@ func Unmarshal(rawData []byte) ([]Packet, error) { // Marshal takes an array of Packets and serializes them to a single buffer func Marshal(packets []Packet) ([]byte, error) { - out := make([]byte, 0) + var buf bytes.Buffer for _, p := range packets { data, err := p.Marshal() if err != nil { return nil, err } - out = append(out, data...) + buf.Write(data) + p.Release() } - return out, nil + return buf.Bytes(), nil } // unmarshal is a factory which pulls the first RTCP packet from a bytestream, @@ -72,53 +107,53 @@ func unmarshal(rawData []byte) (packet Packet, bytesprocessed int, err error) { switch h.Type { case TypeSenderReport: - packet = new(SenderReport) + packet = senderReportPool.Get().(*SenderReport) case TypeReceiverReport: - packet = new(ReceiverReport) + packet = receiverReportPool.Get().(*ReceiverReport) case TypeSourceDescription: - packet = new(SourceDescription) + packet = sourceDescriptionPool.Get().(*SourceDescription) case TypeGoodbye: - packet = new(Goodbye) + packet = goodbyePool.Get().(*Goodbye) case TypeTransportSpecificFeedback: switch h.Count { case FormatTLN: - packet = new(TransportLayerNack) + packet = transportLayerNackPool.Get().(*TransportLayerNack) case FormatRRR: - packet = new(RapidResynchronizationRequest) + packet = rapidResynchronizationRequestPool.Get().(*RapidResynchronizationRequest) case FormatTCC: - packet = new(TransportLayerCC) + packet = transportLayerCCPool.Get().(*TransportLayerCC) case FormatCCFB: - packet = new(CCFeedbackReport) + packet = ccFeedbackReportPool.Get().(*CCFeedbackReport) default: - packet = new(RawPacket) + packet = rawPacketPool.Get().(*RawPacket) } case TypePayloadSpecificFeedback: switch h.Count { case FormatPLI: - packet = new(PictureLossIndication) + packet = pictureLossIndicationPool.Get().(*PictureLossIndication) case FormatSLI: - packet = new(SliceLossIndication) + packet = sliceLossIndicationPool.Get().(*SliceLossIndication) case FormatREMB: - packet = new(ReceiverEstimatedMaximumBitrate) + packet = receiverEstimatedMaximumBitratePool.Get().(*ReceiverEstimatedMaximumBitrate) case FormatFIR: - packet = new(FullIntraRequest) + packet = fullIntraRequestPool.Get().(*FullIntraRequest) default: - packet = new(RawPacket) + packet = rawPacketPool.Get().(*RawPacket) } case TypeExtendedReport: - packet = new(ExtendedReport) + packet = extendedReportPool.Get().(*ExtendedReport) case TypeApplicationDefined: - packet = new(ApplicationDefined) + packet = applicationDefinedPool.Get().(*ApplicationDefined) default: - packet = new(RawPacket) + packet = rawPacketPool.Get().(*RawPacket) } err = packet.Unmarshal(inPacket) diff --git a/packet_test.go b/packet_test.go index 6eb0a83..5aa2de8 100644 --- a/packet_test.go +++ b/packet_test.go @@ -83,6 +83,21 @@ func realPacket() []byte { } } +func BenchmarkUnmarshal(b *testing.B) { + packetData := realPacket() + for i := 0; i < b.N; i++ { + pkts, err := Unmarshal(packetData) + if err != nil { + b.Fatalf("Error unmarshalling packets: %s", err) + } + + for _, pkt := range pkts { + pkt.Release() + } + + } +} + func TestUnmarshal(t *testing.T) { packet, err := Unmarshal(realPacket()) if err != nil { @@ -144,3 +159,111 @@ func TestInvalidHeaderLength(t *testing.T) { t.Fatalf("Unmarshal(nil) err = %v, want %v", got, want) } } + +func TestPacketPool(t *testing.T) { + t.Run("SenderReport", func(t *testing.T) { + sr := senderReportPool.Get() + p, ok := sr.(*SenderReport) + assert.True(t, ok) + + p.Release() + }) + + t.Run("ReceiverReport", func(t *testing.T) { + rr := receiverReportPool.Get() + p, ok := rr.(*ReceiverReport) + assert.True(t, ok) + p.Release() + }) + + t.Run("SourceDescription", func(t *testing.T) { + sd := sourceDescriptionPool.Get() + p, ok := sd.(*SourceDescription) + assert.True(t, ok) + p.Release() + }) + + t.Run("Goodbye", func(t *testing.T) { + gb := goodbyePool.Get() + p, ok := gb.(*Goodbye) + assert.True(t, ok) + p.Release() + }) + + t.Run("TransportLayerNack", func(t *testing.T) { + tln := transportLayerNackPool.Get() + p, ok := tln.(*TransportLayerNack) + assert.True(t, ok) + p.Release() + }) + + t.Run("RapidResynchronizationRequest", func(t *testing.T) { + rrr := rapidResynchronizationRequestPool.Get() + p, ok := rrr.(*RapidResynchronizationRequest) + assert.True(t, ok) + p.Release() + }) + + t.Run("TransportLayerCC", func(t *testing.T) { + tcc := transportLayerCCPool.Get() + p, ok := tcc.(*TransportLayerCC) + assert.True(t, ok) + p.Release() + }) + + t.Run("CCFeedbackReport", func(t *testing.T) { + ccfb := ccFeedbackReportPool.Get() + p, ok := ccfb.(*CCFeedbackReport) + assert.True(t, ok) + p.Release() + }) + + t.Run("PictureLossIndication", func(t *testing.T) { + pli := pictureLossIndicationPool.Get() + p, ok := pli.(*PictureLossIndication) + assert.True(t, ok) + p.Release() + }) + + t.Run("SliceLossIndication", func(t *testing.T) { + sli := sliceLossIndicationPool.Get() + p, ok := sli.(*SliceLossIndication) + assert.True(t, ok) + p.Release() + }) + + t.Run("ReceiverEstimatedMaximumBitrate", func(t *testing.T) { + remb := receiverEstimatedMaximumBitratePool.Get() + p, ok := remb.(*ReceiverEstimatedMaximumBitrate) + assert.True(t, ok) + p.Release() + }) + + t.Run("FullIntraRequest", func(t *testing.T) { + fir := fullIntraRequestPool.Get() + p, ok := fir.(*FullIntraRequest) + assert.True(t, ok) + p.Release() + }) + + t.Run("ExtendedReport", func(t *testing.T) { + er := extendedReportPool.Get() + p, ok := er.(*ExtendedReport) + assert.True(t, ok) + p.Release() + }) + + t.Run("ApplicationDefined", func(t *testing.T) { + ad := applicationDefinedPool.Get() + p, ok := ad.(*ApplicationDefined) + assert.True(t, ok) + p.Release() + }) + + t.Run("RawPacket", func(t *testing.T) { + rp := rawPacketPool.Get() + p, ok := rp.(*RawPacket) + assert.True(t, ok) + p.Release() + }) +} diff --git a/picture_loss_indication.go b/picture_loss_indication.go index 56a7de2..3370fc6 100644 --- a/picture_loss_indication.go +++ b/picture_loss_indication.go @@ -91,3 +91,9 @@ func (p *PictureLossIndication) String() string { func (p *PictureLossIndication) DestinationSSRC() []uint32 { return []uint32{p.MediaSSRC} } + +// Release returns the packet to its pool and resets it +func (p *PictureLossIndication) Release() { + *p = PictureLossIndication{} // Reset the packet + pictureLossIndicationPool.Put(p) +} diff --git a/rapid_resynchronization_request.go b/rapid_resynchronization_request.go index dc67d49..92d78e3 100644 --- a/rapid_resynchronization_request.go +++ b/rapid_resynchronization_request.go @@ -92,3 +92,9 @@ func (p *RapidResynchronizationRequest) DestinationSSRC() []uint32 { func (p *RapidResynchronizationRequest) String() string { return fmt.Sprintf("RapidResynchronizationRequest %x %x", p.SenderSSRC, p.MediaSSRC) } + +// Release returns the packet to its pool and resets it +func (p *RapidResynchronizationRequest) Release() { + *p = RapidResynchronizationRequest{} // Reset the packet + rapidResynchronizationRequestPool.Put(p) +} diff --git a/raw_packet.go b/raw_packet.go index eafb034..2f58f9b 100644 --- a/raw_packet.go +++ b/raw_packet.go @@ -48,3 +48,9 @@ func (r RawPacket) String() string { func (r RawPacket) MarshalSize() int { return len(r) } + +// Release returns the packet to its pool and resets it +func (p *RawPacket) Release() { + *p = RawPacket{} // Reset the packet + rawPacketPool.Put(p) +} diff --git a/receiver_estimated_maximum_bitrate.go b/receiver_estimated_maximum_bitrate.go index 7be57e6..6ee6b61 100644 --- a/receiver_estimated_maximum_bitrate.go +++ b/receiver_estimated_maximum_bitrate.go @@ -283,3 +283,9 @@ func (p *ReceiverEstimatedMaximumBitrate) String() string { func (p *ReceiverEstimatedMaximumBitrate) DestinationSSRC() []uint32 { return p.SSRCs } + +// Release returns the packet to its pool and resets it +func (p *ReceiverEstimatedMaximumBitrate) Release() { + *p = ReceiverEstimatedMaximumBitrate{} // Reset the packet + receiverEstimatedMaximumBitratePool.Put(p) +} diff --git a/receiver_report.go b/receiver_report.go index e917702..30f7141 100644 --- a/receiver_report.go +++ b/receiver_report.go @@ -193,3 +193,9 @@ func (r ReceiverReport) String() string { out += fmt.Sprintf("\tProfile Extension Data: %v\n", r.ProfileExtensions) return out } + +// Release returns the packet to its pool and resets it +func (p *ReceiverReport) Release() { + *p = ReceiverReport{} // Reset the packet + receiverReportPool.Put(p) +} diff --git a/rfc8888.go b/rfc8888.go index 544c6e3..70f25f3 100644 --- a/rfc8888.go +++ b/rfc8888.go @@ -186,6 +186,12 @@ func (b *CCFeedbackReport) Unmarshal(rawPacket []byte) error { return nil } +// Release returns the packet to its pool and resets it +func (p *CCFeedbackReport) Release() { + *p = CCFeedbackReport{} // Reset the packet + ccFeedbackReportPool.Put(p) +} + const ( ssrcOffset = 0 beginSequenceOffset = 4 diff --git a/sender_report.go b/sender_report.go index aaee0ee..85caffd 100644 --- a/sender_report.go +++ b/sender_report.go @@ -260,3 +260,9 @@ func (r SenderReport) String() string { out += fmt.Sprintf("\tProfile Extension Data: %v\n", r.ProfileExtensions) return out } + +// Release returns the packet to its pool and resets it +func (p *SenderReport) Release() { + *p = SenderReport{} // Reset the packet + senderReportPool.Put(p) +} diff --git a/slice_loss_indication.go b/slice_loss_indication.go index 014fcb7..5eafc7d 100644 --- a/slice_loss_indication.go +++ b/slice_loss_indication.go @@ -115,3 +115,9 @@ func (p *SliceLossIndication) String() string { func (p *SliceLossIndication) DestinationSSRC() []uint32 { return []uint32{p.MediaSSRC} } + +// Release returns the packet to its pool and resets it +func (p *SliceLossIndication) Release() { + *p = SliceLossIndication{} // Reset the packet + sliceLossIndicationPool.Put(p) +} diff --git a/source_description.go b/source_description.go index fc29d8e..d962b30 100644 --- a/source_description.go +++ b/source_description.go @@ -366,3 +366,9 @@ func (s *SourceDescription) String() string { } return out } + +// Release returns the packet to its pool and resets it +func (p *SourceDescription) Release() { + *p = SourceDescription{} // Reset the packet + sourceDescriptionPool.Put(p) +} diff --git a/transport_layer_cc.go b/transport_layer_cc.go index 84ead7d..a4fc978 100644 --- a/transport_layer_cc.go +++ b/transport_layer_cc.go @@ -558,6 +558,12 @@ func (t TransportLayerCC) DestinationSSRC() []uint32 { return []uint32{t.MediaSSRC} } +// Release returns the packet to its pool and resets it +func (p *TransportLayerCC) Release() { + *p = TransportLayerCC{} // Reset the packet + transportLayerCCPool.Put(p) +} + func localMin(x, y uint16) uint16 { if x < y { return x diff --git a/transport_layer_nack.go b/transport_layer_nack.go index 802b915..9254e72 100644 --- a/transport_layer_nack.go +++ b/transport_layer_nack.go @@ -179,3 +179,9 @@ func (p TransportLayerNack) String() string { func (p *TransportLayerNack) DestinationSSRC() []uint32 { return []uint32{p.MediaSSRC} } + +// Release returns the packet to its pool and resets it +func (p *TransportLayerNack) Release() { + *p = TransportLayerNack{} // Reset the packet + transportLayerNackPool.Put(p) +}