diff --git a/candidate_base.go b/candidate_base.go index 0e38d4cc..af5df114 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -14,9 +14,11 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "time" "github.com/pion/stun/v3" + "github.com/pion/transport/v3" ) type candidateBase struct { @@ -240,8 +242,12 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) { return } + c.enableSocketOptions() + oob := make([]byte, 128) // buffer for out of band packet attributes + attr := transport.NewPacketAttributesWithLen(1) for { - n, srcAddr, err := c.conn.ReadFrom(buf) + attr.Reset() + n, srcAddr, err := c.readPacketWithAttributes(buf, oob, attr) if err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { agent.log.Warnf("Failed to read from candidate %s: %v", c, err) @@ -250,10 +256,80 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) { return } - c.handleInboundPacket(buf[:n], srcAddr) + c.handleInboundPacket(buf[:n], attr.GetReadPacketAttributes(), srcAddr) } } +func (c *candidateBase) enableSocketOptions() { + if uc, ok := c.conn.(*net.UDPConn); ok { + raw, _ := uc.SyscallConn() + _ = raw.Control(func(fd uintptr) { + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_RECVTOS, 1) // TOS/ECN + }) + } +} + +// Reads a packet including its out of band attributes like ECN +// if the underlying conn supports it. +func (c *candidateBase) readPacketWithAttributes( + buf []byte, oob []byte, attr *transport.PacketAttributes) (n int, srcAddr net.Addr, err error) { + var uc *net.UDPConn + var ok bool + + // in case the underlying socket is not udp socket (not a net.UDPConn) + if uc, ok = c.conn.(*net.UDPConn); !ok { + n, srcAddr, err = c.conn.ReadFrom(buf) + return + } + + return c.doReadPacketWithAttributes(buf, oob, attr, uc) +} + +// Reads a packet including its out of band attributes like ECN if possible. +func (c *candidateBase) doReadPacketWithAttributes( + buf []byte, oob []byte, attr *transport.PacketAttributes, uc *net.UDPConn) (n int, srcAddr net.Addr, err error) { + var oobn int + var flags int + var udpAddr *net.UDPAddr + n, oobn, flags, udpAddr, err = uc.ReadMsgUDP(buf, oob) + srcAddr = udpAddr + + if oobn <= 0 { + return + } + + _ = flags + + // Parse control messages for ECN/TOS + cms, err := syscall.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + return + } + + for _, cm := range cms { + // IPv4 TOS + if cm.Header.Level == syscall.IPPROTO_IP && cm.Header.Type == syscall.IP_TOS { + if len(cm.Data) > 0 { + tos := cm.Data[0] + ecn := tos & 0x03 // ECN is the two least significant bits + attr.Buffer[0] = ecn + attr.BytesCopied = 1 + } + } + // IPv6 Traffic Class + if cm.Header.Level == syscall.IPPROTO_IPV6 && cm.Header.Type == syscall.IPV6_TCLASS { + if len(cm.Data) > 0 { + tos := cm.Data[0] + ecn := tos & 0x03 // ECN is the two least significant bits + attr.Buffer[0] = ecn + attr.BytesCopied = 1 + } + } + } + + return +} + func (c *candidateBase) validateSTUNTrafficCache(addr net.Addr) bool { if candidate, ok := c.remoteCandidateCaches[toAddrPort(addr)]; ok { candidate.seen(false) @@ -271,7 +347,7 @@ func (c *candidateBase) addRemoteCandidateCache(candidate Candidate, srcAddr net c.remoteCandidateCaches[toAddrPort(srcAddr)] = candidate } -func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) { +func (c *candidateBase) handleInboundPacket(buf []byte, attr *transport.PacketAttributes, srcAddr net.Addr) { agent := c.agent() if stun.IsMessage(buf) { @@ -309,7 +385,7 @@ func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) { } // Note: This will return packetio.ErrFull if the buffer ever manages to fill up. - n, err := agent.buf.Write(buf) + n, err := agent.buf.WriteWithAttributes(buf, attr) if err != nil { agent.log.Warnf("Failed to write packet: %s", err) diff --git a/transport.go b/transport.go index 31e12866..2047dd9e 100644 --- a/transport.go +++ b/transport.go @@ -10,6 +10,7 @@ import ( "time" "github.com/pion/stun/v3" + "github.com/pion/transport/v3" ) // Dial connects to the remote agent, acting as the controlling ice agent. @@ -79,6 +80,19 @@ func (c *Conn) Read(p []byte) (int, error) { return n, err } +// ReadWithAttributes implements the Conn ReadWithAttributes method. +func (c *Conn) ReadWithAttributes(b []byte, attr *transport.PacketAttributes) (n int, err error) { + err = c.agent.loop.Err() + if err != nil { + return 0, err + } + + n, err = c.agent.buf.ReadWithAttributes(b, attr) + c.bytesReceived.Add(uint64(n)) //nolint:gosec // G115 + + return n, err +} + // Write implements the Conn Write method. func (c *Conn) Write(packet []byte) (int, error) { err := c.agent.loop.Err()