diff --git a/session_srtp.go b/session_srtp.go index 73ff253..5489210 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -139,8 +139,19 @@ var bufferpool = sync.Pool{ // nolint:gochecknoglobals } func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) { + encrypted, err := s.EncryptRTP(header, payload) + if err != nil { + return 0, err + } + + return s.session.nextConn.Write(encrypted) +} + +// EncryptRTP encrypts an RTP packet and returns the encrypted bytes. +// This allows the caller to handle writing to any destination. +func (s *SessionSRTP) EncryptRTP(header *rtp.Header, payload []byte) ([]byte, error) { if _, ok := <-s.session.started; ok { - return 0, errStartedChannelUsedIncorrectly + return nil, errStartedChannelUsedIncorrectly } // encryptRTP will either return our buffer, or, if it is too @@ -159,7 +170,7 @@ func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) } _, err := rtp.MarshalPacketTo(buf, header, payload) // nolint:staticcheck if err != nil { - return 0, err + return nil, err } s.session.localContextMutex.Lock() @@ -167,10 +178,10 @@ func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) s.session.localContextMutex.Unlock() if err != nil { - return 0, err + return nil, err } - return s.session.nextConn.Write(encrypted) + return encrypted, nil } func (s *SessionSRTP) setWriteDeadline(t time.Time) error { diff --git a/session_srtp_test.go b/session_srtp_test.go index 67a80ed..1117406 100644 --- a/session_srtp_test.go +++ b/session_srtp_test.go @@ -411,3 +411,120 @@ func TestSessionSRTPPacketWithPadding(t *testing.T) { assert.NoError(t, aSession.Close()) assert.NoError(t, bSession.Close()) } + +func TestSessionSRTPEncryptRTP(t *testing.T) { + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + testSSRC = 5000 + rtpHeaderSize = 12 + ) + testPayload := []byte{0x00, 0x01, 0x03, 0x04} + readBuffer := make([]byte, rtpHeaderSize+len(testPayload)) + + // Create pipes + aPipe, bPipe := net.Pipe() + config := &Config{ + Profile: ProtectionProfileAes128CmHmacSha1_80, + Keys: SessionKeys{ + []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39}, + []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6}, + []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39}, + []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6}, + }, + } + + aSession, err := NewSessionSRTP(aPipe, config) + assert.NoError(t, err) + + bSession, err := NewSessionSRTP(bPipe, config) + assert.NoError(t, err) + + // Test EncryptRTP returns encrypted bytes + header := &rtp.Header{SSRC: testSSRC} + encrypted, err := aSession.EncryptRTP(header, append([]byte{}, testPayload...)) + assert.NoError(t, err) + assert.Greater(t, len(encrypted), 0, "EncryptRTP should return encrypted bytes") + + // Write the encrypted packet manually + _, err = aPipe.Write(encrypted) + assert.NoError(t, err) + + // Verify the packet can be decrypted by the receiver + bReadStream, ssrc, err := bSession.AcceptStream() + assert.NoError(t, err) + assert.Equal(t, uint32(testSSRC), ssrc) + + _, err = bReadStream.Read(readBuffer) + assert.NoError(t, err) + assert.Equal(t, testPayload, readBuffer[rtpHeaderSize:]) + + assert.NoError(t, aSession.Close()) + assert.NoError(t, bSession.Close()) +} + +func TestSessionSRTPEncryptRTPSharedContext(t *testing.T) { + // This test verifies that EncryptRTP and writeRTP share the same + // encryption context (ROC state). Sending packets via both methods + // should result in consecutive sequence numbers being decryptable. + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + testSSRC = 5000 + rtpHeaderSize = 12 + ) + testPayload := []byte{0x00, 0x01, 0x03, 0x04} + + aPipe, bPipe := net.Pipe() + config := &Config{ + Profile: ProtectionProfileAes128CmHmacSha1_80, + Keys: SessionKeys{ + []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39}, + []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6}, + []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39}, + []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6}, + }, + } + + aSession, err := NewSessionSRTP(aPipe, config) + assert.NoError(t, err) + + bSession, err := NewSessionSRTP(bPipe, config) + assert.NoError(t, err) + + bReadStream, err := bSession.OpenReadStream(testSSRC) + assert.NoError(t, err) + + aWriteStream, err := aSession.OpenWriteStream() + assert.NoError(t, err) + + // Send packets alternating between WriteRTP and EncryptRTP + // All should be decryptable if they share the same context + for seqNum := uint16(0); seqNum < 10; seqNum++ { + header := &rtp.Header{SSRC: testSSRC, SequenceNumber: seqNum} + if seqNum%2 == 0 { + _, err = aWriteStream.WriteRTP(header, testPayload) + } else { + encrypted, encErr := aSession.EncryptRTP(header, testPayload) + assert.NoError(t, encErr) + _, err = aPipe.Write(encrypted) + } + assert.NoError(t, err) + + readBuffer := make([]byte, rtpHeaderSize+len(testPayload)) + _, err = bReadStream.Read(readBuffer) + assert.NoError(t, err, "Failed to read packet %d", seqNum) + assert.Equal(t, testPayload, readBuffer[rtpHeaderSize:], "Payload mismatch for packet %d", seqNum) + } + + assert.NoError(t, aSession.Close()) + assert.NoError(t, bSession.Close()) +}