diff --git a/session_srtcp.go b/session_srtcp.go index 753576f..f9954f9 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -138,14 +138,11 @@ func (s *SessionSRTCP) setWriteDeadline(t time.Time) error { return s.session.nextConn.SetWriteDeadline(t) } -// create a list of Destination SSRCs -// that's a superset of all Destinations in the slice. -func destinationSSRC(pkts []rtcp.Packet) []uint32 { +// create a list of Destination SSRCs for the packet. +func destinationSSRC(pkt rtcp.Packet) []uint32 { ssrcSet := make(map[uint32]struct{}) - for _, p := range pkts { - for _, ssrc := range p.DestinationSSRC() { - ssrcSet[ssrc] = struct{}{} - } + for _, ssrc := range pkt.DestinationSSRC() { + ssrcSet[ssrc] = struct{}{} } out := make([]uint32, 0, len(ssrcSet)) @@ -156,36 +153,44 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 { return out } +//nolint:cyclop func (s *SessionSRTCP) decrypt(buf []byte) error { decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) if err != nil { return err } - pkt, err := rtcp.Unmarshal(decrypted) + pkts, err := rtcp.Unmarshal(decrypted) if err != nil { return err } - for _, ssrc := range destinationSSRC(pkt) { - r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) - if r == nil { - return nil // Session has been closed - } else if isNew { - if !s.session.acceptStreamTimeout.IsZero() { - _ = s.session.nextConn.SetReadDeadline(time.Time{}) - } - s.session.newStream <- r // Notify AcceptStream + for _, pkt := range pkts { + marshaled, err := pkt.Marshal() + if err != nil { + return err } - readStream, ok := r.(*ReadStreamSRTCP) - if !ok { - return errFailedTypeAssertion - } + for _, ssrc := range destinationSSRC(pkt) { + r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) + if r == nil { + return nil // Session has been closed + } else if isNew { + if !s.session.acceptStreamTimeout.IsZero() { + _ = s.session.nextConn.SetReadDeadline(time.Time{}) + } + s.session.newStream <- r // Notify AcceptStream + } - _, err = readStream.write(decrypted) - if err != nil { - return err + readStream, ok := r.(*ReadStreamSRTCP) + if !ok { + return errFailedTypeAssertion + } + + _, err = readStream.write(marshaled) + if err != nil { + return err + } } } diff --git a/session_srtcp_test.go b/session_srtcp_test.go index f297368..49cdc6f 100644 --- a/session_srtcp_test.go +++ b/session_srtcp_test.go @@ -245,6 +245,129 @@ func TestSessionSRTCPReplayProtection(t *testing.T) { expectedSSRC, receivedSSRC) } +func TestSessionSRTCPCompoundPacket(t *testing.T) { + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + testSSRCSenderReport1SR := uint32(0x902f9e2e) + testSSRCSenderReport1RR := uint32(0xbc5e9a40) + testSSRCCNAME := uint32(1234) + testSSRCSenderReport2SR := uint32(0x12345678) + aSession, bSession := buildSessionSRTCPPair(t) + bReadStreamSR1SR, err := bSession.OpenReadStream(testSSRCSenderReport1SR) + assert.NoError(t, err) + bReadStreamSR1RR, err := bSession.OpenReadStream(testSSRCSenderReport1RR) + assert.NoError(t, err) + bReadStreamCNAME, err := bSession.OpenReadStream(testSSRCCNAME) + assert.NoError(t, err) + bReadStreamSR2SR, err := bSession.OpenReadStream(testSSRCSenderReport2SR) + assert.NoError(t, err) + + // Compound packet + // first packet - Sender Report with a Receiver Report + // seconde packet - Sender Report without a Receiver Report + cp := &rtcp.CompoundPacket{ + &rtcp.SenderReport{ + SSRC: testSSRCSenderReport1SR, + NTPTime: 0xda8bd1fcdddda05a, + RTPTime: 0xaaf4edd5, + PacketCount: 1, + OctetCount: 2, + Reports: []rtcp.ReceptionReport{{ + SSRC: testSSRCSenderReport1RR, + FractionLost: 0, + TotalLost: 0, + LastSequenceNumber: 0x46e1, + Jitter: 273, + LastSenderReport: 0x9f36432, + Delay: 150137, + }}, + ProfileExtensions: []byte{ + 0x81, 0xca, 0x0, 0x6, + 0x2b, 0x7e, 0xc0, 0xc5, + 0x1, 0x10, 0x4c, 0x63, + 0x49, 0x66, 0x7a, 0x58, + 0x6f, 0x6e, 0x44, 0x6f, + 0x72, 0x64, 0x53, 0x65, + 0x57, 0x36, 0x0, 0x0, + }, + }, + rtcp.NewCNAMESourceDescription(testSSRCCNAME, "cname"), // to make it a valid compound packet + &rtcp.SenderReport{ + SSRC: testSSRCSenderReport2SR, + NTPTime: 0xda8bd1fcdddda05a, + RTPTime: 0xaaf4edd5, + PacketCount: 1, + OctetCount: 2, + }, + } + + done := make(chan struct{}) + go func() { + readBuffer := make([]byte, 200) + + senderReport := &rtcp.SenderReport{} + n, _, rerr := bReadStreamSR1SR.ReadRTCP(readBuffer) + assert.NoError(t, rerr) + rerr = senderReport.Unmarshal(readBuffer[:n]) + assert.NoError(t, rerr) + assert.Equal(t, uint32(0x902f9e2e), senderReport.SSRC) + assert.Len(t, senderReport.Reports, 1) + assert.Equal(t, uint32(0xbc5e9a40), senderReport.Reports[0].SSRC) + assert.Len(t, senderReport.DestinationSSRC(), 2) + assert.ElementsMatch(t, []uint32{0x902f9e2e, 0xbc5e9a40}, senderReport.DestinationSSRC()) + + // should read via receiver report embedded in sender report + senderReport = &rtcp.SenderReport{} + n, _, rerr = bReadStreamSR1RR.ReadRTCP(readBuffer) + assert.NoError(t, rerr) + rerr = senderReport.Unmarshal(readBuffer[:n]) + assert.NoError(t, rerr) + assert.Equal(t, uint32(0x902f9e2e), senderReport.SSRC) + assert.Len(t, senderReport.Reports, 1) + assert.Equal(t, uint32(0xbc5e9a40), senderReport.Reports[0].SSRC) + assert.Len(t, senderReport.DestinationSSRC(), 2) + assert.ElementsMatch(t, []uint32{0x902f9e2e, 0xbc5e9a40}, senderReport.DestinationSSRC()) + + cname := &rtcp.SourceDescription{} + n, _, rerr = bReadStreamCNAME.ReadRTCP(readBuffer) + assert.NoError(t, rerr) + rerr = cname.Unmarshal(readBuffer[:n]) + assert.NoError(t, rerr) + assert.Len(t, cname.DestinationSSRC(), 1) + assert.Equal(t, uint32(1234), cname.DestinationSSRC()[0]) + + senderReport = &rtcp.SenderReport{} + n, _, rerr = bReadStreamSR2SR.ReadRTCP(readBuffer) + assert.NoError(t, rerr) + rerr = senderReport.Unmarshal(readBuffer[:n]) + assert.NoError(t, rerr) + assert.Equal(t, uint32(0x12345678), senderReport.SSRC) + assert.Len(t, senderReport.Reports, 0) + assert.Len(t, senderReport.DestinationSSRC(), 1) + assert.ElementsMatch(t, []uint32{0x12345678}, senderReport.DestinationSSRC()) + + close(done) + }() + + encrypted, err := encryptSRTCP(aSession.session.localContext, cp) + assert.NoError(t, err) + _, err = aSession.session.nextConn.Write(encrypted) + assert.NoError(t, err) + + <-done + + assert.NoError(t, aSession.Close()) + assert.NoError(t, bSession.Close()) + assert.NoError(t, bReadStreamSR1SR.Close()) + assert.NoError(t, bReadStreamSR1RR.Close()) + assert.NoError(t, bReadStreamCNAME.Close()) + assert.NoError(t, bReadStreamSR2SR.Close()) +} + // nolint: dupl func TestSessionSRTCPAcceptStreamTimeout(t *testing.T) { lim := test.TimeOut(time.Second * 5)