From 8de1a00e13fc2d88e3227d76db669ae944db2d74 Mon Sep 17 00:00:00 2001 From: Amirmohammad Ghasemi Date: Thu, 13 Nov 2025 11:31:22 -0500 Subject: [PATCH 1/2] ECN support --- session.go | 15 ++++++++------- session_srtcp.go | 16 +++++++++++++++- session_srtp.go | 16 +++++++++++++++- stream.go | 5 +++++ stream_srtcp.go | 41 +++++++++++++++++++++-------------------- stream_srtp.go | 41 +++++++++++++++++++++-------------------- stream_srtp_test.go | 2 +- 7 files changed, 86 insertions(+), 50 deletions(-) diff --git a/session.go b/session.go index dd0b365..7ff3388 100644 --- a/session.go +++ b/session.go @@ -6,11 +6,11 @@ package srtp import ( "errors" "io" - "net" "sync" "time" "github.com/pion/logging" + "github.com/pion/transport/v3" "github.com/pion/transport/v3/packetio" ) @@ -18,6 +18,7 @@ type streamSession interface { Close() error write([]byte) (int, error) decrypt([]byte) error + decryptWithAttributes(b []byte, attr *transport.PacketAttributes) error } type session struct { @@ -36,9 +37,9 @@ type session struct { readStreamsLock sync.Mutex log logging.LeveledLogger - bufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser + bufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) *packetio.Buffer - nextConn net.Conn + nextConn transport.NetConnSocket } // Config is used to configure a session. @@ -48,7 +49,7 @@ type session struct { type Config struct { Keys SessionKeys Profile ProtectionProfile - BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser + BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) *packetio.Buffer LoggerFactory logging.LoggerFactory AcceptStreamTimeout time.Time @@ -145,9 +146,9 @@ func (s *session) start( }() b := make([]byte, 8192) + attr := transport.NewPacketAttributes() for { - var i int - i, err = s.nextConn.Read(b) + n, err := s.nextConn.ReadWithAttributes(b, attr) if err != nil { if !errors.Is(err, io.EOF) { s.log.Error(err.Error()) @@ -156,7 +157,7 @@ func (s *session) start( return } - if err = child.decrypt(b[:i]); err != nil { + if err = child.decryptWithAttributes(b[:n], attr); err != nil { s.log.Info(err.Error()) } } diff --git a/session_srtcp.go b/session_srtcp.go index 753576f..6b1b902 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -9,6 +9,7 @@ import ( "github.com/pion/logging" "github.com/pion/rtcp" + "github.com/pion/transport/v3" ) const defaultSessionSRTCPReplayProtectionWindow = 64 @@ -24,6 +25,15 @@ type SessionSRTCP struct { // NewSessionSRTCP creates a SRTCP session using conn as the underlying transport. func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl + return NewSessionSRTCPWithNewSocket( + transport.NewNetConnToNetConnSocket(conn), + config, + ) +} + +// NewSessionSRTCPWithNewSocket creates a SRTCP session using conn as the underlying transport. +// The conn argument implements transport.NetConnSocket, with more capabilities than a net.Conn socket. +func NewSessionSRTCPWithNewSocket(conn transport.NetConnSocket, config *Config) (*SessionSRTCP, error) { //nolint:dupl if config == nil { return nil, errNoConfig } else if conn == nil { @@ -157,6 +167,10 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 { } func (s *SessionSRTCP) decrypt(buf []byte) error { + return s.decryptWithAttributes(buf, nil) +} + +func (s *SessionSRTCP) decryptWithAttributes(buf []byte, attr *transport.PacketAttributes) error { decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) if err != nil { return err @@ -183,7 +197,7 @@ func (s *SessionSRTCP) decrypt(buf []byte) error { return errFailedTypeAssertion } - _, err = readStream.write(decrypted) + _, err = readStream.writeWithAttributes(decrypted, attr) if err != nil { return err } diff --git a/session_srtp.go b/session_srtp.go index 73ff253..6c0ccc3 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -10,6 +10,7 @@ import ( "github.com/pion/logging" "github.com/pion/rtp" + "github.com/pion/transport/v3" ) const defaultSessionSRTPReplayProtectionWindow = 64 @@ -25,6 +26,15 @@ type SessionSRTP struct { // NewSessionSRTP creates a SRTP session using conn as the underlying transport. func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nolint:dupl + return NewSessionSRTPWithNewSocket( + transport.NewNetConnToNetConnSocket(conn), + config, + ) +} + +// NewSessionSRTPWithNewSocket creates a SRTP session using conn as the underlying transport. +// The conn argument implements transport.NetConnSocket, with more capabilities than a net.Conn socket. +func NewSessionSRTPWithNewSocket(conn transport.NetConnSocket, config *Config) (*SessionSRTP, error) { //nolint:dupl if config == nil { return nil, errNoConfig } else if conn == nil { @@ -178,6 +188,10 @@ func (s *SessionSRTP) setWriteDeadline(t time.Time) error { } func (s *SessionSRTP) decrypt(buf []byte) error { + return s.decryptWithAttributes(buf, nil) +} + +func (s *SessionSRTP) decryptWithAttributes(buf []byte, attr *transport.PacketAttributes) error { header := &rtp.Header{} headerLen, err := header.Unmarshal(buf) if err != nil { @@ -204,7 +218,7 @@ func (s *SessionSRTP) decrypt(buf []byte) error { return err } - _, err = readStream.write(decrypted) + _, err = readStream.writeWithAttributes(decrypted, attr) if err != nil { return err } diff --git a/stream.go b/stream.go index 5f9c58a..b2a18a9 100644 --- a/stream.go +++ b/stream.go @@ -3,9 +3,14 @@ package srtp +import "github.com/pion/transport/v3" + type readStream interface { init(child streamSession, ssrc uint32) error Read(buf []byte) (int, error) + + ReadWithAttributes(b []byte, attr *transport.PacketAttributes) (int, error) + GetSSRC() uint32 } diff --git a/stream_srtcp.go b/stream_srtcp.go index 8fe407f..cad689b 100644 --- a/stream_srtcp.go +++ b/stream_srtcp.go @@ -5,11 +5,11 @@ package srtp import ( "errors" - "io" "sync" "time" "github.com/pion/rtcp" + "github.com/pion/transport/v3" "github.com/pion/transport/v3/packetio" ) @@ -26,18 +26,7 @@ type ReadStreamSRTCP struct { ssrc uint32 isInited bool - buffer io.ReadWriteCloser -} - -func (r *ReadStreamSRTCP) write(buf []byte) (n int, err error) { - n, err = r.buffer.Write(buf) - - if errors.Is(err, packetio.ErrFull) { - // Silently drop data when the buffer is full. - return len(buf), nil - } - - return n, err + buffer *packetio.Buffer } // Used by getOrCreateReadStream. @@ -66,16 +55,17 @@ func (r *ReadStreamSRTCP) Read(buf []byte) (int, error) { return r.buffer.Read(buf) } +// ReadWithAttributes reads and decrypts full RTCP packet from the nextConn with additional packet attributes. +func (r *ReadStreamSRTCP) ReadWithAttributes(b []byte, attr *transport.PacketAttributes) (int, error) { + n, err := r.buffer.ReadWithAttributes(b, attr) + + return n, err +} + // SetReadDeadline sets the deadline for the Read operation. // Setting to zero means no deadline. func (r *ReadStreamSRTCP) SetReadDeadline(t time.Time) error { - if b, ok := r.buffer.(interface { - SetReadDeadline(time.Time) error - }); ok { - return b.SetReadDeadline(t) - } - - return nil + return r.buffer.SetReadDeadline(t) } // Close removes the ReadStream from the session and cleans up any associated state. @@ -160,3 +150,14 @@ func (w *WriteStreamSRTCP) Write(b []byte) (int, error) { func (w *WriteStreamSRTCP) SetWriteDeadline(t time.Time) error { return w.session.setWriteDeadline(t) } + +func (r *ReadStreamSRTCP) writeWithAttributes(b []byte, attr *transport.PacketAttributes) (int, error) { + n, err := r.buffer.WriteWithAttributes(b, attr) + + if errors.Is(err, packetio.ErrFull) { + // Silently drop data when the buffer is full. + return 0, nil + } + + return n, err +} diff --git a/stream_srtp.go b/stream_srtp.go index 1b34266..716f7fe 100644 --- a/stream_srtp.go +++ b/stream_srtp.go @@ -5,11 +5,11 @@ package srtp import ( "errors" - "io" "sync" "time" "github.com/pion/rtp" + "github.com/pion/transport/v3" "github.com/pion/transport/v3/packetio" ) @@ -26,7 +26,7 @@ type ReadStreamSRTP struct { ssrc uint32 isInited bool - buffer io.ReadWriteCloser + buffer *packetio.Buffer } // Used by getOrCreateReadStream. @@ -63,22 +63,18 @@ func (r *ReadStreamSRTP) init(child streamSession, ssrc uint32) error { return nil } -func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { - n, err = r.buffer.Write(buf) - - if errors.Is(err, packetio.ErrFull) { - // Silently drop data when the buffer is full. - return len(buf), nil - } - - return n, err -} - // Read reads and decrypts full RTP packet from the nextConn. func (r *ReadStreamSRTP) Read(buf []byte) (int, error) { return r.buffer.Read(buf) } +// Read reads and decrypts full RTP packet from the nextConn with additional packet attributes. +func (r *ReadStreamSRTP) ReadWithAttributes(buf []byte, attr *transport.PacketAttributes) (int, error) { + n, err := r.buffer.ReadWithAttributes(buf, attr) + + return n, err +} + // ReadRTP reads and decrypts full RTP packet and its header from the nextConn. func (r *ReadStreamSRTP) ReadRTP(buf []byte) (int, *rtp.Header, error) { n, err := r.Read(buf) @@ -99,13 +95,7 @@ func (r *ReadStreamSRTP) ReadRTP(buf []byte) (int, *rtp.Header, error) { // SetReadDeadline sets the deadline for the Read operation. // Setting to zero means no deadline. func (r *ReadStreamSRTP) SetReadDeadline(t time.Time) error { - if b, ok := r.buffer.(interface { - SetReadDeadline(time.Time) error - }); ok { - return b.SetReadDeadline(t) - } - - return nil + return r.buffer.SetReadDeadline(t) } // Close removes the ReadStream from the session and cleans up any associated state. @@ -157,3 +147,14 @@ func (w *WriteStreamSRTP) Write(b []byte) (int, error) { func (w *WriteStreamSRTP) SetWriteDeadline(t time.Time) error { return w.session.setWriteDeadline(t) } + +func (r *ReadStreamSRTP) writeWithAttributes(buff []byte, attr *transport.PacketAttributes) (n int, err error) { + n, err = r.buffer.WriteWithAttributes(buff, attr) + + if errors.Is(err, packetio.ErrFull) { + // Silently drop data when the buffer is full. + return len(buff), nil + } + + return n, err +} diff --git a/stream_srtp_test.go b/stream_srtp_test.go index 2137d90..dc82080 100644 --- a/stream_srtp_test.go +++ b/stream_srtp_test.go @@ -39,7 +39,7 @@ func TestBufferFactory(t *testing.T) { wg := sync.WaitGroup{} wg.Add(2) conn := newNoopConn() - bf := func(_ packetio.BufferPacketType, _ uint32) io.ReadWriteCloser { + bf := func(_ packetio.BufferPacketType, _ uint32) *packetio.Buffer { wg.Done() return packetio.NewBuffer() From d6af1e05ae092d27103dff2a4ca7903bdafd5164 Mon Sep 17 00:00:00 2001 From: Amirmohammad Ghasemi Date: Fri, 14 Nov 2025 14:26:29 -0500 Subject: [PATCH 2/2] Adapt with transport changes --- session.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/session.go b/session.go index 7ff3388..dc77ddd 100644 --- a/session.go +++ b/session.go @@ -146,7 +146,7 @@ func (s *session) start( }() b := make([]byte, 8192) - attr := transport.NewPacketAttributes() + attr := transport.NewPacketAttributesWithLen(transport.MaxAttributesLen) for { n, err := s.nextConn.ReadWithAttributes(b, attr) if err != nil { @@ -157,7 +157,7 @@ func (s *session) start( return } - if err = child.decryptWithAttributes(b[:n], attr); err != nil { + if err = child.decryptWithAttributes(b[:n], attr.GetReadPacketAttributes()); err != nil { s.log.Info(err.Error()) } }