Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions conn/bind_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ type StdNetBind struct {

blackhole4 bool
blackhole6 bool

extraFns []ControlFn
}

func NewStdNetBind() Bind {
func NewStdNetBind(fns []ControlFn) Bind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
Expand All @@ -70,6 +72,8 @@ func NewStdNetBind() Bind {
return &msgs
},
},

extraFns: fns,
}
}

Expand Down Expand Up @@ -119,8 +123,8 @@ func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}

func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
func listenNet(network string, port int, fns []ControlFn) (*net.UDPConn, int, error) {
conn, err := listenConfig(fns).ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
Expand Down Expand Up @@ -156,13 +160,13 @@ again:
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn

v4conn, port, err = listenNet("udp4", port)
v4conn, port, err = listenNet("udp4", port, s.extraFns)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}

// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port)
v6conn, port, err = listenNet("udp6", port, s.extraFns)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
tries++
Expand Down
12 changes: 9 additions & 3 deletions conn/controlfns.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,29 @@ const socketBufferSize = 7 << 20
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
type ControlFn func(network, address string, c syscall.RawConn) error

// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
var controlFns = []ControlFn{}

// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
func listenConfig(extraFns []ControlFn) *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}

for _, fn := range extraFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
Expand Down
2 changes: 1 addition & 1 deletion conn/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

package conn

func NewDefaultBind() Bind { return NewStdNetBind() }
func NewDefaultBind() Bind { return NewStdNetBind(nil) }
4 changes: 2 additions & 2 deletions conn/sticky_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func Test_getSrcFromControl(t *testing.T) {

func Test_listenConfig(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatal(err)
}
Expand All @@ -239,7 +239,7 @@ func Test_listenConfig(t *testing.T) {
}
})
t.Run("IPv6", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp6", ":0")
if err != nil {
t.Fatal(err)
}
Expand Down