Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions session_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,19 @@ var bufferpool = sync.Pool{ // nolint:gochecknoglobals
}

func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) {
encrypted, err := s.EncryptRTP(header, payload)
if err != nil {
return 0, err
}

return s.session.nextConn.Write(encrypted)
}

// EncryptRTP encrypts an RTP packet and returns the encrypted bytes.
// This allows the caller to handle writing to any destination.
func (s *SessionSRTP) EncryptRTP(header *rtp.Header, payload []byte) ([]byte, error) {
if _, ok := <-s.session.started; ok {
return 0, errStartedChannelUsedIncorrectly
return nil, errStartedChannelUsedIncorrectly
}

// encryptRTP will either return our buffer, or, if it is too
Expand All @@ -159,18 +170,18 @@ func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error)
}
_, err := rtp.MarshalPacketTo(buf, header, payload) // nolint:staticcheck
if err != nil {
return 0, err
return nil, err
}

s.session.localContextMutex.Lock()
encrypted, err := s.localContext.encryptRTP(buf, header, headerLen, buf[:marshalSize])
s.session.localContextMutex.Unlock()

if err != nil {
return 0, err
return nil, err
}

return s.session.nextConn.Write(encrypted)
return encrypted, nil
}

func (s *SessionSRTP) setWriteDeadline(t time.Time) error {
Expand Down
117 changes: 117 additions & 0 deletions session_srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,120 @@ func TestSessionSRTPPacketWithPadding(t *testing.T) {
assert.NoError(t, aSession.Close())
assert.NoError(t, bSession.Close())
}

func TestSessionSRTPEncryptRTP(t *testing.T) {
lim := test.TimeOut(time.Second * 5)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

const (
testSSRC = 5000
rtpHeaderSize = 12
)
testPayload := []byte{0x00, 0x01, 0x03, 0x04}
readBuffer := make([]byte, rtpHeaderSize+len(testPayload))

// Create pipes
aPipe, bPipe := net.Pipe()
config := &Config{
Profile: ProtectionProfileAes128CmHmacSha1_80,
Keys: SessionKeys{
[]byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39},
[]byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6},
[]byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39},
[]byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6},
},
}

aSession, err := NewSessionSRTP(aPipe, config)
assert.NoError(t, err)

bSession, err := NewSessionSRTP(bPipe, config)
assert.NoError(t, err)

// Test EncryptRTP returns encrypted bytes
header := &rtp.Header{SSRC: testSSRC}
encrypted, err := aSession.EncryptRTP(header, append([]byte{}, testPayload...))
assert.NoError(t, err)
assert.Greater(t, len(encrypted), 0, "EncryptRTP should return encrypted bytes")

// Write the encrypted packet manually
_, err = aPipe.Write(encrypted)
assert.NoError(t, err)

// Verify the packet can be decrypted by the receiver
bReadStream, ssrc, err := bSession.AcceptStream()
assert.NoError(t, err)
assert.Equal(t, uint32(testSSRC), ssrc)

_, err = bReadStream.Read(readBuffer)
assert.NoError(t, err)
assert.Equal(t, testPayload, readBuffer[rtpHeaderSize:])

assert.NoError(t, aSession.Close())
assert.NoError(t, bSession.Close())
}

func TestSessionSRTPEncryptRTPSharedContext(t *testing.T) {
// This test verifies that EncryptRTP and writeRTP share the same
// encryption context (ROC state). Sending packets via both methods
// should result in consecutive sequence numbers being decryptable.
lim := test.TimeOut(time.Second * 5)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

const (
testSSRC = 5000
rtpHeaderSize = 12
)
testPayload := []byte{0x00, 0x01, 0x03, 0x04}

aPipe, bPipe := net.Pipe()
config := &Config{
Profile: ProtectionProfileAes128CmHmacSha1_80,
Keys: SessionKeys{
[]byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39},
[]byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6},
[]byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39},
[]byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6},
},
}

aSession, err := NewSessionSRTP(aPipe, config)
assert.NoError(t, err)

bSession, err := NewSessionSRTP(bPipe, config)
assert.NoError(t, err)

bReadStream, err := bSession.OpenReadStream(testSSRC)
assert.NoError(t, err)

aWriteStream, err := aSession.OpenWriteStream()
assert.NoError(t, err)

// Send packets alternating between WriteRTP and EncryptRTP
// All should be decryptable if they share the same context
for seqNum := uint16(0); seqNum < 10; seqNum++ {
header := &rtp.Header{SSRC: testSSRC, SequenceNumber: seqNum}
if seqNum%2 == 0 {
_, err = aWriteStream.WriteRTP(header, testPayload)
} else {
encrypted, encErr := aSession.EncryptRTP(header, testPayload)
assert.NoError(t, encErr)
_, err = aPipe.Write(encrypted)
}
assert.NoError(t, err)

readBuffer := make([]byte, rtpHeaderSize+len(testPayload))
_, err = bReadStream.Read(readBuffer)
assert.NoError(t, err, "Failed to read packet %d", seqNum)
assert.Equal(t, testPayload, readBuffer[rtpHeaderSize:], "Payload mismatch for packet %d", seqNum)
}

assert.NoError(t, aSession.Close())
assert.NoError(t, bSession.Close())
}
Loading