Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions session_srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,11 @@ func (s *SessionSRTCP) setWriteDeadline(t time.Time) error {
return s.session.nextConn.SetWriteDeadline(t)
}

// create a list of Destination SSRCs
// that's a superset of all Destinations in the slice.
func destinationSSRC(pkts []rtcp.Packet) []uint32 {
// create a list of Destination SSRCs for the packet.
func destinationSSRC(pkt rtcp.Packet) []uint32 {
ssrcSet := make(map[uint32]struct{})
for _, p := range pkts {
for _, ssrc := range p.DestinationSSRC() {
ssrcSet[ssrc] = struct{}{}
}
for _, ssrc := range pkt.DestinationSSRC() {
ssrcSet[ssrc] = struct{}{}
}

out := make([]uint32, 0, len(ssrcSet))
Expand All @@ -156,36 +153,44 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 {
return out
}

//nolint:cyclop
func (s *SessionSRTCP) decrypt(buf []byte) error {
decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil)
if err != nil {
return err
}

pkt, err := rtcp.Unmarshal(decrypted)
pkts, err := rtcp.Unmarshal(decrypted)
if err != nil {
return err
}

for _, ssrc := range destinationSSRC(pkt) {
r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
if r == nil {
return nil // Session has been closed
} else if isNew {
if !s.session.acceptStreamTimeout.IsZero() {
_ = s.session.nextConn.SetReadDeadline(time.Time{})
}
s.session.newStream <- r // Notify AcceptStream
for _, pkt := range pkts {
marshaled, err := pkt.Marshal()
if err != nil {
return err
}

readStream, ok := r.(*ReadStreamSRTCP)
if !ok {
return errFailedTypeAssertion
}
for _, ssrc := range destinationSSRC(pkt) {
r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
if r == nil {
return nil // Session has been closed
} else if isNew {
if !s.session.acceptStreamTimeout.IsZero() {
_ = s.session.nextConn.SetReadDeadline(time.Time{})
}
s.session.newStream <- r // Notify AcceptStream
}

_, err = readStream.write(decrypted)
if err != nil {
return err
readStream, ok := r.(*ReadStreamSRTCP)
if !ok {
return errFailedTypeAssertion
}

_, err = readStream.write(marshaled)
if err != nil {
return err
}
}
}

Expand Down
123 changes: 123 additions & 0 deletions session_srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,129 @@ func TestSessionSRTCPReplayProtection(t *testing.T) {
expectedSSRC, receivedSSRC)
}

func TestSessionSRTCPCompoundPacket(t *testing.T) {
lim := test.TimeOut(time.Second * 5)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

testSSRCSenderReport1SR := uint32(0x902f9e2e)
testSSRCSenderReport1RR := uint32(0xbc5e9a40)
testSSRCCNAME := uint32(1234)
testSSRCSenderReport2SR := uint32(0x12345678)
aSession, bSession := buildSessionSRTCPPair(t)
bReadStreamSR1SR, err := bSession.OpenReadStream(testSSRCSenderReport1SR)
assert.NoError(t, err)
bReadStreamSR1RR, err := bSession.OpenReadStream(testSSRCSenderReport1RR)
assert.NoError(t, err)
bReadStreamCNAME, err := bSession.OpenReadStream(testSSRCCNAME)
assert.NoError(t, err)
bReadStreamSR2SR, err := bSession.OpenReadStream(testSSRCSenderReport2SR)
assert.NoError(t, err)

// Compound packet
// first packet - Sender Report with a Receiver Report
// seconde packet - Sender Report without a Receiver Report
cp := &rtcp.CompoundPacket{
&rtcp.SenderReport{
SSRC: testSSRCSenderReport1SR,
NTPTime: 0xda8bd1fcdddda05a,
RTPTime: 0xaaf4edd5,
PacketCount: 1,
OctetCount: 2,
Reports: []rtcp.ReceptionReport{{
SSRC: testSSRCSenderReport1RR,
FractionLost: 0,
TotalLost: 0,
LastSequenceNumber: 0x46e1,
Jitter: 273,
LastSenderReport: 0x9f36432,
Delay: 150137,
}},
ProfileExtensions: []byte{
0x81, 0xca, 0x0, 0x6,
0x2b, 0x7e, 0xc0, 0xc5,
0x1, 0x10, 0x4c, 0x63,
0x49, 0x66, 0x7a, 0x58,
0x6f, 0x6e, 0x44, 0x6f,
0x72, 0x64, 0x53, 0x65,
0x57, 0x36, 0x0, 0x0,
},
},
rtcp.NewCNAMESourceDescription(testSSRCCNAME, "cname"), // to make it a valid compound packet
&rtcp.SenderReport{
SSRC: testSSRCSenderReport2SR,
NTPTime: 0xda8bd1fcdddda05a,
RTPTime: 0xaaf4edd5,
PacketCount: 1,
OctetCount: 2,
},
}

done := make(chan struct{})
go func() {
readBuffer := make([]byte, 200)

senderReport := &rtcp.SenderReport{}
n, _, rerr := bReadStreamSR1SR.ReadRTCP(readBuffer)
assert.NoError(t, rerr)
rerr = senderReport.Unmarshal(readBuffer[:n])
assert.NoError(t, rerr)
assert.Equal(t, uint32(0x902f9e2e), senderReport.SSRC)
assert.Len(t, senderReport.Reports, 1)
assert.Equal(t, uint32(0xbc5e9a40), senderReport.Reports[0].SSRC)
assert.Len(t, senderReport.DestinationSSRC(), 2)
assert.ElementsMatch(t, []uint32{0x902f9e2e, 0xbc5e9a40}, senderReport.DestinationSSRC())

// should read via receiver report embedded in sender report
senderReport = &rtcp.SenderReport{}
n, _, rerr = bReadStreamSR1RR.ReadRTCP(readBuffer)
assert.NoError(t, rerr)
rerr = senderReport.Unmarshal(readBuffer[:n])
assert.NoError(t, rerr)
assert.Equal(t, uint32(0x902f9e2e), senderReport.SSRC)
assert.Len(t, senderReport.Reports, 1)
assert.Equal(t, uint32(0xbc5e9a40), senderReport.Reports[0].SSRC)
assert.Len(t, senderReport.DestinationSSRC(), 2)
assert.ElementsMatch(t, []uint32{0x902f9e2e, 0xbc5e9a40}, senderReport.DestinationSSRC())

cname := &rtcp.SourceDescription{}
n, _, rerr = bReadStreamCNAME.ReadRTCP(readBuffer)
assert.NoError(t, rerr)
rerr = cname.Unmarshal(readBuffer[:n])
assert.NoError(t, rerr)
assert.Len(t, cname.DestinationSSRC(), 1)
assert.Equal(t, uint32(1234), cname.DestinationSSRC()[0])

senderReport = &rtcp.SenderReport{}
n, _, rerr = bReadStreamSR2SR.ReadRTCP(readBuffer)
assert.NoError(t, rerr)
rerr = senderReport.Unmarshal(readBuffer[:n])
assert.NoError(t, rerr)
assert.Equal(t, uint32(0x12345678), senderReport.SSRC)
assert.Len(t, senderReport.Reports, 0)
assert.Len(t, senderReport.DestinationSSRC(), 1)
assert.ElementsMatch(t, []uint32{0x12345678}, senderReport.DestinationSSRC())

close(done)
}()

encrypted, err := encryptSRTCP(aSession.session.localContext, cp)
assert.NoError(t, err)
_, err = aSession.session.nextConn.Write(encrypted)
assert.NoError(t, err)

<-done

assert.NoError(t, aSession.Close())
assert.NoError(t, bSession.Close())
assert.NoError(t, bReadStreamSR1SR.Close())
assert.NoError(t, bReadStreamSR1RR.Close())
assert.NoError(t, bReadStreamCNAME.Close())
assert.NoError(t, bReadStreamSR2SR.Close())
}

// nolint: dupl
func TestSessionSRTCPAcceptStreamTimeout(t *testing.T) {
lim := test.TimeOut(time.Second * 5)
Expand Down
Loading