From bd7e9d35d114da345d0aa845e61686df8cfc5e1e Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 28 Mar 2023 13:40:31 -0700 Subject: [PATCH 01/39] all: rename module (#7) Signed-off-by: Jordan Whited --- conn/bind_windows.go | 2 +- conn/bindtest/bindtest.go | 2 +- device/bind_test.go | 2 +- device/device.go | 8 ++++---- device/device_test.go | 8 ++++---- device/keypair.go | 2 +- device/noise-protocol.go | 2 +- device/noise_test.go | 4 ++-- device/peer.go | 2 +- device/queueconstants_android.go | 2 +- device/queueconstants_default.go | 2 +- device/receive.go | 2 +- device/send.go | 2 +- device/sticky_default.go | 4 ++-- device/sticky_linux.go | 4 ++-- device/tun.go | 2 +- device/uapi.go | 2 +- go.mod | 2 +- ipc/namedpipe/namedpipe_test.go | 2 +- ipc/uapi_linux.go | 2 +- ipc/uapi_windows.go | 2 +- main.go | 8 ++++---- main_windows.go | 8 ++++---- tun/netstack/examples/http_client.go | 6 +++--- tun/netstack/examples/http_server.go | 6 +++--- tun/netstack/examples/ping_client.go | 6 +++--- tun/netstack/tun.go | 2 +- tun/tcp_offload_linux.go | 2 +- tun/tcp_offload_linux_test.go | 2 +- tun/tun_linux.go | 4 ++-- tun/tuntest/tuntest.go | 2 +- 31 files changed, 53 insertions(+), 53 deletions(-) diff --git a/conn/bind_windows.go b/conn/bind_windows.go index d5095e004..9638b3096 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -17,7 +17,7 @@ import ( "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/conn/winrio" + "github.com/tailscale/wireguard-go/conn/winrio" ) const ( diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 74e7addd2..836d983ce 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -12,7 +12,7 @@ import ( "net/netip" "os" - "golang.zx2c4.com/wireguard/conn" + "github.com/tailscale/wireguard-go/conn" ) type ChannelBind struct { diff --git a/device/bind_test.go b/device/bind_test.go index 302a52180..d64ca09b4 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -8,7 +8,7 @@ package device import ( "errors" - "golang.zx2c4.com/wireguard/conn" + "github.com/tailscale/wireguard-go/conn" ) type DummyDatagram struct { diff --git a/device/device.go b/device/device.go index 1af9fe017..7482d9b33 100644 --- a/device/device.go +++ b/device/device.go @@ -11,10 +11,10 @@ import ( "sync/atomic" "time" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/ratelimiter" - "golang.zx2c4.com/wireguard/rwcancel" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/ratelimiter" + "github.com/tailscale/wireguard-go/rwcancel" + "github.com/tailscale/wireguard-go/tun" ) type Device struct { diff --git a/device/device_test.go b/device/device_test.go index fff172bb8..4088b9fab 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -20,10 +20,10 @@ import ( "testing" "time" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/conn/bindtest" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/tun/tuntest" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/conn/bindtest" + "github.com/tailscale/wireguard-go/tun" + "github.com/tailscale/wireguard-go/tun/tuntest" ) // uapiCfg returns a string that contains cfg formatted use with IpcSet. 
diff --git a/device/keypair.go b/device/keypair.go index e3540d7d4..2689ee2a5 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -11,7 +11,7 @@ import ( "sync/atomic" "time" - "golang.zx2c4.com/wireguard/replay" + "github.com/tailscale/wireguard-go/replay" ) /* Due to limitations in Go and /x/crypto there is currently diff --git a/device/noise-protocol.go b/device/noise-protocol.go index e8f6145e5..9f2ba509a 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -15,7 +15,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" - "golang.zx2c4.com/wireguard/tai64n" + "github.com/tailscale/wireguard-go/tai64n" ) type handshakeState int diff --git a/device/noise_test.go b/device/noise_test.go index 2dd53241d..7d6af1df0 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -10,8 +10,8 @@ import ( "encoding/binary" "testing" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/tun/tuntest" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/tun/tuntest" ) func TestCurveWrappers(t *testing.T) { diff --git a/device/peer.go b/device/peer.go index 0ac48962c..c7163ac58 100644 --- a/device/peer.go +++ b/device/peer.go @@ -12,7 +12,7 @@ import ( "sync/atomic" "time" - "golang.zx2c4.com/wireguard/conn" + "github.com/tailscale/wireguard-go/conn" ) type Peer struct { diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index 3d80eadb0..bab9625c4 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -5,7 +5,7 @@ package device -import "golang.zx2c4.com/wireguard/conn" +import "github.com/tailscale/wireguard-go/conn" /* Reduce memory consumption for Android */ diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index ea763d01c..9749cb789 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -7,7 +7,7 @@ package device -import "golang.zx2c4.com/wireguard/conn" +import "github.com/tailscale/wireguard-go/conn" const ( QueueStagedSize = conn.IdealBatchSize diff --git a/device/receive.go b/device/receive.go index e24d29f5b..c0cf74791 100644 --- a/device/receive.go +++ b/device/receive.go @@ -13,10 +13,10 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/conn" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { diff --git a/device/send.go b/device/send.go index d22bf264e..a95a46fb6 100644 --- a/device/send.go +++ b/device/send.go @@ -14,10 +14,10 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/tun" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow diff --git a/device/sticky_default.go b/device/sticky_default.go index 10382565e..732f84ce7 100644 --- a/device/sticky_default.go +++ b/device/sticky_default.go @@ -3,8 +3,8 @@ package device import ( - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/rwcancel" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { diff --git a/device/sticky_linux.go b/device/sticky_linux.go index f9230f8c3..7a519c1f8 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -20,8 +20,8 @@ import ( "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - 
"golang.zx2c4.com/wireguard/rwcancel" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { diff --git a/device/tun.go b/device/tun.go index 2a2ace920..960ecca40 100644 --- a/device/tun.go +++ b/device/tun.go @@ -8,7 +8,7 @@ package device import ( "fmt" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/tun" ) const DefaultMTU = 1420 diff --git a/device/uapi.go b/device/uapi.go index 617dcd333..2a91a9361 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,7 +18,7 @@ import ( "sync" "time" - "golang.zx2c4.com/wireguard/ipc" + "github.com/tailscale/wireguard-go/ipc" ) type IPCError struct { diff --git a/go.mod b/go.mod index c04e1bb61..0d60c9a46 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module golang.zx2c4.com/wireguard +module github.com/tailscale/wireguard-go go 1.20 diff --git a/ipc/namedpipe/namedpipe_test.go b/ipc/namedpipe/namedpipe_test.go index 998453be0..de7d0f63c 100644 --- a/ipc/namedpipe/namedpipe_test.go +++ b/ipc/namedpipe/namedpipe_test.go @@ -20,8 +20,8 @@ import ( "testing" "time" + "github.com/tailscale/wireguard-go/ipc/namedpipe" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/namedpipe" ) func randomPipePath() string { diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go index 1562a1834..bfdf1bffd 100644 --- a/ipc/uapi_linux.go +++ b/ipc/uapi_linux.go @@ -9,8 +9,8 @@ import ( "net" "os" + "github.com/tailscale/wireguard-go/rwcancel" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" ) type UAPIListener struct { diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index aa023c92a..bc30ae0dd 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -8,8 +8,8 @@ package ipc import ( "net" + "github.com/tailscale/wireguard-go/ipc/namedpipe" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/namedpipe" ) // TODO: replace these with actual standard windows error numbers from the win package diff --git a/main.go b/main.go index e01611694..55000e9f9 100644 --- a/main.go +++ b/main.go @@ -14,11 +14,11 @@ import ( "runtime" "strconv" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/ipc" + "github.com/tailscale/wireguard-go/tun" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/tun" ) const ( diff --git a/main_windows.go b/main_windows.go index a4dc46f2c..689a9a79a 100644 --- a/main_windows.go +++ b/main_windows.go @@ -12,11 +12,11 @@ import ( "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/ipc" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/tun" ) const ( diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go index ccd32ede3..81f4d3180 100644 --- a/tun/netstack/examples/http_client.go +++ b/tun/netstack/examples/http_client.go @@ -13,9 +13,9 @@ import ( "net/http" "net/netip" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun/netstack" ) func 
main() { diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go index f5b7a8ff8..30f454454 100644 --- a/tun/netstack/examples/http_server.go +++ b/tun/netstack/examples/http_server.go @@ -14,9 +14,9 @@ import ( "net/http" "net/netip" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun/netstack" ) func main() { diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go index 2eef0fbc2..fc991b156 100644 --- a/tun/netstack/examples/ping_client.go +++ b/tun/netstack/examples/ping_client.go @@ -17,9 +17,9 @@ import ( "golang.org/x/net/icmp" "golang.org/x/net/ipv4" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun/netstack" ) func main() { diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 596cfcd8a..24028427c 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -22,7 +22,7 @@ import ( "syscall" "time" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/tun" "golang.org/x/net/dns/dnsmessage" "gvisor.dev/gvisor/pkg/bufferv2" diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go index 39a7180c5..d64010d01 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/tcp_offload_linux.go @@ -12,8 +12,8 @@ import ( "io" "unsafe" + "github.com/tailscale/wireguard-go/conn" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" ) const tcpFlagsOffset = 13 diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go index 9160e18cd..e828642e8 100644 --- a/tun/tcp_offload_linux_test.go +++ b/tun/tcp_offload_linux_test.go @@ -9,8 +9,8 @@ import ( "net/netip" "testing" + "github.com/tailscale/wireguard-go/conn" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 12cd49f74..eb5051ed8 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,9 +17,9 @@ import ( "time" "unsafe" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/rwcancel" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/rwcancel" ) const ( diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index d07e8601a..e7507c26c 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -11,7 +11,7 @@ import ( "net/netip" "os" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/tun" ) func Ping(dst, src netip.Addr) []byte { From e26adb828d950319d0d0f17178035597e70904a7 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Wed, 27 Sep 2023 16:15:09 -0700 Subject: [PATCH 02/39] go.mod,tun/netstack: bump gvisor Signed-off-by: James Tucker --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- tun/netstack/tun.go | 14 +++++++------- tun/tcp_offload_linux_test.go | 8 ++++---- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/go.mod b/go.mod index 0d60c9a46..9c9b02a66 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,14 @@ module github.com/tailscale/wireguard-go go 1.20 require ( - golang.org/x/crypto v0.6.0 - golang.org/x/net v0.7.0 - golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 + golang.org/x/crypto 
v0.13.0 + golang.org/x/net v0.15.0 + golang.org/x/sys v0.12.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 ) require ( github.com/google/btree v1.0.1 // indirect - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) diff --git a/go.sum b/go.sum index cfeaee623..6bcecea3f 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,14 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 24028427c..d8e70bb03 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -25,7 +25,7 @@ import ( "github.com/tailscale/wireguard-go/tun" "golang.org/x/net/dns/dnsmessage" - "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,7 +43,7 @@ type netTun struct { ep *channel.Endpoint stack *stack.Stack events chan tun.Event - incomingPacket chan *bufferv2.View + incomingPacket chan *buffer.View mtu int dnsServers []netip.Addr hasV4, hasV6 bool @@ -61,7 +61,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, ep: channel.New(1024, uint32(mtu), ""), stack: 
stack.New(opts), events: make(chan tun.Event, 10), - incomingPacket: make(chan *bufferv2.View), + incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } @@ -84,7 +84,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, - AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { @@ -140,7 +140,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { continue } - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) @@ -198,7 +198,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ } return tcpip.FullAddress{ NIC: 1, - Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } @@ -453,7 +453,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } - remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr)) + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) return res.Count, &PingAddr{remoteAddr}, nil } diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go index e828642e8..41fba7064 100644 --- a/tun/tcp_offload_linux_test.go +++ b/tun/tcp_offload_linux_test.go @@ -35,8 +35,8 @@ func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header. srcAs4 := srcIPPort.Addr().As4() dstAs4 := dstIPPort.Addr().As4() ipFields := &header.IPv4Fields{ - SrcAddr: tcpip.Address(srcAs4[:]), - DstAddr: tcpip.Address(dstAs4[:]), + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), Protocol: unix.IPPROTO_TCP, TTL: 64, TotalLength: uint16(totalLen), @@ -72,8 +72,8 @@ func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header. srcAs16 := srcIPPort.Addr().As16() dstAs16 := dstIPPort.Addr().As16() ipFields := &header.IPv6Fields{ - SrcAddr: tcpip.Address(srcAs16[:]), - DstAddr: tcpip.Address(dstAs16[:]), + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), TransportProtocol: unix.IPPROTO_TCP, HopLimit: 64, PayloadLength: uint16(segmentSize + 20), From e06231b8611133d249b8a1d5eaf1588c27800f05 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 4 Apr 2023 13:04:30 -0700 Subject: [PATCH 03/39] conn, device: use UDP GSO and GRO on Linux StdNetBind probes for UDP GSO and GRO support at runtime. UDP GSO is dependent on checksum offload support on the egress netdev. UDP GSO will be disabled in the event sendmmsg() returns EIO, which is a strong signal that the egress netdev does not support checksum offload. The iperf3 results below demonstrate the effect of this commit between two Linux computers with i5-12400 CPUs. There is roughly ~13us of round trip latency between them. The first result is from commit 052af4a without UDP GSO or GRO. 
Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 9.85 GBytes 8.46 Gbits/sec 1139 3.01 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 9.85 GBytes 8.46 Gbits/sec 1139 sender [ 5] 0.00-10.04 sec 9.85 GBytes 8.42 Gbits/sec receiver The second result is with UDP GSO and GRO. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 3.15 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 sender [ 5] 0.00-10.04 sec 12.3 GBytes 10.6 Gbits/sec receiver Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited --- conn/bind_std.go | 398 ++++++++++++------ conn/bind_std_test.go | 230 +++++++++- .../{sticky_default.go => control_default.go} | 22 +- conn/{sticky_linux.go => control_linux.go} | 51 ++- ...ky_linux_test.go => control_linux_test.go} | 4 +- conn/controlfns_linux.go | 8 + conn/errors_default.go | 12 + conn/errors_linux.go | 26 ++ conn/features_default.go | 15 + conn/features_linux.go | 42 ++ device/send.go | 8 + 11 files changed, 674 insertions(+), 142 deletions(-) rename conn/{sticky_default.go => control_default.go} (54%) rename conn/{sticky_linux.go => control_linux.go} (65%) rename conn/{sticky_linux_test.go => control_linux_test.go} (98%) create mode 100644 conn/errors_default.go create mode 100644 conn/errors_linux.go create mode 100644 conn/features_default.go create mode 100644 conn/features_linux.go diff --git a/conn/bind_std.go b/conn/bind_std.go index c701ef872..cc5cf2311 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -8,6 +8,7 @@ package conn import ( "context" "errors" + "fmt" "net" "net/netip" "runtime" @@ -20,7 +21,8 @@ import ( ) var ( - _ Bind = (*StdNetBind)(nil) + _ Bind = (*StdNetBind)(nil) + _ Endpoint = (*StdNetEndpoint)(nil) ) // StdNetBind implements Bind for all platforms. While Windows has its own Bind @@ -29,16 +31,19 @@ var ( // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. 
type StdNetBind struct { - mu sync.Mutex // protects all fields except as specified - ipv4 *net.UDPConn - ipv6 *net.UDPConn - ipv4PC *ipv4.PacketConn // will be nil on non-Linux - ipv6PC *ipv6.PacketConn // will be nil on non-Linux - - // these three fields are not guarded by mu - udpAddrPool sync.Pool - ipv4MsgsPool sync.Pool - ipv6MsgsPool sync.Pool + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool + + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool blackhole4 bool blackhole6 bool @@ -54,23 +59,12 @@ func NewStdNetBind() Bind { }, }, - ipv4MsgsPool: sync.Pool{ - New: func() any { - msgs := make([]ipv4.Message, IdealBatchSize) - for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) - } - return &msgs - }, - }, - - ipv6MsgsPool: sync.Pool{ + msgsPool: sync.Pool{ New: func() any { msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) + msgs[i].OOB = make([]byte, controlSize) } return &msgs }, @@ -179,19 +173,21 @@ again: } var fns []ReceiveFunc if v4conn != nil { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) if runtime.GOOS == "linux" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } - fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) s.ipv4 = v4conn } if v6conn != nil { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) if runtime.GOOS == "linux" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } - fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) s.ipv6 = v6conn } if len(fns) == 0 { @@ -201,69 +197,93 @@ again: return fns, uint16(port), nil } -func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - defer s.ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. 
+ _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" { + if rxOffload { + readAt := len(*msgs) - 2 + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) if err != nil { return 0, err } } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } - numMsgs = 1 } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue } - return numMsgs, nil + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep } + return numMsgs, nil } -func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - defer s.ipv6MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - return 0, err - } - numMsgs = 1 - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep - } - return numMsgs, nil + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } @@ -293,28 +313,42 @@ func (s *StdNetBind) Close() error { } s.blackhole4 = false s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false if err1 != nil { return err1 } return err2 } +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e 
ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 - var ( - pc4 *ipv4.PacketConn - pc6 *ipv6.PacketConn - ) + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 - pc6 = s.ipv6PC + br = s.ipv6PC is6 = true - } else { - pc4 = s.ipv4PC + offload = s.ipv6TxOffload } s.mu.Unlock() @@ -324,25 +358,56 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) if is6 { - return s.send6(conn, pc6, endpoint, bufs) + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] } else { - return s.send4(conn, pc4, endpoint, bufs) + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } + } else { + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} + } + return err } -func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as4 := ep.DstIP().As4() - copy(ua.IP, as4[:]) - ua.IP = ua.IP[:4] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { var ( n int err error @@ -350,59 +415,128 @@ func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, ) if runtime.GOOS == "linux" { for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { break } start += n } } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) if err != nil { break } } } - s.udpAddrPool.Put(ua) - s.ipv4MsgsPool.Put(msgs) return err } -func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as16 := ep.DstIP().As16() - copy(ua.IP, as16[:]) - ua.IP = ua.IP[:16] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +const ( + // Exceeding these values results in EMSGSIZE. 
They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - n int - err error - start int + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs ) - if runtime.GOOS == "linux" { - for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { - break + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue } - start += n } - } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) - if err != nil { - break + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. 
+ msg.N = 0 } } - s.udpAddrPool.Put(ua) - s.ipv6MsgsPool.Put(msgs) - return err + return n, nil } diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 1e4677654..34a3c9acf 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -1,6 +1,12 @@ package conn -import "testing" +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { bind := NewStdNetBind().(*StdNetBind) @@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { fn(bufs, sizes, eps) } } + +func mockSetGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func Test_coalesceMessages(t *testing.T) { + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1").To4(), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := mockGetGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} + +func mockGetGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_splitCoalescedMessages(t *testing.T) { + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1<<16-1)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no 
split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} diff --git a/conn/sticky_default.go b/conn/control_default.go similarity index 54% rename from conn/sticky_default.go rename to conn/control_default.go index 1fa8a0c4b..a8bc06a95 100644 --- a/conn/sticky_default.go +++ b/conn/control_default.go @@ -1,4 +1,4 @@ -//go:build !linux || android +//go:build !(linux && !android) /* SPDX-License-Identifier: MIT * @@ -21,8 +21,9 @@ func (e *StdNetEndpoint) SrcToString() string { return "" } -// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but -// use alternatively named flags and need ports and require testing. +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// ({get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. @@ -34,8 +35,17 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { func setSrcControl(control *[]byte, ep *StdNetEndpoint) { } -// srcControlSize returns the recommended buffer size for pooling sticky control -// data. -const srcControlSize = 0 +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// controlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. 
+const controlSize = 0 const StdNetSupportsStickySockets = false diff --git a/conn/sticky_linux.go b/conn/control_linux.go similarity index 65% rename from conn/sticky_linux.go rename to conn/control_linux.go index a30ccc715..f32f26a69 100644 --- a/conn/sticky_linux.go +++ b/conn/control_linux.go @@ -8,6 +8,7 @@ package conn import ( + "fmt" "net/netip" "unsafe" @@ -105,6 +106,54 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) { *control = append(*control, ep.src...) } -var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == socketOptionLevelUDP && hdr.Type == socketOptionUDPGRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = socketOptionLevelUDP + hdr.Type = socketOptionUDPSegment + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// controlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. 
+var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData) const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/control_linux_test.go similarity index 98% rename from conn/sticky_linux_test.go rename to conn/control_linux_test.go index 679213a82..3ca7d3723 100644 --- a/conn/sticky_linux_test.go +++ b/conn/control_linux_test.go @@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) { } setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, controlSize) setSrcControl(&control, ep) @@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) { } setSrc(ep, netip.MustParseAddr("::1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, controlSize) setSrcControl(&control, ep) diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go index a2396fe89..752fbca4c 100644 --- a/conn/controlfns_linux.go +++ b/conn/controlfns_linux.go @@ -57,5 +57,13 @@ func init() { } return err }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO, 1) + }) + return nil + }, ) } diff --git a/conn/errors_default.go b/conn/errors_default.go new file mode 100644 index 000000000..f1e5b90e5 --- /dev/null +++ b/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(err error) bool { + return false +} diff --git a/conn/errors_linux.go b/conn/errors_linux.go new file mode 100644 index 000000000..8e61000f8 --- /dev/null +++ b/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/conn/features_default.go b/conn/features_default.go new file mode 100644 index 000000000..d53ff5f7b --- /dev/null +++ b/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/conn/features_linux.go b/conn/features_linux.go new file mode 100644 index 000000000..513202e53 --- /dev/null +++ b/conn/features_linux.go @@ -0,0 +1,42 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +const ( + // TODO: upstream to x/sys/unix + socketOptionLevelUDP = 17 + socketOptionUDPSegment = 103 + socketOptionUDPGRO = 104 +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPSegment) + if errSyscall != nil { + return + } + txOffload = true + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO) + if errSyscall != nil { + return + } + rxOffload = opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/device/send.go b/device/send.go index a95a46fb6..ea349979e 100644 --- a/device/send.go +++ b/device/send.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/tun" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" @@ -525,6 +526,13 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device.PutOutboundElement(elem) } device.PutOutboundElementsSlice(elems) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr + } + } if err != nil { device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue From d831fef379ddf6876304385621dd61995e7e32f5 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 4 Apr 2023 13:06:08 -0700 Subject: [PATCH 04/39] device: distribute crypto work as slice of elements After reducing UDP stack traversal overhead via GSO and GRO, runtime.chanrecv() began to account for a high percentage (20% in one environment) of perf samples during a throughput benchmark. The individual packet channel ops with the crypto goroutines was the primary contributor to this overhead. Updating these channels to pass vectors, which the device package already handles at its ends, reduced this overhead substantially, and improved throughput. The iperf3 results below demonstrate the effect of this commit between two Linux computers with i5-12400 CPUs. There is roughly ~13us of round trip latency between them. The first result is with UDP GSO and GRO, and with single element channels. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 3.15 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 sender [ 5] 0.00-10.04 sec 12.3 GBytes 10.6 Gbits/sec receiver The second result is with channels updated to pass a slice of elements. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 13.2 GBytes 11.3 Gbits/sec 182 3.15 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. 
Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 13.2 GBytes 11.3 Gbits/sec 182 sender [ 5] 0.00-10.04 sec 13.2 GBytes 11.3 Gbits/sec receiver Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited --- device/channels.go | 8 +++---- device/receive.go | 44 +++++++++++++++++------------------ device/send.go | 58 +++++++++++++++++++++++----------------------- 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/device/channels.go b/device/channels.go index 039d8dfd0..40ee5c9a5 100644 --- a/device/channels.go +++ b/device/channels.go @@ -19,13 +19,13 @@ import ( // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. type outboundQueue struct { - c chan *QueueOutboundElement + c chan *[]*QueueOutboundElement wg sync.WaitGroup } func newOutboundQueue() *outboundQueue { q := &outboundQueue{ - c: make(chan *QueueOutboundElement, QueueOutboundSize), + c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), } q.wg.Add(1) go func() { @@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue { // A inboundQueue is similar to an outboundQueue; see those docs. type inboundQueue struct { - c chan *QueueInboundElement + c chan *[]*QueueInboundElement wg sync.WaitGroup } func newInboundQueue() *inboundQueue { q := &inboundQueue{ - c: make(chan *QueueInboundElement, QueueInboundSize), + c: make(chan *[]*QueueInboundElement, QueueInboundSize), } q.wg.Add(1) go func() { diff --git a/device/receive.go b/device/receive.go index c0cf74791..744bf182a 100644 --- a/device/receive.go +++ b/device/receive.go @@ -220,9 +220,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive for peer, elems := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elems - for _, elem := range *elems { - device.queue.decryption.c <- elem - } + device.queue.decryption.c <- elems } else { for _, elem := range *elems { device.PutMessageBuffer(elem.buffer) @@ -241,26 +239,28 @@ func (device *Device) RoutineDecryption(id int) { defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) device.log.Verbosef("Routine: decryption worker %d - started", id) - for elem := range device.queue.decryption.c { - // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // decrypt and release to consumer - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - // copy counter to nonce - binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) - elem.packet, err = elem.keypair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.packet = nil + for elems := range device.queue.decryption.c { + for _, elem := range *elems { + // split message into fields + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // decrypt and release to consumer + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.packet = nil + } + elem.Unlock() } - elem.Unlock() } } diff --git a/device/send.go b/device/send.go index ea349979e..7adfacff1 100644 --- a/device/send.go +++ b/device/send.go @@ -385,9 +385,7 @@ top: // add to parallel and 
sequential queue if peer.isRunning.Load() { peer.queue.outbound.c <- elems - for _, elem := range *elems { - peer.device.queue.encryption.c <- elem - } + peer.device.queue.encryption.c <- elems } else { for _, elem := range *elems { peer.device.PutMessageBuffer(elem.buffer) @@ -447,32 +445,34 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elem := range device.queue.encryption.c { - // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] - - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] - - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - - // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) - elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) - - // encrypt content and release to consumer - - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() + for elems := range device.queue.encryption.c { + for _, elem := range *elems { + // populate header fields + header := elem.buffer[:MessageTransportHeaderSize] + + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] + + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + + // pad content to multiple of 16 + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) 
+ + // encrypt content and release to consumer + + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + elem.Unlock() + } } } From 915962ded2318b41ddc7d2664bad3ffb12a8998f Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 4 Apr 2023 13:07:11 -0700 Subject: [PATCH 05/39] tun: unwind summing loop in checksumNoFold() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit $ benchstat old.txt new.txt goos: linux goarch: amd64 pkg: golang.zx2c4.com/wireguard/tun cpu: 12th Gen Intel(R) Core(TM) i5-12400 │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ Checksum/64-12 10.670n ± 2% 4.769n ± 0% -55.30% (p=0.000 n=10) Checksum/128-12 19.665n ± 2% 8.032n ± 0% -59.16% (p=0.000 n=10) Checksum/256-12 37.68n ± 1% 16.06n ± 0% -57.37% (p=0.000 n=10) Checksum/512-12 76.61n ± 3% 32.13n ± 0% -58.06% (p=0.000 n=10) Checksum/1024-12 160.55n ± 4% 64.25n ± 0% -59.98% (p=0.000 n=10) Checksum/1500-12 231.05n ± 7% 94.12n ± 0% -59.26% (p=0.000 n=10) Checksum/2048-12 309.5n ± 3% 128.5n ± 0% -58.48% (p=0.000 n=10) Checksum/4096-12 603.8n ± 4% 257.2n ± 0% -57.41% (p=0.000 n=10) Checksum/8192-12 1185.0n ± 3% 515.5n ± 0% -56.50% (p=0.000 n=10) Checksum/9000-12 1328.5n ± 5% 564.8n ± 0% -57.49% (p=0.000 n=10) Checksum/9001-12 1340.5n ± 3% 564.8n ± 0% -57.87% (p=0.000 n=10) geomean 185.3n 77.99n -57.92% Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited --- tun/checksum.go | 100 +++++++++++++++++++++++++++++++++++++------ tun/checksum_test.go | 35 +++++++++++++++ 2 files changed, 123 insertions(+), 12 deletions(-) create mode 100644 tun/checksum_test.go diff --git a/tun/checksum.go b/tun/checksum.go index f4f847164..29a8fc8fc 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -3,23 +3,99 @@ package tun import "encoding/binary" // TODO: Explore SIMD and/or other assembly optimizations. +// TODO: Test native endian loads. See RFC 1071 section 2 part B. 
func checksumNoFold(b []byte, initial uint64) uint64 { ac := initial - i := 0 - n := len(b) - for n >= 4 { - ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) - n -= 4 - i += 4 + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] + } + if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] + } + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] + } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] } - for n >= 2 { - ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + b = b[8:] } - if n == 1 { - ac += 
uint64(b[i]) << 8 + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + return ac } diff --git a/tun/checksum_test.go b/tun/checksum_test.go new file mode 100644 index 000000000..c1ccff531 --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,35 @@ +package tun + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkChecksum(b *testing.B) { + lengths := []int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + } + + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum(buf, 0) + } + }) + } +} From ceb9a09d035f373d6ca9617b5e25195c100017fb Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 4 Apr 2023 13:07:35 -0700 Subject: [PATCH 06/39] tun: reduce redundant checksumming in tcpGRO() IPv4 header and pseudo header checksums were being computed on every merge operation. Additionally, virtioNetHdr was being written at the same time. This delays those operations until after all coalescing has occurred. Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited --- tun/tcp_offload_linux.go | 162 ++++++++++++++++++++++++--------------- 1 file changed, 99 insertions(+), 63 deletions(-) diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go index d64010d01..67288237f 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/tcp_offload_linux.go @@ -269,11 +269,11 @@ func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { type coalesceResult int const ( - coalesceInsufficientCap coalesceResult = 0 - coalescePSHEnding coalesceResult = 1 - coalesceItemInvalidCSum coalesceResult = 2 - coalescePktInvalidCSum coalesceResult = 3 - coalesceSuccess coalesceResult = 4 + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess ) // coalesceTCPPackets attempts to coalesce pkt with the packet described by @@ -339,42 +339,6 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize if gsoSize > item.gsoSize { item.gsoSize = gsoSize } - hdr := virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(headersLen), - gsoSize: uint16(item.gsoSize), - csumStart: uint16(item.iphLen), - csumOffset: 16, - } - - // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the - // (IPv4) header checksum. - if isV6 { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len - } else { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pktHead[10], pktHead[11] = 0, 0 // clear checksum field - binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length - iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum - binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field - } - hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:]) - - // Calculate the pseudo header checksum and place it at the TCP checksum - // offset. Downstream checksum offloading will combine this with computation - // of the tcp header and payload checksum. 
- addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := bufsOffset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen))) - binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) item.numMerged++ return coalesceSuccess @@ -390,43 +354,52 @@ const ( maxUint16 = 1<<16 - 1 ) +type tcpGROResult int + +const ( + tcpGROResultNoop tcpGROResult = iota + tcpGROResultTableInsert + tcpGROResultCoalesced +) + // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It will return false when pktI is not -// coalesced, otherwise true. This indicates to the caller if bufs[pktI] -// should be written to the Device. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) { +// existing packets tracked in table. It returns a tcpGROResultNoop when no +// action was taken, tcpGROResultTableInsert when the evaluated packet was +// inserted into table, and tcpGROResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. - return false + return tcpGROResultNoop } iphLen := int((pkt[0] & 0x0F) * 4) if isV6 { iphLen = 40 ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) if ipv6HPayloadLen != len(pkt)-iphLen { - return false + return tcpGROResultNoop } } else { totalLen := int(binary.BigEndian.Uint16(pkt[2:])) if totalLen != len(pkt) { - return false + return tcpGROResultNoop } } if len(pkt) < iphLen { - return false + return tcpGROResultNoop } tcphLen := int((pkt[iphLen+12] >> 4) * 4) if tcphLen < 20 || tcphLen > 60 { - return false + return tcpGROResultNoop } if len(pkt) < iphLen+tcphLen { - return false + return tcpGROResultNoop } if !isV6 { if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { // no GRO support for fragmented segments for now - return false + return tcpGROResultNoop } } tcpFlags := pkt[iphLen+tcpFlagsOffset] @@ -434,14 +407,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // not a candidate if any non-ACK flags (except PSH+ACK) are set if tcpFlags != tcpFlagACK { if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return false + return tcpGROResultNoop } pshSet = true } gsoSize := uint16(len(pkt) - tcphLen - iphLen) // not a candidate if payload len is 0 if gsoSize < 1 { - return false + return tcpGROResultNoop } seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) srcAddrOffset := ipv4SrcAddrOffset @@ -452,7 +425,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) } items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) if !existing { - return false + return tcpGROResultNoop } for i := len(items) - 1; i >= 0; i-- { // In the best case of packets arriving in order iterating in reverse is @@ -470,20 +443,20 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) switch result { case coalesceSuccess: table.updateAt(item, i) - return true + return tcpGROResultCoalesced case 
coalesceItemInvalidCSum: // delete the item with an invalid csum table.deleteAt(item.key, i) case coalescePktInvalidCSum: // no point in inserting an item that we can't coalesce - return false + return tcpGROResultNoop default: } } } // failed to coalesce with any other packets; store the item in the flow table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return false + return tcpGROResultTableInsert } func isTCP4NoIPOptions(b []byte) bool { @@ -515,6 +488,64 @@ func isTCP6NoEH(b []byte) bool { return true } +// applyCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + // handleGRO evaluates bufs for GRO, and writes the indices of the resulting // packets into toWrite. 
toWrite, tcp4Table, and tcp6Table should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset @@ -524,23 +555,28 @@ func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toW if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } - var coalesced bool + var result tcpGROResult switch { case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp4Table, false) + result = tcpGRO(bufs, offset, i, tcp4Table, false) case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp6Table, true) + result = tcpGRO(bufs, offset, i, tcp6Table, true) } - if !coalesced { + switch result { + case tcpGROResultNoop: hdr := virtioNetHdr{} err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) if err != nil { return err } + fallthrough + case tcpGROResultTableInsert: *toWrite = append(*toWrite, i) } } - return nil + err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) + err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) + return errors.Join(err4, err6) } // tcpTSO splits packets from in into outBuffs, writing the size of each From cc7b29b8c60406713a50faf0175d6a0481a182f6 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Wed, 5 Apr 2023 15:40:40 -0700 Subject: [PATCH 07/39] device: move Queue{In,Out}boundElement Mutex to container type Queue{In,Out}boundElement locking can contribute to significant overhead via sync.Mutex.lockSlow() in some environments. These types are passed throughout the device package as elements in a slice, so move the per-element Mutex to a container around the slice. Signed-off-by: Jordan Whited --- device/channels.go | 32 +++++++-------- device/device.go | 10 ++--- device/peer.go | 8 ++-- device/pools.go | 44 +++++++++++---------- device/receive.go | 43 ++++++++++---------- device/send.go | 99 ++++++++++++++++++++++++---------------------- 6 files changed, 123 insertions(+), 113 deletions(-) diff --git a/device/channels.go b/device/channels.go index 40ee5c9a5..e526f6bb1 100644 --- a/device/channels.go +++ b/device/channels.go @@ -19,13 +19,13 @@ import ( // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. type outboundQueue struct { - c chan *[]*QueueOutboundElement + c chan *QueueOutboundElementsContainer wg sync.WaitGroup } func newOutboundQueue() *outboundQueue { q := &outboundQueue{ - c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } q.wg.Add(1) go func() { @@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue { // A inboundQueue is similar to an outboundQueue; see those docs. type inboundQueue struct { - c chan *[]*QueueInboundElement + c chan *QueueInboundElementsContainer wg sync.WaitGroup } func newInboundQueue() *inboundQueue { q := &inboundQueue{ - c: make(chan *[]*QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } q.wg.Add(1) go func() { @@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue { } type autodrainingInboundQueue struct { - c chan *[]*QueueInboundElement + c chan *QueueInboundElementsContainer } // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. @@ -81,7 +81,7 @@ type autodrainingInboundQueue struct { // some other means, such as sending a sentinel nil values. 
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { q := &autodrainingInboundQueue{ - c: make(chan *[]*QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } runtime.SetFinalizer(q, device.flushInboundQueue) return q @@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) default: return } @@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { } type autodrainingOutboundQueue struct { - c chan *[]*QueueOutboundElement + c chan *QueueOutboundElementsContainer } // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. @@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct { // All sends to the channel must be best-effort, because there may be no receivers. func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { q := &autodrainingOutboundQueue{ - c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } runtime.SetFinalizer(q, device.flushOutboundQueue) return q @@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) default: return } diff --git a/device/device.go b/device/device.go index 7482d9b33..5c666acc9 100644 --- a/device/device.go +++ b/device/device.go @@ -68,11 +68,11 @@ type Device struct { cookieChecker CookieChecker pool struct { - outboundElementsSlice *WaitPool - inboundElementsSlice *WaitPool - messageBuffers *WaitPool - inboundElements *WaitPool - outboundElements *WaitPool + inboundElementsContainer *WaitPool + outboundElementsContainer *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { diff --git a/device/peer.go b/device/peer.go index c7163ac58..22757d443 100644 --- a/device/peer.go +++ b/device/peer.go @@ -45,9 +45,9 @@ type Peer struct { } queue struct { - staged chan *[]*QueueOutboundElement // staged packets before a handshake is available - outbound *autodrainingOutboundQueue // sequential ordering of udp transmission - inbound *autodrainingInboundQueue // sequential ordering of tun writing + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.inbound = 
newAutodrainingInboundQueue(device) - peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key _, ok := device.peers.keyMap[pk] diff --git a/device/pools.go b/device/pools.go index 02a5d6acb..94f3dc7e6 100644 --- a/device/pools.go +++ b/device/pools.go @@ -46,13 +46,13 @@ func (p *WaitPool) Put(x any) { } func (device *Device) PopulatePools() { - device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { - s := make([]*QueueOutboundElement, 0, device.BatchSize()) - return &s - }) - device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { s := make([]*QueueInboundElement, 0, device.BatchSize()) - return &s + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} }) device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) @@ -65,28 +65,32 @@ func (device *Device) PopulatePools() { }) } -func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement { - return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement) +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.outboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) } -func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement { - return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement) +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.inboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) } func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { diff --git a/device/receive.go b/device/receive.go index 744bf182a..da663e9ee 100644 --- a/device/receive.go +++ b/device/receive.go @@ -27,7 +27,6 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 @@ -35,6 +34,11 @@ type QueueInboundElement struct { endpoint conn.Endpoint } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement +} + // clearPointers clears elem fields that contain pointers. // This makes the garbage collector's life easier and // avoids accidentally keeping other objects around unnecessarily. 
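[Editor's illustrative aside, not part of this patch: a minimal, self-contained sketch of the batching pattern the container type above enables. The names batch and the int elements are placeholders; in the change itself the producer locks the container once, an encryption/decryption worker unlocks it after finishing every element, and the per-peer sequential routine locks it again, so contention is one Mutex operation per batch rather than one per element.]

    package main

    import "sync"

    // batch mirrors the QueueInboundElementsContainer idea: one Mutex guarding
    // a whole slice of elements rather than a Mutex embedded in every element.
    type batch struct {
            sync.Mutex
            elems []int // stand-in for []*QueueInboundElement
    }

    func main() {
            b := &batch{elems: []int{1, 2, 3}}
            b.Lock() // producer: lock once for the whole batch before queueing it

            done := make(chan struct{})
            go func() {
                    // parallel worker: process every element, then release the batch.
                    for i := range b.elems {
                            b.elems[i] *= 2
                    }
                    b.Unlock()
                    close(done)
            }()

            b.Lock()    // sequential consumer: blocks until the worker finishes the batch
            _ = b.elems // elements are now safe to consume in order
            b.Unlock()
            <-done
    }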
@@ -87,7 +91,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int - elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) for i := range bufsArrs { @@ -170,15 +174,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetInboundElementsSlice() + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) bufsArrs[i] = device.GetMessageBuffer() bufs[i] = bufsArrs[i][:] continue @@ -217,16 +220,16 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive default: } } - for peer, elems := range elemsByPeer { + for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { - peer.queue.inbound.c <- elems - device.queue.decryption.c <- elems + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } delete(elemsByPeer, peer) } @@ -239,8 +242,8 @@ func (device *Device) RoutineDecryption(id int) { defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) device.log.Verbosef("Routine: decryption worker %d - started", id) - for elems := range device.queue.decryption.c { - for _, elem := range *elems { + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { // split message into fields counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] @@ -259,8 +262,8 @@ func (device *Device) RoutineDecryption(id int) { if err != nil { elem.packet = nil } - elem.Unlock() } + elemsContainer.Unlock() } } @@ -437,12 +440,12 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.inbound.c { - if elems == nil { + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return } - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed continue @@ -515,11 +518,11 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } bufs = bufs[:0] - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index 7adfacff1..95a5cbe49 100644 --- a/device/send.go +++ b/device/send.go @@ -46,7 +46,6 @@ import ( */ type QueueOutboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) 
nonce uint64 // nonce for encryption @@ -54,10 +53,14 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -79,15 +82,15 @@ func (elem *QueueOutboundElement) clearPointers() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() - elems := peer.device.GetOutboundElementsSlice() - *elems = append(*elems, elem) + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case peer.queue.staged <- elems: + case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } } peer.SendStagedPackets() @@ -219,7 +222,7 @@ func (device *Device) RoutineReadFromTUN() { readErr error elems = make([]*QueueOutboundElement, batchSize) bufs = make([][]byte, batchSize) - elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) offset = MessageTransportHeaderSize @@ -276,10 +279,10 @@ func (device *Device) RoutineReadFromTUN() { } elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetOutboundElementsSlice() + elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) elems[i] = device.NewOutboundElement() bufs[i] = elems[i].buffer[:] } @@ -289,11 +292,11 @@ func (device *Device) RoutineReadFromTUN() { peer.StagePackets(elemsForPeer) peer.SendStagedPackets() } else { - for _, elem := range *elemsForPeer { + for _, elem := range elemsForPeer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elemsForPeer) + device.PutOutboundElementsContainer(elemsForPeer) } delete(elemsByPeer, peer) } @@ -317,7 +320,7 @@ func (device *Device) RoutineReadFromTUN() { } } -func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { for { select { case peer.queue.staged <- elems: @@ -326,11 +329,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { } select { case tooOld := <-peer.queue.staged: - for _, elem := range *tooOld { + for _, elem := range tooOld.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(tooOld) + peer.device.PutOutboundElementsContainer(tooOld) default: } } @@ -349,52 +352,52 @@ top: } for { - var elemsOOO *[]*QueueOutboundElement + var elemsContainerOOO *QueueOutboundElementsContainer select { - case elems := <-peer.queue.staged: + case elemsContainer := <-peer.queue.staged: i := 0 - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { elem.peer = peer elem.nonce = keypair.sendNonce.Add(1) - 1 if elem.nonce >= RejectAfterMessages { 
keypair.sendNonce.Store(RejectAfterMessages) - if elemsOOO == nil { - elemsOOO = peer.device.GetOutboundElementsSlice() + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } - *elemsOOO = append(*elemsOOO, elem) + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) continue } else { - (*elems)[i] = elem + elemsContainer.elems[i] = elem i++ } elem.keypair = keypair - elem.Lock() } - *elems = (*elems)[:i] + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - if elemsOOO != nil { - peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans } - if len(*elems) == 0 { - peer.device.PutOutboundElementsSlice(elems) + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) goto top } // add to parallel and sequential queue if peer.isRunning.Load() { - peer.queue.outbound.c <- elems - peer.device.queue.encryption.c <- elems + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } - if elemsOOO != nil { + if elemsContainerOOO != nil { goto top } default: @@ -406,12 +409,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elems := <-peer.queue.staged: - for _, elem := range *elems { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) default: return } @@ -433,7 +436,7 @@ func calculatePaddingSize(packetSize, mtu int) int { return paddedSize - lastUnit } -/* Encrypts the elements in the queue +/* Encrypts the elems in the queue * and marks them for sequential consumption (by releasing the mutex) * * Obs. One instance per core @@ -445,8 +448,8 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elems := range device.queue.encryption.c { - for _, elem := range *elems { + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { // populate header fields header := elem.buffer[:MessageTransportHeaderSize] @@ -471,8 +474,8 @@ func (device *Device) RoutineEncryption(id int) { elem.packet, nil, ) - elem.Unlock() } + elemsContainer.Unlock() } } @@ -486,28 +489,28 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.outbound.c { + for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] - if elems == nil { + if elemsContainer == nil { return } if !peer.isRunning.Load() { - // peer has been stopped; return re-usable elems to the shared pool. + // peer has been stopped; return re-usable elemsContainer to the shared pool. // This is an optimization only. It is possible for the peer to be stopped // immediately after this check, in which case, elem will get processed. 
// The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } continue } dataSent := false - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true } @@ -521,11 +524,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if dataSent { peer.timersDataSent() } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) if err != nil { var errGSO conn.ErrUDPGSODisabled if errors.As(err, &errGSO) { From ec6f23b33e03f304681b0a440f318636d002a380 Mon Sep 17 00:00:00 2001 From: Adrian Dewhurst Date: Thu, 18 May 2023 10:09:53 -0400 Subject: [PATCH 08/39] tun: checksum tests and benchmarks Signed-off-by: Adrian Dewhurst --- tun/checksum_generic_test.go | 9 + tun/checksum_test.go | 577 ++++++++++++++++++++++++++++++++++- 2 files changed, 579 insertions(+), 7 deletions(-) create mode 100644 tun/checksum_generic_test.go diff --git a/tun/checksum_generic_test.go b/tun/checksum_generic_test.go new file mode 100644 index 000000000..a0c945740 --- /dev/null +++ b/tun/checksum_generic_test.go @@ -0,0 +1,9 @@ +package tun + +var archChecksumFuncs = []archChecksumDetails{ + { + name: "generic", + available: true, + f: checksum, + }, +} diff --git a/tun/checksum_test.go b/tun/checksum_test.go index c1ccff531..b3a358378 100644 --- a/tun/checksum_test.go +++ b/tun/checksum_test.go @@ -2,33 +2,596 @@ package tun import ( "fmt" + "math" "math/rand" + "net/netip" + "sort" + "syscall" "testing" + "unsafe" + + "gvisor.dev/gvisor/pkg/tcpip" + gvisorChecksum "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" ) +type archChecksumDetails struct { + name string + available bool + f func([]byte, uint64) uint16 +} + +func deterministicRandomBytes(seed int64, length int) []byte { + rng := rand.New(rand.NewSource(seed)) + buf := make([]byte, length) + n, err := rng.Read(buf) + if err != nil { + panic(err) + } + if n != length { + panic("incomplete random buffer") + } + return buf +} + +func getPageAlignedRandomBytes(seed int64, length int) []byte { + alignment := syscall.Getpagesize() + buf := deterministicRandomBytes(seed, length+(alignment-1)) + bufPtr := uintptr(unsafe.Pointer(&buf[0])) + alignedBufPtr := (bufPtr + uintptr(alignment-1)) & ^uintptr(alignment-1) + alignedStart := int(alignedBufPtr - bufPtr) + return buf[alignedStart:] +} + +func TestChecksum(t *testing.T) { + alignedBuf := getPageAlignedRandomBytes(10, 8192) + allOnes := make([]byte, 65535) + for i := range allOnes { + allOnes[i] = 0xff + } + allFE := make([]byte, 65535) + for i := range allFE { + allFE[i] = 0xfe + } + + tests := []struct { + name string + data []byte + initial uint16 + want uint16 + }{ + { + name: "empty", + data: []byte{}, + initial: 0, + want: 0, + }, + { + name: "max initial", + data: []byte{}, + initial: math.MaxUint16, + want: 0xffff, + }, + { + name: "odd length", + data: []byte{0x01, 0x02, 0x01}, + initial: 0, + want: 0x0202, + }, + { + name: "tiny", + data: []byte{0x01, 0x02, 0x01, 0x02, 
0x01, 0x02}, + initial: 0, + want: 0x0306, + }, + { + name: "initial", + data: []byte{0x01, 0x02, 0x01, 0x02, 0x01, 0x02}, + initial: 0x1000, + want: 0x1306, + }, + // cleanup0 through cleanup15 is 1024 (handled by large SIMD loops) + + // 32 (handled by small SIMD loops) + n, where n ranges from 0 to 15 + // to cover all of the leftover byte sizes that are possible after small + // SIMD loops that handle 16 bytes. + { + name: "cleanup0", + data: deterministicRandomBytes(1, 1056), + initial: 0, + want: 0x11ec, + }, + { + name: "cleanup1", + data: deterministicRandomBytes(1, 1057), + initial: 0, + want: 0xc5ec, + }, + { + name: "cleanup2", + data: deterministicRandomBytes(1, 1058), + initial: 0, + want: 0xc6ad, + }, + { + name: "cleanup3", + data: deterministicRandomBytes(1, 1059), + initial: 0, + want: 0x86ae, + }, + { + name: "cleanup4", + data: deterministicRandomBytes(1, 1060), + initial: 0, + want: 0x878e, + }, + { + name: "cleanup5", + data: deterministicRandomBytes(1, 1061), + initial: 0, + want: 0xdb8e, + }, + { + name: "cleanup6", + data: deterministicRandomBytes(1, 1062), + initial: 0, + want: 0xdbd5, + }, + { + name: "cleanup7", + data: deterministicRandomBytes(1, 1063), + initial: 0, + want: 0xcfd6, + }, + { + name: "cleanup8", + data: deterministicRandomBytes(1, 1064), + initial: 0, + want: 0xd090, + }, + { + name: "cleanup9", + data: deterministicRandomBytes(1, 1065), + initial: 0, + want: 0x0791, + }, + { + name: "cleanup10", + data: deterministicRandomBytes(1, 1066), + initial: 0, + want: 0x079f, + }, + { + name: "cleanup11", + data: deterministicRandomBytes(1, 1067), + initial: 0, + want: 0xba9f, + }, + { + name: "cleanup12", + data: deterministicRandomBytes(1, 1068), + initial: 0, + want: 0xbb0c, + }, + { + name: "cleanup13", + data: deterministicRandomBytes(1, 1069), + initial: 0, + want: 0x770d, + }, + { + name: "cleanup14", + data: deterministicRandomBytes(1, 1070), + initial: 0, + want: 0x780a, + }, + { + name: "cleanup15", + data: deterministicRandomBytes(1, 1071), + initial: 0, + want: 0x640b, + }, + // small1 through small15 covers small sizes that are not large enough + // to do overlapped reads. 
+ { + name: "small1", + data: deterministicRandomBytes(2, 1), + initial: 0x1122, + want: 0x4022, + }, + { + name: "small2", + data: deterministicRandomBytes(2, 2), + initial: 0x1122, + want: 0x40a4, + }, + { + name: "small3", + data: deterministicRandomBytes(2, 3), + initial: 0x1122, + want: 0xc2a4, + }, + { + name: "small4", + data: deterministicRandomBytes(2, 4), + initial: 0x1122, + want: 0xc36f, + }, + { + name: "small5", + data: deterministicRandomBytes(2, 5), + initial: 0x1122, + want: 0xa570, + }, + { + name: "small6", + data: deterministicRandomBytes(2, 6), + initial: 0x1122, + want: 0xa669, + }, + { + name: "small7", + data: deterministicRandomBytes(2, 7), + initial: 0x1122, + want: 0x0f6a, + }, + { + name: "small8", + data: deterministicRandomBytes(2, 8), + initial: 0x1122, + want: 0x0fd9, + }, + { + name: "small9", + data: deterministicRandomBytes(2, 9), + initial: 0x1122, + want: 0x40d9, + }, + { + name: "small10", + data: deterministicRandomBytes(2, 10), + initial: 0x1122, + want: 0x411d, + }, + { + name: "small11", + data: deterministicRandomBytes(2, 11), + initial: 0x1122, + want: 0x011e, + }, + { + name: "small12", + data: deterministicRandomBytes(2, 12), + initial: 0x1122, + want: 0x01c8, + }, + { + name: "small13", + data: deterministicRandomBytes(2, 13), + initial: 0x1122, + want: 0x4dc8, + }, + { + name: "small14", + data: deterministicRandomBytes(2, 14), + initial: 0x1122, + want: 0x4eb5, + }, + { + name: "small15", + data: deterministicRandomBytes(2, 15), + initial: 0x1122, + want: 0xa4b5, + }, + // other small-ish sizes + { + name: "small16", + data: deterministicRandomBytes(1, 16), + initial: 0, + want: 0x02fa, + }, + { + name: "small32", + data: deterministicRandomBytes(1, 32), + initial: 0, + want: 0x03ee, + }, + { + name: "small64", + data: deterministicRandomBytes(1, 64), + initial: 0, + want: 0x3f85, + }, + { + name: "medium", + data: deterministicRandomBytes(1, 1400), + initial: 0, + want: 0xbea5, + }, + { + name: "big", + data: deterministicRandomBytes(2, 65000), + initial: 0, + want: 0x3ba7, + }, + { + name: "big-initial", + data: deterministicRandomBytes(2, 65000), + initial: 0x1234, + want: 0x4ddb, + }, + { + // big-small-loop is intended to exercise a few iterations of a big + // initial loop of 128 bytes or larger + a smaller loop of 16 bytes + // + some leftover + name: "big-small-loop", + data: deterministicRandomBytes(3, 1094), + initial: 0x9999, + want: 0xe65b, + }, + { + name: "page-aligned", + data: alignedBuf[:4096], + initial: 0, + want: 0x963b, + }, + { + name: "32-aligned", + data: alignedBuf[32:4128], + initial: 0, + want: 0x30c4, + }, + { + name: "16-aligned", + data: alignedBuf[16:4112], + initial: 0, + want: 0xaeff, + }, + { + name: "8-aligned", + data: alignedBuf[8:4104], + initial: 0, + want: 0x6c3b, + }, + { + name: "4-aligned", + data: alignedBuf[4:4100], + initial: 0, + want: 0x2e4a, + }, + { + name: "2-aligned", + data: alignedBuf[2:4098], + initial: 0, + want: 0xc702, + }, + { + name: "unaligned", + data: alignedBuf[1:4097], + initial: 0, + want: 0x3bc7, + }, + { + name: "unalignedAndOdd", + data: alignedBuf[1:4096], + initial: 0, + want: 0x3b13, + }, + { + name: "fe1282", + data: allFE[:1282], + initial: 0, + want: 0x7c7c, + }, + { + name: "fe", + data: allFE, + initial: 0, + want: 0x7e81, + }, + { + name: "maximum", + data: allOnes, + initial: 0, + want: 0xff00, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, fd := range archChecksumFuncs { + t.Run(fd.name, func(t *testing.T) { + if 
!fd.available { + t.Skip("can not run on this system") + } + if got := fd.f(tt.data, uint64(tt.initial)); got != tt.want { + t.Errorf("%s checksum = %04x, want %04x", fd.name, got, tt.want) + } + }) + } + t.Run("reference", func(t *testing.T) { + if got := gvisorChecksum.Checksum(tt.data, tt.initial); got != tt.want { + t.Errorf("reference checksum = %04x, want %04x", got, tt.want) + } + }) + }) + } +} + +func TestPseudoHeaderChecksumNoFold(t *testing.T) { + tests := []struct { + name string + protocol uint8 + srcAddr []byte + dstAddr []byte + totalLen uint16 + want uint16 + }{ + { + name: "ipv4", + protocol: syscall.IPPROTO_TCP, + srcAddr: netip.MustParseAddr("192.168.1.1").AsSlice(), + dstAddr: netip.MustParseAddr("192.168.1.2").AsSlice(), + totalLen: 1492, + want: 0x892e, + }, + { + name: "ipv6", + protocol: syscall.IPPROTO_TCP, + srcAddr: netip.MustParseAddr("2001:db8:3333:4444:5555:6666:7777:8888").AsSlice(), + dstAddr: netip.MustParseAddr("2001:db8:aaaa:bbbb:cccc:dddd:eeee:ffff").AsSlice(), + totalLen: 1492, + want: 0x947f, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotNoFold := pseudoHeaderChecksumNoFold(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + got := checksum([]byte{}, gotNoFold) + if got != tt.want { + t.Errorf("pseudoHeaderChecksumNoFold() = %x, folds to %04x, want %04x", gotNoFold, got, tt.want) + } + + got = header.PseudoHeaderChecksum( + tcpip.TransportProtocolNumber(tt.protocol), + tcpip.AddrFromSlice(tt.srcAddr), + tcpip.AddrFromSlice(tt.dstAddr), + tt.totalLen) + if got != tt.want { + t.Errorf("header.PseudoHeaderChecksum() = %04x, want %04x", got, tt.want) + } + }) + } +} + +func FuzzChecksum(f *testing.F) { + buf := getPageAlignedRandomBytes(1234, 65536) + + f.Add([]byte{}, uint16(0)) + f.Add([]byte{}, uint16(0x1234)) + f.Add([]byte{}, uint16(0)) + f.Add(buf[:15], uint16(0x1234)) + f.Add(buf[:256], uint16(0x1234)) + f.Add(buf[:1280], uint16(0x1234)) + f.Add(buf[:1288], uint16(0x1234)) + f.Add(buf[1:1050], uint16(0x1234)) + + f.Fuzz(func(t *testing.T, data []byte, initial uint16) { + want := gvisorChecksum.Checksum(data, initial) + + for _, fd := range archChecksumFuncs { + t.Run(fd.name, func(t *testing.T) { + if !fd.available { + t.Skip("can not run on this system") + } + if got := fd.f(data, uint64(initial)); got != want { + t.Errorf("%s checksum = %04x, want %04x", fd.name, got, want) + } + }) + } + }) +} + +var result uint16 +var result64 uint64 + func BenchmarkChecksum(b *testing.B) { + offsets := []int{ // offsets from page alignment + 0, + 1, + 2, + 4, + 8, + 16, + } lengths := []int{ + 0, + 7, + 15, + 16, + 31, 64, + 90, + 95, 128, 256, 512, 1024, + 1240, 1500, 2048, 4096, 8192, 9000, 9001, + 16384, + 65536, + } + if !sort.IntsAreSorted(offsets) { + b.Fatal("offsets are not sorted") + } + largestLength := lengths[len(lengths)-1] + if !sort.IntsAreSorted(lengths) { + b.Fatal("lengths are not sorted") } + largestOffset := lengths[len(offsets)-1] + alignedBuf := getPageAlignedRandomBytes(1, largestOffset+largestLength) + var r uint16 + for _, offset := range offsets { + name := fmt.Sprintf("%vAligned", offset) + if offset == 0 { + name = "pageAligned" + } + offsetBuf := alignedBuf[offset:] + b.Run(name, func(b *testing.B) { + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + for _, fd := range archChecksumFuncs { + b.Run(fd.name, func(b *testing.B) { + if !fd.available { + b.Skip("can not run on this system") + } + b.SetBytes(int64(length)) + for i := 0; i < b.N; i++ { + r += 
fd.f(offsetBuf[:length], 0) + } + }) + } + }) + } + }) + } + result = r +} - for _, length := range lengths { - b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { - buf := make([]byte, length) - rng := rand.New(rand.NewSource(1)) - rng.Read(buf) - b.ResetTimer() +func BenchmarkPseudoHeaderChecksum(b *testing.B) { + tests := []struct { + name string + protocol uint8 + srcAddr []byte + dstAddr []byte + totalLen uint16 + want uint16 + }{ + { + name: "ipv4", + protocol: syscall.IPPROTO_TCP, + srcAddr: []byte{192, 168, 1, 1}, + dstAddr: []byte{192, 168, 1, 2}, + totalLen: 1492, + want: 0x892e, + }, + { + name: "ipv6", + protocol: syscall.IPPROTO_TCP, + srcAddr: netip.MustParseAddr("2001:db8:3333:4444:5555:6666:7777:8888").AsSlice(), + dstAddr: netip.MustParseAddr("2001:db8:aaaa:bbbb:cccc:dddd:eeee:ffff").AsSlice(), + totalLen: 1492, + want: 0x892e, + }, + } + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - checksum(buf, 0) + result64 += pseudoHeaderChecksumNoFold(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) } }) } From 88b11b4a0dde05ecd823e896c198014b1652e7c9 Mon Sep 17 00:00:00 2001 From: Adrian Dewhurst Date: Thu, 30 Mar 2023 21:48:43 -0400 Subject: [PATCH 09/39] tun: AMD64 optimized checksum This adds AMD64 assembly implementations of IP checksum computation, one for baseline AMD64 and the other for v3 AMD64 (AVX2 and BMI2). All performance numbers reported are from a Ryzen 7 4750U but similar improvements are expected for a wide range of processors. The generic IP checksum implementation has also been further improved to be significantly faster using bits.AddUint64 (for a 64KiB buffer the throughput improves from 15,000MiB/s to 27,600MiB/s; similar gains are also reported on ARM64 but I do not have specific numbers). The baseline AMD64 implementation for a 64KiB buffer reports 32,700MiB/s and the AVX2 implementation is slightly over 107,000MiB/s. Unfortunately, for very small sizes (e.g. the expected size for an IPv4 header) setting up SIMD computation involves some overhead that makes computing a checksum for small buffers slower than a non-SIMD implementation. Even more unfortunately, testing for this at runtimen in Go and calling a func optimized for small buffers mitigates most of the improvement due to call overhead. The break even point is around 256 byte buffers; IPv4 headers are no more than 60 bytes including extensions. IPv6 headers do not have a checksum but are a fixed size of 40 bytes. As a result, the generated assembly code uses an alternate approach for buffers of less than 256 bytes. Additionally, buffers of less than 32 bytes need to be handled specially because the strategy for reading buffers that are not a multiple of 8 bytes fails when the buffer is too small. As suggested by additional benchmarking, pseudo header computation has been rewritten to be faster (benchmark time reduced by 1/2 to 1/4). 
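(Editor's note, illustrative only and not part of this patch: the generic implementation accumulates 64-bit words with bits.Add64, threading each carry into the next addition, and then folds the accumulator plus the final end-around carry back down to 16 bits. The fold helper itself is not shown in this excerpt; a minimal sketch of that step, under the assumption that it performs only the usual ones'-complement folding, is:

    // fold64 collapses a 64-bit ones'-complement accumulator and its final
    // end-around carry (0 or 1) into the 16-bit checksum value.
    func fold64(sum, carry uint64) uint16 {
            sum = (sum >> 32) + (sum & 0xffffffff) + carry // wrap the end-around carry
            sum = (sum >> 32) + (sum & 0xffffffff)
            sum = (sum >> 16) + (sum & 0xffff)
            sum = (sum >> 16) + (sum & 0xffff)
            sum = (sum >> 16) + (sum & 0xffff)
            return uint16(sum)
    }

Byte order is handled by the caller: on little-endian machines the initial value and the folded result are byte-swapped with bits.ReverseBytes16, so the word loads themselves can stay native-endian.)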
Updates tailscale/corp#9755 Signed-off-by: Adrian Dewhurst --- tun/checksum.go | 764 ++++++++++++++++++++++++---- tun/checksum_amd64.go | 20 + tun/checksum_amd64_test.go | 45 ++ tun/checksum_generated_amd64.go | 18 + tun/checksum_generated_amd64.s | 851 ++++++++++++++++++++++++++++++++ tun/checksum_generic.go | 15 + tun/checksum_generic_test.go | 21 +- tun/checksum_test.go | 56 ++- tun/generate_amd64.go | 579 ++++++++++++++++++++++ tun/tcp_offload_linux.go | 12 +- 10 files changed, 2266 insertions(+), 115 deletions(-) create mode 100644 tun/checksum_amd64.go create mode 100644 tun/checksum_amd64_test.go create mode 100644 tun/checksum_generated_amd64.go create mode 100644 tun/checksum_generated_amd64.s create mode 100644 tun/checksum_generic.go create mode 100644 tun/generate_amd64.go diff --git a/tun/checksum.go b/tun/checksum.go index 29a8fc8fc..ee3f35960 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -1,118 +1,710 @@ package tun -import "encoding/binary" +import ( + "encoding/binary" + "math/bits" + "strconv" -// TODO: Explore SIMD and/or other assembly optimizations. -// TODO: Test native endian loads. See RFC 1071 section 2 part B. -func checksumNoFold(b []byte, initial uint64) uint64 { - ac := initial + "golang.org/x/sys/cpu" +) + +// checksumGeneric64 is a reference implementation of checksum using 64 bit +// arithmetic for use in testing or when an architecture-specific implementation +// is not available. +func checksumGeneric64(b []byte, initial uint16) uint16 { + var ac uint64 + var carry uint64 + + if cpu.IsBigEndian { + ac = uint64(initial) + } else { + ac = uint64(bits.ReverseBytes16(initial)) + } for len(b) >= 128 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) - ac += uint64(binary.BigEndian.Uint32(b[32:36])) - ac += uint64(binary.BigEndian.Uint32(b[36:40])) - ac += uint64(binary.BigEndian.Uint32(b[40:44])) - ac += uint64(binary.BigEndian.Uint32(b[44:48])) - ac += uint64(binary.BigEndian.Uint32(b[48:52])) - ac += uint64(binary.BigEndian.Uint32(b[52:56])) - ac += uint64(binary.BigEndian.Uint32(b[56:60])) - ac += uint64(binary.BigEndian.Uint32(b[60:64])) - ac += uint64(binary.BigEndian.Uint32(b[64:68])) - ac += uint64(binary.BigEndian.Uint32(b[68:72])) - ac += uint64(binary.BigEndian.Uint32(b[72:76])) - ac += uint64(binary.BigEndian.Uint32(b[76:80])) - ac += uint64(binary.BigEndian.Uint32(b[80:84])) - ac += uint64(binary.BigEndian.Uint32(b[84:88])) - ac += uint64(binary.BigEndian.Uint32(b[88:92])) - ac += uint64(binary.BigEndian.Uint32(b[92:96])) - ac += uint64(binary.BigEndian.Uint32(b[96:100])) - ac += uint64(binary.BigEndian.Uint32(b[100:104])) - ac += uint64(binary.BigEndian.Uint32(b[104:108])) - ac += uint64(binary.BigEndian.Uint32(b[108:112])) - ac += uint64(binary.BigEndian.Uint32(b[112:116])) - ac += uint64(binary.BigEndian.Uint32(b[116:120])) - ac += uint64(binary.BigEndian.Uint32(b[120:124])) - ac += uint64(binary.BigEndian.Uint32(b[124:128])) + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, 
binary.BigEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry) + } b = b[128:] } if len(b) >= 64 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) - ac += uint64(binary.BigEndian.Uint32(b[32:36])) - ac += uint64(binary.BigEndian.Uint32(b[36:40])) - ac += uint64(binary.BigEndian.Uint32(b[40:44])) - ac += uint64(binary.BigEndian.Uint32(b[44:48])) - ac += uint64(binary.BigEndian.Uint32(b[48:52])) - ac += uint64(binary.BigEndian.Uint32(b[52:56])) - ac += uint64(binary.BigEndian.Uint32(b[56:60])) - ac += uint64(binary.BigEndian.Uint32(b[60:64])) + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry) + } else { + ac, carry = 
bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry) + } else { + ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry) + } else { + ac, carry = bits.Add64(ac, uint64(b[0]), carry) + } + } + + folded := ipChecksumFold64(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32 is a reference implementation of checksum using 32 bit +// arithmetic for use in testing or when an architecture-specific implementation +// is not available. 
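+//
+// Like checksumGeneric64, it relies on the byte order independence of the
+// ones' complement sum (RFC 1071, section 2.B): words are summed in native
+// byte order and the folded result is byte swapped once at the end, which is
+// what the cpu.IsBigEndian branches and the final bits.ReverseBytes16
+// accomplish.
+//
+// A rough usage sketch (buf and field are hypothetical stand-ins for a packet
+// and its checksum field): the return value is the folded but uncomplemented
+// sum, so a caller storing a final header checksum would typically write its
+// complement:
+//
+//	sum := checksumGeneric32(buf, 0)
+//	binary.BigEndian.PutUint16(field, ^sum)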
+func checksumGeneric32(b []byte, initial uint16) uint16 { + var ac uint32 + var carry uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, 
binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry) + } else { + ac, carry = bits.Add32(ac, uint32(b[0]), carry) + } + } + + folded := ipChecksumFold32(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32Alternate is an alternate reference implementation of +// checksum using 32 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. 
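+//
+// Unlike checksumGeneric32, this variant adds 16 bit words directly into a
+// 32 bit accumulator without tracking carries. That is safe for packet sized
+// inputs: each addition contributes at most 0xffff, so on the order of 2^16
+// halfwords (roughly 128 KiB of data) would be required before the
+// accumulator could overflow.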
+func checksumGeneric32Alternate(b []byte, initial uint16) uint16 { + var ac uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + ac += uint32(binary.BigEndian.Uint16(b[16:18])) + ac += uint32(binary.BigEndian.Uint16(b[18:20])) + ac += uint32(binary.BigEndian.Uint16(b[20:22])) + ac += uint32(binary.BigEndian.Uint16(b[22:24])) + ac += uint32(binary.BigEndian.Uint16(b[24:26])) + ac += uint32(binary.BigEndian.Uint16(b[26:28])) + ac += uint32(binary.BigEndian.Uint16(b[28:30])) + ac += uint32(binary.BigEndian.Uint16(b[30:32])) + ac += uint32(binary.BigEndian.Uint16(b[32:34])) + ac += uint32(binary.BigEndian.Uint16(b[34:36])) + ac += uint32(binary.BigEndian.Uint16(b[36:38])) + ac += uint32(binary.BigEndian.Uint16(b[38:40])) + ac += uint32(binary.BigEndian.Uint16(b[40:42])) + ac += uint32(binary.BigEndian.Uint16(b[42:44])) + ac += uint32(binary.BigEndian.Uint16(b[44:46])) + ac += uint32(binary.BigEndian.Uint16(b[46:48])) + ac += uint32(binary.BigEndian.Uint16(b[48:50])) + ac += uint32(binary.BigEndian.Uint16(b[50:52])) + ac += uint32(binary.BigEndian.Uint16(b[52:54])) + ac += uint32(binary.BigEndian.Uint16(b[54:56])) + ac += uint32(binary.BigEndian.Uint16(b[56:58])) + ac += uint32(binary.BigEndian.Uint16(b[58:60])) + ac += uint32(binary.BigEndian.Uint16(b[60:62])) + ac += uint32(binary.BigEndian.Uint16(b[62:64])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + ac += uint32(binary.LittleEndian.Uint16(b[16:18])) + ac += uint32(binary.LittleEndian.Uint16(b[18:20])) + ac += uint32(binary.LittleEndian.Uint16(b[20:22])) + ac += uint32(binary.LittleEndian.Uint16(b[22:24])) + ac += uint32(binary.LittleEndian.Uint16(b[24:26])) + ac += uint32(binary.LittleEndian.Uint16(b[26:28])) + ac += uint32(binary.LittleEndian.Uint16(b[28:30])) + ac += uint32(binary.LittleEndian.Uint16(b[30:32])) + ac += uint32(binary.LittleEndian.Uint16(b[32:34])) + ac += uint32(binary.LittleEndian.Uint16(b[34:36])) + ac += uint32(binary.LittleEndian.Uint16(b[36:38])) + ac += uint32(binary.LittleEndian.Uint16(b[38:40])) + ac += uint32(binary.LittleEndian.Uint16(b[40:42])) + ac += uint32(binary.LittleEndian.Uint16(b[42:44])) + ac += uint32(binary.LittleEndian.Uint16(b[44:46])) + ac += uint32(binary.LittleEndian.Uint16(b[46:48])) + ac += uint32(binary.LittleEndian.Uint16(b[48:50])) + ac += uint32(binary.LittleEndian.Uint16(b[50:52])) + ac += uint32(binary.LittleEndian.Uint16(b[52:54])) + ac += uint32(binary.LittleEndian.Uint16(b[54:56])) + ac += uint32(binary.LittleEndian.Uint16(b[56:58])) + ac += uint32(binary.LittleEndian.Uint16(b[58:60])) + ac += uint32(binary.LittleEndian.Uint16(b[60:62])) + ac += uint32(binary.LittleEndian.Uint16(b[62:64])) + } b = b[64:] } if len(b) >= 
32 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) - ac += uint64(binary.BigEndian.Uint32(b[16:20])) - ac += uint64(binary.BigEndian.Uint32(b[20:24])) - ac += uint64(binary.BigEndian.Uint32(b[24:28])) - ac += uint64(binary.BigEndian.Uint32(b[28:32])) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + ac += uint32(binary.BigEndian.Uint16(b[16:18])) + ac += uint32(binary.BigEndian.Uint16(b[18:20])) + ac += uint32(binary.BigEndian.Uint16(b[20:22])) + ac += uint32(binary.BigEndian.Uint16(b[22:24])) + ac += uint32(binary.BigEndian.Uint16(b[24:26])) + ac += uint32(binary.BigEndian.Uint16(b[26:28])) + ac += uint32(binary.BigEndian.Uint16(b[28:30])) + ac += uint32(binary.BigEndian.Uint16(b[30:32])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + ac += uint32(binary.LittleEndian.Uint16(b[16:18])) + ac += uint32(binary.LittleEndian.Uint16(b[18:20])) + ac += uint32(binary.LittleEndian.Uint16(b[20:22])) + ac += uint32(binary.LittleEndian.Uint16(b[22:24])) + ac += uint32(binary.LittleEndian.Uint16(b[24:26])) + ac += uint32(binary.LittleEndian.Uint16(b[26:28])) + ac += uint32(binary.LittleEndian.Uint16(b[28:30])) + ac += uint32(binary.LittleEndian.Uint16(b[30:32])) + } b = b[32:] } if len(b) >= 16 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) - ac += uint64(binary.BigEndian.Uint32(b[8:12])) - ac += uint64(binary.BigEndian.Uint32(b[12:16])) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + } b = b[16:] } if len(b) >= 8 { - ac += uint64(binary.BigEndian.Uint32(b[:4])) - ac += uint64(binary.BigEndian.Uint32(b[4:8])) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += 
uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + } b = b[8:] } if len(b) >= 4 { - ac += uint64(binary.BigEndian.Uint32(b)) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + } b = b[4:] } if len(b) >= 2 { - ac += uint64(binary.BigEndian.Uint16(b)) + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b)) + } else { + ac += uint32(binary.LittleEndian.Uint16(b)) + } b = b[2:] } - if len(b) == 1 { - ac += uint64(b[0]) << 8 + if len(b) >= 1 { + if cpu.IsBigEndian { + ac += uint32(b[0]) << 8 + } else { + ac += uint32(b[0]) + } } - return ac + folded := ipChecksumFold32(ac, 0) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded } -func checksum(b []byte, initial uint64) uint16 { - ac := checksumNoFold(b, initial) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - return uint16(ac) +// checksumGeneric64Alternate is an alternate reference implementation of +// checksum using 64 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. +func checksumGeneric64Alternate(b []byte, initial uint16) uint16 { + var ac uint64 + + if cpu.IsBigEndian { + ac = uint64(initial) + } else { + ac = uint64(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + ac += uint64(binary.LittleEndian.Uint32(b[16:20])) + ac += uint64(binary.LittleEndian.Uint32(b[20:24])) + ac += uint64(binary.LittleEndian.Uint32(b[24:28])) + ac += uint64(binary.LittleEndian.Uint32(b[28:32])) + ac += uint64(binary.LittleEndian.Uint32(b[32:36])) + ac += uint64(binary.LittleEndian.Uint32(b[36:40])) + ac += uint64(binary.LittleEndian.Uint32(b[40:44])) + ac += uint64(binary.LittleEndian.Uint32(b[44:48])) + ac += uint64(binary.LittleEndian.Uint32(b[48:52])) + ac += uint64(binary.LittleEndian.Uint32(b[52:56])) + ac += uint64(binary.LittleEndian.Uint32(b[56:60])) + ac += uint64(binary.LittleEndian.Uint32(b[60:64])) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += 
uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + ac += uint64(binary.LittleEndian.Uint32(b[16:20])) + ac += uint64(binary.LittleEndian.Uint32(b[20:24])) + ac += uint64(binary.LittleEndian.Uint32(b[24:28])) + ac += uint64(binary.LittleEndian.Uint32(b[28:32])) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b)) + } else { + ac += uint64(binary.LittleEndian.Uint32(b)) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint16(b)) + } else { + ac += uint64(binary.LittleEndian.Uint16(b)) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac += uint64(b[0]) << 8 + } else { + ac += uint64(b[0]) + } + } + + folded := ipChecksumFold64(ac, 0) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 { + sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry)) + // if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff + // therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is + // no need to save the carry flag + sum = (sum >> 16) + (sum & 0xffff) + carry + // sum <= 0x1_fffe therefore this is the last fold needed: + // if (sum >> 16) > 0 then + // (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore + // the addition will not overflow + // otherwise (sum >> 16) == 0 and sum will be unchanged + sum = (sum >> 16) + (sum & 0xffff) + return uint16(sum) +} + +func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 { + sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry + // sum <= 0x1_ffff: + // 0xffff + 0xffff = 0x1_fffe + // initialCarry is 0 or 1, for a combined maximum of 0x1_ffff + sum = (sum >> 16) + (sum & 0xffff) + // sum <= 0x1_0000 therefore this is the last fold needed: + // if (sum >> 16) > 0 then + // (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore + // the addition will not overflow + // otherwise (sum >> 16) == 0 and sum will be unchanged + sum = (sum >> 16) + (sum & 0xffff) + return uint16(sum) +} + +func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) { + sum, carry = initial, carryIn + switch len(addr) { + case 4: // IPv4 + if cpu.IsBigEndian { + sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry) + } else { + sum, carry = bits.Add64(sum, 
uint64(binary.LittleEndian.Uint32(addr)), carry) + } + case 16: // IPv6 + if cpu.IsBigEndian { + sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry) + sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry) + } else { + sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry) + sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry) + } + default: + panic("bad addr length") + } + return sum, carry +} + +func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) { + sum, carry = initial, carryIn + switch len(addr) { + case 4: // IPv4 + if cpu.IsBigEndian { + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry) + } else { + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry) + } + case 16: // IPv6 + if cpu.IsBigEndian { + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry) + } else { + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry) + } + default: + panic("bad addr length") + } + return sum, carry +} + +func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + var sum uint64 + if cpu.IsBigEndian { + sum = uint64(totalLen) + uint64(protocol) + } else { + sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8 + } + sum, carry := addrPartialChecksum64(srcAddr, sum, 0) + sum, carry = addrPartialChecksum64(dstAddr, sum, carry) + + foldedSum := ipChecksumFold64(sum, carry) + if !cpu.IsBigEndian { + foldedSum = bits.ReverseBytes16(foldedSum) + } + return foldedSum } -func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { - sum := checksumNoFold(srcAddr, 0) - sum = checksumNoFold(dstAddr, sum) - sum = checksumNoFold([]byte{0, protocol}, sum) - tmp := make([]byte, 2) - binary.BigEndian.PutUint16(tmp, totalLen) - return checksumNoFold(tmp, sum) +func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + var sum uint32 + if cpu.IsBigEndian { + sum = uint32(totalLen) + uint32(protocol) + } else { + sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8 + } + sum, carry := addrPartialChecksum32(srcAddr, sum, 0) + sum, carry = addrPartialChecksum32(dstAddr, sum, carry) + + foldedSum := ipChecksumFold32(sum, carry) + if !cpu.IsBigEndian { + foldedSum = bits.ReverseBytes16(foldedSum) + } + return foldedSum +} + +func pseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + if strconv.IntSize < 64 { + return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen) + } + return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen) } diff --git a/tun/checksum_amd64.go b/tun/checksum_amd64.go new file mode 100644 index 000000000..5e87693b6 --- /dev/null +++ b/tun/checksum_amd64.go @@ -0,0 +1,20 @@ +package tun + +import "golang.org/x/sys/cpu" + +// checksum computes an IP checksum starting with the provided initial value. +// The length of data should be at least 128 bytes for best performance. 
Smaller +// buffers will still compute a correct result. For best performance with +// smaller buffers, use shortChecksum(). +var checksum = checksumAMD64 + +func init() { + if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 { + checksum = checksumAVX2 + return + } + if cpu.X86.HasSSE2 { + checksum = checksumSSE2 + return + } +} diff --git a/tun/checksum_amd64_test.go b/tun/checksum_amd64_test.go new file mode 100644 index 000000000..7a0b68140 --- /dev/null +++ b/tun/checksum_amd64_test.go @@ -0,0 +1,45 @@ +//go:build amd64 + +package tun + +import ( + "golang.org/x/sys/cpu" +) + +var archChecksumFuncs = []archChecksumDetails{ + { + name: "generic32", + available: true, + f: checksumGeneric32, + }, + { + name: "generic64", + available: true, + f: checksumGeneric64, + }, + { + name: "generic32Alternate", + available: true, + f: checksumGeneric32Alternate, + }, + { + name: "generic64Alternate", + available: true, + f: checksumGeneric64Alternate, + }, + { + name: "AMD64", + available: true, + f: checksumAMD64, + }, + { + name: "SSE2", + available: cpu.X86.HasSSE2, + f: checksumSSE2, + }, + { + name: "AVX2", + available: cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2, + f: checksumAVX2, + }, +} diff --git a/tun/checksum_generated_amd64.go b/tun/checksum_generated_amd64.go new file mode 100644 index 000000000..b4a29419b --- /dev/null +++ b/tun/checksum_generated_amd64.go @@ -0,0 +1,18 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. + +package tun + +// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2) +// +//go:noescape +func checksumAVX2(b []byte, initial uint16) uint16 + +// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2) +// +//go:noescape +func checksumSSE2(b []byte, initial uint16) uint16 + +// checksumAMD64 computes an IP checksum using amd64 baseline instructions +// +//go:noescape +func checksumAMD64(b []byte, initial uint16) uint16 diff --git a/tun/checksum_generated_amd64.s b/tun/checksum_generated_amd64.s new file mode 100644 index 000000000..5f2e4c525 --- /dev/null +++ b/tun/checksum_generated_amd64.s @@ -0,0 +1,851 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. 
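+
+// xmmLoadMasks below holds seven 16 byte masks, one for each even remainder
+// length from 2 through 14: mask n/2-1 keeps the trailing n bytes of a 16
+// byte register and zeroes the rest. The SIMD tails use them to discard the
+// already accumulated portion of their final, overlapped 16 byte load.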
+ +#include "textflag.h" + +DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" +DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff" +DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112 + +// func checksumAVX2(b []byte, initial uint16) uint16 +// Requires: AVX, AVX2, BMI2 +TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
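+	// For example, with 44 bytes remaining: the five full 8 byte words cover
+	// bytes 0-39, the overlapped read loads bytes 36-43, and the shift by
+	// 64 - 4*8 = 32 bits discards the duplicated bytes 36-39, leaving only
+	// bytes 40-43 to be added into the accumulator.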
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + VPXOR Y0, Y0, Y0 + VPXOR Y1, Y1, Y1 + VPXOR Y2, Y2, Y2 + VPXOR Y3, Y3, Y3 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 16(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 32(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 48(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 64(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 80(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 96(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 112(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 128(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 144(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 160(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 176(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 192(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 208(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 224(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 240(DX), Y4 + VPADDD Y4, Y3, Y3 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + ADDQ $0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. 
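+	// The remaining length is even here, since any odd trailing byte was
+	// folded in at the start, so masks only exist for lengths 2 through 14.
+	// The mask at offset BX*8-16 keeps the trailing BX bytes of the 16 byte
+	// load (the bytes not yet accumulated) and zeroes the leading bytes that
+	// earlier iterations already covered.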
+ LEAQ xmmLoadMasks<>+0(SB), CX + VMOVDQU -16(DX)(BX*1), X4 + VPAND -16(CX)(BX*8), X4, X4 + VPMOVZXWD X4, Y4 + VPADDD Y4, Y0, Y0 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + VPADDD Y1, Y0, Y0 + VPADDD Y2, Y0, Y0 + VPADDD Y3, Y0, Y0 + + // extract the YMM into a pair of XMM and sum them + VEXTRACTI128 $0x01, Y0, X1 + VPADDD X0, X1, X0 + + // extract the XMM into GP64 + VPEXTRQ $0x00, X0, CX + VPEXTRQ $0x01, X0, DX + + // no more AVX code, clear upper registers to avoid SSE slowdowns + VZEROUPPER + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + RORXQ $0x20, AX, CX + ADCL CX, AX + RORXL $0x10, AX, CX + ADCW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumSSE2(b []byte, initial uint16) uint16 +// Requires: SSE2 +TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + PXOR X0, X0 + PXOR X1, X1 + PXOR X2, X2 + PXOR X3, X3 + PXOR X4, X4 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 16(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 32(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 48(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 64(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 80(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 96(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 112(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 128(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 144(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 160(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 176(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 192(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 208(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 224(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 240(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + ADDQ 
$0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. + LEAQ xmmLoadMasks<>+0(SB), CX + MOVOU -16(DX)(BX*1), X5 + PAND -16(CX)(BX*8), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + PADDD X1, X0 + PADDD X2, X0 + PADDD X3, X0 + + // extract the XMM into GP64 + MOVQ X0, CX + PSRLDQ $0x08, X0 + MOVQ X0, DX + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumAMD64(b []byte, initial uint16) uint16 +TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. 
+ ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // Number of 256 byte iterations into loop counter + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + SHRQ $0x08, CX + JZ startCleanup + CLC + XORQ SI, SI + XORQ DI, DI + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + XORQ R12, R12 + +bigLoop: + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ 32(DX), DI + ADCQ 40(DX), DI + ADCQ 48(DX), DI + ADCQ 56(DX), DI + ADCQ $0x00, R8 + ADDQ 64(DX), R9 + ADCQ 72(DX), R9 + ADCQ 80(DX), R9 + ADCQ 88(DX), R9 + ADCQ $0x00, R10 + ADDQ 96(DX), R11 + ADCQ 104(DX), R11 + ADCQ 112(DX), R11 + ADCQ 120(DX), R11 + ADCQ $0x00, R12 + ADDQ 128(DX), AX + ADCQ 136(DX), AX + ADCQ 144(DX), AX + ADCQ 152(DX), AX + ADCQ $0x00, SI + ADDQ 160(DX), DI + ADCQ 168(DX), DI + ADCQ 176(DX), DI + ADCQ 184(DX), DI + ADCQ $0x00, R8 + ADDQ 192(DX), R9 + ADCQ 200(DX), R9 + ADCQ 208(DX), R9 + ADCQ 216(DX), R9 + ADCQ $0x00, R10 + ADDQ 224(DX), R11 + ADCQ 232(DX), R11 + ADCQ 240(DX), R11 + ADCQ 248(DX), R11 + ADCQ $0x00, R12 + ADDQ $0x00000100, DX + SUBQ $0x01, CX + JNZ bigLoop + ADDQ SI, AX + ADCQ DI, AX + ADCQ R8, AX + ADCQ R9, AX + ADCQ R10, AX + ADCQ R11, AX + ADCQ R12, AX + + // accumulate CF (twice, in case the first time overflows) + ADCQ $0x00, AX + ADCQ $0x00, AX + +startCleanup: + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET diff --git a/tun/checksum_generic.go b/tun/checksum_generic.go new file mode 100644 index 000000000..d0bfb697c --- /dev/null +++ b/tun/checksum_generic.go @@ -0,0 +1,15 @@ +// This file contains IP checksum algorithms that are not specific to any +// architecture and don't use hardware acceleration. 
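+//
+// checksum selects between the 32 bit and 64 bit generic implementations
+// using strconv.IntSize, a constant, so the unused branch can be eliminated
+// at compile time.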
+ +//go:build !amd64 + +package tun + +import "strconv" + +func checksum(data []byte, initial uint16) uint16 { + if strconv.IntSize < 64 { + return checksumGeneric32(data, initial) + } + return checksumGeneric64(data, initial) +} diff --git a/tun/checksum_generic_test.go b/tun/checksum_generic_test.go index a0c945740..401a7bb88 100644 --- a/tun/checksum_generic_test.go +++ b/tun/checksum_generic_test.go @@ -1,9 +1,26 @@ +//go:build !amd64 + package tun var archChecksumFuncs = []archChecksumDetails{ { - name: "generic", + name: "generic32", + available: true, + f: checksumGeneric32, + }, + { + name: "generic32Alternate", + available: true, + f: checksumGeneric32Alternate, + }, + { + name: "generic64", + available: true, + f: checksumGeneric64, + }, + { + name: "generic64Alternate", available: true, - f: checksum, + f: checksumGeneric64Alternate, }, } diff --git a/tun/checksum_test.go b/tun/checksum_test.go index b3a358378..c40efc996 100644 --- a/tun/checksum_test.go +++ b/tun/checksum_test.go @@ -18,7 +18,7 @@ import ( type archChecksumDetails struct { name string available bool - f func([]byte, uint64) uint16 + f func([]byte, uint16) uint16 } func deterministicRandomBytes(seed int64, length int) []byte { @@ -402,7 +402,7 @@ func TestChecksum(t *testing.T) { if !fd.available { t.Skip("can not run on this system") } - if got := fd.f(tt.data, uint64(tt.initial)); got != tt.want { + if got := fd.f(tt.data, tt.initial); got != tt.want { t.Errorf("%s checksum = %04x, want %04x", fd.name, got, tt.want) } }) @@ -444,20 +444,28 @@ func TestPseudoHeaderChecksumNoFold(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotNoFold := pseudoHeaderChecksumNoFold(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) - got := checksum([]byte{}, gotNoFold) - if got != tt.want { - t.Errorf("pseudoHeaderChecksumNoFold() = %x, folds to %04x, want %04x", gotNoFold, got, tt.want) - } - - got = header.PseudoHeaderChecksum( - tcpip.TransportProtocolNumber(tt.protocol), - tcpip.AddrFromSlice(tt.srcAddr), - tcpip.AddrFromSlice(tt.dstAddr), - tt.totalLen) - if got != tt.want { - t.Errorf("header.PseudoHeaderChecksum() = %04x, want %04x", got, tt.want) - } + t.Run("pseudoHeaderChecksum32", func(t *testing.T) { + got := pseudoHeaderChecksum32(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + if got != tt.want { + t.Errorf("got %04x, want %04x", got, tt.want) + } + }) + t.Run("pseudoHeaderChecksum64", func(t *testing.T) { + got := pseudoHeaderChecksum64(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + if got != tt.want { + t.Errorf("got %04x, want %04x", got, tt.want) + } + }) + t.Run("reference", func(t *testing.T) { + got := header.PseudoHeaderChecksum( + tcpip.TransportProtocolNumber(tt.protocol), + tcpip.AddrFromSlice(tt.srcAddr), + tcpip.AddrFromSlice(tt.dstAddr), + tt.totalLen) + if got != tt.want { + t.Errorf("got %04x, want %04x", got, tt.want) + } + }) }) } } @@ -482,7 +490,7 @@ func FuzzChecksum(f *testing.F) { if !fd.available { t.Skip("can not run on this system") } - if got := fd.f(data, uint64(initial)); got != want { + if got := fd.f(data, initial); got != want { t.Errorf("%s checksum = %04x, want %04x", fd.name, got, want) } }) @@ -491,7 +499,6 @@ func FuzzChecksum(f *testing.F) { } var result uint16 -var result64 uint64 func BenchmarkChecksum(b *testing.B) { offsets := []int{ // offsets from page alignment @@ -590,9 +597,16 @@ func BenchmarkPseudoHeaderChecksum(b *testing.B) { } for _, tt := range tests { b.Run(tt.name, func(b *testing.B) { - for i := 0; i < b.N; i++ { - 
result64 += pseudoHeaderChecksumNoFold(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) - } + b.Run("pseudoHeaderChecksum32", func(b *testing.B) { + for i := 0; i < b.N; i++ { + result += pseudoHeaderChecksum32(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + } + }) + b.Run("pseudoHeaderChecksum64", func(b *testing.B) { + for i := 0; i < b.N; i++ { + result += pseudoHeaderChecksum64(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + } + }) }) } } diff --git a/tun/generate_amd64.go b/tun/generate_amd64.go new file mode 100644 index 000000000..543e21132 --- /dev/null +++ b/tun/generate_amd64.go @@ -0,0 +1,579 @@ +//go:build ignore + +//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go + +package main + +import ( + "fmt" + "math" + "math/bits" + + . "github.com/mmcloughlin/avo/build" + "github.com/mmcloughlin/avo/operand" + "github.com/mmcloughlin/avo/reg" +) + +const checksumSignature = "func(b []byte, initial uint16) uint16" + +func loadParams() (accum, buf, n reg.GPVirtual) { + accum, buf, n = GP64(), GP64(), GP64() + Load(Param("initial"), accum) + XCHGB(accum.As8H(), accum.As8L()) + Load(Param("b").Base(), buf) + Load(Param("b").Len(), n) + return +} + +type simdStrategy int + +const ( + sse2 = iota + avx2 +) + +const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes. + +func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) { + TEXT(name, NOSPLIT|NOFRAME, checksumSignature) + Pragma("noescape") + Doc(doc) + + accum64, buf, n := loadParams() + + handleOddLength(n, buf, accum64) + // no chance of overflow because accum64 was initialized by a uint16 and + // handleOddLength adds at most a uint8 + handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny")) + Label("bufferIsNotTiny") + + const simdReadSize = 16 + + if minSIMDSize > tinyBufferSize { + Comment("skip all SIMD for small buffers") + if minSIMDSize <= math.MaxUint8 { + CMPQ(n, operand.U8(minSIMDSize)) + } else { + CMPQ(n, operand.U32(minSIMDSize)) + } + JGE(operand.LabelRef("startSIMD")) + + handleRemaining(n, buf, accum64, minSIMDSize-1) + JMP(operand.LabelRef("foldAndReturn")) + } + + Label("startSIMD") + + // chains is the number of accumulators to use. This improves speed via + // reduced data dependency. We combine the accumulators once when the big + // loop is complete. 
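+	// Each 32 bit SIMD lane accumulates zero extended 16 bit words, so a
+	// lane would need on the order of 2^16 additions before it could
+	// overflow, far more than any packet sized buffer requires; no carry
+	// handling is needed inside the loop.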
+ simdAccumulate := make([]reg.VecVirtual, chains) + for i := range simdAccumulate { + switch strategy { + case sse2: + simdAccumulate[i] = XMM() + PXOR(simdAccumulate[i], simdAccumulate[i]) + case avx2: + simdAccumulate[i] = YMM() + VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i]) + } + } + var zero reg.VecVirtual + if strategy == sse2 { + zero = XMM() + PXOR(zero, zero) + } + + // Number of loads per big loop + const unroll = 16 + // Number of bytes + loopSize := uint64(simdReadSize * unroll) + if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 { + panic("loopSize is not a power of 2") + } + loopCount := GP64() + + MOVQ(n, loopCount) + Comment("Update number of bytes remaining after the loop completes") + ANDQ(operand.Imm(loopSize-1), n) + Comment(fmt.Sprintf("Number of %d byte iterations", loopSize)) + SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount) + JZ(operand.LabelRef("smallLoop")) + Label("bigLoop") + for i := 0; i < unroll; i++ { + chain := i % chains + switch strategy { + case sse2: + sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains]) + case avx2: + avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain]) + } + } + ADDQ(operand.U32(loopSize), buf) + DECQ(loopCount) + JNZ(operand.LabelRef("bigLoop")) + + Label("bigCleanup") + + CMPQ(n, operand.Imm(uint64(simdReadSize))) + JLT(operand.LabelRef("doneSmallLoop")) + + Commentf("now read a single %d byte unit of data at a time", simdReadSize) + Label("smallLoop") + + switch strategy { + case sse2: + sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1]) + case avx2: + avx2AccumulateStep(0, buf, simdAccumulate[0]) + } + ADDQ(operand.Imm(uint64(simdReadSize)), buf) + SUBQ(operand.Imm(uint64(simdReadSize)), n) + CMPQ(n, operand.Imm(uint64(simdReadSize))) + JGE(operand.LabelRef("smallLoop")) + + Label("doneSmallLoop") + CMPQ(n, operand.Imm(0)) + JE(operand.LabelRef("doneSIMD")) + + Commentf("There are between 1 and %d bytes remaining. Perform an overlapped read.", simdReadSize-1) + + maskDataPtr := GP64() + LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr) + dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize} + // scale 8 is only correct here because n is guaranteed to be even and we + // do not generate masks for odd lengths + maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16} + remainder := XMM() + + switch strategy { + case sse2: + MOVOU(dataAddr, remainder) + PAND(maskAddr, remainder) + low := XMM() + MOVOA(remainder, low) + PUNPCKHWL(zero, remainder) + PUNPCKLWL(zero, low) + PADDD(remainder, simdAccumulate[0]) + PADDD(low, simdAccumulate[1]) + case avx2: + // Note: this is very similar to the sse2 path but MOVOU has a massive + // performance hit if used here, presumably due to switching between SSE + // and AVX2 modes. 
+ VMOVDQU(dataAddr, remainder) + VPAND(maskAddr, remainder, remainder) + + temp := YMM() + VPMOVZXWD(remainder, temp) + VPADDD(temp, simdAccumulate[0], simdAccumulate[0]) + } + + Label("doneSIMD") + + Comment("Multi-chain loop is done, combine the accumulators") + for i := range simdAccumulate { + if i == 0 { + continue + } + switch strategy { + case sse2: + PADDD(simdAccumulate[i], simdAccumulate[0]) + case avx2: + VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0]) + } + } + + if strategy == avx2 { + Comment("extract the YMM into a pair of XMM and sum them") + tmp := YMM() + VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX()) + + xAccumulate := XMM() + VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate) + simdAccumulate = []reg.VecVirtual{xAccumulate} + } + + Comment("extract the XMM into GP64") + low, high := GP64(), GP64() + switch strategy { + case sse2: + MOVQ(simdAccumulate[0], low) + PSRLDQ(operand.Imm(8), simdAccumulate[0]) + MOVQ(simdAccumulate[0], high) + case avx2: + VPEXTRQ(operand.Imm(0), simdAccumulate[0], low) + VPEXTRQ(operand.Imm(1), simdAccumulate[0], high) + + Comment("no more AVX code, clear upper registers to avoid SSE slowdowns") + VZEROUPPER() + } + ADDQ(low, accum64) + ADCQ(high, accum64) + Label("foldAndReturn") + foldWithCF(accum64, strategy == avx2) + XCHGB(accum64.As8H(), accum64.As8L()) + Store(accum64.As16(), ReturnIndex(0)) + RET() +} + +// handleOddLength generates instructions to incorporate the last byte into +// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to +// handle that if overflow is possible. +func handleOddLength(n, buf, accum64 reg.GPVirtual) { + Comment("handle odd length buffers; they are difficult to handle in general") + TESTQ(operand.U32(1), n) + JZ(operand.LabelRef("lengthIsEven")) + + tmp := GP64() + MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp) + DECQ(n) + ADDQ(tmp, accum64) + + Label("lengthIsEven") +} + +func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) { + high, low := XMM(), XMM() + MOVOU(operand.Mem{Disp: offset, Base: buf}, high) + MOVOA(high, low) + PUNPCKHWL(zero, high) + PUNPCKLWL(zero, low) + PADDD(high, accumulate1) + PADDD(low, accumulate2) +} + +func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) { + tmp := YMM() + VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp) + VPADDD(tmp, accumulate, accumulate) +} + +func generateAMD64Checksum(name, doc string) { + TEXT(name, NOSPLIT|NOFRAME, checksumSignature) + Pragma("noescape") + Doc(doc) + + accum64, buf, n := loadParams() + + handleOddLength(n, buf, accum64) + // no chance of overflow because accum64 was initialized by a uint16 and + // handleOddLength adds at most a uint8 + handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny")) + Label("bufferIsNotTiny") + + const ( + // numChains is the number of accumulators and carry counters to use. + // This improves speed via reduced data dependency. We combine the + // accumulators and carry counters once when the loop is complete. + numChains = 4 + unroll = 32 // The number of 64-bit reads to perform per iteration of the loop. + loopSize = 8 * unroll // The number of bytes read per iteration of the loop. 
+ ) + if bits.Len(loopSize) != bits.Len(loopSize-1)+1 { + panic("loopSize is not a power of 2") + } + loopCount := GP64() + + Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize)) + MOVQ(n, loopCount) + Comment("Update number of bytes remaining after the loop completes") + ANDQ(operand.Imm(loopSize-1), n) + SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount) + JZ(operand.LabelRef("startCleanup")) + CLC() + + chains := make([]struct { + accum reg.GPVirtual + carries reg.GPVirtual + }, numChains) + for i := range chains { + if i == 0 { + chains[i].accum = accum64 + } else { + chains[i].accum = GP64() + XORQ(chains[i].accum, chains[i].accum) + } + chains[i].carries = GP64() + XORQ(chains[i].carries, chains[i].carries) + } + + Label("bigLoop") + + var curChain int + for i := 0; i < unroll; i++ { + // It is significantly faster to use a ADCX/ADOX pair instead of plain + // ADC, which results in two dependency chains, however those require + // ADX support, which was added after AVX2. If AVX2 is available, that's + // even better than ADCX/ADOX. + // + // However, multiple dependency chains using multiple accumulators and + // occasionally storing CF into temporary counters seems to work almost + // as well. + addr := operand.Mem{Disp: i * 8, Base: buf} + + if i%4 == 0 { + if i > 0 { + ADCQ(operand.Imm(0), chains[curChain].carries) + curChain = (curChain + 1) % len(chains) + } + ADDQ(addr, chains[curChain].accum) + } else { + ADCQ(addr, chains[curChain].accum) + } + } + ADCQ(operand.Imm(0), chains[curChain].carries) + ADDQ(operand.U32(loopSize), buf) + SUBQ(operand.Imm(1), loopCount) + JNZ(operand.LabelRef("bigLoop")) + for i := range chains { + if i == 0 { + ADDQ(chains[i].carries, accum64) + continue + } + ADCQ(chains[i].accum, accum64) + ADCQ(chains[i].carries, accum64) + } + + accumulateCF(accum64) + + Label("startCleanup") + handleRemaining(n, buf, accum64, loopSize-1) + Label("foldAndReturn") + foldWithCF(accum64, false) + + XCHGB(accum64.As8H(), accum64.As8L()) + Store(accum64.As16(), ReturnIndex(0)) + RET() +} + +// handleTinyBuffers computes checksums if the buffer length (the n parameter) +// is less than 32. After computing the checksum, a jump to returnLabel will +// be executed. Otherwise, if the buffer length is at least 32, nothing will be +// modified; a jump to continueLabel will be executed instead. +// +// When jumping to returnLabel, CF may be set and must be accommodated e.g. +// using foldWithCF or accumulateCF. +// +// Anecdotally, this appears to be faster than attempting to coordinate an +// overlapped read (which would also require special handling for buffers +// smaller than 8). 
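+//
+// For example, a 26 byte buffer (0b11010) takes the 2, 8, and 16 byte reads:
+// the initial shift by two places bit 1 of the length into CF, and each
+// subsequent single-bit shift exposes bits 2, 3, and 4 in turn.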
+func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) { + Comment("handle tiny buffers (<=31 bytes) specially") + CMPQ(n, operand.Imm(tinyBufferSize)) + JGT(continueLabel) + + tmp2, tmp4, tmp8 := GP64(), GP64(), GP64() + XORQ(tmp2, tmp2) + XORQ(tmp4, tmp4) + XORQ(tmp8, tmp8) + + Comment("shift twice to start because length is guaranteed to be even", + "n = n >> 2; CF = originalN & 2") + SHRQ(operand.Imm(2), n) + JNC(operand.LabelRef("handleTiny4")) + Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]") + MOVWQZX(operand.Mem{Base: buf}, tmp2) + ADDQ(operand.Imm(2), buf) + + Label("handleTiny4") + Comment("n = n >> 1; CF = originalN & 4") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTiny8")) + Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]") + MOVLQZX(operand.Mem{Base: buf}, tmp4) + ADDQ(operand.Imm(4), buf) + + Label("handleTiny8") + Comment("n = n >> 1; CF = originalN & 8") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTiny16")) + Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]") + MOVQ(operand.Mem{Base: buf}, tmp8) + ADDQ(operand.Imm(8), buf) + + Label("handleTiny16") + Comment("n = n >> 1; CF = originalN & 16", + "n == 0 now, otherwise we would have branched after comparing with tinyBufferSize") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTinyFinish")) + ADDQ(operand.Mem{Base: buf}, accum) + ADCQ(operand.Mem{Base: buf, Disp: 8}, accum) + + Label("handleTinyFinish") + Comment("CF should be included from the previous add, so we use ADCQ.", + "If we arrived via the JNC above, then CF=0 due to the branch condition,", + "so ADCQ will still produce the correct result.") + ADCQ(tmp2, accum) + ADCQ(tmp4, accum) + ADCQ(tmp8, accum) + + JMP(returnLabel) +} + +// handleRemaining generates a series of conditional unrolled additions, +// starting with 8 bytes long and doubling each time until the length reaches +// max. This is the reverse order of what may be intuitive, but makes the branch +// conditions convenient to compute: perform one right shift each time and test +// against CF. +// +// When done, CF may be set and must be accommodated e.g., using foldWithCF or +// accumulateCF. +// +// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer +// will be performed, overlapping with data that will be read later. The +// duplicate data will be shifted off. +// +// The original buffer length must have been at least 8 bytes long, even if +// n < 8, otherwise this will access memory before the start of the buffer, +// which may be unsafe. +func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) { + Comment("Accumulate carries in this register. It is never expected to overflow.") + carries := GP64() + XORQ(carries, carries) + + Comment("We will perform an overlapped read for buffers with length not a multiple of 8.", + "Overlapped in this context means some memory will be read twice, but a shift will", + "eliminate the duplicated data. 
This extra read is performed at the end of the buffer to", + "preserve any alignment that may exist for the start of the buffer.") + leftover := reg.RCX + MOVQ(n, leftover) + SHRQ(operand.Imm(3), n) // n is now the number of 64 bit reads remaining + ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end + JZ(operand.LabelRef("handleRemaining8")) + endBuf := GP64() + // endBuf is the position near the end of the buffer that is just past the + // last multiple of 8: (buf + len(buf)) & ^0x7 + LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf) + + overlapRead := GP64() + // equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)]) + MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead) + + Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)") + SHLQ(operand.Imm(3), leftover) // leftover = leftover * 8 + NEGQ(leftover) // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression + ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8) + SHRQ(reg.CL, overlapRead) // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL + + ADDQ(overlapRead, accum64) + ADCQ(operand.Imm(0), carries) + + for curBytes := 8; curBytes <= max; curBytes *= 2 { + Label(fmt.Sprintf("handleRemaining%d", curBytes)) + SHRQ(operand.Imm(1), n) + if curBytes*2 <= max { + JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2))) + } else { + JNC(operand.LabelRef("handleRemainingComplete")) + } + + numLoads := curBytes / 8 + for i := 0; i < numLoads; i++ { + addr := operand.Mem{Base: buf, Disp: i * 8} + // It is possible to add the multiple dependency chains trick here + // that generateAMD64Checksum uses but anecdotally it does not + // appear to outweigh the cost. + if i == 0 { + ADDQ(addr, accum64) + continue + } + ADCQ(addr, accum64) + } + ADCQ(operand.Imm(0), carries) + + if curBytes > math.MaxUint8 { + ADDQ(operand.U32(uint64(curBytes)), buf) + } else { + ADDQ(operand.U8(uint64(curBytes)), buf) + } + if curBytes*2 >= max { + continue + } + JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2))) + } + Label("handleRemainingComplete") + ADDQ(carries, accum64) +} + +func accumulateCF(accum64 reg.GPVirtual) { + Comment("accumulate CF (twice, in case the first time overflows)") + // accum64 += CF + ADCQ(operand.Imm(0), accum64) + // accum64 += CF again if the previous add overflowed. The previous add was + // 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can + // never overflow. + ADCQ(operand.Imm(0), accum64) +} + +// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value +// according to ones-complement arithmetic. BMI2 instructions will be used if +// allowBMI2 is true (requires fewer instructions). 
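+//
+// Folding repeatedly adds the upper half of the register into the lower half
+// and feeds the carry bit back in, i.e. the end-around-carry addition that a
+// ones-complement (RFC 1071) checksum requires.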
+func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) { + Comment("add CF and fold") + + // CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff + + tmp := GP64() + if allowBMI2 { + // effectively, tmp = accum >> 32 (technically, this is a rotate) + RORXQ(operand.Imm(32), accum, tmp) + // accum as uint32 = uint32(accum) + uint32(tmp64) + CF; max value 0xffff_ffff + CF set + ADCL(tmp.As32(), accum.As32()) + // effectively, tmp64 as uint32 = uint32(accum) >> 16 (also a rotate) + RORXL(operand.Imm(16), accum.As32(), tmp.As32()) + // accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set + ADCW(tmp.As16(), accum.As16()) + } else { + // tmp = uint32(accum); max value 0xffff_ffff + // MOVL clears the upper 32 bits of a GP64 so this is equivalent to the + // non-existent MOVLQZX. + MOVL(accum.As32(), tmp.As32()) + // tmp += CF; max value 0x1_0000_0000, CF unset + ADCQ(operand.Imm(0), tmp) + // accum = accum >> 32; max value 0xffff_ffff + SHRQ(operand.Imm(32), accum) + // accum = accum + tmp; max value 0x1_ffff_ffff + CF unset + ADDQ(tmp, accum) + // tmp = uint16(accum); max value 0xffff + MOVWQZX(accum.As16(), tmp) + // accum = accum >> 16; max value 0x1_ffff + SHRQ(operand.Imm(16), accum) + // accum = accum + tmp; max value 0x2_fffe + CF unset + ADDQ(tmp, accum) + // tmp as uint16 = uint16(accum); max value 0xffff + MOVW(accum.As16(), tmp.As16()) + // accum = accum >> 16; max value 0x2 + SHRQ(operand.Imm(16), accum) + // accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set + ADDW(tmp.As16(), accum.As16()) + } + // accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe + ADCW(operand.Imm(0), accum.As16()) +} + +func generateLoadMasks() { + var offset int + // xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14 + GLOBL("xmmLoadMasks", RODATA|NOPTR) + + for n := 2; n < 16; n += 2 { + var pattern [16]byte + for i := 0; i < len(pattern); i++ { + if i < len(pattern)-n { + pattern[i] = 0 + continue + } + pattern[i] = 0xff + } + DATA(offset, operand.String(pattern[:])) + offset += len(pattern) + } +} + +func main() { + generateLoadMasks() + generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2) + generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2) + generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions") + Generate() +} diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go index 67288237f..b023bbd60 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/tcp_offload_linux.go @@ -260,8 +260,8 @@ func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { addrSize = 16 } tcpTotalLen := uint16(len(pkt) - int(iphLen)) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) - return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0 + tcpCSum := pseudoHeaderChecksum(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) + return ^checksum(pkt[iphLen:], tcpCSum) == 0 } // coalesceResult represents the result of attempting to coalesce two TCP @@ -532,7 +532,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 srcAddrAt := offset + 
addrOffset srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + psum := pseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) } else { hdr := virtioNetHdr{} @@ -643,8 +643,8 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs // TCP checksum tcpHLen := int(hdr.hdrLen - hdr.csumStart) tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) - tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold) + tcpCSum := pseudoHeaderChecksum(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) + tcpCSum = ^checksum(out[hdr.csumStart:totalLen], tcpCSum) binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) nextSegmentDataAt += int(hdr.gsoSize) @@ -658,6 +658,6 @@ func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { // checksum we compute. This is typically the pseudo-header checksum. initial := binary.BigEndian.Uint16(in[cSumAt:]) in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], initial)) return nil } From 6cd5922a04de9370d48cf742751d4cdcc6ce65e1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 24 Aug 2023 14:39:23 -0700 Subject: [PATCH 10/39] all: adjust build tags for plan9 Signed-off-by: Brad Fitzpatrick --- conn/bind_std.go | 7 ++++++- conn/controlfns_unix.go | 2 +- conn/erraddrinuse.go | 14 ++++++++++++++ ipc/{uapi_wasm.go => uapi_fake.go} | 4 +++- rwcancel/rwcancel.go | 2 +- rwcancel/rwcancel_stub.go | 2 +- 6 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 conn/erraddrinuse.go rename ipc/{uapi_wasm.go => uapi_fake.go} (72%) diff --git a/conn/bind_std.go b/conn/bind_std.go index cc5cf2311..428e52815 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -136,6 +136,11 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn.(*net.UDPConn), uaddr.Port, nil } +// errEADDRINUSE is syscall.EADDRINUSE, boxed into an interface once +// in erraddrinuse.go on almost all platforms. For other platforms, +// it's at least non-nil. +var errEADDRINUSE error = errors.New("") + func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { s.mu.Lock() defer s.mu.Unlock() @@ -162,7 +167,7 @@ again: // Listen on the same port as we're using for ipv4. 
v6conn, port, err = listenNet("udp6", port) - if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + if uport == 0 && errors.Is(err, errEADDRINUSE) && tries < 100 { v4conn.Close() tries++ goto again diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go index 91692c0a6..5cc4d98f9 100644 --- a/conn/controlfns_unix.go +++ b/conn/controlfns_unix.go @@ -1,4 +1,4 @@ -//go:build !windows && !linux && !wasm +//go:build !windows && !linux && !wasm && !plan9 /* SPDX-License-Identifier: MIT * diff --git a/conn/erraddrinuse.go b/conn/erraddrinuse.go new file mode 100644 index 000000000..751660edf --- /dev/null +++ b/conn/erraddrinuse.go @@ -0,0 +1,14 @@ +//go:build !plan9 + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "syscall" + +func init() { + errEADDRINUSE = syscall.EADDRINUSE +} diff --git a/ipc/uapi_wasm.go b/ipc/uapi_fake.go similarity index 72% rename from ipc/uapi_wasm.go rename to ipc/uapi_fake.go index fa84684aa..e68863d4b 100644 --- a/ipc/uapi_wasm.go +++ b/ipc/uapi_fake.go @@ -1,3 +1,5 @@ +//go:build wasm || plan9 + /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. @@ -5,7 +7,7 @@ package ipc -// Made up sentinel error codes for {js,wasip1}/wasm. +// Made up sentinel error codes for {js,wasip1}/wasm, and plan9. const ( IpcErrorIO = 1 IpcErrorInvalid = 2 diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go index e397c0e8a..dd649d49f 100644 --- a/rwcancel/rwcancel.go +++ b/rwcancel/rwcancel.go @@ -1,4 +1,4 @@ -//go:build !windows && !wasm +//go:build !windows && !wasm && !plan9 /* SPDX-License-Identifier: MIT * diff --git a/rwcancel/rwcancel_stub.go b/rwcancel/rwcancel_stub.go index 2a98b2b4a..46238014c 100644 --- a/rwcancel/rwcancel_stub.go +++ b/rwcancel/rwcancel_stub.go @@ -1,4 +1,4 @@ -//go:build windows || wasm +//go:build windows || wasm || plan9 // SPDX-License-Identifier: MIT From 202a3401e7f83a2f809fefaa8231d79e0e97edae Mon Sep 17 00:00:00 2001 From: Andrea Barisani Date: Wed, 30 Aug 2023 14:05:01 +0200 Subject: [PATCH 11/39] adjust build tags for tamago --- conn/controlfns_unix.go | 2 +- ipc/uapi_tamago.go | 17 +++++++++++++++++ rwcancel/rwcancel.go | 2 +- rwcancel/rwcancel_stub.go | 2 +- 4 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 ipc/uapi_tamago.go diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go index 5cc4d98f9..144a8808f 100644 --- a/conn/controlfns_unix.go +++ b/conn/controlfns_unix.go @@ -1,4 +1,4 @@ -//go:build !windows && !linux && !wasm && !plan9 +//go:build !windows && !linux && !wasm && !plan9 && !tamago /* SPDX-License-Identifier: MIT * diff --git a/ipc/uapi_tamago.go b/ipc/uapi_tamago.go new file mode 100644 index 000000000..85a725a76 --- /dev/null +++ b/ipc/uapi_tamago.go @@ -0,0 +1,17 @@ +//go:build tamago + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package ipc + +// Made up sentinel error codes for tamago platform. 
+const ( + IpcErrorIO = 1 + IpcErrorInvalid = 2 + IpcErrorPortInUse = 3 + IpcErrorUnknown = 4 + IpcErrorProtocol = 5 +) diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go index dd649d49f..ceb87e4da 100644 --- a/rwcancel/rwcancel.go +++ b/rwcancel/rwcancel.go @@ -1,4 +1,4 @@ -//go:build !windows && !wasm && !plan9 +//go:build !windows && !wasm && !plan9 && !tamago /* SPDX-License-Identifier: MIT * diff --git a/rwcancel/rwcancel_stub.go b/rwcancel/rwcancel_stub.go index 46238014c..60ae9af0e 100644 --- a/rwcancel/rwcancel_stub.go +++ b/rwcancel/rwcancel_stub.go @@ -1,4 +1,4 @@ -//go:build windows || wasm || plan9 +//go:build windows || wasm || plan9 || tamago // SPDX-License-Identifier: MIT From 2f6748dc88e777ff6eed22f5ce5d7658c6bb9410 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Wed, 27 Sep 2023 14:52:21 -0700 Subject: [PATCH 12/39] tun: fix crash when ForceMTU is called after close Close closes the events channel, resulting in a panic from send on closed channel. Reported-By: Brad Fitzpatrick Link: https://github.com/tailscale/tailscale/issues/9555 Signed-off-by: James Tucker --- tun/tun_windows.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 0cb4ce192..34f29805d 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } update := tun.forcedMTU != mtu tun.forcedMTU = mtu if update { From 24f8d7c9e7312590e1f2d8e0ca02c30b9d88e91a Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Fri, 6 Oct 2023 17:16:49 -0700 Subject: [PATCH 13/39] tun: implement UDP GSO/GRO for Linux Signed-off-by: Jordan Whited --- ...{tcp_offload_linux.go => offload_linux.go} | 595 ++++++++++++++---- ...ad_linux_test.go => offload_linux_test.go} | 405 ++++++++++-- ...65e4830d6dc087cab24cd1e154c2e790589a309b77 | 8 - ...6784411a8ce2e8e03aa3384105e581f2c67494700d | 8 - tun/tun_linux.go | 71 ++- 5 files changed, 849 insertions(+), 238 deletions(-) rename tun/{tcp_offload_linux.go => offload_linux.go} (50%) rename tun/{tcp_offload_linux_test.go => offload_linux_test.go} (52%) delete mode 100644 tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 delete mode 100644 tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d diff --git a/tun/tcp_offload_linux.go b/tun/offload_linux.go similarity index 50% rename from tun/tcp_offload_linux.go rename to tun/offload_linux.go index b023bbd60..54461cc9b 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/offload_linux.go @@ -57,22 +57,23 @@ const ( virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) ) -// flowKey represents the key for a flow. -type flowKey struct { +// tcpFlowKey represents the key for a TCP flow. +type tcpFlowKey struct { srcAddr, dstAddr [16]byte srcPort, dstPort uint16 rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool } -// tcpGROTable holds flow and coalescing information for the purposes of GRO. +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. 
type tcpGROTable struct { - itemsByFlow map[flowKey][]tcpGROItem + itemsByFlow map[tcpFlowKey][]tcpGROItem itemsPool [][]tcpGROItem } func newTCPGROTable() *tcpGROTable { t := &tcpGROTable{ - itemsByFlow: make(map[flowKey][]tcpGROItem, conn.IdealBatchSize), + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize), itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), } for i := range t.itemsPool { @@ -81,14 +82,15 @@ func newTCPGROTable() *tcpGROTable { return t } -func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { - key := flowKey{} - addrSize := dstAddr - srcAddr - copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) - copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 return key } @@ -96,7 +98,7 @@ func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { // returning the packets found for the flow, or inserting a new one if none // is found. func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) items, ok := t.itemsByFlow[key] if ok { return items, ok @@ -108,7 +110,7 @@ func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, t // insert an item in the table for the provided packet and packet metadata. func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) item := tcpGROItem{ key: key, bufsIndex: uint16(bufsIndex), @@ -131,7 +133,7 @@ func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { items[i] = item } -func (t *tcpGROTable) deleteAt(key flowKey, i int) { +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { items, _ := t.itemsByFlow[key] items = append(items[:i], items[i+1:]...) t.itemsByFlow[key] = items @@ -140,7 +142,7 @@ func (t *tcpGROTable) deleteAt(key flowKey, i int) { // tcpGROItem represents bookkeeping data for a TCP packet during the lifetime // of a GRO evaluation across a vector of packets. type tcpGROItem struct { - key flowKey + key tcpFlowKey sentSeq uint32 // the sequence number bufsIndex uint16 // the index into the original bufs slice numMerged uint16 // the number of packets merged into this item @@ -164,6 +166,103 @@ func (t *tcpGROTable) reset() { } } +// udpFlowKey represents the key for a UDP flow. +type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. 
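+// Unlike TCP, UDP coalescing only ever appends to the most recent item of a
+// flow (see udpGRO), so packet ordering within a flow is preserved.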
+type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize), + itemsPool: make([][]udpGROItem, conn.IdealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + // canCoalesce represents the outcome of checking if two TCP packets are // candidates for coalescing. type canCoalesce int @@ -174,6 +273,61 @@ const ( coalesceAppend canCoalesce = 1 ) +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. 
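+// For IPv4 the ToS, TTL, and the reserved/DF/MF flag bits must match; for
+// IPv6 the Traffic Class and Hop Limit must match.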
+func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend +} + // tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet // described by item. This function makes considerations that match the kernel's // GRO self tests, which can be found in tools/testing/selftests/net/gro.c. @@ -189,29 +343,8 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet return coalesceUnavailable } } - if pkt[0]>>4 == 6 { - if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 { - // cannot coalesce with unequal Traffic class values - return coalesceUnavailable - } - if pkt[7] != pktTarget[7] { - // cannot coalesce with unequal Hop limit values - return coalesceUnavailable - } - } else { - if pkt[1] != pktTarget[1] { - // cannot coalesce with unequal ToS values - return coalesceUnavailable - } - if pkt[6]>>5 != pktTarget[6]>>5 { - // cannot coalesce with unequal DF or reserved bits. MF is checked - // further up the stack. 
- return coalesceUnavailable - } - if pkt[8] != pktTarget[8] { - // cannot coalesce with unequal TTL values - return coalesceUnavailable - } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable } // seq adjacency lhsLen := item.gsoSize @@ -252,16 +385,16 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet return coalesceUnavailable } -func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { srcAddrAt := ipv4SrcAddrOffset addrSize := 4 if isV6 { srcAddrAt = ipv6SrcAddrOffset addrSize = 16 } - tcpTotalLen := uint16(len(pkt) - int(iphLen)) - tcpCSum := pseudoHeaderChecksum(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) - return ^checksum(pkt[iphLen:], tcpCSum) == 0 + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := pseudoHeaderChecksum(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^checksum(pkt[iphLen:], cSum) == 0 } // coalesceResult represents the result of attempting to coalesce two TCP @@ -276,8 +409,36 @@ const ( coalesceSuccess ) +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + // coalesceTCPPackets attempts to coalesce pkt with the packet described by -// item, returning the outcome. This function may swap bufs elements in the +// item, and returns the outcome. This function may swap bufs elements in the // event of a prepend as item's bufs index is already being tracked for writing // to a Device. 
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { @@ -297,11 +458,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize return coalescePSHEnding } if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { return coalescePktInvalidCSum } item.sentSeq = seq @@ -319,11 +480,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize return coalesceInsufficientCap } if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { return coalescePktInvalidCSum } if pshSet { @@ -354,52 +515,52 @@ const ( maxUint16 = 1<<16 - 1 ) -type tcpGROResult int +type groResult int const ( - tcpGROResultNoop tcpGROResult = iota - tcpGROResultTableInsert - tcpGROResultCoalesced + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced ) // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It returns a tcpGROResultNoop when no -// action was taken, tcpGROResultTableInsert when the evaluated packet was -// inserted into table, and tcpGROResultCoalesced when the evaluated packet was +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. 
- return tcpGROResultNoop + return groResultNoop } iphLen := int((pkt[0] & 0x0F) * 4) if isV6 { iphLen = 40 ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) if ipv6HPayloadLen != len(pkt)-iphLen { - return tcpGROResultNoop + return groResultNoop } } else { totalLen := int(binary.BigEndian.Uint16(pkt[2:])) if totalLen != len(pkt) { - return tcpGROResultNoop + return groResultNoop } } if len(pkt) < iphLen { - return tcpGROResultNoop + return groResultNoop } tcphLen := int((pkt[iphLen+12] >> 4) * 4) if tcphLen < 20 || tcphLen > 60 { - return tcpGROResultNoop + return groResultNoop } if len(pkt) < iphLen+tcphLen { - return tcpGROResultNoop + return groResultNoop } if !isV6 { if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { // no GRO support for fragmented segments for now - return tcpGROResultNoop + return groResultNoop } } tcpFlags := pkt[iphLen+tcpFlagsOffset] @@ -407,14 +568,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // not a candidate if any non-ACK flags (except PSH+ACK) are set if tcpFlags != tcpFlagACK { if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return tcpGROResultNoop + return groResultNoop } pshSet = true } gsoSize := uint16(len(pkt) - tcphLen - iphLen) // not a candidate if payload len is 0 if gsoSize < 1 { - return tcpGROResultNoop + return groResultNoop } seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) srcAddrOffset := ipv4SrcAddrOffset @@ -425,7 +586,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) } items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) if !existing { - return tcpGROResultNoop + return groResultTableInsert } for i := len(items) - 1; i >= 0; i-- { // In the best case of packets arriving in order iterating in reverse is @@ -443,54 +604,25 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) switch result { case coalesceSuccess: table.updateAt(item, i) - return tcpGROResultCoalesced + return groResultCoalesced case coalesceItemInvalidCSum: // delete the item with an invalid csum table.deleteAt(item.key, i) case coalescePktInvalidCSum: // no point in inserting an item that we can't coalesce - return tcpGROResultNoop + return groResultNoop default: } } } // failed to coalesce with any other packets; store the item in the flow table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return tcpGROResultTableInsert -} - -func isTCP4NoIPOptions(b []byte) bool { - if len(b) < 40 { - return false - } - if b[0]>>4 != 4 { - return false - } - if b[0]&0x0F != 5 { - return false - } - if b[9] != unix.IPPROTO_TCP { - return false - } - return true + return groResultTableInsert } -func isTCP6NoEH(b []byte) bool { - if len(b) < 60 { - return false - } - if b[0]>>4 != 6 { - return false - } - if b[6] != unix.IPPROTO_TCP { - return false - } - return true -} - -// applyCoalesceAccounting updates bufs to account for coalescing based on the +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the // metadata found in table. 
-func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { for _, items := range table.itemsByFlow { for _, item := range items { if item.numMerged > 0 { @@ -505,7 +637,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 // Recalculate the total len (IPv4) or payload len (IPv6). // Recalculate the (IPv4) header checksum. - if isV6 { + if item.key.isV6 { hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len } else { @@ -525,7 +657,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 // this with computation of the tcp header and payload checksum. addrLen := 4 addrOffset := ipv4SrcAddrOffset - if isV6 { + if item.key.isV6 { addrLen = 16 addrOffset = ipv6SrcAddrOffset } @@ -546,54 +678,244 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 return nil } +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. 
+ addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +func packetIsGROCandidate(b []byte) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. 
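+	// A packet that fails to coalesce is always inserted below, so later packets
+	// in the flow are compared against it rather than an earlier item.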
+ item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + // handleGRO evaluates bufs for GRO, and writes the indices of the resulting -// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset // and recycle them across vectors of packets. -func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, toWrite *[]int) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } - var result tcpGROResult - switch { - case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - result = tcpGRO(bufs, offset, i, tcp4Table, false) - case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - result = tcpGRO(bufs, offset, i, tcp6Table, true) + var result groResult + switch packetIsGROCandidate(bufs[i][offset:]) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) } switch result { - case tcpGROResultNoop: + case groResultNoop: hdr := virtioNetHdr{} err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) if err != nil { return err } fallthrough - case tcpGROResultTableInsert: + case groResultTableInsert: *toWrite = append(*toWrite, i) } } - err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) - err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) - return errors.Join(err4, err6) + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) } -// tcpTSO splits packets from in into outBuffs, writing the size of each +// gsoSplit splits packets from in into outBuffs, writing the size of each // element into sizes. It returns the number of buffers populated, and/or an // error. 
-func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { +func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) { iphLen := int(hdr.csumStart) srcAddrOffset := ipv6SrcAddrOffset addrLen := 16 - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + if !isV6 { in[10], in[11] = 0, 0 // clear ipv4 header checksum srcAddrOffset = ipv4SrcAddrOffset addrLen = 4 } - tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) - in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum - firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + transportCsumAt := int(hdr.csumStart + hdr.csumOffset) + in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + var firstTCPSeqNum uint32 + var protocol uint8 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 { + protocol = unix.IPPROTO_TCP + firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + } else { + protocol = unix.IPPROTO_UDP + } nextSegmentDataAt := int(hdr.hdrLen) i := 0 for ; nextSegmentDataAt < len(in); i++ { @@ -610,7 +932,7 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs out := outBuffs[i][outOffset:] copy(out, in[:iphLen]) - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + if !isV6 { // For IPv4 we are responsible for incrementing the ID field, // updating the total len field, and recalculating the header // checksum. @@ -627,25 +949,32 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) } - // TCP header + // copy transport header copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) - tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) - binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) - if nextSegmentEnd != len(in) { - // FIN and PSH should only be set on last segment - clearFlags := tcpFlagFIN | tcpFlagPSH - out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + + if protocol == unix.IPPROTO_TCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart)) } // payload copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) - // TCP checksum - tcpHLen := int(hdr.hdrLen - hdr.csumStart) - tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) - tcpCSum := pseudoHeaderChecksum(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) - tcpCSum = ^checksum(out[hdr.csumStart:totalLen], tcpCSum) - binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) + // transport checksum + transportHeaderLen := int(hdr.hdrLen - hdr.csumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSum := pseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum = ^checksum(out[hdr.csumStart:totalLen], transportCSum) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum) nextSegmentDataAt += 
int(hdr.gsoSize) } diff --git a/tun/tcp_offload_linux_test.go b/tun/offload_linux_test.go similarity index 52% rename from tun/tcp_offload_linux_test.go rename to tun/offload_linux_test.go index 41fba7064..192232c32 100644 --- a/tun/tcp_offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -28,6 +28,71 @@ var ( ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") ) +func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 28 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_UDP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + udpH := header.UDP(b[offset+20:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 48 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_UDP, + HopLimit: 64, + PayloadLength: uint16(payloadLen + udphLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + udpH := header.UDP(b[offset+40:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { totalLen := 40 + segmentSize b := make([]byte, offset+int(totalLen), 65535) @@ -137,6 +202,34 @@ func Test_handleVirtioRead(t *testing.T) { []int{160, 160}, false, }, + { + "udp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 28, + csumStart: 20, + csumOffset: 6, + }, + udp4Packet(ip4PortA, ip4PortB, 200), + []int{128, 128}, + false, + }, + { + "udp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 48, + csumStart: 40, + csumOffset: 6, + }, + udp6Packet(ip6PortA, ip6PortB, 200), + []int{148, 148}, + false, + }, } for _, tt := range tests { @@ -173,6 
+266,13 @@ func flipTCP4Checksum(b []byte) []byte { return b } +func flipUDP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + func Fuzz_handleGRO(f *testing.F) { pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) @@ -180,11 +280,17 @@ func Fuzz_handleGRO(f *testing.F) { pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) - f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) - f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { - pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5} + pkt6 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt7 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt8 := udp4Packet(ip4PortA, ip4PortC, 100) + pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, offset int) { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), &toWrite) if len(toWrite) > len(pkts) { t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) } @@ -210,17 +316,22 @@ func Test_handleGRO(t *testing.T) { wantErr bool }{ { - "multiple flows", + "multiple protocols and flows", [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, - []int{0, 2, 3, 5}, - []int{240, 140, 260, 160}, + []int{0, 1, 2, 4, 5, 7, 9}, + []int{240, 228, 128, 140, 260, 160, 248}, false, }, { @@ -245,9 +356,12 @@ func Test_handleGRO(t *testing.T) { flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + 
flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4Packet(ip4PortA, ip4PortB, 100), }, - []int{0, 1}, - []int{140, 240}, + []int{0, 1, 3, 4}, + []int{140, 240, 128, 228}, false, }, { @@ -262,75 +376,99 @@ func Test_handleGRO(t *testing.T) { false, }, { - "tcp4 unequal TTL", + "unequal TTL", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { fields.TTL++ }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TTL++ + }), }, - []int{0, 1}, - []int{140, 140}, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, false, }, { - "tcp4 unequal ToS", + "unequal ToS", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { fields.TOS++ }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TOS++ + }), }, - []int{0, 1}, - []int{140, 140}, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, false, }, { - "tcp4 unequal flags more fragments set", + "unequal flags more fragments set", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { fields.Flags = 1 }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), }, - []int{0, 1}, - []int{140, 140}, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, false, }, { - "tcp4 unequal flags DF set", + "unequal flags DF set", [][]byte{ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { fields.Flags = 2 }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), }, - []int{0, 1}, - []int{140, 140}, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, false, }, { - "tcp6 unequal hop limit", + "ipv6 unequal hop limit", [][]byte{ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { fields.HopLimit++ }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), }, - []int{0, 1}, - []int{160, 160}, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, false, }, { - "tcp6 unequal traffic class", + "ipv6 unequal traffic class", [][]byte{ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { fields.TrafficClass++ }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), }, - []int{0, 1}, - []int{160, 160}, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, false, }, } @@ -338,7 +476,7 @@ func Test_handleGRO(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), 
&toWrite) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), &toWrite) if err != nil { if tt.wantErr { return @@ -360,51 +498,198 @@ func Test_handleGRO(t *testing.T) { } } -func Test_isTCP4NoIPOptions(t *testing.T) { - valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] - invalidLen := valid[:39] - invalidHeaderLen := make([]byte, len(valid)) - copy(invalidHeaderLen, valid) - invalidHeaderLen[0] = 0x46 - invalidProtocol := make([]byte, len(valid)) - copy(invalidProtocol, valid) - invalidProtocol[9] = unix.IPPROTO_TCP + 1 +func Test_packetIsGROCandidate(t *testing.T) { + tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp4TooShort := tcp4[:39] + ip4InvalidHeaderLen := make([]byte, len(tcp4)) + copy(ip4InvalidHeaderLen, tcp4) + ip4InvalidHeaderLen[0] = 0x46 + ip4InvalidProtocol := make([]byte, len(tcp4)) + copy(ip4InvalidProtocol, tcp4) + ip4InvalidProtocol[9] = unix.IPPROTO_GRE + + tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp6TooShort := tcp6[:59] + ip6InvalidProtocol := make([]byte, len(tcp6)) + copy(ip6InvalidProtocol, tcp6) + ip6InvalidProtocol[6] = unix.IPPROTO_GRE + + udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:] + udp4TooShort := udp4[:27] + + udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:] + udp6TooShort := udp6[:47] tests := []struct { name string b []byte - want bool + want groCandidateType }{ { - "valid", - valid, - true, + "tcp4", + tcp4, + tcp4GROCandidate, }, { - "invalid length", - invalidLen, - false, + "tcp6", + tcp6, + tcp6GROCandidate, + }, + { + "udp4", + udp4, + udp4GROCandidate, + }, + { + "udp6", + udp6, + udp6GROCandidate, + }, + { + "udp4 too short", + udp4TooShort, + notGROCandidate, + }, + { + "udp6 too short", + udp6TooShort, + notGROCandidate, + }, + { + "tcp4 too short", + tcp4TooShort, + notGROCandidate, }, { - "invalid version", + "tcp6 too short", + tcp6TooShort, + notGROCandidate, + }, + { + "invalid IP version", []byte{0x00}, - false, + notGROCandidate, }, { - "invalid header len", - invalidHeaderLen, - false, + "invalid IP header len", + ip4InvalidHeaderLen, + notGROCandidate, }, { - "invalid protocol", - invalidProtocol, - false, + "ip4 invalid protocol", + ip4InvalidProtocol, + notGROCandidate, + }, + { + "ip6 invalid protocol", + ip6InvalidProtocol, + notGROCandidate, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := packetIsGROCandidate(tt.b); got != tt.want { + t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_udpPacketsCanCoalesce(t *testing.T) { + udp4a := udp4Packet(ip4PortA, ip4PortB, 100) + udp4b := udp4Packet(ip4PortA, ip4PortB, 100) + udp4c := udp4Packet(ip4PortA, ip4PortB, 110) + + type args struct { + pkt []byte + iphLen uint8 + gsoSize uint16 + item udpGROItem + bufs [][]byte + bufsOffset int + } + tests := []struct { + name string + args args + want canCoalesce + }{ + { + "coalesceAppend equal gso", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceAppend smaller gso", + args{ + pkt: udp4a[offset : len(udp4a)-90], + iphLen: 20, + gsoSize: 10, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceUnavailable smaller gso 
previously appended", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4c, + udp4b, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable larger following smaller", + args{ + pkt: udp4c[offset:], + iphLen: 20, + gsoSize: 110, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4c, + }, + bufsOffset: offset, + }, + coalesceUnavailable, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := isTCP4NoIPOptions(tt.b); got != tt.want { - t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want) + if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { + t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) } }) } diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 deleted file mode 100644 index 5461e79a0..000000000 --- a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 +++ /dev/null @@ -1,8 +0,0 @@ -go test fuzz v1 -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -int(34) diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d deleted file mode 100644 index b441819e7..000000000 --- a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d +++ /dev/null @@ -1,8 +0,0 @@ -go test fuzz v1 -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -int(-48) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index eb5051ed8..94bffcc76 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -38,6 +38,7 @@ type NativeTun struct { statusListenersShutdown chan struct{} batchSize int vnetHdr bool + udpGSO bool closeOnce sync.Once @@ -48,9 +49,10 @@ type NativeTun struct { readOpMu sync.Mutex // readOpMu guards readBuff readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr - writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable - toWrite []int - tcp4GROTable, tcp6GROTable *tcpGROTable + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable + toWrite []int + tcpGROTable *tcpGROTable + udpGROTable *udpGROTable } func (tun *NativeTun) File() *os.File { @@ -333,8 +335,8 @@ func (tun *NativeTun) nameSlow() (string, error) { func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.writeOpMu.Lock() defer func() { - tun.tcp4GROTable.reset() - tun.tcp6GROTable.reset() + tun.tcpGROTable.reset() + tun.udpGROTable.reset() tun.writeOpMu.Unlock() }() var ( @@ -343,7 +345,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, &tun.toWrite) if err != nil { return 0, err } @@ -394,37 +396,42 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e sizes[0] = n return 1, nil } - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + 
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) } ipVersion := in[0] >> 4 switch ipVersion { case 4: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) } case 6: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) } default: return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) } - if len(in) <= int(hdr.csumStart+12) { - return 0, errors.New("packet is too short") - } // Don't trust hdr.hdrLen from the kernel as it can be equal to the length // of the entire first packet when the kernel is handling it as part of a - // FORWARD path. Instead, parse the TCP header length and add it onto + // FORWARD path. Instead, parse the transport header length and add it onto // csumStart, which is synonymous for IP header length. - tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) - if tcpHLen < 20 || tcpHLen > 60 { - // A TCP header must be between 20 and 60 bytes in length. - return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + hdr.hdrLen = hdr.csumStart + 8 + } else { + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } + + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.hdrLen = hdr.csumStart + tcpHLen } - hdr.hdrLen = hdr.csumStart + tcpHLen if len(in) < int(hdr.hdrLen) { return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) @@ -438,7 +445,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) } - return tcpTSO(in, hdr, bufs, sizes, offset) + return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { @@ -497,7 +504,8 @@ func (tun *NativeTun) BatchSize() int { const ( // TODO: support TSO with ECN bits - tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 ) func (tun *NativeTun) initFromFlags(name string) error { @@ -519,12 +527,17 @@ func (tun *NativeTun) initFromFlags(name string) error { } got := ifr.Uint16() if got&unix.IFF_VNET_HDR != 0 { - err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) + // tunTCPOffloads were added in Linux v2.6. We require their support + // if IFF_VNET_HDR is set. + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) if err != nil { return } tun.vnetHdr = true tun.batchSize = conn.IdealBatchSize + // tunUDPOffloads were added in Linux v6.2. We do not return an + // error if they are unsupported at runtime. 
+ tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil } else { tun.batchSize = 1 } @@ -575,8 +588,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - tcp4GROTable: newTCPGROTable(), - tcp6GROTable: newTCPGROTable(), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), } @@ -628,12 +641,12 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - tcp4GROTable: newTCPGROTable(), - tcp6GROTable: newTCPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { From db7604d1aa907a37c91164982d2a881fe6edc3c1 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 31 Oct 2023 18:08:04 -0700 Subject: [PATCH 14/39] tun: don't assume UDP GRO is supported Signed-off-by: Jordan Whited --- tun/offload_linux.go | 13 +++---- tun/offload_linux_test.go | 72 ++++++++++++++++++++++++++++++++++----- tun/tun_linux.go | 2 +- 3 files changed, 72 insertions(+), 15 deletions(-) diff --git a/tun/offload_linux.go b/tun/offload_linux.go index 54461cc9b..9a9d38e79 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -748,7 +748,7 @@ const ( udp6GROCandidate ) -func packetIsGROCandidate(b []byte) groCandidateType { +func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { if len(b) < 28 { return notGROCandidate } @@ -760,14 +760,14 @@ func packetIsGROCandidate(b []byte) groCandidateType { if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { return tcp4GROCandidate } - if b[9] == unix.IPPROTO_UDP { + if b[9] == unix.IPPROTO_UDP && canUDPGRO { return udp4GROCandidate } } else if b[0]>>4 == 6 { if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { return tcp6GROCandidate } - if b[6] == unix.IPPROTO_UDP && len(b) >= 48 { + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO { return udp6GROCandidate } } @@ -860,14 +860,15 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) // handleGRO evaluates bufs for GRO, and writes the indices of the resulting // packets into toWrite. toWrite, tcpTable, and udpTable should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset -// and recycle them across vectors of packets. -func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, toWrite *[]int) error { +// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is +// supported. 
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } var result groResult - switch packetIsGROCandidate(bufs[i][offset:]) { + switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) { case tcp4GROCandidate: result = tcpGRO(bufs, offset, i, tcpTable, false) case tcp6GROCandidate: diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index 192232c32..91f394108 100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -286,11 +286,11 @@ func Fuzz_handleGRO(f *testing.F) { pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) - f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, offset) - f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, offset int) { + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), &toWrite) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) if len(toWrite) > len(pkts) { t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) } @@ -311,6 +311,7 @@ func Test_handleGRO(t *testing.T) { tests := []struct { name string pktsIn [][]byte + canUDPGRO bool wantToWrite []int wantLens []int wantErr bool @@ -330,10 +331,31 @@ func Test_handleGRO(t *testing.T) { udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, + true, []int{0, 1, 2, 4, 5, 7, 9}, []int{240, 228, 128, 140, 260, 160, 248}, false, }, + { + "multiple protocols and flows no UDP GRO", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + false, + []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, + []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + false, + }, { "PSH interleaved", [][]byte{ @@ -346,6 +368,7 @@ func Test_handleGRO(t *testing.T) { tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 }, + true, []int{0, 2, 4, 6}, []int{240, 240, 260, 260}, false, @@ -360,6 +383,7 @@ func Test_handleGRO(t *testing.T) { udp4Packet(ip4PortA, ip4PortB, 100), udp4Packet(ip4PortA, ip4PortB, 100), }, + true, []int{0, 1, 3, 4}, []int{140, 240, 128, 228}, false, @@ -371,6 +395,7 @@ func Test_handleGRO(t *testing.T) { 
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 }, + true, []int{0}, []int{340}, false, @@ -387,6 +412,7 @@ func Test_handleGRO(t *testing.T) { fields.TTL++ }), }, + true, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, @@ -403,6 +429,7 @@ func Test_handleGRO(t *testing.T) { fields.TOS++ }), }, + true, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, @@ -419,6 +446,7 @@ func Test_handleGRO(t *testing.T) { fields.Flags = 1 }), }, + true, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, @@ -435,6 +463,7 @@ func Test_handleGRO(t *testing.T) { fields.Flags = 2 }), }, + true, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, @@ -451,6 +480,7 @@ func Test_handleGRO(t *testing.T) { fields.HopLimit++ }), }, + true, []int{0, 1, 2, 3}, []int{160, 160, 148, 148}, false, @@ -467,6 +497,7 @@ func Test_handleGRO(t *testing.T) { fields.TrafficClass++ }), }, + true, []int{0, 1, 2, 3}, []int{160, 160, 148, 148}, false, @@ -476,7 +507,7 @@ func Test_handleGRO(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), &toWrite) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) if err != nil { if tt.wantErr { return @@ -521,74 +552,99 @@ func Test_packetIsGROCandidate(t *testing.T) { udp6TooShort := udp6[:47] tests := []struct { - name string - b []byte - want groCandidateType + name string + b []byte + canUDPGRO bool + want groCandidateType }{ { "tcp4", tcp4, + true, tcp4GROCandidate, }, { "tcp6", tcp6, + true, tcp6GROCandidate, }, { "udp4", udp4, + true, udp4GROCandidate, }, + { + "udp4 no support", + udp4, + false, + notGROCandidate, + }, { "udp6", udp6, + true, udp6GROCandidate, }, + { + "udp6 no support", + udp6, + false, + notGROCandidate, + }, { "udp4 too short", udp4TooShort, + true, notGROCandidate, }, { "udp6 too short", udp6TooShort, + true, notGROCandidate, }, { "tcp4 too short", tcp4TooShort, + true, notGROCandidate, }, { "tcp6 too short", tcp6TooShort, + true, notGROCandidate, }, { "invalid IP version", []byte{0x00}, + true, notGROCandidate, }, { "invalid IP header len", ip4InvalidHeaderLen, + true, notGROCandidate, }, { "ip4 invalid protocol", ip4InvalidProtocol, + true, notGROCandidate, }, { "ip6 invalid protocol", ip6InvalidProtocol, + true, notGROCandidate, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := packetIsGROCandidate(tt.b); got != tt.want { + if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) } }) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 94bffcc76..9313ebf0b 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -345,7 +345,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) if err != nil { return 0, err } From 8cc8b8b11b1f7189f3e19616d2e233fce6ee7eda Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 20 Nov 2023 16:49:06 -0800 Subject: [PATCH 15/39] device: change Peer.endpoint locking to reduce contention Access to Peer.endpoint was previously synchronized by Peer.RWMutex. 
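For illustration, the shape of the replacement is sketched below; the examplePeer, exampleSend, and endpointStub names are stand-ins rather than code from this patch, and the authoritative change is in the device/peer.go hunk that follows.

package example

import (
	"errors"
	"sync"
)

// endpointStub stands in for conn.Endpoint in this sketch.
type endpointStub struct{ addr string }

func (e *endpointStub) ClearSrc() {}

// examplePeer mirrors the shape of the new Peer.endpoint field: the endpoint
// value and its transmit-time flags live under their own mutex rather than
// under the peer-wide RWMutex.
type examplePeer struct {
	endpoint struct {
		sync.Mutex
		val            *endpointStub
		clearSrcOnTx   bool // clear the endpoint src before the next transmission
		disableRoaming bool
	}
}

// exampleSend shows the transmit-side pattern: hold the endpoint lock only
// long enough to read the value and honor clearSrcOnTx, then send unlocked.
func (p *examplePeer) exampleSend(send func(ep *endpointStub) error) error {
	p.endpoint.Lock()
	ep := p.endpoint.val
	if ep == nil {
		p.endpoint.Unlock()
		return errors.New("no known endpoint for peer")
	}
	if p.endpoint.clearSrcOnTx {
		ep.ClearSrc()
		p.endpoint.clearSrcOnTx = false
	}
	p.endpoint.Unlock()
	return send(ep) // the bind send happens outside the lock
}

The real Peer.endpoint field follows the same shape, with conn.Endpoint as the value type.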
This has now moved to Peer.endpoint.Mutex. Peer.SendBuffers() is now the sole caller of Endpoint.ClearSrc(), which is signaled via a new bool, Peer.endpoint.clearSrcOnTx. Previous Callers of Endpoint.ClearSrc() now set this bool, primarily via peer.markEndpointSrcForClearing(). Peer.SetEndpointFromPacket() clears Peer.endpoint.clearSrcOnTx when an updated conn.Endpoint is stored. This maintains the same event order as before, i.e. a conn.Endpoint received after peer.endpoint.clearSrcOnTx is set, but before the next Peer.SendBuffers() call results in the latest conn.Endpoint source being used for the next packet transmission. These changes result in throughput improvements for single flow, parallel (-P n) flow, and bidirectional (--bidir) flow iperf3 TCP/UDP tests as measured on both Linux and Windows. Latency under load improves especially for high throughput Linux scenarios. These improvements are likely realized on all platforms to some degree, as the changes are not platform-specific. Co-authored-by: James Tucker Signed-off-by: Jordan Whited --- device/device.go | 12 ++------- device/mobilequirks.go | 6 ++--- device/peer.go | 50 ++++++++++++++++++++++++----------- device/sticky_linux.go | 30 ++++++++++----------- device/timers.go | 12 ++------- device/uapi.go | 60 ++++++++++++++++++++---------------------- 6 files changed, 86 insertions(+), 84 deletions(-) diff --git a/device/device.go b/device/device.go index 5c666acc9..86dff0d7e 100644 --- a/device/device.go +++ b/device/device.go @@ -461,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error { // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() @@ -515,11 +511,7 @@ func (device *Device) BindUpdate() error { // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() diff --git a/device/mobilequirks.go b/device/mobilequirks.go index 4e5051d7e..0a0080efd 100644 --- a/device/mobilequirks.go +++ b/device/mobilequirks.go @@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { device.net.brokenRoaming = true device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - peer.disableRoaming = peer.endpoint != nil - peer.Unlock() + peer.endpoint.Lock() + peer.endpoint.disableRoaming = peer.endpoint.val != nil + peer.endpoint.Unlock() } device.peers.RUnlock() } diff --git a/device/peer.go b/device/peer.go index 22757d443..89b719bf3 100644 --- a/device/peer.go +++ b/device/peer.go @@ -17,17 +17,20 @@ import ( type Peer struct { isRunning atomic.Bool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer keypairs Keypairs handshake Handshake device *Device - endpoint conn.Endpoint stopping sync.WaitGroup // routines pending stop txBytes atomic.Uint64 // bytes send to peer (endpoint) rxBytes atomic.Uint64 // bytes received from peer lastHandshakeNano atomic.Int64 // nano seconds since epoch - disableRoaming bool + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool + } timers struct { retransmitHandshake *Timer @@ -74,8 +77,6 @@ func (device *Device) NewPeer(pk 
NoisePublicKey) (*Peer, error) { // create peer peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device @@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.mutex.Unlock() // reset endpoint - peer.endpoint = nil + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() // init timers peer.timersInit() @@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffers, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { var totalLen uint64 for _, b := range buffers { @@ -267,10 +277,20 @@ func (peer *Peer) Stop() { } func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { - if peer.disableRoaming { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { + return + } + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} + +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = true } diff --git a/device/sticky_linux.go b/device/sticky_linux.go index 7a519c1f8..6eeced2b0 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl if !ok { break } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() + pePtr.peer.endpoint.Lock() + if &pePtr.peer.endpoint.val != pePtr.endpoint { + pePtr.peer.endpoint.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { - pePtr.peer.Unlock() + if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { + pePtr.peer.endpoint.Unlock() break } - pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc() - pePtr.peer.Unlock() + pePtr.peer.endpoint.clearSrcOnTx = true + pePtr.peer.endpoint.Unlock() } attr = attr[attrhdr.Len:] } @@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl device.peers.RLock() i := uint32(1) for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil { - peer.RUnlock() + peer.endpoint.Lock() + if peer.endpoint.val == nil { + peer.endpoint.Unlock() continue } - nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint) + nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) if nativeEP == nil { - peer.RUnlock() + peer.endpoint.Unlock() continue } if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { - peer.RUnlock() + peer.endpoint.Unlock() break } nlmsg := struct { @@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl reqPeerLock.Lock() reqPeer[i] = peerEndpointPtr{ peer: peer, - endpoint: &peer.endpoint, + endpoint: &peer.endpoint.val, } reqPeerLock.Unlock() - peer.RUnlock() + peer.endpoint.Unlock() i++ _, err := 
netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) if err != nil { diff --git a/device/timers.go b/device/timers.go index e28732c42..d4a4ed4e5 100644 --- a/device/timers.go +++ b/device/timers.go @@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(true) } @@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) { func expiredNewHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(false) } diff --git a/device/uapi.go b/device/uapi.go index 2a91a9361..4987cdae0 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error { for _, peer := range device.peers.keyMap { // Serialize peer state. - // Do the work in an anonymous function so that we can use defer. - func() { - peer.RLock() - defer peer.RUnlock() - - keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) - keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) - sendf("protocol_version=1") - if peer.endpoint != nil { - sendf("endpoint=%s", peer.endpoint.DstToString()) - } - - nano := peer.lastHandshakeNano.Load() - secs := nano / time.Second.Nanoseconds() - nano %= time.Second.Nanoseconds() - - sendf("last_handshake_time_sec=%d", secs) - sendf("last_handshake_time_nsec=%d", nano) - sendf("tx_bytes=%d", peer.txBytes.Load()) - sendf("rx_bytes=%d", peer.rxBytes.Load()) - sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) - - device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { - sendf("allowed_ip=%s", prefix.String()) - return true - }) - }() + peer.handshake.mutex.RLock() + keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) + keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() + sendf("protocol_version=1") + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) + } + peer.endpoint.Unlock() + + nano := peer.lastHandshakeNano.Load() + secs := nano / time.Second.Nanoseconds() + nano %= time.Second.Nanoseconds() + + sendf("last_handshake_time_sec=%d", secs) + sendf("last_handshake_time_nsec=%d", nano) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) + + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }) } }() @@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() { return } if peer.created { - peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil } if peer.device.isUp() 
{ peer.Start() @@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) } - peer.Lock() - defer peer.Unlock() - peer.endpoint = endpoint + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint case "persistent_keepalive_interval": device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) From cc193a0b327276d902b3b688d3a06a857b17fcb7 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 7 Nov 2023 15:24:21 -0800 Subject: [PATCH 16/39] device: reduce redundant per-packet overhead in RX path Peer.RoutineSequentialReceiver() deals with packet vectors and does not need to perform timer and endpoint operations for every packet in a given vector. Changing these per-packet operations to per-vector improves throughput by as much as 10% in some environments. Signed-off-by: Jordan Whited --- device/receive.go | 21 +++++++++++++++------ device/send.go | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/device/receive.go b/device/receive.go index da663e9ee..af2db44e8 100644 --- a/device/receive.go +++ b/device/receive.go @@ -445,7 +445,9 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { return } elemsContainer.Lock() - for _, elem := range elemsContainer.elems { + validTailPacket := -1 + dataPacketReceived := false + for i, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed continue @@ -455,21 +457,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { continue } - peer.SetEndpointFromPacket(elem.endpoint) + validTailPacket = i if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) peer.timersHandshakeComplete() peer.SendStagedPackets() } - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } - peer.timersDataReceived() + dataPacketReceived = true switch elem.packet[0] >> 4 { case 4: @@ -512,6 +512,15 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } if len(bufs) > 0 { _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) if err != nil && !device.isClosed() { diff --git a/device/send.go b/device/send.go index 95a5cbe49..2701f4e92 100644 --- a/device/send.go +++ b/device/send.go @@ -436,7 +436,7 @@ func calculatePaddingSize(packetSize, mtu int) int { return paddedSize - lastUnit } -/* Encrypts the elems in the queue +/* Encrypts the elements in the queue * and marks them for sequential consumption (by releasing the mutex) * * Obs. One instance per core @@ -495,7 +495,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { return } if !peer.isRunning.Load() { - // peer has been stopped; return re-usable elemsContainer to the shared pool. + // peer has been stopped; return re-usable elems to the shared pool. // This is an optimization only. 
It is possible for the peer to be stopped // immediately after this check, in which case, elem will get processed. // The timers and SendBuffers code are resilient to a few stragglers. From 64040e66467d89c9312cf442866185404b7ec60c Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 13 Apr 2024 10:55:05 -0700 Subject: [PATCH 17/39] ipc: build on aix Updates tailscale/tailscale#11361 --- ipc/uapi_fake.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipc/uapi_fake.go b/ipc/uapi_fake.go index e68863d4b..3e3dcdf80 100644 --- a/ipc/uapi_fake.go +++ b/ipc/uapi_fake.go @@ -1,4 +1,4 @@ -//go:build wasm || plan9 +//go:build wasm || plan9 || aix /* SPDX-License-Identifier: MIT * From 03c5a0ccf7546055344017505460bdb9c0425b17 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 29 Apr 2024 09:15:56 -0700 Subject: [PATCH 18/39] tun: implement API for disabling UDP GRO on Linux Certain device drivers (e.g. vxlan, geneve) do not properly handle coalesced UDP packets later in the stack, resulting in packet loss. Signed-off-by: Jordan Whited --- tun/tun.go | 8 ++++++++ tun/tun_linux.go | 12 ++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tun/tun.go b/tun/tun.go index 0ae53d073..d3c5012c4 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -51,3 +51,11 @@ type Device interface { // lifetime of a Device. BatchSize() int } + +type LinuxDevice interface { + Device + // DisableUDPGRO disables UDP GRO if it is enabled. Certain device drivers + // (e.g. vxlan, geneve) do not properly handle coalesced UDP packets later + // in the stack, resulting in packet loss. + DisableUDPGRO() +} diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 9313ebf0b..6aa03d4b5 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -49,10 +49,11 @@ type NativeTun struct { readOpMu sync.Mutex // readOpMu guards readBuff readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr - writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable + writeOpMu sync.Mutex // writeOpMu guards the following fields toWrite []int tcpGROTable *tcpGROTable udpGROTable *udpGROTable + udpGRO bool } func (tun *NativeTun) File() *os.File { @@ -345,7 +346,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGRO, &tun.toWrite) if err != nil { return 0, err } @@ -502,6 +503,12 @@ func (tun *NativeTun) BatchSize() int { return tun.batchSize } +func (tun *NativeTun) DisableUDPGRO() { + tun.writeOpMu.Lock() + tun.udpGRO = false + tun.writeOpMu.Unlock() +} + const ( // TODO: support TSO with ECN bits tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 @@ -538,6 +545,7 @@ func (tun *NativeTun) initFromFlags(name string) error { // tunUDPOffloads were added in Linux v6.2. We do not return an // error if they are unsupported at runtime. 
tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil + tun.udpGRO = tun.udpGSO } else { tun.batchSize = 1 } From 1e088837d114a74a2c335968e7d40537447b9b4f Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 27 Jun 2024 08:43:41 -0700 Subject: [PATCH 19/39] device: fix WaitPool sync.Cond usage The sync.Locker used with a sync.Cond must be acquired when changing the associated condition, otherwise there is a window within sync.Cond.Wait() where a wake-up may be missed. Fixes: 4846070 ("device: use a waiting sync.Pool instead of a channel") Signed-off-by: Jordan Whited --- device/pools.go | 11 ++++++----- device/pools_test.go | 5 +++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/device/pools.go b/device/pools.go index 94f3dc7e6..55d2be7df 100644 --- a/device/pools.go +++ b/device/pools.go @@ -7,14 +7,13 @@ package device import ( "sync" - "sync/atomic" ) type WaitPool struct { pool sync.Pool cond sync.Cond lock sync.Mutex - count atomic.Uint32 + count uint32 // Get calls not yet Put back max uint32 } @@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool { func (p *WaitPool) Get() any { if p.max != 0 { p.lock.Lock() - for p.count.Load() >= p.max { + for p.count >= p.max { p.cond.Wait() } - p.count.Add(1) + p.count++ p.lock.Unlock() } return p.pool.Get() @@ -41,7 +40,9 @@ func (p *WaitPool) Put(x any) { if p.max == 0 { return } - p.count.Add(^uint32(0)) + p.lock.Lock() + defer p.lock.Unlock() + p.count-- p.cond.Signal() } diff --git a/device/pools_test.go b/device/pools_test.go index 82d7493e1..2b16f3984 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -15,7 +15,6 @@ import ( ) func TestWaitPool(t *testing.T) { - t.Skip("Currently disabled") var wg sync.WaitGroup var trials atomic.Int32 startTrials := int32(100000) @@ -32,7 +31,9 @@ func TestWaitPool(t *testing.T) { wg.Add(workers) var max atomic.Uint32 updateMax := func() { - count := p.count.Load() + p.lock.Lock() + count := p.count + p.lock.Unlock() if count > p.max { t.Errorf("count (%d) > max (%d)", count, p.max) } From cfa45674af86ac5074d879257ca513438bb291eb Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 27 Jun 2024 09:06:40 -0700 Subject: [PATCH 20/39] device: fix missed return of QueueOutboundElementsContainer to its WaitPool Fixes: 3bb8fec ("conn, device, tun: implement vectorized I/O plumbing") Signed-off-by: Jordan Whited --- device/send.go | 1 + 1 file changed, 1 insertion(+) diff --git a/device/send.go b/device/send.go index 2701f4e92..8ed2e5f6c 100644 --- a/device/send.go +++ b/device/send.go @@ -506,6 +506,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } + device.PutOutboundElementsContainer(elemsContainer) continue } dataSent := false From 2f5d148bcfe13a65d0f85dca8698200543c226da Mon Sep 17 00:00:00 2001 From: James Tucker Date: Fri, 7 Jun 2024 16:57:40 -0700 Subject: [PATCH 21/39] conn,device: enable cryptorouting via PeerAwareEndpoint Introduce an optional extension point for Endpoint that enables a path for WireGuard to inform an integration about the peer public key that is associated with an Endpoint. The API is expected to return either the same or a new Endpoint in response to this function. A future version of this patch could potentially remove the returned Endpoint, but would require larger integrator changes downstream. 
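As a rough illustration, an integration might satisfy the new interface along the following lines; the lazyEndpoint type and its map lookup are hypothetical and not part of this patch.

package example

import "github.com/tailscale/wireguard-go/conn"

// lazyEndpoint is a hypothetical Endpoint that an integration hands to
// wireguard-go for packets whose sending peer it has not yet identified. It
// embeds the provisional Endpoint built from the packet's addresses.
type lazyEndpoint struct {
	conn.Endpoint
	byKey map[[32]byte]conn.Endpoint // endpoints the integration already tracks, per peer public key
}

// GetPeerEndpoint implements conn.PeerAwareEndpoint. Once wireguard-go has
// cryptographically identified the sender, the integration may return the
// endpoint it tracks for that peer, or nil if it has none, in which case
// wireguard-go cannot respond until a later packet supplies an endpoint.
func (e *lazyEndpoint) GetPeerEndpoint(peerPublicKey [32]byte) conn.Endpoint {
	if ep, ok := e.byKey[peerPublicKey]; ok {
		return ep
	}
	return nil
}

wireguard-go consults this hook from Peer.SetEndpointFromPacket() once the
sender has been cryptographically identified.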
This adds a small per-packet cost that could later be removed with a larger refactor of the wireguard-go interface and Tailscale magicsock code, as well as introducing a generic bound for Endpoint in a device & bind instance. Updates tailscale/corp#20732 --- conn/conn.go | 14 ++++++++++++++ device/noise-protocol.go | 2 +- device/peer.go | 3 +++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/conn/conn.go b/conn/conn.go index a1f57d2b1..8df5aaa66 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -84,6 +84,20 @@ type Endpoint interface { SrcIP() netip.Addr } +// PeerAwareEndpoint is an optional Endpoint specialization for +// integrations that want to know about the outcome of cryptorouting +// identification. +// +// If they receive a packet from a source they had not pre-identified, +// to learn the identification WireGuard can derive from the session +// or handshake. +// +// If GetPeerEndpoint returns nil, WireGuard will be unable to respond +// to the peer until a new endpoint is written by a later packet. +type PeerAwareEndpoint interface { + GetPeerEndpoint(peerPublicKey [32]byte) Endpoint +} + var ( ErrBindAlreadyOpen = errors.New("bind is already open") ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 9f2ba509a..2d8f98426 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -124,7 +124,7 @@ type Handshake struct { localEphemeral NoisePrivateKey // ephemeral secret key localIndex uint32 // used to clear hash-table remoteIndex uint32 // index for sending - remoteStatic NoisePublicKey // long term key + remoteStatic NoisePublicKey // long term key, never changes, can be accessed without mutex remoteEphemeral NoisePublicKey // ephemeral public key precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp diff --git a/device/peer.go b/device/peer.go index 89b719bf3..876e5daf0 100644 --- a/device/peer.go +++ b/device/peer.go @@ -283,6 +283,9 @@ func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { return } peer.endpoint.clearSrcOnTx = false + if ep, ok := endpoint.(conn.PeerAwareEndpoint); ok { + endpoint = ep.GetPeerEndpoint(peer.handshake.remoteStatic) + } peer.endpoint.val = endpoint } From 60eeedfd624bb0ee4d173c1aaac1ed0ddedf13fe Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 16 Jul 2024 15:36:20 -0700 Subject: [PATCH 22/39] tun: export GSOSplit() for external Device implementers External implementers of tun.Device may support GSO, and may also be platform-agnostic, e.g. gVisor. Signed-off-by: Jordan Whited --- tun/offload.go | 220 +++++++++++++++++++++++++++++++++++++++++++ tun/offload_linux.go | 134 +++++--------------------- tun/offload_test.go | 95 +++++++++++++++++++ tun/tun_linux.go | 65 +++---------- 4 files changed, 353 insertions(+), 161 deletions(-) create mode 100644 tun/offload.go create mode 100644 tun/offload_test.go diff --git a/tun/offload.go b/tun/offload.go new file mode 100644 index 000000000..6627e46e4 --- /dev/null +++ b/tun/offload.go @@ -0,0 +1,220 @@ +package tun + +import ( + "encoding/binary" + "fmt" +) + +// GSOType represents the type of segmentation offload. 
+type GSOType int + +const ( + GSONone GSOType = iota + GSOTCPv4 + GSOTCPv6 + GSOUDPL4 +) + +func (g GSOType) String() string { + switch g { + case GSONone: + return "GSONone" + case GSOTCPv4: + return "GSOTCPv4" + case GSOTCPv6: + return "GSOTCPv6" + case GSOUDPL4: + return "GSOUDPL4" + default: + return "unknown" + } +} + +// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO +// specification. It is a common representation of GSO metadata that can be +// applied to support packet GSO across tun.Device implementations. +type GSOOptions struct { + // GSOType represents the type of segmentation offload. + GSOType GSOType + // HdrLen is the sum of the layer 3 and 4 header lengths. This field may be + // zero when GSOType == GSONone. + HdrLen uint16 + // CsumStart is the head byte index of the packet data to be checksummed, + // i.e. the start of the TCP or UDP header. + CsumStart uint16 + // CsumOffset is the offset from CsumStart where the 2-byte checksum value + // should be placed. + CsumOffset uint16 + // GSOSize is the size of each segment exclusive of HdrLen. The tail segment + // may be smaller than this value. + GSOSize uint16 + // NeedsCsum may be set where GSOType == GSONone. When set, the checksum + // at CsumStart + CsumOffset must be a partial checksum, i.e. the + // pseudo-header sum. + NeedsCsum bool +} + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +const ( + // defined here in order to avoid importation of any platform-specific pkgs + ipProtoTCP = 6 + ipProtoUDP = 17 +) + +// GSOSplit splits packets from 'in' into outBufs[][outOffset:], writing +// the size of each element into sizes. It returns the number of buffers +// populated, and/or an error. Callers may pass an 'in' slice that overlaps with +// the first element of outBuffers, i.e. &in[0] may be equal to +// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the +// value of options.NeedsCsum. Length of each outBufs element must be greater +// than or equal to the length of 'in', otherwise output may be silently +// truncated. +func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) { + cSumAt := int(options.CsumStart) + int(options.CsumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + if len(in) < int(options.HdrLen) { + return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen) + } + + // Handle the conditions where we are copying a single element to outBuffs. + payloadLen := len(in) - int(options.HdrLen) + if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { + if len(in) > len(outBufs[0][outOffset:]) { + return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:])) + } + if options.NeedsCsum { + // The initial value at the checksum offset should be summed with + // the checksum we compute. This is typically the pseudo-header sum. 
+ initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[options.CsumStart:], initial)) + } + sizes[0] = copy(outBufs[0][outOffset:], in) + return 1, nil + } + + if options.HdrLen < options.CsumStart { + return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart) + } + + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType) + } + if len(in) < 20 { + return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20) + } + case 6: + if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType) + } + if len(in) < 40 { + return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + iphLen := int(options.CsumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if ipVersion == 4 { + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + transportCsumAt := int(options.CsumStart + options.CsumOffset) + var firstTCPSeqNum uint32 + var protocol uint8 + if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 { + protocol = ipProtoTCP + if len(in) < int(options.CsumStart)+20 { + return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)", + len(in), options.CsumStart, 20) + } + firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:]) + } else { + protocol = ipProtoUDP + } + nextSegmentDataAt := int(options.HdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBufs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(options.HdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBufs[i][outOffset:] + + copy(out, in[:iphLen]) + if ipVersion == 4 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + out[10], out[11] = 0, 0 // clear ipv4 header checksum + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. 
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // copy transport header + copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen]) + + if protocol == ipProtoTCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(options.GSOSize*uint16(i)) + binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[options.CsumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart)) + } + + // payload + copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // transport checksum + out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + transportHeaderLen := int(options.HdrLen - options.CsumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSum := pseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum = ^checksum(out[options.CsumStart:totalLen], transportCSum) + binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum) + + nextSegmentDataAt += int(options.GSOSize) + } + return i, nil +} diff --git a/tun/offload_linux.go b/tun/offload_linux.go index 9a9d38e79..3f0dc5368 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -9,6 +9,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" "unsafe" @@ -16,14 +17,6 @@ import ( "golang.org/x/sys/unix" ) -const tcpFlagsOffset = 13 - -const ( - tcpFlagFIN uint8 = 0x01 - tcpFlagPSH uint8 = 0x08 - tcpFlagACK uint8 = 0x10 -) - // virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The // kernel symbol is virtio_net_hdr. type virtioNetHdr struct { @@ -35,6 +28,30 @@ type virtioNetHdr struct { csumOffset uint16 } +func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) { + var gsoType GSOType + switch v.gsoType { + case unix.VIRTIO_NET_HDR_GSO_NONE: + gsoType = GSONone + case unix.VIRTIO_NET_HDR_GSO_TCPV4: + gsoType = GSOTCPv4 + case unix.VIRTIO_NET_HDR_GSO_TCPV6: + gsoType = GSOTCPv6 + case unix.VIRTIO_NET_HDR_GSO_UDP_L4: + gsoType = GSOUDPL4 + default: + return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType) + } + return GSOOptions{ + GSOType: gsoType, + HdrLen: v.hdrLen, + CsumStart: v.csumStart, + CsumOffset: v.csumOffset, + GSOSize: v.gsoSize, + NeedsCsum: v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0, + }, nil +} + func (v *virtioNetHdr) decode(b []byte) error { if len(b) < virtioNetHdrLen { return io.ErrShortBuffer @@ -510,9 +527,7 @@ const ( ) const ( - ipv4SrcAddrOffset = 12 - ipv6SrcAddrOffset = 8 - maxUint16 = 1<<16 - 1 + maxUint16 = 1<<16 - 1 ) type groResult int @@ -894,100 +909,3 @@ func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGR errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) return errors.Join(errTCP, errUDP) } - -// gsoSplit splits packets from in into outBuffs, writing the size of each -// element into sizes. It returns the number of buffers populated, and/or an -// error. 
-func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) { - iphLen := int(hdr.csumStart) - srcAddrOffset := ipv6SrcAddrOffset - addrLen := 16 - if !isV6 { - in[10], in[11] = 0, 0 // clear ipv4 header checksum - srcAddrOffset = ipv4SrcAddrOffset - addrLen = 4 - } - transportCsumAt := int(hdr.csumStart + hdr.csumOffset) - in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum - var firstTCPSeqNum uint32 - var protocol uint8 - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 { - protocol = unix.IPPROTO_TCP - firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:]) - } else { - protocol = unix.IPPROTO_UDP - } - nextSegmentDataAt := int(hdr.hdrLen) - i := 0 - for ; nextSegmentDataAt < len(in); i++ { - if i == len(outBuffs) { - return i - 1, ErrTooManySegments - } - nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) - if nextSegmentEnd > len(in) { - nextSegmentEnd = len(in) - } - segmentDataLen := nextSegmentEnd - nextSegmentDataAt - totalLen := int(hdr.hdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBuffs[i][outOffset:] - - copy(out, in[:iphLen]) - if !isV6 { - // For IPv4 we are responsible for incrementing the ID field, - // updating the total len field, and recalculating the header - // checksum. - if i > 0 { - id := binary.BigEndian.Uint16(out[4:]) - id += uint16(i) - binary.BigEndian.PutUint16(out[4:], id) - } - binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) - ipv4CSum := ^checksum(out[:iphLen], 0) - binary.BigEndian.PutUint16(out[10:], ipv4CSum) - } else { - // For IPv6 we are responsible for updating the payload length field. - binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) - } - - // copy transport header - copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) - - if protocol == unix.IPPROTO_TCP { - // set TCP seq and adjust TCP flags - tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) - binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) - if nextSegmentEnd != len(in) { - // FIN and PSH should only be set on last segment - clearFlags := tcpFlagFIN | tcpFlagPSH - out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags - } - } else { - // set UDP header len - binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart)) - } - - // payload - copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) - - // transport checksum - transportHeaderLen := int(hdr.hdrLen - hdr.csumStart) - lenForPseudo := uint16(transportHeaderLen + segmentDataLen) - transportCSum := pseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) - transportCSum = ^checksum(out[hdr.csumStart:totalLen], transportCSum) - binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum) - - nextSegmentDataAt += int(hdr.gsoSize) - } - return i, nil -} - -func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { - cSumAt := cSumStart + cSumOffset - // The initial value at the checksum offset should be summed with the - // checksum we compute. This is typically the pseudo-header checksum. 
- initial := binary.BigEndian.Uint16(in[cSumAt:]) - in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], initial)) - return nil -} diff --git a/tun/offload_test.go b/tun/offload_test.go new file mode 100644 index 000000000..82a37b9cc --- /dev/null +++ b/tun/offload_test.go @@ -0,0 +1,95 @@ +package tun + +import ( + "net/netip" + "testing" + + "github.com/tailscale/wireguard-go/conn" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func Fuzz_GSOSplit(f *testing.F) { + const segmentSize = 100 + + tcpFields := &header.TCPFields{ + SrcPort: 1, + DstPort: 1, + SeqNum: 1, + AckNum: 1, + DataOffset: 20, + Flags: header.TCPFlagAck | header.TCPFlagPsh, + WindowSize: 3000, + } + udpFields := &header.UDPFields{ + SrcPort: 1, + DstPort: 1, + Length: 8 + segmentSize, + } + + gsoTCPv4 := make([]byte, 20+20+segmentSize) + header.IPv4(gsoTCPv4).Encode(&header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.1").AsSlice()), + DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.2").AsSlice()), + Protocol: ipProtoTCP, + TTL: 64, + TotalLength: uint16(len(gsoTCPv4)), + }) + header.TCP(gsoTCPv4[20:]).Encode(tcpFields) + + gsoUDPv4 := make([]byte, 20+8+segmentSize) + header.IPv4(gsoUDPv4).Encode(&header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.1").AsSlice()), + DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.2").AsSlice()), + Protocol: ipProtoUDP, + TTL: 64, + TotalLength: uint16(len(gsoUDPv4)), + }) + header.UDP(gsoTCPv4[20:]).Encode(udpFields) + + gsoTCPv6 := make([]byte, 40+20+segmentSize) + header.IPv6(gsoTCPv6).Encode(&header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::1").AsSlice()), + DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::2").AsSlice()), + TransportProtocol: ipProtoTCP, + HopLimit: 64, + PayloadLength: uint16(20 + segmentSize), + }) + header.TCP(gsoTCPv6[40:]).Encode(tcpFields) + + gsoUDPv6 := make([]byte, 40+8+segmentSize) + header.IPv6(gsoUDPv6).Encode(&header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::1").AsSlice()), + DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::2").AsSlice()), + TransportProtocol: ipProtoUDP, + HopLimit: 64, + PayloadLength: uint16(8 + segmentSize), + }) + header.UDP(gsoUDPv6[20:]).Encode(udpFields) + + out := make([][]byte, conn.IdealBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + sizes := make([]int, conn.IdealBatchSize) + + f.Add(gsoTCPv4, int(GSOTCPv4), uint16(40), uint16(20), uint16(16), uint16(100), false) + f.Add(gsoUDPv4, int(GSOUDPL4), uint16(28), uint16(20), uint16(6), uint16(100), false) + f.Add(gsoTCPv6, int(GSOTCPv6), uint16(60), uint16(40), uint16(16), uint16(100), false) + f.Add(gsoUDPv6, int(GSOUDPL4), uint16(48), uint16(40), uint16(6), uint16(100), false) + + f.Fuzz(func(t *testing.T, pkt []byte, gsoType int, hdrLen, csumStart, csumOffset, gsoSize uint16, needsCsum bool) { + options := GSOOptions{ + GSOType: GSOType(gsoType), + HdrLen: hdrLen, + CsumStart: csumStart, + CsumOffset: csumOffset, + GSOSize: gsoSize, + NeedsCsum: needsCsum, + } + n, _ := GSOSplit(pkt, options, out, sizes, 0) + if n > len(sizes) { + t.Errorf("n (%d) > len(sizes): %d", n, len(sizes)) + } + }) +} diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 6aa03d4b5..664eecc5e 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -380,73 +380,32 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e 
return 0, err } in = in[virtioNetHdrLen:] - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { - if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { - // This means CHECKSUM_PARTIAL in skb context. We are responsible - // for computing the checksum starting at hdr.csumStart and placing - // at hdr.csumOffset. - err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) - if err != nil { - return 0, err - } - } - if len(in) > len(bufs[0][offset:]) { - return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) - } - n := copy(bufs[0][offset:], in) - sizes[0] = n - return 1, nil - } - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { - return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) - } - ipVersion := in[0] >> 4 - switch ipVersion { - case 4: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { - return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) - } - case 6: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { - return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) - } - default: - return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + options, err := hdr.toGSOOptions() + if err != nil { + return 0, err } - // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // Don't trust HdrLen from the kernel as it can be equal to the length // of the entire first packet when the kernel is handling it as part of a // FORWARD path. Instead, parse the transport header length and add it onto - // csumStart, which is synonymous for IP header length. - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { - hdr.hdrLen = hdr.csumStart + 8 - } else { - if len(in) <= int(hdr.csumStart+12) { + // CsumStart, which is synonymous for IP header length. + if options.GSOType == GSOUDPL4 { + options.HdrLen = options.CsumStart + 8 + } else if options.GSOType != GSONone { + if len(in) <= int(options.CsumStart+12) { return 0, errors.New("packet is too short") } - tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4) if tcpHLen < 20 || tcpHLen > 60 { // A TCP header must be between 20 and 60 bytes in length. return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) } - hdr.hdrLen = hdr.csumStart + tcpHLen - } - - if len(in) < int(hdr.hdrLen) { - return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) - } - - if hdr.hdrLen < hdr.csumStart { - return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) - } - cSumAt := int(hdr.csumStart + hdr.csumOffset) - if cSumAt+1 >= len(in) { - return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + options.HdrLen = options.CsumStart + tcpHLen } - return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) + return GSOSplit(in, options, bufs, sizes, offset) } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { From 6c039a188c2d15592023f851731cd03214af32ec Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 23 Jul 2024 13:34:34 -0700 Subject: [PATCH 23/39] tun: export optimized IP checksum funcs External implementers of tun.Device may support GRO, requiring checksum offload. 
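For illustration only (not part of this change): an external Device implementation doing
its own checksum offload might use the newly exported helpers roughly as below. The
fixed 20-byte IPv4 header (no IP options), the function name, and the packet layout are
simplifying assumptions; the snippet assumes imports of encoding/binary,
golang.org/x/sys/unix, and this repo's tun package.

    // fillTCP4Checksum computes and writes the TCP checksum of an IPv4
    // packet with a plain 20-byte header followed by the TCP segment.
    func fillTCP4Checksum(pkt []byte) {
        src := pkt[12:16]               // IPv4 source address
        dst := pkt[16:20]               // IPv4 destination address
        tcpLen := uint16(len(pkt) - 20) // TCP header + payload length
        pkt[36], pkt[37] = 0, 0         // zero the TCP checksum field (20+16)
        psum := tun.PseudoHeaderChecksum(unix.IPPROTO_TCP, src, dst, tcpLen)
        csum := ^tun.Checksum(pkt[20:], psum) // fold pseudo-header sum into data sum
        binary.BigEndian.PutUint16(pkt[36:], csum)
    }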
Signed-off-by: Jordan Whited --- tun/checksum.go | 4 +++- tun/checksum_amd64.go | 11 +++++++---- tun/checksum_generic.go | 2 +- tun/offload.go | 8 ++++---- tun/offload_linux.go | 16 ++++++++-------- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/tun/checksum.go b/tun/checksum.go index ee3f35960..6634050c3 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -702,7 +702,9 @@ func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen ui return foldedSum } -func pseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { +// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and +// dstAddr must be 4 or 16 bytes in length. +func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { if strconv.IntSize < 64 { return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen) } diff --git a/tun/checksum_amd64.go b/tun/checksum_amd64.go index 5e87693b6..4fb684ec7 100644 --- a/tun/checksum_amd64.go +++ b/tun/checksum_amd64.go @@ -2,12 +2,15 @@ package tun import "golang.org/x/sys/cpu" -// checksum computes an IP checksum starting with the provided initial value. -// The length of data should be at least 128 bytes for best performance. Smaller -// buffers will still compute a correct result. For best performance with -// smaller buffers, use shortChecksum(). var checksum = checksumAMD64 +// Checksum computes an IP checksum starting with the provided initial value. +// The length of data should be at least 128 bytes for best performance. Smaller +// buffers will still compute a correct result. +func Checksum(data []byte, initial uint16) uint16 { + return checksum(data, initial) +} + func init() { if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 { checksum = checksumAVX2 diff --git a/tun/checksum_generic.go b/tun/checksum_generic.go index d0bfb697c..2ef201a1f 100644 --- a/tun/checksum_generic.go +++ b/tun/checksum_generic.go @@ -7,7 +7,7 @@ package tun import "strconv" -func checksum(data []byte, initial uint16) uint16 { +func Checksum(data []byte, initial uint16) uint16 { if strconv.IntSize < 64 { return checksumGeneric32(data, initial) } diff --git a/tun/offload.go b/tun/offload.go index 6627e46e4..6db437c34 100644 --- a/tun/offload.go +++ b/tun/offload.go @@ -102,7 +102,7 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO // the checksum we compute. This is typically the pseudo-header sum. initial := binary.BigEndian.Uint16(in[cSumAt:]) in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[options.CsumStart:], initial)) + binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial)) } sizes[0] = copy(outBufs[0][outOffset:], in) return 1, nil @@ -179,7 +179,7 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO } out[10], out[11] = 0, 0 // clear ipv4 header checksum binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) - ipv4CSum := ^checksum(out[:iphLen], 0) + ipv4CSum := ^Checksum(out[:iphLen], 0) binary.BigEndian.PutUint16(out[10:], ipv4CSum) } else { // For IPv6 we are responsible for updating the payload length field. 
@@ -210,8 +210,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum transportHeaderLen := int(options.HdrLen - options.CsumStart) lenForPseudo := uint16(transportHeaderLen + segmentDataLen) - transportCSum := pseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) - transportCSum = ^checksum(out[options.CsumStart:totalLen], transportCSum) + transportCSum := PseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum = ^Checksum(out[options.CsumStart:totalLen], transportCSum) binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum) nextSegmentDataAt += int(options.GSOSize) diff --git a/tun/offload_linux.go b/tun/offload_linux.go index 3f0dc5368..fe3440162 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -410,8 +410,8 @@ func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { addrSize = 16 } lenForPseudo := uint16(len(pkt) - int(iphLen)) - cSum := pseudoHeaderChecksum(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) - return ^checksum(pkt[iphLen:], cSum) == 0 + cSum := PseudoHeaderChecksum(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^Checksum(pkt[iphLen:], cSum) == 0 } // coalesceResult represents the result of attempting to coalesce two TCP @@ -659,7 +659,7 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 pkt[10], pkt[11] = 0, 0 binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field } err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) @@ -679,8 +679,8 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e srcAddrAt := offset + addrOffset srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) - binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + psum := PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum)) } else { hdr := virtioNetHdr{} err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) @@ -716,7 +716,7 @@ func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) e } else { pkt[10], pkt[11] = 0, 0 binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field } err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) @@ -739,8 +739,8 @@ func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) e srcAddrAt := offset 
+ addrOffset srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) - binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + psum := PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum)) } else { hdr := virtioNetHdr{} err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) From 71393c576b98c1903cd8d31450b603535586f81e Mon Sep 17 00:00:00 2001 From: Adrian Dewhurst Date: Wed, 31 Jul 2024 16:08:04 -0400 Subject: [PATCH 24/39] tun: fix checksum test failures on non-4KiB page sizes When generating page-aligned random bytes, random data started at the beginning of the buffer that will be chopped off. When the page size differs, the start of the returned slice is different than expected for the expected checksums, causing the tests to fail. --- tun/checksum_test.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tun/checksum_test.go b/tun/checksum_test.go index c40efc996..f5b8f1869 100644 --- a/tun/checksum_test.go +++ b/tun/checksum_test.go @@ -21,26 +21,33 @@ type archChecksumDetails struct { f func([]byte, uint16) uint16 } -func deterministicRandomBytes(seed int64, length int) []byte { +func fillRandomBuffer(seed int64, buf []byte) { rng := rand.New(rand.NewSource(seed)) - buf := make([]byte, length) n, err := rng.Read(buf) if err != nil { panic(err) } - if n != length { + if n != len(buf) { panic("incomplete random buffer") } +} + +func deterministicRandomBytes(seed int64, length int) []byte { + buf := make([]byte, length) + fillRandomBuffer(seed, buf) return buf } func getPageAlignedRandomBytes(seed int64, length int) []byte { alignment := syscall.Getpagesize() - buf := deterministicRandomBytes(seed, length+(alignment-1)) + buf := make([]byte, length+(alignment-1)) bufPtr := uintptr(unsafe.Pointer(&buf[0])) alignedBufPtr := (bufPtr + uintptr(alignment-1)) & ^uintptr(alignment-1) alignedStart := int(alignedBufPtr - bufPtr) - return buf[alignedStart:] + + buf = buf[alignedStart : alignedStart+length] + fillRandomBuffer(seed, buf) + return buf } func TestChecksum(t *testing.T) { From 799c1978fafc07fae8a15a42c8535b70e6c69e6e Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Wed, 4 Sep 2024 12:17:44 -0700 Subject: [PATCH 25/39] tun: add method for disabling TCP GRO on Linux torvalds/linux@e269d79c7d35aa3808b1f3c1737d63dab504ddc8 broke virtio_net TCP & UDP GRO causing GRO writes to return EINVAL. The bug was then resolved later in torvalds/linux@89add40066f9ed9abe5f7f886fe5789ff7e0c50e. The offending commit was pulled into various LTS releases. 
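For illustration only (not part of this change): a caller hitting the kernel bug described
above might fall back roughly as below. The EINVAL-detection and retry policy is an
assumption, not code from this patch; real callers may instead gate on kernel version.
Assumes imports of errors, golang.org/x/sys/unix, and this repo's tun package.

    func writeWithGROFallback(dev tun.Device, bufs [][]byte, offset int) (int, error) {
        n, err := dev.Write(bufs, offset)
        if err != nil && errors.Is(err, unix.EINVAL) {
            if g, ok := dev.(tun.GRODevice); ok {
                // Affected kernels reject coalesced GRO writes; disable GRO and retry.
                g.DisableTCPGRO()
                g.DisableUDPGRO()
                return dev.Write(bufs, offset)
            }
        }
        return n, err
    }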
Updates tailscale/tailscale#13041 Signed-off-by: Jordan Whited --- tun/offload_linux.go | 18 ++++----- tun/offload_linux_test.go | 82 ++++++++++++++++++++++----------------- tun/tun.go | 23 +++++++++-- tun/tun_linux.go | 45 ++++++++++++++++++--- 4 files changed, 114 insertions(+), 54 deletions(-) diff --git a/tun/offload_linux.go b/tun/offload_linux.go index fe3440162..fb6ac5b94 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -763,7 +763,7 @@ const ( udp6GROCandidate ) -func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { +func packetIsGROCandidate(b []byte, gro groDisablementFlags) groCandidateType { if len(b) < 28 { return notGROCandidate } @@ -772,17 +772,17 @@ func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { // IPv4 packets w/IP options do not coalesce return notGROCandidate } - if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 && gro.canTCPGRO() { return tcp4GROCandidate } - if b[9] == unix.IPPROTO_UDP && canUDPGRO { + if b[9] == unix.IPPROTO_UDP && gro.canUDPGRO() { return udp4GROCandidate } } else if b[0]>>4 == 6 { - if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 && gro.canTCPGRO() { return tcp6GROCandidate } - if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO { + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && gro.canUDPGRO() { return udp6GROCandidate } } @@ -875,15 +875,15 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) // handleGRO evaluates bufs for GRO, and writes the indices of the resulting // packets into toWrite. toWrite, tcpTable, and udpTable should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset -// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is -// supported. -func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error { +// and recycle them across vectors of packets. gro indicates if TCP and UDP GRO +// are supported/enabled. 
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } var result groResult - switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) { + switch packetIsGROCandidate(bufs[i][offset:], gro) { case tcp4GROCandidate: result = tcpGRO(bufs, offset, i, tcpTable, false) case tcp6GROCandidate: diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index 91f394108..407037863 100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -286,11 +286,11 @@ func Fuzz_handleGRO(f *testing.F) { pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) - f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) - f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, 0, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, gro int, offset int) { pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite) if len(toWrite) > len(pkts) { t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) } @@ -311,7 +311,7 @@ func Test_handleGRO(t *testing.T) { tests := []struct { name string pktsIn [][]byte - canUDPGRO bool + gro groDisablementFlags wantToWrite []int wantLens []int wantErr bool @@ -331,7 +331,7 @@ func Test_handleGRO(t *testing.T) { udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, - true, + 0, []int{0, 1, 2, 4, 5, 7, 9}, []int{240, 228, 128, 140, 260, 160, 248}, false, @@ -351,7 +351,7 @@ func Test_handleGRO(t *testing.T) { udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, - false, + udpGRODisabled, []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, false, @@ -368,7 +368,7 @@ func Test_handleGRO(t *testing.T) { tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 }, - true, + 0, []int{0, 2, 4, 6}, []int{240, 240, 260, 260}, false, @@ -383,7 +383,7 @@ func Test_handleGRO(t *testing.T) { udp4Packet(ip4PortA, ip4PortB, 100), udp4Packet(ip4PortA, ip4PortB, 100), }, - true, + 0, []int{0, 1, 3, 4}, []int{140, 240, 128, 228}, false, @@ -395,7 +395,7 @@ func Test_handleGRO(t *testing.T) { tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 }, - true, + 0, []int{0}, []int{340}, false, @@ -412,7 +412,7 @@ func Test_handleGRO(t *testing.T) { fields.TTL++ }), }, - true, + 0, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, @@ -429,7 +429,7 @@ func Test_handleGRO(t *testing.T) { fields.TOS++ }), }, - true, + 0, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, @@ -446,7 +446,7 @@ func Test_handleGRO(t *testing.T) { fields.Flags = 1 
}), }, - true, + 0, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, @@ -463,7 +463,7 @@ func Test_handleGRO(t *testing.T) { fields.Flags = 2 }), }, - true, + 0, []int{0, 1, 2, 3}, []int{140, 140, 128, 128}, false, @@ -480,7 +480,7 @@ func Test_handleGRO(t *testing.T) { fields.HopLimit++ }), }, - true, + 0, []int{0, 1, 2, 3}, []int{160, 160, 148, 148}, false, @@ -497,7 +497,7 @@ func Test_handleGRO(t *testing.T) { fields.TrafficClass++ }), }, - true, + 0, []int{0, 1, 2, 3}, []int{160, 160, 148, 148}, false, @@ -507,7 +507,7 @@ func Test_handleGRO(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite) if err != nil { if tt.wantErr { return @@ -552,99 +552,111 @@ func Test_packetIsGROCandidate(t *testing.T) { udp6TooShort := udp6[:47] tests := []struct { - name string - b []byte - canUDPGRO bool - want groCandidateType + name string + b []byte + gro groDisablementFlags + want groCandidateType }{ { "tcp4", tcp4, - true, + 0, tcp4GROCandidate, }, + { + "tcp4 no support", + tcp4, + tcpGRODisabled, + notGROCandidate, + }, { "tcp6", tcp6, - true, + 0, tcp6GROCandidate, }, + { + "tcp6 no support", + tcp6, + tcpGRODisabled, + notGROCandidate, + }, { "udp4", udp4, - true, + 0, udp4GROCandidate, }, { "udp4 no support", udp4, - false, + udpGRODisabled, notGROCandidate, }, { "udp6", udp6, - true, + 0, udp6GROCandidate, }, { "udp6 no support", udp6, - false, + udpGRODisabled, notGROCandidate, }, { "udp4 too short", udp4TooShort, - true, + 0, notGROCandidate, }, { "udp6 too short", udp6TooShort, - true, + 0, notGROCandidate, }, { "tcp4 too short", tcp4TooShort, - true, + 0, notGROCandidate, }, { "tcp6 too short", tcp6TooShort, - true, + 0, notGROCandidate, }, { "invalid IP version", []byte{0x00}, - true, + 0, notGROCandidate, }, { "invalid IP header len", ip4InvalidHeaderLen, - true, + 0, notGROCandidate, }, { "ip4 invalid protocol", ip4InvalidProtocol, - true, + 0, notGROCandidate, }, { "ip6 invalid protocol", ip6InvalidProtocol, - true, + 0, notGROCandidate, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { + if got := packetIsGROCandidate(tt.b, tt.gro); got != tt.want { t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) } }) diff --git a/tun/tun.go b/tun/tun.go index d3c5012c4..719a60631 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -52,10 +52,25 @@ type Device interface { BatchSize() int } -type LinuxDevice interface { +// GRODevice is a Device extended with methods for disabling GRO. Certain OS +// versions may have offload bugs. Where these bugs negatively impact throughput +// or break connectivity entirely we can use these methods to disable the +// related offload. +// +// Linux has the following known, GRO bugs. +// +// torvalds/linux@e269d79c7d35aa3808b1f3c1737d63dab504ddc8 broke virtio_net +// TCP & UDP GRO causing GRO writes to return EINVAL. The bug was then +// resolved later in +// torvalds/linux@89add40066f9ed9abe5f7f886fe5789ff7e0c50e. The offending +// commit was pulled into various LTS releases. +// +// UDP GRO writes end up blackholing/dropping packets destined for a +// vxlan/geneve interface on kernel versions prior to 6.8.5. +type GRODevice interface { Device - // DisableUDPGRO disables UDP GRO if it is enabled. 
Certain device drivers - // (e.g. vxlan, geneve) do not properly handle coalesced UDP packets later - // in the stack, resulting in packet loss. + // DisableUDPGRO disables UDP GRO if it is enabled. DisableUDPGRO() + // DisableTCPGRO disables TCP GRO if it is enabled. + DisableTCPGRO() } diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 664eecc5e..4a0338714 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -38,7 +38,6 @@ type NativeTun struct { statusListenersShutdown chan struct{} batchSize int vnetHdr bool - udpGSO bool closeOnce sync.Once @@ -53,7 +52,30 @@ type NativeTun struct { toWrite []int tcpGROTable *tcpGROTable udpGROTable *udpGROTable - udpGRO bool + gro groDisablementFlags +} + +type groDisablementFlags int + +const ( + tcpGRODisabled groDisablementFlags = 1 << iota + udpGRODisabled +) + +func (g *groDisablementFlags) disableTCPGRO() { + *g |= tcpGRODisabled +} + +func (g *groDisablementFlags) canTCPGRO() bool { + return (*g)&tcpGRODisabled == 0 +} + +func (g *groDisablementFlags) disableUDPGRO() { + *g |= udpGRODisabled +} + +func (g *groDisablementFlags) canUDPGRO() bool { + return (*g)&udpGRODisabled == 0 } func (tun *NativeTun) File() *os.File { @@ -346,7 +368,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGRO, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite) if err != nil { return 0, err } @@ -462,9 +484,19 @@ func (tun *NativeTun) BatchSize() int { return tun.batchSize } +// DisableUDPGRO disables UDP GRO if it is enabled. See the GRODevice interface +// for cases where it should be called. func (tun *NativeTun) DisableUDPGRO() { tun.writeOpMu.Lock() - tun.udpGRO = false + tun.gro.disableUDPGRO() + tun.writeOpMu.Unlock() +} + +// DisableTCPGRO disables TCP GRO if it is enabled. See the GRODevice interface +// for cases where it should be called. +func (tun *NativeTun) DisableTCPGRO() { + tun.writeOpMu.Lock() + tun.gro.disableTCPGRO() tun.writeOpMu.Unlock() } @@ -503,8 +535,9 @@ func (tun *NativeTun) initFromFlags(name string) error { tun.batchSize = conn.IdealBatchSize // tunUDPOffloads were added in Linux v6.2. We do not return an // error if they are unsupported at runtime. - tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil - tun.udpGRO = tun.udpGSO + if unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) != nil { + tun.gro.disableUDPGRO() + } } else { tun.batchSize = 1 } From 4e883d38c8d363e92e02445e4c664f2c27ffdb88 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 12 Nov 2024 16:53:55 -0800 Subject: [PATCH 26/39] tun: use x/sys/unix IoctlIfreq/NewIfreq to set MTU on Linux The manual struct packing was suspect: https://github.com/tailscale/tailscale/issues/11899 And no need for doing it manually if there's API for it already. 
Updates tailscale/tailscale#11899 Signed-off-by: Brad Fitzpatrick --- tun/tun_linux.go | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 4a0338714..7cdbf8825 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -269,21 +269,15 @@ func (tun *NativeTun) setMTU(n int) error { defer unix.Close(fd) - // do ioctl call - var ifr [ifReqSize]byte - copy(ifr[:], name) - *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return fmt.Errorf("failed to set MTU of TUN device: %w", errno) + req, err := unix.NewIfreq(name) + if err != nil { + return fmt.Errorf("unix.NewIfreq(%q): %w", name, err) + } + req.SetUint32(uint32(n)) + err = unix.IoctlIfreq(fd, unix.SIOCSIFMTU, req) + if err != nil { + return fmt.Errorf("failed to set MTU of TUN device %q: %w", name, err) } - return nil } From 0b8b35511f19726b3dc1ff25e57061c326ad0e5b Mon Sep 17 00:00:00 2001 From: Nahum Shalman Date: Mon, 2 Dec 2024 14:30:52 +0000 Subject: [PATCH 27/39] ipc: build on solaris/illumos --- ipc/uapi_fake.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipc/uapi_fake.go b/ipc/uapi_fake.go index 3e3dcdf80..a2e0f85b3 100644 --- a/ipc/uapi_fake.go +++ b/ipc/uapi_fake.go @@ -1,4 +1,4 @@ -//go:build wasm || plan9 || aix +//go:build wasm || plan9 || aix || solaris || illumos /* SPDX-License-Identifier: MIT * From 91a0587fb251a72c28724ee111fe04cf1436ca4c Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 3 Mar 2025 16:01:00 -0800 Subject: [PATCH 28/39] tun: add plan9 support Reviewed-by: James Tucker --- tun/tun_plan9.go | 147 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 tun/tun_plan9.go diff --git a/tun/tun_plan9.go b/tun/tun_plan9.go new file mode 100644 index 000000000..7b66eadf6 --- /dev/null +++ b/tun/tun_plan9.go @@ -0,0 +1,147 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package tun + +import ( + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" +) + +type NativeTun struct { + name string // "/net/ipifc/2" + ctlFile *os.File + dataFile *os.File + events chan Event + errors chan error + closeOnce sync.Once +} + +func CreateTUN(_ string, mtu int) (Device, error) { + ctl, err := os.OpenFile("/net/ipifc/clone", os.O_RDWR, 0) + if err != nil { + return nil, err + } + nbuf := make([]byte, 5) + n, err := ctl.Read(nbuf) + if err != nil { + ctl.Close() + return nil, fmt.Errorf("error reading from clone file: %w", err) + } + ifn, err := strconv.Atoi(strings.TrimSpace(string(nbuf[:n]))) + if err != nil { + ctl.Close() + return nil, fmt.Errorf("error converting clone result %q to int: %w", nbuf[:n], err) + } + + if _, err := fmt.Fprintf(ctl, "bind pkt\n"); err != nil { + ctl.Close() + return nil, fmt.Errorf("error binding to pkt: %w", err) + } + if mtu > 0 { + if _, err := fmt.Fprintf(ctl, "mtu %d\n", mtu); err != nil { + ctl.Close() + return nil, fmt.Errorf("error setting MTU: %w", err) + } + } + + dataFile, err := os.OpenFile(fmt.Sprintf("/net/ipifc/%d/data", ifn), os.O_RDWR, 0) + if err != nil { + ctl.Close() + return nil, err + } + + tun := &NativeTun{ + ctlFile: ctl, + dataFile: dataFile, + name: fmt.Sprintf("/net/ipifc/%d", ifn), + events: make(chan Event, 10), + errors: make(chan error, 5), + } + tun.events <- EventUp + + return tun, nil +} + +func (tun *NativeTun) Name() (string, error) { + return tun.name, nil +} + +func (tun *NativeTun) File() *os.File { + return tun.ctlFile +} + +func (tun *NativeTun) Events() <-chan Event { + return tun.events +} + +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + select { + case err := <-tun.errors: + return 0, err + default: + n, err := tun.dataFile.Read(bufs[0][offset:]) + if n == 1 && bufs[0][offset] == 0 { + // EOF + err = io.EOF + n = 0 + } + sizes[0] = n + return 1, err + } +} + +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + if _, err := tun.dataFile.Write(buf[offset:]); err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (tun *NativeTun) Close() error { + var err1, err2 error + tun.closeOnce.Do(func() { + _, err1 := fmt.Fprintf(tun.ctlFile, "unbind\n") + if err := tun.ctlFile.Close(); err != nil && err1 == nil { + err1 = err + } + err2 = tun.dataFile.Close() + }) + if err1 != nil { + return err1 + } + return err2 +} + +func (tun *NativeTun) MTU() (int, error) { + var buf [100]byte + f, err := os.Open(tun.name + "/status") + if err != nil { + return 0, err + } + defer f.Close() + n, err := f.Read(buf[:]) + _, res, ok := strings.Cut(string(buf[:n]), " maxtu ") + if ok { + if mtus, _, ok := strings.Cut(res, " "); ok { + mtu, err := strconv.Atoi(mtus) + if err != nil { + return 0, fmt.Errorf("error converting mtu %q to int: %w", mtus, err) + } + return mtu, nil + } + } + return 0, fmt.Errorf("no 'maxtu' field found in %s/status", tun.name) +} + +func (tun *NativeTun) BatchSize() int { + return 1 +} From 6413b491d4dc3c5e5b295907b0b1242c6272427e Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Wed, 28 May 2025 15:05:06 -0700 Subject: [PATCH 29/39] .github/workflows: establish basic build and test actions jobs Upstream doesn't use GitHub actions for CI as GitHub is simply a mirror. Our workflows involve GitHub, so establish some basic CI jobs. 
Updates tailscale/corp#28877 Signed-off-by: Jordan Whited --- .github/workflows/test.yml | 68 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..a37099545 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,68 @@ +name: CI + +on: + push: + branches: ["tailscale"] + pull_request: + branches: ["tailscale"] + +jobs: + build: + runs-on: ubuntu-22.04 + strategy: + matrix: + include: + - goos: linux + goarch: amd64 + - goos: linux + goarch: arm64 + - goos: linux + goarch: "386" + - goos: linux + goarch: loong64 + - goos: linux + goarch: arm + goarm: "5" + - goos: linux + goarch: arm + goarm: "7" + # macOS + - goos: darwin + goarch: amd64 + - goos: darwin + goarch: arm64 + # Windows + - goos: windows + goarch: amd64 + - goos: windows + goarch: arm64 + # BSDs + - goos: freebsd + goarch: amd64 + - goos: openbsd + goarch: amd64 + steps: + - name: checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: setup go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 + with: + go-version-file: go.mod + - name: build + run: go build ./... + env: + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} + CGO_ENABLED: "0" + + test: + runs-on: ubuntu-22.04 + steps: + - name: checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: setup go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 + with: + go-version-file: go.mod + - name: test + run: go test -race -v ./... From 2b555120c89de2b197f2cfed2787ccda28616c8c Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 11 Dec 2023 16:35:57 +0100 Subject: [PATCH 30/39] device: do atomic 64-bit add outside of vector loop Only bother updating the rxBytes counter once we've processed a whole vector, since additions are atomic. cherry picked from commit WireGuard/wireguard-go@542e565baa776ed4c5c55b73ef9aa38d33d55197 Updates tailscale/corp#28879 Signed-off-by: Jason A. 
Donenfeld --- device/receive.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/device/receive.go b/device/receive.go index af2db44e8..01804b7f3 100644 --- a/device/receive.go +++ b/device/receive.go @@ -447,6 +447,7 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { elemsContainer.Lock() validTailPacket := -1 dataPacketReceived := false + rxBytesLen := uint64(0) for i, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed @@ -463,7 +464,7 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { peer.timersHandshakeComplete() peer.SendStagedPackets() } - peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) @@ -512,6 +513,8 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } + + peer.rxBytes.Add(rxBytesLen) if validTailPacket >= 0 { peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) peer.keepKeyFreshReceiving() From f74ff38c79b272da9b154a0a324cd51ec6a1fee8 Mon Sep 17 00:00:00 2001 From: Martin Basovnik Date: Fri, 10 Nov 2023 11:10:12 +0100 Subject: [PATCH 31/39] device: fix possible deadlock in close method There is a possible deadlock in `device.Close()` when you try to close the device very soon after its start. The problem is that two different methods acquire the same locks in different order: 1. device.Close() - device.ipcMutex.Lock() - device.state.Lock() 2. device.changeState(deviceState) - device.state.Lock() - device.ipcMutex.Lock() Reproducer: func TestDevice_deadlock(t *testing.T) { d := randDevice(t) d.Close() } Problem: $ go clean -testcache && go test -race -timeout 3s -run TestDevice_deadlock ./device | grep -A 10 sync.runtime_SemacquireMutex sync.runtime_SemacquireMutex(0xc000117d20?, 0x94?, 0x0?) /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25 sync.(*Mutex).lockSlow(0xc000130518) /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213 sync.(*Mutex).Lock(0xc000130518) /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55 golang.zx2c4.com/wireguard/device.(*Device).Close(0xc000130500) /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:373 +0xb6 golang.zx2c4.com/wireguard/device.TestDevice_deadlock(0x0?) /Users/martin.basovnik/git/basovnik/wireguard-go/device/device_test.go:480 +0x2c testing.tRunner(0xc00014c000, 0x131d7b0) -- sync.runtime_SemacquireMutex(0xc000130564?, 0x60?, 0xc000130548?) /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25 sync.(*Mutex).lockSlow(0xc000130750) /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213 sync.(*Mutex).Lock(0xc000130750) /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55 sync.(*RWMutex).Lock(0xc000130750) /usr/local/opt/go/libexec/src/sync/rwmutex.go:147 +0x45 golang.zx2c4.com/wireguard/device.(*Device).upLocked(0xc000130500) /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:179 +0x72 golang.zx2c4.com/wireguard/device.(*Device).changeState(0xc000130500, 0x1) cherry picked from commit WireGuard/wireguard-go@12269c2761734b15625017d8565745096325392f Updates tailscale/corp#28879 Signed-off-by: Martin Basovnik Signed-off-by: Jason A. 
Donenfeld --- device/device.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/device/device.go b/device/device.go index 86dff0d7e..5b2348564 100644 --- a/device/device.go +++ b/device/device.go @@ -368,10 +368,10 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { - device.ipcMutex.Lock() - defer device.ipcMutex.Unlock() device.state.Lock() defer device.state.Unlock() + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() if device.isClosed() { return } From 19f7e298052c939c07a5f9eb2b60fb62ff0d10c1 Mon Sep 17 00:00:00 2001 From: Alexander Yastrebov Date: Thu, 26 Dec 2024 20:36:53 +0100 Subject: [PATCH 32/39] device: reduce RoutineHandshake allocations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reduce allocations by eliminating byte reader, hand-rolled decoding and reusing message structs. Synthetic benchmark: var msgSink MessageInitiation func BenchmarkMessageInitiationUnmarshal(b *testing.B) { packet := make([]byte, MessageInitiationSize) reader := bytes.NewReader(packet) err := binary.Read(reader, binary.LittleEndian, &msgSink) if err != nil { b.Fatal(err) } b.Run("binary.Read", func(b *testing.B) { b.ReportAllocs() for range b.N { reader := bytes.NewReader(packet) _ = binary.Read(reader, binary.LittleEndian, &msgSink) } }) b.Run("unmarshal", func(b *testing.B) { b.ReportAllocs() for range b.N { _ = msgSink.unmarshal(packet) } }) } Results: │ - │ │ sec/op │ MessageInitiationUnmarshal/binary.Read-8 1.508µ ± 2% MessageInitiationUnmarshal/unmarshal-8 12.66n ± 2% │ - │ │ B/op │ MessageInitiationUnmarshal/binary.Read-8 208.0 ± 0% MessageInitiationUnmarshal/unmarshal-8 0.000 ± 0% │ - │ │ allocs/op │ MessageInitiationUnmarshal/binary.Read-8 2.000 ± 0% MessageInitiationUnmarshal/unmarshal-8 0.000 ± 0% cherry picked from commit WireGuard/wireguard-go@9e7529c3d2d0c54f4d5384c01645a9279e4740ae Updates tailscale/corp#28879 Signed-off-by: Alexander Yastrebov Signed-off-by: Jason A. 
Donenfeld --- device/noise-protocol.go | 48 ++++++++++++++++++++++++++++++++++++++++ device/receive.go | 10 +++------ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 2d8f98426..2e7e9ae33 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -6,6 +6,7 @@ package device import ( + "encoding/binary" "errors" "fmt" "sync" @@ -115,6 +116,53 @@ type MessageCookieReply struct { Cookie [blake2s.Size128 + poly1305.TagSize]byte } +var errMessageTooShort = errors.New("message too short") + +func (msg *MessageInitiation) unmarshal(b []byte) error { + if len(b) < MessageInitiationSize { + return errMessageTooShort + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Ephemeral[:], b[8:]) + copy(msg.Static[:], b[8+len(msg.Ephemeral):]) + copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):]) + copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):]) + copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageResponse) unmarshal(b []byte) error { + if len(b) < MessageResponseSize { + return errMessageTooShort + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + msg.Receiver = binary.LittleEndian.Uint32(b[8:]) + copy(msg.Ephemeral[:], b[12:]) + copy(msg.Empty[:], b[12+len(msg.Ephemeral):]) + copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):]) + copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageCookieReply) unmarshal(b []byte) error { + if len(b) < MessageCookieReplySize { + return errMessageTooShort + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Receiver = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Nonce[:], b[8:]) + copy(msg.Cookie[:], b[8+len(msg.Nonce):]) + + return nil +} + type Handshake struct { state handshakeState mutex sync.RWMutex diff --git a/device/receive.go b/device/receive.go index 01804b7f3..bc37f915e 100644 --- a/device/receive.go +++ b/device/receive.go @@ -6,7 +6,6 @@ package device import ( - "bytes" "encoding/binary" "errors" "net" @@ -287,8 +286,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal packet var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) + err := reply.unmarshal(elem.packet) if err != nil { device.log.Verbosef("Failed to decode cookie reply") goto skip @@ -353,8 +351,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode initiation message") goto skip @@ -386,8 +383,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode response message") goto skip From ae0636254c5f6232d78b7d9a9908eee434639406 Mon Sep 17 00:00:00 2001 From: "Jason A. 
Donenfeld" Date: Thu, 15 May 2025 16:48:14 +0200 Subject: [PATCH 33/39] device: make unmarshall length checks exact This is already enforced in receive.go, but if these unmarshallers are to have error return values anyway, make them as explicit as possible. cherry picked from commit WireGuard/wireguard-go@842888ac5c93ccc5ee6344eceaadf783fcf1e243 Updates tailscale/corp#28879 Signed-off-by: Jason A. Donenfeld --- device/noise-protocol.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 2e7e9ae33..1e99f6886 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -116,11 +116,11 @@ type MessageCookieReply struct { Cookie [blake2s.Size128 + poly1305.TagSize]byte } -var errMessageTooShort = errors.New("message too short") +var errMessageLengthMismatch = errors.New("message length mismatch") func (msg *MessageInitiation) unmarshal(b []byte) error { - if len(b) < MessageInitiationSize { - return errMessageTooShort + if len(b) != MessageInitiationSize { + return errMessageLengthMismatch } msg.Type = binary.LittleEndian.Uint32(b) @@ -135,8 +135,8 @@ func (msg *MessageInitiation) unmarshal(b []byte) error { } func (msg *MessageResponse) unmarshal(b []byte) error { - if len(b) < MessageResponseSize { - return errMessageTooShort + if len(b) != MessageResponseSize { + return errMessageLengthMismatch } msg.Type = binary.LittleEndian.Uint32(b) @@ -151,8 +151,8 @@ func (msg *MessageResponse) unmarshal(b []byte) error { } func (msg *MessageCookieReply) unmarshal(b []byte) error { - if len(b) < MessageCookieReplySize { - return errMessageTooShort + if len(b) != MessageCookieReplySize { + return errMessageLengthMismatch } msg.Type = binary.LittleEndian.Uint32(b) From 022546570f38af621ea050600e3b1eb43d430a88 Mon Sep 17 00:00:00 2001 From: Alexander Yastrebov Date: Sat, 17 May 2025 11:34:30 +0200 Subject: [PATCH 34/39] device: optimize message encoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Optimize message encoding by eliminating binary.Write (which internally uses reflection) in favour of hand-rolled encoding. This is companion to 9e7529c3d2d0c54f4d5384c01645a9279e4740ae. 
Synthetic benchmark: var packetSink []byte func BenchmarkMessageInitiationMarshal(b *testing.B) { var msg MessageInitiation b.Run("binary.Write", func(b *testing.B) { b.ReportAllocs() for range b.N { var buf [MessageInitiationSize]byte writer := bytes.NewBuffer(buf[:0]) _ = binary.Write(writer, binary.LittleEndian, msg) packetSink = writer.Bytes() } }) b.Run("binary.Encode", func(b *testing.B) { b.ReportAllocs() for range b.N { packet := make([]byte, MessageInitiationSize) _, _ = binary.Encode(packet, binary.LittleEndian, msg) packetSink = packet } }) b.Run("marshal", func(b *testing.B) { b.ReportAllocs() for range b.N { packet := make([]byte, MessageInitiationSize) _ = msg.marshal(packet) packetSink = packet } }) } Results: │ - │ │ sec/op │ MessageInitiationMarshal/binary.Write-8 1.337µ ± 0% MessageInitiationMarshal/binary.Encode-8 1.242µ ± 0% MessageInitiationMarshal/marshal-8 53.05n ± 1% │ - │ │ B/op │ MessageInitiationMarshal/binary.Write-8 368.0 ± 0% MessageInitiationMarshal/binary.Encode-8 160.0 ± 0% MessageInitiationMarshal/marshal-8 160.0 ± 0% │ - │ │ allocs/op │ MessageInitiationMarshal/binary.Write-8 3.000 ± 0% MessageInitiationMarshal/binary.Encode-8 1.000 ± 0% MessageInitiationMarshal/marshal-8 1.000 ± 0% cherry picked from commit WireGuard/wireguard-go@264889f0bbdf9250bb8389a637dd5f38389bfe0b Updates tailscale/corp#28879 Signed-off-by: Alexander Yastrebov Signed-off-by: Jason A. Donenfeld --- device/noise-protocol.go | 45 ++++++++++++++++++++++++++++++++++++++++ device/send.go | 21 +++++++------------ 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 1e99f6886..cb4dedb11 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -134,6 +134,22 @@ func (msg *MessageInitiation) unmarshal(b []byte) error { return nil } +func (msg *MessageInitiation) marshal(b []byte) error { + if len(b) != MessageInitiationSize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Sender) + copy(b[8:], msg.Ephemeral[:]) + copy(b[8+len(msg.Ephemeral):], msg.Static[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.Timestamp[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):], msg.MAC1[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):], msg.MAC2[:]) + + return nil +} + func (msg *MessageResponse) unmarshal(b []byte) error { if len(b) != MessageResponseSize { return errMessageLengthMismatch @@ -150,6 +166,22 @@ func (msg *MessageResponse) unmarshal(b []byte) error { return nil } +func (msg *MessageResponse) marshal(b []byte) error { + if len(b) != MessageResponseSize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Sender) + binary.LittleEndian.PutUint32(b[8:], msg.Receiver) + copy(b[12:], msg.Ephemeral[:]) + copy(b[12+len(msg.Ephemeral):], msg.Empty[:]) + copy(b[12+len(msg.Ephemeral)+len(msg.Empty):], msg.MAC1[:]) + copy(b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):], msg.MAC2[:]) + + return nil +} + func (msg *MessageCookieReply) unmarshal(b []byte) error { if len(b) != MessageCookieReplySize { return errMessageLengthMismatch @@ -163,6 +195,19 @@ func (msg *MessageCookieReply) unmarshal(b []byte) error { return nil } +func (msg *MessageCookieReply) marshal(b []byte) error { + if len(b) != MessageCookieReplySize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, 
msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Receiver) + copy(b[8:], msg.Nonce[:]) + copy(b[8+len(msg.Nonce):], msg.Cookie[:]) + + return nil +} + type Handshake struct { state handshakeState mutex sync.RWMutex diff --git a/device/send.go b/device/send.go index 8ed2e5f6c..7900f577f 100644 --- a/device/send.go +++ b/device/send.go @@ -6,7 +6,6 @@ package device import ( - "bytes" "encoding/binary" "errors" "net" @@ -124,10 +123,8 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - var buf [MessageInitiationSize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() + packet := make([]byte, MessageInitiationSize) + _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() @@ -155,10 +152,8 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - var buf [MessageResponseSize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() + packet := make([]byte, MessageResponseSize) + _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() @@ -189,11 +184,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - var buf [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, reply) + packet := make([]byte, MessageCookieReplySize) + _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint) + return nil } From 65cd6eed7d7f688acea0b24ba64a781a9c58248e Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Fri, 30 May 2025 13:14:43 -0700 Subject: [PATCH 35/39] conn,device: provide 8 free bytes at packet head to conn.Bind.Send() This enables a conn.Bind to bring its own encapsulating transport, e.g. VXLAN/Geneve. 
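For illustration only (not part of this change): a wrapping conn.Bind could use the new
headroom to prepend a fixed 8-byte encapsulation header (e.g. a Geneve base header)
without copying packet bodies. encapBind, b.header, and b.inner are hypothetical; assumes
imports of fmt and this repo's conn package.

    func (b *encapBind) Send(bufs [][]byte, ep conn.Endpoint, offset int) error {
        const hdrLen = 8 // matches device.MessageEncapsulatingTransportSize
        if offset < hdrLen {
            return fmt.Errorf("need %d bytes of headroom, have %d", hdrLen, offset)
        }
        for i := range bufs {
            // Write the precomputed 8-byte header into the free space just
            // ahead of the WireGuard packet data.
            copy(bufs[i][offset-hdrLen:offset], b.header[:])
        }
        // Hand packets to the underlying Bind with the headroom consumed.
        return b.inner.Send(bufs, ep, offset-hdrLen)
    }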
Updates tailscale/corp#27502 Signed-off-by: Jordan Whited --- conn/bind_std.go | 9 +++++---- conn/bind_std_test.go | 2 +- conn/bind_windows.go | 3 ++- conn/bindtest/bindtest.go | 3 ++- conn/conn.go | 8 +++++--- device/constants.go | 6 +++--- device/device_test.go | 10 +++++----- device/noise-protocol.go | 15 ++++++++------- device/peer.go | 5 ++++- device/send.go | 36 +++++++++++++++++++++++------------- 10 files changed, 58 insertions(+), 39 deletions(-) diff --git a/conn/bind_std.go b/conn/bind_std.go index 428e52815..fc0563456 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -341,7 +341,7 @@ func (e ErrUDPGSODisabled) Unwrap() error { return e.RetryErr } -func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { +func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint, offset int) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 @@ -384,7 +384,7 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { ) retry: if offload { - n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, offset, *msgs, setGSOSize) err = s.send(conn, br, (*msgs)[:n]) if err != nil && offload && errShouldDisableUDPGSO(err) { offload = false @@ -401,7 +401,7 @@ retry: } else { for i := range bufs { (*msgs)[i].Addr = ua - (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].Buffers[0] = bufs[i][offset:] setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) } err = s.send(conn, br, (*msgs)[:len(bufs)]) @@ -450,7 +450,7 @@ const ( type setGSOFunc func(control *[]byte, gsoSize uint16) -func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offset int, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( base = -1 // index of msg we are currently coalescing into gsoSize int // segmentation size of msgs[base] @@ -462,6 +462,7 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs maxPayloadLen = maxIPv6PayloadLen } for i, buf := range bufs { + buf = buf[offset:] if i > 0 { msgLen := len(buf) baseLenBefore := len(msgs[base].Buffers[0]) diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 34a3c9acf..77af0d925 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -98,7 +98,7 @@ func Test_coalesceMessages(t *testing.T) { msgs[i].Buffers = make([][]byte, 1) msgs[i].OOB = make([]byte, 0, 2) } - got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) + got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, 0, msgs, mockSetGSOSize) if got != len(tt.wantLens) { t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) } diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 9638b3096..737b475e1 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -486,7 +486,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { +func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint, offset int) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType @@ -494,6 +494,7 @@ func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { bind.mu.RLock() defer bind.mu.RUnlock() for _, buf := range bufs { + buf = 
buf[offset:] switch nend.family { case windows.AF_INET: if bind.v4.blackhole { diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 836d983ce..741b776c4 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -107,8 +107,9 @@ func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { } } -func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { +func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint, offset int) error { for _, b := range bufs { + b = b[offset:] select { case <-c.closeSignal: return net.ErrClosed diff --git a/conn/conn.go b/conn/conn.go index 8df5aaa66..5083648ff 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -45,9 +45,11 @@ type Bind interface { // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error - // Send writes one or more packets in bufs to address ep. The length of - // bufs must not exceed BatchSize(). - Send(bufs [][]byte, ep Endpoint) error + // Send writes one or more packets in bufs to address ep. A nonzero offset + // can be used to instruct the Bind on where packet data begins in each + // element of the bufs slice. Space preceding offset is free to use for + // additional encapsulation. The length of bufs must not exceed BatchSize(). + Send(bufs [][]byte, ep Endpoint, offset int) error // ParseEndpoint creates a new endpoint from a string. ParseEndpoint(s string) (Endpoint, error) diff --git a/device/constants.go b/device/constants.go index 59854a126..92c3bdea8 100644 --- a/device/constants.go +++ b/device/constants.go @@ -27,9 +27,9 @@ const ( ) const ( - MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) - MaxMessageSize = MaxSegmentSize // maximum size of transport message - MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content + MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) + MaxMessageSize = MaxSegmentSize // maximum size of transport message + MaxContentSize = MaxSegmentSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content ) /* Implementation constants */ diff --git a/device/device_test.go b/device/device_test.go index 4088b9fab..e44342170 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -426,11 +426,11 @@ type fakeBindSized struct { func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 0, nil } -func (b *fakeBindSized) Close() error { return nil } -func (b *fakeBindSized) SetMark(mark uint32) error { return nil } -func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } -func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } -func (b *fakeBindSized) BatchSize() int { return b.size } +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint, offset int) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } type fakeTUNDeviceSized struct { size int diff --git a/device/noise-protocol.go b/device/noise-protocol.go index cb4dedb11..555ce915a 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -61,13 +61,14 @@ const ( ) const ( - MessageInitiationSize = 148 // size of handshake 
initiation message - MessageResponseSize = 92 // size of response message - MessageCookieReplySize = 64 // size of cookie reply message - MessageTransportHeaderSize = 16 // size of data preceding content in transport message - MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport - MessageKeepaliveSize = MessageTransportSize // size of keepalive - MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message + MessageInitiationSize = 148 // size of handshake initiation message + MessageResponseSize = 92 // size of response message + MessageCookieReplySize = 64 // size of cookie reply message + MessageTransportHeaderSize = 16 // size of data preceding content in transport message + MessageEncapsulatingTransportSize = 8 // size of optional, free (for use by conn.Bind.Send()) space preceding the transport header + MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport + MessageKeepaliveSize = MessageTransportSize // size of keepalive + MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message ) const ( diff --git a/device/peer.go b/device/peer.go index 876e5daf0..f79a0af29 100644 --- a/device/peer.go +++ b/device/peer.go @@ -113,6 +113,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +// SendBuffers sends buffers to peer. WireGuard packet data in each element of +// buffers must be preceded by MessageEncapsulatingTransportSize number of +// bytes. func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -133,7 +136,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { } peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffers, endpoint) + err := peer.device.net.bind.Send(buffers, endpoint, MessageEncapsulatingTransportSize) if err == nil { var totalLen uint64 for _, b := range buffers { diff --git a/device/send.go b/device/send.go index 7900f577f..bf854b714 100644 --- a/device/send.go +++ b/device/send.go @@ -45,11 +45,15 @@ import ( */ type QueueOutboundElement struct { - buffer *[MaxMessageSize]byte // slice holding the packet data - packet []byte // slice of "buffer" (always!) - nonce uint64 // nonce for encryption - keypair *Keypair // keypair for encryption - peer *Peer // related peer + buffer *[MaxMessageSize]byte // slice holding the packet data + // packet is always a slice of "buffer". 
The starting offset in buffer + // is either: + // a) MessageEncapsulatingTransportSize+MessageTransportHeaderSize (plaintext) + // b) 0 (post-encryption) + packet []byte + nonce uint64 // nonce for encryption + keypair *Keypair // keypair for encryption + peer *Peer // related peer } type QueueOutboundElementsContainer struct { @@ -123,14 +127,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - packet := make([]byte, MessageInitiationSize) + buf := make([]byte, MessageEncapsulatingTransportSize+MessageInitiationSize) + packet := buf[MessageEncapsulatingTransportSize:] _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{packet}) + err = peer.SendBuffers([][]byte{buf}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -152,7 +157,8 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - packet := make([]byte, MessageResponseSize) + buf := make([]byte, MessageEncapsulatingTransportSize+MessageResponseSize) + packet := buf[MessageEncapsulatingTransportSize:] _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) @@ -167,7 +173,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{packet}) + err = peer.SendBuffers([][]byte{buf}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -184,10 +190,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - packet := make([]byte, MessageCookieReplySize) + buf := make([]byte, MessageEncapsulatingTransportSize+MessageCookieReplySize) + packet := buf[MessageEncapsulatingTransportSize:] _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{buf}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) return nil } @@ -220,7 +227,7 @@ func (device *Device) RoutineReadFromTUN() { elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) - offset = MessageTransportHeaderSize + offset = MessageEncapsulatingTransportSize + MessageTransportHeaderSize ) for i := range elems { @@ -446,7 +453,7 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] + header := elem.buffer[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] @@ -469,6 +476,9 @@ func (device *Device) RoutineEncryption(id int) { elem.packet, nil, ) + + // re-slice packet to include encapsulating transport space + elem.packet = elem.buffer[:MessageEncapsulatingTransportSize+len(elem.packet)] } elemsContainer.Unlock() } From 24483d7a00033707e47a0322de6970ddc504e454 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 1 Jul 2025 14:44:37 -0700 Subject: [PATCH 36/39] conn,device: always perform PeerAwareEndpoint check It was previously suppressed if roaming was disabled for the peer. Tailscale always disables roaming as we explicitly configure conn.Endpoint's for all peers. 
This commit also modifies PeerAwareEndpoint usage such that wireguard-go never uses/sets it as a Peer Endpoint value. In theory we (Tailscale) always disable roaming, so we should always return early from SetEndpointFromPacket(), but this acts as an extra footgun guard and improves clarity around intended usage. Updates tailscale/corp#27502 Updates tailscale/corp#29422 Updates tailscale/corp#30042 Signed-off-by: Jordan Whited --- conn/conn.go | 12 ++++++++---- device/peer.go | 7 ++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/conn/conn.go b/conn/conn.go index 5083648ff..2a04e6dd7 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -87,17 +87,21 @@ type Endpoint interface { } // PeerAwareEndpoint is an optional Endpoint specialization for -// integrations that want to know about the outcome of cryptorouting +// integrations that want to know about the outcome of Cryptokey Routing // identification. // // If they receive a packet from a source they had not pre-identified, // to learn the identification WireGuard can derive from the session // or handshake. // -// If GetPeerEndpoint returns nil, WireGuard will be unable to respond -// to the peer until a new endpoint is written by a later packet. +// wireguard-go never installs a [PeerAwareEndpoint] as the [Endpoint] for a +// [Peer]. type PeerAwareEndpoint interface { - GetPeerEndpoint(peerPublicKey [32]byte) Endpoint + // FromPeer is called at least once per successfully Cryptokey Routing ID'd + // [ReceiveFunc] packets batch for a given node key. wireguard-go will + // always call it for the latest/tail packet in the batch, only ever + // suppressing calls for older packets. + FromPeer(peerPublicKey [32]byte) } var ( diff --git a/device/peer.go b/device/peer.go index f79a0af29..c188c3151 100644 --- a/device/peer.go +++ b/device/peer.go @@ -282,13 +282,14 @@ func (peer *Peer) Stop() { func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { peer.endpoint.Lock() defer peer.endpoint.Unlock() + if ep, ok := endpoint.(conn.PeerAwareEndpoint); ok { + ep.FromPeer(peer.handshake.remoteStatic) + return + } if peer.endpoint.disableRoaming { return } peer.endpoint.clearSrcOnTx = false - if ep, ok := endpoint.(conn.PeerAwareEndpoint); ok { - endpoint = ep.GetPeerEndpoint(peer.handshake.remoteStatic) - } peer.endpoint.val = endpoint } From 1f398ae148a8a8183e855289b4a4492dde0490dd Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 3 Jul 2025 20:54:49 -0700 Subject: [PATCH 37/39] conn,device: implement InitiationAwareEndpoint To be implemented by [magicsock.lazyEndpoint], which is responsible for triggering JIT peer configuration. Updates tailscale/corp#20732 Updates tailscale/corp#30042 Signed-off-by: Jordan Whited --- conn/conn.go | 15 +++++++++++++++ device/noise-protocol.go | 8 +++++++- device/noise_test.go | 30 +++++++++++++++++++++++++++++- device/receive.go | 2 +- 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/conn/conn.go b/conn/conn.go index 2a04e6dd7..b949641e0 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -86,6 +86,21 @@ type Endpoint interface { SrcIP() netip.Addr } +// InitiationAwareEndpoint is an optional [Endpoint] specialization for +// integrations that want to know when a WireGuard handshake initiation +// message has been received, enabling just-in-time peer configuration before +// attempted decryption. +// +// It's most useful when used in combination with [PeerAwareEndpoint], enabling +// JIT peer configuration and post-decryption peer verification from a single +// implementer. 
+type InitiationAwareEndpoint interface { + // InitiationMessagePublicKey is called when a handshake initiation message + // has been received, and the sender's public key has been identified, but + // BEFORE an attempt has been made to verify it. + InitiationMessagePublicKey(peerPublicKey [32]byte) +} + // PeerAwareEndpoint is an optional Endpoint specialization for // integrations that want to know about the outcome of Cryptokey Routing // identification. diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 555ce915a..ad5838e1d 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -16,6 +16,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" + "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/tai64n" ) @@ -338,7 +339,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e return &msg, nil } -func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { +func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation, endpoint conn.Endpoint) *Peer { var ( hash [blake2s.Size]byte chainKey [blake2s.Size]byte @@ -372,6 +373,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // lookup peer + initEP, ok := endpoint.(conn.InitiationAwareEndpoint) + if ok { + initEP.InitiationMessagePublicKey(peerPK) + } + peer := device.LookupPeer(peerPK) if peer == nil || !peer.isRunning.Load() { return nil diff --git a/device/noise_test.go b/device/noise_test.go index 7d6af1df0..160bee588 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -8,6 +8,7 @@ package device import ( "bytes" "encoding/binary" + "net/netip" "testing" "github.com/tailscale/wireguard-go/conn" @@ -56,6 +57,26 @@ func assertEqual(t *testing.T, a, b []byte) { } } +type initAwareEP struct { + calledWith *[32]byte +} + +var _ conn.Endpoint = (*initAwareEP)(nil) +var _ conn.InitiationAwareEndpoint = (*initAwareEP)(nil) + +func (i *initAwareEP) ClearSrc() {} +func (i *initAwareEP) SrcToString() string { return "" } +func (i *initAwareEP) DstToString() string { return "" } +func (i *initAwareEP) DstToBytes() []byte { return nil } +func (i *initAwareEP) DstIP() netip.Addr { return netip.Addr{} } +func (i *initAwareEP) SrcIP() netip.Addr { return netip.Addr{} } + +func (i *initAwareEP) InitiationMessagePublicKey(peerPublicKey [32]byte) { + calledWith := [32]byte{} + copy(calledWith[:], peerPublicKey[:]) + i.calledWith = &calledWith +} + func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) @@ -93,10 +114,17 @@ func TestNoiseHandshake(t *testing.T) { writer := bytes.NewBuffer(packet) err = binary.Write(writer, binary.LittleEndian, msg1) assertNil(t, err) - peer := dev2.ConsumeMessageInitiation(msg1) + initEP := &initAwareEP{} + peer := dev2.ConsumeMessageInitiation(msg1, initEP) if peer == nil { t.Fatal("handshake failed at initiation message") } + if initEP.calledWith == nil { + t.Fatal("initAwareEP never called") + } + if *initEP.calledWith != dev1.staticIdentity.publicKey { + t.Fatal("initAwareEP called with unexpected public key") + } assertEqual( t, diff --git a/device/receive.go b/device/receive.go index bc37f915e..e74de1a45 100644 --- a/device/receive.go +++ b/device/receive.go @@ -359,7 +359,7 @@ func (device *Device) RoutineHandshake(id int) { // consume initiation - peer := device.ConsumeMessageInitiation(&msg) + peer := device.ConsumeMessageInitiation(&msg, elem.endpoint) if peer == nil { device.log.Verbosef("Received 
invalid initiation message from %s", elem.endpoint.DstToString()) goto skip From 4064566ecaf999e6d097f16f47c16453fbc67d8e Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Thu, 10 Jul 2025 21:36:38 -0700 Subject: [PATCH 38/39] device: fix keepalive detection in TX path Updates tailscale/corp#30364 Signed-off-by: Jordan Whited --- device/send.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/device/send.go b/device/send.go index bf854b714..c8bb0792f 100644 --- a/device/send.go +++ b/device/send.go @@ -517,7 +517,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { dataSent := false elemsContainer.Lock() for _, elem := range elemsContainer.elems { - if len(elem.packet) != MessageKeepaliveSize { + if len(elem.packet[MessageEncapsulatingTransportSize:]) != MessageKeepaliveSize { dataSent = true } bufs = append(bufs, elem.packet) From 1d0488a3d7da6b6ed79202519f30e7a286e0d4e6 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Wed, 16 Jul 2025 08:43:20 -0700 Subject: [PATCH 39/39] conn,device: eval conn.PeerAwareEndpoint per-packet Peer.SetEndpointFromPacket is not called per-packet. It is guaranteed to be called at least once per packet batch. Updates tailscale/corp#30042 Updates tailscale/corp#20732 Signed-off-by: Jordan Whited --- conn/conn.go | 5 +++-- device/peer.go | 4 ---- device/receive.go | 3 +++ 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/conn/conn.go b/conn/conn.go index b949641e0..f1781614d 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -109,8 +109,9 @@ type InitiationAwareEndpoint interface { // to learn the identification WireGuard can derive from the session // or handshake. // -// wireguard-go never installs a [PeerAwareEndpoint] as the [Endpoint] for a -// [Peer]. +// A [PeerAwareEndpoint] may be installed as the [conn.Endpoint] following +// successful decryption unless endpoint roaming has been disabled for +// the peer. type PeerAwareEndpoint interface { // FromPeer is called at least once per successfully Cryptokey Routing ID'd // [ReceiveFunc] packets batch for a given node key. wireguard-go will diff --git a/device/peer.go b/device/peer.go index c188c3151..064feb22b 100644 --- a/device/peer.go +++ b/device/peer.go @@ -282,10 +282,6 @@ func (peer *Peer) Stop() { func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { peer.endpoint.Lock() defer peer.endpoint.Unlock() - if ep, ok := endpoint.(conn.PeerAwareEndpoint); ok { - ep.FromPeer(peer.handshake.remoteStatic) - return - } if peer.endpoint.disableRoaming { return } diff --git a/device/receive.go b/device/receive.go index e74de1a45..02c8f21fc 100644 --- a/device/receive.go +++ b/device/receive.go @@ -460,6 +460,9 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { peer.timersHandshakeComplete() peer.SendStagedPackets() } + if ep, ok := elem.endpoint.(conn.PeerAwareEndpoint); ok { + ep.FromPeer(peer.handshake.remoteStatic) + } rxBytesLen += uint64(len(elem.packet) + MinMessageSize) if len(elem.packet) == 0 {