diff --git a/candidate.go b/candidate.go index 89082f98..811c17a0 100644 --- a/candidate.go +++ b/candidate.go @@ -92,4 +92,5 @@ type Candidate interface { seen(outbound bool) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) writeTo(raw []byte, dst Candidate) (int, error) + writeBatchTo(rawPackets [][]byte, dst Candidate) (int, error) } diff --git a/candidate_base.go b/candidate_base.go index 40aeab5f..440c6e05 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -17,6 +17,8 @@ import ( "time" "github.com/pion/stun/v3" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) type candidateBase struct { @@ -35,6 +37,8 @@ type candidateBase struct { lastSent atomic.Int64 lastReceived atomic.Int64 conn net.PacketConn + ipv4Conn *ipv4.PacketConn + ipv6Conn *ipv6.PacketConn currAgent *Agent closeCh chan struct{} @@ -227,6 +231,12 @@ func (c *candidateBase) start(a *Agent, conn net.PacketConn, initializedCh <-cha c.closeCh = make(chan struct{}) c.closedCh = make(chan struct{}) + if c.networkType.IsIPv6() { + c.ipv6Conn = ipv6.NewPacketConn(conn) + } else { + c.ipv4Conn = ipv4.NewPacketConn(conn) + } + go c.recvLoop(initializedCh) } @@ -391,6 +401,59 @@ func (c *candidateBase) writeTo(raw []byte, dst Candidate) (int, error) { return n, nil } +func (c *candidateBase) writeBatchTo(rawPackets [][]byte, dst Candidate) (int, error) { + if len(rawPackets) == 0 { + return 0, nil + } + + dstAddr := dst.addr() + + // Build messages for batch write. + messages := make([]ipv4.Message, len(rawPackets)) + for i, raw := range rawPackets { + messages[i] = ipv4.Message{ + Buffers: [][]byte{raw}, + Addr: dstAddr, + } + } + + // WriteBatch uses sendmmsg on Linux for improved performance. + // On other platforms it writes one message at a time, so we loop. + totalWritten := 0 + for totalWritten < len(messages) { + var n int + var err error + + if c.ipv6Conn != nil { + n, err = c.ipv6Conn.WriteBatch(messages[totalWritten:], 0) + } else { + n, err = c.ipv4Conn.WriteBatch(messages[totalWritten:], 0) + } + + if err != nil { + // If the connection is closed, we should return the error. + if errors.Is(err, io.ErrClosedPipe) { + return totalWritten, err + } + c.agent().log.Infof("Failed to send batch packets: %v", err) + + return totalWritten, nil + } + + if n == 0 { + break + } + + totalWritten += n + } + + if totalWritten > 0 { + c.seen(true) + } + + return totalWritten, nil +} + // TypePreference returns the type preference for this candidate. func (c *candidateBase) TypePreference() uint16 { pref := c.Type().Preference() diff --git a/candidate_test.go b/candidate_test.go index c76a3db4..bdf54dbe 100644 --- a/candidate_test.go +++ b/candidate_test.go @@ -12,6 +12,8 @@ import ( "github.com/pion/logging" "github.com/stretchr/testify/require" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) const localhostIPStr = "127.0.0.1" @@ -620,6 +622,99 @@ func TestCandidateWriteTo(t *testing.T) { require.Error(t, err, "writing to closed conn") } +func TestCandidateWriteBatchTo(t *testing.T) { + testCases := []struct { + name string + network string + ip net.IP + networkType NetworkType + setupConn func(*candidateBase, *net.UDPConn) + }{ + { + name: "UDP_IPv4", + network: "udp4", + ip: net.IP{127, 0, 0, 1}, + networkType: NetworkTypeUDP4, + setupConn: func(c *candidateBase, conn *net.UDPConn) { + c.ipv4Conn = ipv4.NewPacketConn(conn) + }, + }, + { + name: "UDP_IPv6", + network: "udp6", + ip: net.IPv6loopback, + networkType: NetworkTypeUDP6, + setupConn: func(c *candidateBase, conn *net.UDPConn) { + c.ipv6Conn = ipv6.NewPacketConn(conn) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + receiverConn, err := net.ListenUDP(tc.network, &net.UDPAddr{IP: tc.ip, Port: 0}) + if err != nil { + t.Skipf("%s not available on this system", tc.name) + } + defer func() { _ = receiverConn.Close() }() + + senderConn, err := net.ListenUDP(tc.network, &net.UDPAddr{IP: tc.ip, Port: 0}) + require.NoError(t, err, "error creating test UDP sender") + defer func() { _ = senderConn.Close() }() + + loggerFactory := logging.NewDefaultLoggerFactory() + + c1 := &candidateBase{ + conn: senderConn, + networkType: tc.networkType, + currAgent: &Agent{ + log: loggerFactory.NewLogger("agent"), + }, + } + tc.setupConn(c1, senderConn) + + c2 := &candidateBase{ + resolvedAddr: receiverConn.LocalAddr(), + } + + // Test with empty batch. + n, err := c1.writeBatchTo([][]byte{}, c2) + require.NoError(t, err, "writing empty batch should not error") + require.Equal(t, 0, n, "writing empty batch should return 0") + + // Test with single packet. + n, err = c1.writeBatchTo([][]byte{[]byte("test1")}, c2) + require.NoError(t, err, "writing single packet batch") + require.Equal(t, 1, n, "should have written 1 message") + + // Read the packet on receiver side. + buf := make([]byte, 1024) + require.NoError(t, receiverConn.SetReadDeadline(time.Now().Add(time.Second))) + nr, _, err := receiverConn.ReadFromUDP(buf) + require.NoError(t, err, "reading packet") + require.Equal(t, "test1", string(buf[:nr])) + + // Test with multiple packets. + packets := [][]byte{ + []byte("packet1"), + []byte("packet2"), + []byte("packet3"), + } + n, err = c1.writeBatchTo(packets, c2) + require.NoError(t, err, "writing multiple packets batch") + require.Equal(t, len(packets), n, "should have written all messages") + + // Read all packets. + for i := 0; i < len(packets); i++ { + require.NoError(t, receiverConn.SetReadDeadline(time.Now().Add(time.Second))) + nr, _, err = receiverConn.ReadFromUDP(buf) + require.NoError(t, err, "reading packet %d", i) + require.Equal(t, string(packets[i]), string(buf[:nr])) + } + }) + } +} + func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) { candidateWithZoneID := mustCandidateHost(t, &CandidateHostConfig{ Network: NetworkTypeUDP6.String(), diff --git a/candidatepair.go b/candidatepair.go index 31b8bf19..93d88a23 100644 --- a/candidatepair.go +++ b/candidatepair.go @@ -127,10 +127,19 @@ func (p *CandidatePair) priority() uint64 { return (1<<32-1)*localMin(g, d) + 2*localMax(g, d) + cmp(g, d) } +// Write sends a single packet on the candidate pair. +// Returns the number of bytes written. func (p *CandidatePair) Write(b []byte) (int, error) { return p.Local.writeTo(b, p.Remote) } +// WriteBatch sends multiple packets on the candidate pair. +// On Linux, this uses sendmmsg for improved performance. +// Returns the number of packets successfully written. +func (p *CandidatePair) WriteBatch(packets [][]byte) (int, error) { + return p.Local.writeBatchTo(packets, p.Remote) +} + func (a *Agent) sendSTUN(msg *stun.Message, local, remote Candidate) { _, err := local.writeTo(msg.Raw, remote) if err != nil {