diff --git a/conn/bind_std.go b/conn/bind_std.go index f5c88160e..d6e0a21be 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -46,9 +46,11 @@ type StdNetBind struct { blackhole4 bool blackhole6 bool + + extraFns []ControlFn } -func NewStdNetBind() Bind { +func NewStdNetBind(fns []ControlFn) Bind { return &StdNetBind{ udpAddrPool: sync.Pool{ New: func() any { @@ -70,6 +72,8 @@ func NewStdNetBind() Bind { return &msgs }, }, + + extraFns: fns, } } @@ -119,8 +123,8 @@ func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } -func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) +func listenNet(network string, port int, fns []ControlFn) (*net.UDPConn, int, error) { + conn, err := listenConfig(fns).ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } @@ -156,13 +160,13 @@ again: var v4pc *ipv4.PacketConn var v6pc *ipv6.PacketConn - v4conn, port, err = listenNet("udp4", port) + v4conn, port, err = listenNet("udp4", port, s.extraFns) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - v6conn, port, err = listenNet("udp6", port) + v6conn, port, err = listenNet("udp6", port, s.extraFns) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { v4conn.Close() tries++ diff --git a/conn/controlfns.go b/conn/controlfns.go index 27421bd26..61b7233f4 100644 --- a/conn/controlfns.go +++ b/conn/controlfns.go @@ -20,16 +20,16 @@ const socketBufferSize = 7 << 20 // controlFn is the callback function signature from net.ListenConfig.Control. // It is used to apply platform specific configuration to the socket prior to // bind. -type controlFn func(network, address string, c syscall.RawConn) error +type ControlFn func(network, address string, c syscall.RawConn) error // controlFns is a list of functions that are called from the listen config // that can apply socket options. -var controlFns = []controlFn{} +var controlFns = []ControlFn{} // listenConfig returns a net.ListenConfig that applies the controlFns to the // socket prior to bind. This is used to apply socket buffer sizing and packet // information OOB configuration for sticky sockets. -func listenConfig() *net.ListenConfig { +func listenConfig(extraFns []ControlFn) *net.ListenConfig { return &net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { for _, fn := range controlFns { @@ -37,6 +37,12 @@ func listenConfig() *net.ListenConfig { return err } } + + for _, fn := range extraFns { + if err := fn(network, address, c); err != nil { + return err + } + } return nil }, } diff --git a/conn/default.go b/conn/default.go index 2ce157956..6fdcabd04 100644 --- a/conn/default.go +++ b/conn/default.go @@ -7,4 +7,4 @@ package conn -func NewDefaultBind() Bind { return NewStdNetBind() } +func NewDefaultBind() Bind { return NewStdNetBind(nil) } diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index 1b1ee6833..9a5f6e372 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -213,7 +213,7 @@ func Test_getSrcFromControl(t *testing.T) { func Test_listenConfig(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp4", ":0") if err != nil { t.Fatal(err) } @@ -239,7 +239,7 @@ func Test_listenConfig(t *testing.T) { } }) t.Run("IPv6", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp6", ":0") if err != nil { t.Fatal(err) }