Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 128 additions & 64 deletions conn/bind_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"errors"
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"syscall"
Expand All @@ -22,16 +23,21 @@ var (
_ Bind = (*StdNetBind)(nil)
)

// StdNetBind implements Bind for all platforms except Windows.
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
udpAddrPool sync.Pool
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux

udpAddrPool sync.Pool // following fields are not guarded by mu
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
}
Expand Down Expand Up @@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn

v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
Expand All @@ -173,63 +181,92 @@ again:
}
var fns []ReceiveFunc
if v4conn != nil {
fns = append(fns, s.receiveIPv4)
if runtime.GOOS == "linux" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
s.ipv4 = v4conn
}
if v6conn != nil {
fns = append(fns, s.receiveIPv6)
if runtime.GOOS == "linux" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}

s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)

return fns, uint16(port), nil
}

func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
return numMsgs, nil
}

func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv6MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
return numMsgs, nil
}

// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
Expand All @@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
s.ipv4PC = nil
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
s.ipv6PC = nil
}
s.blackhole4 = false
s.blackhole6 = false
Expand All @@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
var (
pc4 *ipv4.PacketConn
pc6 *ipv6.PacketConn
)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
pc6 = s.ipv6PC
is6 = true
} else {
pc4 = s.ipv4PC
}
s.mu.Unlock()

Expand All @@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
return syscall.EAFNOSUPPORT
}
if is6 {
return s.send6(s.ipv6PC, endpoint, buffs)
return s.send6(conn, pc6, endpoint, buffs)
} else {
return s.send4(s.ipv4PC, endpoint, buffs)
return s.send4(conn, pc4, endpoint, buffs)
}
}

func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
Expand All @@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
if runtime.GOOS == "linux" {
for {
n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
} else {
for i, buff := range buffs {
_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
if err != nil {
break
}
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}

func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
Expand All @@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
if runtime.GOOS == "linux" {
for {
n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
} else {
for i, buff := range buffs {
_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
if err != nil {
break
}
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
Expand Down
22 changes: 22 additions & 0 deletions conn/bind_std_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package conn

import "testing"

func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
bind := NewStdNetBind().(*StdNetBind)
fns, _, err := bind.Open(0)
if err != nil {
t.Fatal(err)
}
bind.Close()
buffs := make([][]byte, 1)
buffs[0] = make([]byte, 1)
sizes := make([]int, 1)
eps := make([]Endpoint, 1)
for _, fn := range fns {
// The ReceiveFuncs must not access conn-related fields on StdNetBind
// unguarded. Close() nils the conn-related fields resulting in a panic
// if they violate the mutex.
fn(buffs, sizes, eps)
}
}