From 40818b02f3e0fc78945320f7c9ab55e45ede8fd6 Mon Sep 17 00:00:00 2001 From: boks1971 Date: Mon, 29 Dec 2025 09:48:28 +0530 Subject: [PATCH] Split RTCP compound packet before forwarding When a compound RTCP packet is received, it is forwarded as is to the read streams. Downstream components have to check for the SSRC in the packets matching the track(s) it is handling. This leads to situations where downstream components could see duplicate RTCP reports. Happens when there is a downstream handler for all SSRCs. It receives the compound packet, unmarshals it and invokes the handlers. As it is fielding all SSRCs, it will get the same compound packet `n` times and invoke the handlers `n` times. This PR splits up the compound packet and forwards individual packets to avoid the end handlers from seeing duplicates. API wise, it is compatible as it still emits an encoded/marshaled packet. But, this does add a marshaling step. Add an unit test. Before this change, the test fails at the point where the test tries to read the CNAME packet and sees an incorrect type. --- session_srtcp.go | 53 +++++++++--------- session_srtcp_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 24 deletions(-) diff --git a/session_srtcp.go b/session_srtcp.go index 753576f..f9954f9 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -138,14 +138,11 @@ func (s *SessionSRTCP) setWriteDeadline(t time.Time) error { return s.session.nextConn.SetWriteDeadline(t) } -// create a list of Destination SSRCs -// that's a superset of all Destinations in the slice. -func destinationSSRC(pkts []rtcp.Packet) []uint32 { +// create a list of Destination SSRCs for the packet. +func destinationSSRC(pkt rtcp.Packet) []uint32 { ssrcSet := make(map[uint32]struct{}) - for _, p := range pkts { - for _, ssrc := range p.DestinationSSRC() { - ssrcSet[ssrc] = struct{}{} - } + for _, ssrc := range pkt.DestinationSSRC() { + ssrcSet[ssrc] = struct{}{} } out := make([]uint32, 0, len(ssrcSet)) @@ -156,36 +153,44 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 { return out } +//nolint:cyclop func (s *SessionSRTCP) decrypt(buf []byte) error { decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) if err != nil { return err } - pkt, err := rtcp.Unmarshal(decrypted) + pkts, err := rtcp.Unmarshal(decrypted) if err != nil { return err } - for _, ssrc := range destinationSSRC(pkt) { - r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) - if r == nil { - return nil // Session has been closed - } else if isNew { - if !s.session.acceptStreamTimeout.IsZero() { - _ = s.session.nextConn.SetReadDeadline(time.Time{}) - } - s.session.newStream <- r // Notify AcceptStream + for _, pkt := range pkts { + marshaled, err := pkt.Marshal() + if err != nil { + return err } - readStream, ok := r.(*ReadStreamSRTCP) - if !ok { - return errFailedTypeAssertion - } + for _, ssrc := range destinationSSRC(pkt) { + r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) + if r == nil { + return nil // Session has been closed + } else if isNew { + if !s.session.acceptStreamTimeout.IsZero() { + _ = s.session.nextConn.SetReadDeadline(time.Time{}) + } + s.session.newStream <- r // Notify AcceptStream + } - _, err = readStream.write(decrypted) - if err != nil { - return err + readStream, ok := r.(*ReadStreamSRTCP) + if !ok { + return errFailedTypeAssertion + } + + _, err = readStream.write(marshaled) + if err != nil { + return err + } } } diff --git a/session_srtcp_test.go b/session_srtcp_test.go index f297368..49cdc6f 100644 --- a/session_srtcp_test.go +++ b/session_srtcp_test.go @@ -245,6 +245,129 @@ func TestSessionSRTCPReplayProtection(t *testing.T) { expectedSSRC, receivedSSRC) } +func TestSessionSRTCPCompoundPacket(t *testing.T) { + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + testSSRCSenderReport1SR := uint32(0x902f9e2e) + testSSRCSenderReport1RR := uint32(0xbc5e9a40) + testSSRCCNAME := uint32(1234) + testSSRCSenderReport2SR := uint32(0x12345678) + aSession, bSession := buildSessionSRTCPPair(t) + bReadStreamSR1SR, err := bSession.OpenReadStream(testSSRCSenderReport1SR) + assert.NoError(t, err) + bReadStreamSR1RR, err := bSession.OpenReadStream(testSSRCSenderReport1RR) + assert.NoError(t, err) + bReadStreamCNAME, err := bSession.OpenReadStream(testSSRCCNAME) + assert.NoError(t, err) + bReadStreamSR2SR, err := bSession.OpenReadStream(testSSRCSenderReport2SR) + assert.NoError(t, err) + + // Compound packet + // first packet - Sender Report with a Receiver Report + // seconde packet - Sender Report without a Receiver Report + cp := &rtcp.CompoundPacket{ + &rtcp.SenderReport{ + SSRC: testSSRCSenderReport1SR, + NTPTime: 0xda8bd1fcdddda05a, + RTPTime: 0xaaf4edd5, + PacketCount: 1, + OctetCount: 2, + Reports: []rtcp.ReceptionReport{{ + SSRC: testSSRCSenderReport1RR, + FractionLost: 0, + TotalLost: 0, + LastSequenceNumber: 0x46e1, + Jitter: 273, + LastSenderReport: 0x9f36432, + Delay: 150137, + }}, + ProfileExtensions: []byte{ + 0x81, 0xca, 0x0, 0x6, + 0x2b, 0x7e, 0xc0, 0xc5, + 0x1, 0x10, 0x4c, 0x63, + 0x49, 0x66, 0x7a, 0x58, + 0x6f, 0x6e, 0x44, 0x6f, + 0x72, 0x64, 0x53, 0x65, + 0x57, 0x36, 0x0, 0x0, + }, + }, + rtcp.NewCNAMESourceDescription(testSSRCCNAME, "cname"), // to make it a valid compound packet + &rtcp.SenderReport{ + SSRC: testSSRCSenderReport2SR, + NTPTime: 0xda8bd1fcdddda05a, + RTPTime: 0xaaf4edd5, + PacketCount: 1, + OctetCount: 2, + }, + } + + done := make(chan struct{}) + go func() { + readBuffer := make([]byte, 200) + + senderReport := &rtcp.SenderReport{} + n, _, rerr := bReadStreamSR1SR.ReadRTCP(readBuffer) + assert.NoError(t, rerr) + rerr = senderReport.Unmarshal(readBuffer[:n]) + assert.NoError(t, rerr) + assert.Equal(t, uint32(0x902f9e2e), senderReport.SSRC) + assert.Len(t, senderReport.Reports, 1) + assert.Equal(t, uint32(0xbc5e9a40), senderReport.Reports[0].SSRC) + assert.Len(t, senderReport.DestinationSSRC(), 2) + assert.ElementsMatch(t, []uint32{0x902f9e2e, 0xbc5e9a40}, senderReport.DestinationSSRC()) + + // should read via receiver report embedded in sender report + senderReport = &rtcp.SenderReport{} + n, _, rerr = bReadStreamSR1RR.ReadRTCP(readBuffer) + assert.NoError(t, rerr) + rerr = senderReport.Unmarshal(readBuffer[:n]) + assert.NoError(t, rerr) + assert.Equal(t, uint32(0x902f9e2e), senderReport.SSRC) + assert.Len(t, senderReport.Reports, 1) + assert.Equal(t, uint32(0xbc5e9a40), senderReport.Reports[0].SSRC) + assert.Len(t, senderReport.DestinationSSRC(), 2) + assert.ElementsMatch(t, []uint32{0x902f9e2e, 0xbc5e9a40}, senderReport.DestinationSSRC()) + + cname := &rtcp.SourceDescription{} + n, _, rerr = bReadStreamCNAME.ReadRTCP(readBuffer) + assert.NoError(t, rerr) + rerr = cname.Unmarshal(readBuffer[:n]) + assert.NoError(t, rerr) + assert.Len(t, cname.DestinationSSRC(), 1) + assert.Equal(t, uint32(1234), cname.DestinationSSRC()[0]) + + senderReport = &rtcp.SenderReport{} + n, _, rerr = bReadStreamSR2SR.ReadRTCP(readBuffer) + assert.NoError(t, rerr) + rerr = senderReport.Unmarshal(readBuffer[:n]) + assert.NoError(t, rerr) + assert.Equal(t, uint32(0x12345678), senderReport.SSRC) + assert.Len(t, senderReport.Reports, 0) + assert.Len(t, senderReport.DestinationSSRC(), 1) + assert.ElementsMatch(t, []uint32{0x12345678}, senderReport.DestinationSSRC()) + + close(done) + }() + + encrypted, err := encryptSRTCP(aSession.session.localContext, cp) + assert.NoError(t, err) + _, err = aSession.session.nextConn.Write(encrypted) + assert.NoError(t, err) + + <-done + + assert.NoError(t, aSession.Close()) + assert.NoError(t, bSession.Close()) + assert.NoError(t, bReadStreamSR1SR.Close()) + assert.NoError(t, bReadStreamSR1RR.Close()) + assert.NoError(t, bReadStreamCNAME.Close()) + assert.NoError(t, bReadStreamSR2SR.Close()) +} + // nolint: dupl func TestSessionSRTCPAcceptStreamTimeout(t *testing.T) { lim := test.TimeOut(time.Second * 5)