diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..a37099545 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,68 @@ +name: CI + +on: + push: + branches: ["tailscale"] + pull_request: + branches: ["tailscale"] + +jobs: + build: + runs-on: ubuntu-22.04 + strategy: + matrix: + include: + - goos: linux + goarch: amd64 + - goos: linux + goarch: arm64 + - goos: linux + goarch: "386" + - goos: linux + goarch: loong64 + - goos: linux + goarch: arm + goarm: "5" + - goos: linux + goarch: arm + goarm: "7" + # macOS + - goos: darwin + goarch: amd64 + - goos: darwin + goarch: arm64 + # Windows + - goos: windows + goarch: amd64 + - goos: windows + goarch: arm64 + # BSDs + - goos: freebsd + goarch: amd64 + - goos: openbsd + goarch: amd64 + steps: + - name: checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: setup go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 + with: + go-version-file: go.mod + - name: build + run: go build ./... + env: + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} + CGO_ENABLED: "0" + + test: + runs-on: ubuntu-22.04 + steps: + - name: checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: setup go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 + with: + go-version-file: go.mod + - name: test + run: go test -race -v ./... diff --git a/conn/bind_std.go b/conn/bind_std.go index c701ef872..fc0563456 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -8,6 +8,7 @@ package conn import ( "context" "errors" + "fmt" "net" "net/netip" "runtime" @@ -20,7 +21,8 @@ import ( ) var ( - _ Bind = (*StdNetBind)(nil) + _ Bind = (*StdNetBind)(nil) + _ Endpoint = (*StdNetEndpoint)(nil) ) // StdNetBind implements Bind for all platforms. While Windows has its own Bind @@ -29,16 +31,19 @@ var ( // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. 
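For context, a minimal sketch (illustrative only, not part of the patch) of the golang.org/x/net batching API that StdNetBind builds on; on Linux a ReadBatch call maps to a single recvmmsg(2) syscall, while other platforms fall back to one message per call:

package main

import (
	"fmt"
	"net"

	"golang.org/x/net/ipv4"
)

func main() {
	// Listen on an ephemeral loopback UDP port.
	uc, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		panic(err)
	}
	defer uc.Close()

	// Queue two datagrams to ourselves so the batch read below has data.
	for _, payload := range []string{"one", "two"} {
		if _, err := uc.WriteToUDP([]byte(payload), uc.LocalAddr().(*net.UDPAddr)); err != nil {
			panic(err)
		}
	}

	// One buffer per message, mirroring how StdNetBind pools message
	// structs sized to IdealBatchSize.
	pc := ipv4.NewPacketConn(uc)
	msgs := make([]ipv4.Message, 8)
	for i := range msgs {
		msgs[i].Buffers = [][]byte{make([]byte, 1500)}
	}
	n, err := pc.ReadBatch(msgs, 0)
	if err != nil {
		panic(err)
	}
	for _, m := range msgs[:n] {
		fmt.Printf("%d bytes from %v\n", m.N, m.Addr)
	}
}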
type StdNetBind struct { - mu sync.Mutex // protects all fields except as specified - ipv4 *net.UDPConn - ipv6 *net.UDPConn - ipv4PC *ipv4.PacketConn // will be nil on non-Linux - ipv6PC *ipv6.PacketConn // will be nil on non-Linux - - // these three fields are not guarded by mu - udpAddrPool sync.Pool - ipv4MsgsPool sync.Pool - ipv6MsgsPool sync.Pool + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool + + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool blackhole4 bool blackhole6 bool @@ -54,23 +59,12 @@ func NewStdNetBind() Bind { }, }, - ipv4MsgsPool: sync.Pool{ - New: func() any { - msgs := make([]ipv4.Message, IdealBatchSize) - for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) - } - return &msgs - }, - }, - - ipv6MsgsPool: sync.Pool{ + msgsPool: sync.Pool{ New: func() any { msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) + msgs[i].OOB = make([]byte, controlSize) } return &msgs }, @@ -142,6 +136,11 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn.(*net.UDPConn), uaddr.Port, nil } +// errEADDRINUSE is syscall.EADDRINUSE, boxed into an interface once +// in erraddrinuse.go on almost all platforms. For other platforms, +// it's at least non-nil. +var errEADDRINUSE error = errors.New("") + func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { s.mu.Lock() defer s.mu.Unlock() @@ -168,7 +167,7 @@ again: // Listen on the same port as we're using for ipv4. 
v6conn, port, err = listenNet("udp6", port) - if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + if uport == 0 && errors.Is(err, errEADDRINUSE) && tries < 100 { v4conn.Close() tries++ goto again @@ -179,19 +178,21 @@ again: } var fns []ReceiveFunc if v4conn != nil { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) if runtime.GOOS == "linux" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } - fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) s.ipv4 = v4conn } if v6conn != nil { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) if runtime.GOOS == "linux" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } - fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) s.ipv6 = v6conn } if len(fns) == 0 { @@ -201,69 +202,93 @@ again: return fns, uint16(port), nil } -func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - defer s.ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. + _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" { + if rxOffload { + readAt := len(*msgs) - 2 + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) if err != nil { return 0, err } } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } - numMsgs = 1 } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue } - return numMsgs, nil + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep } + return numMsgs, nil } -func (s *StdNetBind) makeReceiveIPv6(pc 
*ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - defer s.ipv6MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - return 0, err - } - numMsgs = 1 - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep - } - return numMsgs, nil + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } @@ -293,28 +318,42 @@ func (s *StdNetBind) Close() error { } s.blackhole4 = false s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false if err1 != nil { return err1 } return err2 } -func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + +func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint, offset int) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 - var ( - pc4 *ipv4.PacketConn - pc6 *ipv6.PacketConn - ) + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 - pc6 = s.ipv6PC + br = s.ipv6PC is6 = true - } else { - pc4 = s.ipv4PC + offload = s.ipv6TxOffload } s.mu.Unlock() @@ -324,25 +363,56 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) if is6 { - return s.send6(conn, pc6, endpoint, bufs) + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] } else { - return s.send4(conn, pc4, endpoint, bufs) + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, offset, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } + } else { + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i][offset:] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return 
ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} + } + return err } -func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as4 := ep.DstIP().As4() - copy(ua.IP, as4[:]) - ua.IP = ua.IP[:4] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { var ( n int err error @@ -350,59 +420,129 @@ func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, ) if runtime.GOOS == "linux" { for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { break } start += n } } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) if err != nil { break } } } - s.udpAddrPool.Put(ua) - s.ipv4MsgsPool.Put(msgs) return err } -func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as16 := ep.DstIP().As16() - copy(ua.IP, as16[:]) - ua.IP = ua.IP[:16] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +const ( + // Exceeding these values results in EMSGSIZE. They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offset int, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - n int - err error - start int + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs ) - if runtime.GOOS == "linux" { - for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { - break + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + buf = buf[offset:] + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. 
+ endBatch = true + } + continue } - start += n } - } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) - if err != nil { - break + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 } } - s.udpAddrPool.Put(ua) - s.ipv6MsgsPool.Put(msgs) - return err + return n, nil } diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 1e4677654..77af0d925 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -1,6 +1,12 @@ package conn -import "testing" +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { bind := NewStdNetBind().(*StdNetBind) @@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { fn(bufs, sizes, eps) } } + +func mockSetGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func Test_coalesceMessages(t *testing.T) { + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1").To4(), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } 
+ got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, 0, msgs, mockSetGSOSize) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := mockGetGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} + +func mockGetGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_splitCoalescedMessages(t *testing.T) { + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1<<16-1)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} diff --git a/conn/bind_windows.go b/conn/bind_windows.go index d5095e004..737b475e1 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -17,7 +17,7 @@ import ( "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/conn/winrio" + "github.com/tailscale/wireguard-go/conn/winrio" ) const ( @@ -486,7 +486,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi return 
winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { +func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint, offset int) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType } @@ -494,6 +494,7 @@ func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { bind.mu.RLock() defer bind.mu.RUnlock() for _, buf := range bufs { + buf = buf[offset:] switch nend.family { case windows.AF_INET: if bind.v4.blackhole {
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 74e7addd2..741b776c4 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -12,7 +12,7 @@ import ( "net/netip" "os" - "golang.zx2c4.com/wireguard/conn" + "github.com/tailscale/wireguard-go/conn" ) type ChannelBind struct { @@ -107,8 +107,9 @@ func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { } } -func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { +func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint, offset int) error { for _, b := range bufs { + b = b[offset:] select { case <-c.closeSignal: return net.ErrClosed
diff --git a/conn/conn.go b/conn/conn.go index a1f57d2b1..f1781614d 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -45,9 +45,11 @@ type Bind interface { // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error - // Send writes one or more packets in bufs to address ep. The length of - // bufs must not exceed BatchSize(). - Send(bufs [][]byte, ep Endpoint) error + // Send writes one or more packets in bufs to address ep. A nonzero offset + // can be used to instruct the Bind on where packet data begins in each + // element of the bufs slice. Space preceding offset is free to use for + // additional encapsulation. The length of bufs must not exceed BatchSize(). + Send(bufs [][]byte, ep Endpoint, offset int) error // ParseEndpoint creates a new endpoint from a string. ParseEndpoint(s string) (Endpoint, error) @@ -84,6 +86,40 @@ type Endpoint interface { SrcIP() netip.Addr } +// InitiationAwareEndpoint is an optional [Endpoint] specialization for +// integrations that want to know when a WireGuard handshake initiation +// message has been received, enabling just-in-time peer configuration before +// attempted decryption. +// +// It's most useful when used in combination with [PeerAwareEndpoint], enabling +// JIT peer configuration and post-decryption peer verification from a single +// implementer. +type InitiationAwareEndpoint interface { + // InitiationMessagePublicKey is called when a handshake initiation message + // has been received, and the sender's public key has been identified, but + // BEFORE an attempt has been made to verify it. + InitiationMessagePublicKey(peerPublicKey [32]byte) +} + +// PeerAwareEndpoint is an optional Endpoint specialization for +// integrations that want to know about the outcome of Cryptokey Routing +// identification. +// +// If they receive a packet from a source they had not pre-identified, they +// can learn the identity WireGuard derives from the session or the +// handshake. +// +// A [PeerAwareEndpoint] may be installed as the [conn.Endpoint] following +// successful decryption unless endpoint roaming has been disabled for +// the peer. +type PeerAwareEndpoint interface { + // FromPeer is called at least once per batch of [ReceiveFunc] packets that + // Cryptokey Routing successfully identifies to a given node key.
wireguard-go will + // always call it for the latest/tail packet in the batch, only ever + // suppressing calls for older packets. + FromPeer(peerPublicKey [32]byte) +} + var ( ErrBindAlreadyOpen = errors.New("bind is already open") ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
diff --git a/conn/sticky_default.go b/conn/control_default.go similarity index 54% rename from conn/sticky_default.go rename to conn/control_default.go index 1fa8a0c4b..a8bc06a95 100644 --- a/conn/sticky_default.go +++ b/conn/control_default.go @@ -1,4 +1,4 @@ -//go:build !linux || android +//go:build !(linux && !android) /* SPDX-License-Identifier: MIT * @@ -21,8 +21,9 @@ func (e *StdNetEndpoint) SrcToString() string { return "" } -// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but -// use alternatively named flags and need ports and require testing. +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// ({get,set}srcControl) feature set, but use alternatively named flags and need +// ports and require testing. // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. @@ -34,8 +35,17 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { } func setSrcControl(control *[]byte, ep *StdNetEndpoint) { } -// srcControlSize returns the recommended buffer size for pooling sticky control -// data. -const srcControlSize = 0 +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// controlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const controlSize = 0 const StdNetSupportsStickySockets = false
diff --git a/conn/sticky_linux.go b/conn/control_linux.go similarity index 65% rename from conn/sticky_linux.go rename to conn/control_linux.go index a30ccc715..f32f26a69 100644 --- a/conn/sticky_linux.go +++ b/conn/control_linux.go @@ -8,6 +8,7 @@ package conn import ( + "fmt" "net/netip" "unsafe" @@ -105,6 +106,54 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) { *control = append(*control, ep.src...) } -var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == socketOptionLevelUDP && hdr.Type == socketOptionUDPGRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched.
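As an aside on the mechanism (a sketch assuming a Linux >= 4.18 kernel, not part of the patch): UDP_SEGMENT can also be set per socket instead of per message; the patch attaches it per message via the cmsg built by setGSOSize so each batch element can carry its own segment size. The levelUDP and udpSegment constants below mirror the patch's socketOptionLevelUDP and socketOptionUDPSegment values:

//go:build linux

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

const (
	levelUDP   = 17  // IPPROTO_UDP, as socketOptionLevelUDP in the patch
	udpSegment = 103 // UDP_SEGMENT, as socketOptionUDPSegment in the patch
)

func main() {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(fd)

	// Per-socket form: every datagram written afterwards is split by the
	// kernel into 1200-byte UDP segments.
	if err := unix.SetsockoptInt(fd, levelUDP, udpSegment, 1200); err != nil {
		fmt.Println("UDP_SEGMENT unsupported:", err)
		return
	}

	dst := &unix.SockaddrInet4{Port: 5000, Addr: [4]byte{127, 0, 0, 1}}
	payload := make([]byte, 3000) // sent as 1200 + 1200 + 600 byte datagrams
	if err := unix.Sendto(fd, payload, 0, dst); err != nil {
		// An EIO here is the condition errShouldDisableUDPGSO (later in
		// this patch) detects: the device lacks the tx checksum offload
		// that UDP_SEGMENT requires.
		fmt.Println("send failed:", err)
		return
	}
	fmt.Println("sent 3000 bytes as three UDP segments")
}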
+func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = socketOptionLevelUDP + hdr.Type = socketOptionUDPSegment + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// controlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData) const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/control_linux_test.go similarity index 98% rename from conn/sticky_linux_test.go rename to conn/control_linux_test.go index 679213a82..3ca7d3723 100644 --- a/conn/sticky_linux_test.go +++ b/conn/control_linux_test.go @@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) { } setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, controlSize) setSrcControl(&control, ep) @@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) { } setSrc(ep, netip.MustParseAddr("::1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, controlSize) setSrcControl(&control, ep) diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go index a2396fe89..752fbca4c 100644 --- a/conn/controlfns_linux.go +++ b/conn/controlfns_linux.go @@ -57,5 +57,13 @@ func init() { } return err }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO, 1) + }) + return nil + }, ) } diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go index 91692c0a6..144a8808f 100644 --- a/conn/controlfns_unix.go +++ b/conn/controlfns_unix.go @@ -1,4 +1,4 @@ -//go:build !windows && !linux && !wasm +//go:build !windows && !linux && !wasm && !plan9 && !tamago /* SPDX-License-Identifier: MIT * diff --git a/conn/erraddrinuse.go b/conn/erraddrinuse.go new file mode 100644 index 000000000..751660edf --- /dev/null +++ b/conn/erraddrinuse.go @@ -0,0 +1,14 @@ +//go:build !plan9 + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "syscall" + +func init() { + errEADDRINUSE = syscall.EADDRINUSE +} diff --git a/conn/errors_default.go b/conn/errors_default.go new file mode 100644 index 000000000..f1e5b90e5 --- /dev/null +++ b/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(err error) bool { + return false +} diff --git a/conn/errors_linux.go b/conn/errors_linux.go new file mode 100644 index 000000000..8e61000f8 --- /dev/null +++ b/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/conn/features_default.go b/conn/features_default.go new file mode 100644 index 000000000..d53ff5f7b --- /dev/null +++ b/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/conn/features_linux.go b/conn/features_linux.go new file mode 100644 index 000000000..513202e53 --- /dev/null +++ b/conn/features_linux.go @@ -0,0 +1,42 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +const ( + // TODO: upstream to x/sys/unix + socketOptionLevelUDP = 17 + socketOptionUDPSegment = 103 + socketOptionUDPGRO = 104 +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPSegment) + if errSyscall != nil { + return + } + txOffload = true + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, socketOptionUDPGRO) + if errSyscall != nil { + return + } + rxOffload = opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/device/bind_test.go b/device/bind_test.go index 302a52180..d64ca09b4 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -8,7 +8,7 @@ package device import ( "errors" - "golang.zx2c4.com/wireguard/conn" + "github.com/tailscale/wireguard-go/conn" ) type DummyDatagram struct { diff --git a/device/channels.go b/device/channels.go index 039d8dfd0..e526f6bb1 100644 --- a/device/channels.go +++ b/device/channels.go @@ -19,13 +19,13 @@ import ( // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. type outboundQueue struct { - c chan *QueueOutboundElement + c chan *QueueOutboundElementsContainer wg sync.WaitGroup } func newOutboundQueue() *outboundQueue { q := &outboundQueue{ - c: make(chan *QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } q.wg.Add(1) go func() { @@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue { // A inboundQueue is similar to an outboundQueue; see those docs. 
type inboundQueue struct { - c chan *QueueInboundElement + c chan *QueueInboundElementsContainer wg sync.WaitGroup } func newInboundQueue() *inboundQueue { q := &inboundQueue{ - c: make(chan *QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } q.wg.Add(1) go func() { @@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue { } type autodrainingInboundQueue struct { - c chan *[]*QueueInboundElement + c chan *QueueInboundElementsContainer } // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. @@ -81,7 +81,7 @@ type autodrainingInboundQueue struct { // some other means, such as sending a sentinel nil values. func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { q := &autodrainingInboundQueue{ - c: make(chan *[]*QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } runtime.SetFinalizer(q, device.flushInboundQueue) return q @@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) default: return } @@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { } type autodrainingOutboundQueue struct { - c chan *[]*QueueOutboundElement + c chan *QueueOutboundElementsContainer } // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. @@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct { // All sends to the channel must be best-effort, because there may be no receivers. 
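The autodraining behavior both queue types rely on comes from a finalizer; a stripped-down sketch of the pattern (illustrative names, not part of the patch):

package main

import (
	"fmt"
	"runtime"
	"time"
)

// gcDrained mirrors the autodraining queues: when the queue itself becomes
// unreachable, a finalizer empties the channel so buffered items (pooled
// elements, in the real code) are released rather than leaked.
type gcDrained struct {
	c chan int
}

func newGCDrained() *gcDrained {
	q := &gcDrained{c: make(chan int, 8)}
	runtime.SetFinalizer(q, func(q *gcDrained) {
		for {
			select {
			case v := <-q.c:
				fmt.Println("drained", v)
			default:
				return
			}
		}
	})
	return q
}

func main() {
	q := newGCDrained()
	q.c <- 1
	q.c <- 2
	q = nil                           // drop the only reference
	runtime.GC()                      // make the finalizer eligible to run
	time.Sleep(10 * time.Millisecond) // crude wait; fine for a demo
}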
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { q := &autodrainingOutboundQueue{ - c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } runtime.SetFinalizer(q, device.flushOutboundQueue) return q @@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) default: return } diff --git a/device/constants.go b/device/constants.go index 59854a126..92c3bdea8 100644 --- a/device/constants.go +++ b/device/constants.go @@ -27,9 +27,9 @@ const ( ) const ( - MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) - MaxMessageSize = MaxSegmentSize // maximum size of transport message - MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content + MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) + MaxMessageSize = MaxSegmentSize // maximum size of transport message + MaxContentSize = MaxSegmentSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content ) /* Implementation constants */ diff --git a/device/device.go b/device/device.go index 1af9fe017..5b2348564 100644 --- a/device/device.go +++ b/device/device.go @@ -11,10 +11,10 @@ import ( "sync/atomic" "time" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/ratelimiter" - "golang.zx2c4.com/wireguard/rwcancel" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/ratelimiter" + "github.com/tailscale/wireguard-go/rwcancel" + "github.com/tailscale/wireguard-go/tun" ) type Device struct { @@ -68,11 +68,11 @@ type Device struct { cookieChecker CookieChecker pool struct { - outboundElementsSlice *WaitPool - inboundElementsSlice *WaitPool - messageBuffers *WaitPool - inboundElements *WaitPool - outboundElements *WaitPool + inboundElementsContainer *WaitPool + outboundElementsContainer *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { @@ -368,10 +368,10 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { - device.ipcMutex.Lock() - defer device.ipcMutex.Unlock() device.state.Lock() defer device.state.Unlock() + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() if device.isClosed() { return } @@ -461,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error { // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() @@ -515,11 +511,7 @@ func (device *Device) BindUpdate() error { // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() diff --git a/device/device_test.go 
b/device/device_test.go index fff172bb8..e44342170 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -20,10 +20,10 @@ import ( "testing" "time" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/conn/bindtest" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/tun/tuntest" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/conn/bindtest" + "github.com/tailscale/wireguard-go/tun" + "github.com/tailscale/wireguard-go/tun/tuntest" ) // uapiCfg returns a string that contains cfg formatted use with IpcSet. @@ -426,11 +426,11 @@ type fakeBindSized struct { func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 0, nil } -func (b *fakeBindSized) Close() error { return nil } -func (b *fakeBindSized) SetMark(mark uint32) error { return nil } -func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } -func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } -func (b *fakeBindSized) BatchSize() int { return b.size } +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint, offset int) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } type fakeTUNDeviceSized struct { size int diff --git a/device/keypair.go b/device/keypair.go index e3540d7d4..2689ee2a5 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -11,7 +11,7 @@ import ( "sync/atomic" "time" - "golang.zx2c4.com/wireguard/replay" + "github.com/tailscale/wireguard-go/replay" ) /* Due to limitations in Go and /x/crypto there is currently diff --git a/device/mobilequirks.go b/device/mobilequirks.go index 4e5051d7e..0a0080efd 100644 --- a/device/mobilequirks.go +++ b/device/mobilequirks.go @@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { device.net.brokenRoaming = true device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - peer.disableRoaming = peer.endpoint != nil - peer.Unlock() + peer.endpoint.Lock() + peer.endpoint.disableRoaming = peer.endpoint.val != nil + peer.endpoint.Unlock() } device.peers.RUnlock() } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index e8f6145e5..ad5838e1d 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -6,6 +6,7 @@ package device import ( + "encoding/binary" "errors" "fmt" "sync" @@ -15,7 +16,8 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" - "golang.zx2c4.com/wireguard/tai64n" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/tai64n" ) type handshakeState int @@ -60,13 +62,14 @@ const ( ) const ( - MessageInitiationSize = 148 // size of handshake initiation message - MessageResponseSize = 92 // size of response message - MessageCookieReplySize = 64 // size of cookie reply message - MessageTransportHeaderSize = 16 // size of data preceding content in transport message - MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport - MessageKeepaliveSize = MessageTransportSize // size of keepalive - MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message + MessageInitiationSize = 148 // size of handshake initiation message + MessageResponseSize = 92 // 
size of response message + MessageCookieReplySize = 64 // size of cookie reply message + MessageTransportHeaderSize = 16 // size of data preceding content in transport message + MessageEncapsulatingTransportSize = 8 // size of optional, free (for use by conn.Bind.Send()) space preceding the transport header + MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport + MessageKeepaliveSize = MessageTransportSize // size of keepalive + MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message ) const ( @@ -115,6 +118,98 @@ type MessageCookieReply struct { Cookie [blake2s.Size128 + poly1305.TagSize]byte } +var errMessageLengthMismatch = errors.New("message length mismatch") + +func (msg *MessageInitiation) unmarshal(b []byte) error { + if len(b) != MessageInitiationSize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Ephemeral[:], b[8:]) + copy(msg.Static[:], b[8+len(msg.Ephemeral):]) + copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):]) + copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):]) + copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageInitiation) marshal(b []byte) error { + if len(b) != MessageInitiationSize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Sender) + copy(b[8:], msg.Ephemeral[:]) + copy(b[8+len(msg.Ephemeral):], msg.Static[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.Timestamp[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):], msg.MAC1[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):], msg.MAC2[:]) + + return nil +} + +func (msg *MessageResponse) unmarshal(b []byte) error { + if len(b) != MessageResponseSize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + msg.Receiver = binary.LittleEndian.Uint32(b[8:]) + copy(msg.Ephemeral[:], b[12:]) + copy(msg.Empty[:], b[12+len(msg.Ephemeral):]) + copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):]) + copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageResponse) marshal(b []byte) error { + if len(b) != MessageResponseSize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Sender) + binary.LittleEndian.PutUint32(b[8:], msg.Receiver) + copy(b[12:], msg.Ephemeral[:]) + copy(b[12+len(msg.Ephemeral):], msg.Empty[:]) + copy(b[12+len(msg.Ephemeral)+len(msg.Empty):], msg.MAC1[:]) + copy(b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):], msg.MAC2[:]) + + return nil +} + +func (msg *MessageCookieReply) unmarshal(b []byte) error { + if len(b) != MessageCookieReplySize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Receiver = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Nonce[:], b[8:]) + copy(msg.Cookie[:], b[8+len(msg.Nonce):]) + + return nil +} + +func (msg *MessageCookieReply) marshal(b []byte) error { + if len(b) != MessageCookieReplySize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Receiver) + copy(b[8:], msg.Nonce[:]) + copy(b[8+len(msg.Nonce):], 
msg.Cookie[:]) + + return nil +} + type Handshake struct { state handshakeState mutex sync.RWMutex @@ -124,7 +219,7 @@ type Handshake struct { localEphemeral NoisePrivateKey // ephemeral secret key localIndex uint32 // used to clear hash-table remoteIndex uint32 // index for sending - remoteStatic NoisePublicKey // long term key + remoteStatic NoisePublicKey // long term key, never changes, can be accessed without mutex remoteEphemeral NoisePublicKey // ephemeral public key precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp @@ -244,7 +339,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e return &msg, nil } -func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { +func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation, endpoint conn.Endpoint) *Peer { var ( hash [blake2s.Size]byte chainKey [blake2s.Size]byte @@ -278,6 +373,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // lookup peer + initEP, ok := endpoint.(conn.InitiationAwareEndpoint) + if ok { + initEP.InitiationMessagePublicKey(peerPK) + } + peer := device.LookupPeer(peerPK) if peer == nil || !peer.isRunning.Load() { return nil diff --git a/device/noise_test.go b/device/noise_test.go index 2dd53241d..160bee588 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -8,10 +8,11 @@ package device import ( "bytes" "encoding/binary" + "net/netip" "testing" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/tun/tuntest" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/tun/tuntest" ) func TestCurveWrappers(t *testing.T) { @@ -56,6 +57,26 @@ func assertEqual(t *testing.T, a, b []byte) { } } +type initAwareEP struct { + calledWith *[32]byte +} + +var _ conn.Endpoint = (*initAwareEP)(nil) +var _ conn.InitiationAwareEndpoint = (*initAwareEP)(nil) + +func (i *initAwareEP) ClearSrc() {} +func (i *initAwareEP) SrcToString() string { return "" } +func (i *initAwareEP) DstToString() string { return "" } +func (i *initAwareEP) DstToBytes() []byte { return nil } +func (i *initAwareEP) DstIP() netip.Addr { return netip.Addr{} } +func (i *initAwareEP) SrcIP() netip.Addr { return netip.Addr{} } + +func (i *initAwareEP) InitiationMessagePublicKey(peerPublicKey [32]byte) { + calledWith := [32]byte{} + copy(calledWith[:], peerPublicKey[:]) + i.calledWith = &calledWith +} + func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) @@ -93,10 +114,17 @@ func TestNoiseHandshake(t *testing.T) { writer := bytes.NewBuffer(packet) err = binary.Write(writer, binary.LittleEndian, msg1) assertNil(t, err) - peer := dev2.ConsumeMessageInitiation(msg1) + initEP := &initAwareEP{} + peer := dev2.ConsumeMessageInitiation(msg1, initEP) if peer == nil { t.Fatal("handshake failed at initiation message") } + if initEP.calledWith == nil { + t.Fatal("initAwareEP never called") + } + if *initEP.calledWith != dev1.staticIdentity.publicKey { + t.Fatal("initAwareEP called with unexpected public key") + } assertEqual( t, diff --git a/device/peer.go b/device/peer.go index 0ac48962c..064feb22b 100644 --- a/device/peer.go +++ b/device/peer.go @@ -12,22 +12,25 @@ import ( "sync/atomic" "time" - "golang.zx2c4.com/wireguard/conn" + "github.com/tailscale/wireguard-go/conn" ) type Peer struct { isRunning atomic.Bool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer keypairs Keypairs handshake Handshake 
device *Device - endpoint conn.Endpoint stopping sync.WaitGroup // routines pending stop txBytes atomic.Uint64 // bytes send to peer (endpoint) rxBytes atomic.Uint64 // bytes received from peer lastHandshakeNano atomic.Int64 // nano seconds since epoch - disableRoaming bool + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool + } timers struct { retransmitHandshake *Timer @@ -45,9 +48,9 @@ type Peer struct { } queue struct { - staged chan *[]*QueueOutboundElement // staged packets before a handshake is available - outbound *autodrainingOutboundQueue // sequential ordering of udp transmission - inbound *autodrainingInboundQueue // sequential ordering of tun writing + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -74,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // create peer peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device) - peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key _, ok := device.peers.keyMap[pk] @@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.mutex.Unlock() // reset endpoint - peer.endpoint = nil + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() // init timers peer.timersInit() @@ -108,6 +113,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +// SendBuffers sends buffers to peer. WireGuard packet data in each element of +// buffers must be preceded by MessageEncapsulatingTransportSize number of +// bytes. 
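A toy Send honoring the offset contract described above (sendWithOffset is hypothetical, not part of the patch): packet data starts at buf[offset:], and the headroom before it belongs to the Bind:

package main

import "fmt"

// sendWithOffset sketches the revised conn.Bind.Send contract: WireGuard
// packet data begins at buf[offset:]; bytes before offset are free headroom
// a Bind may claim for its own encapsulation header.
func sendWithOffset(bufs [][]byte, offset int) {
	for _, buf := range bufs {
		payload := buf[offset:] // the WireGuard packet itself
		header := buf[:offset]  // free space for an outer header
		copy(header, "HDR!")    // hypothetical 4-byte encapsulation
		fmt.Printf("would transmit %d header + %d payload bytes\n", len(header), len(payload))
	}
}

func main() {
	// Mimic SendBuffers: each buffer reserves
	// MessageEncapsulatingTransportSize (8) bytes of headroom.
	buf := make([]byte, 8+12)
	copy(buf[8:], "wg transport")
	sendWithOffset([][]byte{buf}, 8)
}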
func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -116,14 +124,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffers, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint, MessageEncapsulatingTransportSize) if err == nil { var totalLen uint64 for _, b := range buffers { @@ -267,10 +280,20 @@ func (peer *Peer) Stop() { } func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { - if peer.disableRoaming { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { + return + } + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} + +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = true } diff --git a/device/pools.go b/device/pools.go index 02a5d6acb..55d2be7df 100644 --- a/device/pools.go +++ b/device/pools.go @@ -7,14 +7,13 @@ package device import ( "sync" - "sync/atomic" ) type WaitPool struct { pool sync.Pool cond sync.Cond lock sync.Mutex - count atomic.Uint32 + count uint32 // Get calls not yet Put back max uint32 } @@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool { func (p *WaitPool) Get() any { if p.max != 0 { p.lock.Lock() - for p.count.Load() >= p.max { + for p.count >= p.max { p.cond.Wait() } - p.count.Add(1) + p.count++ p.lock.Unlock() } return p.pool.Get() @@ -41,18 +40,20 @@ func (p *WaitPool) Put(x any) { if p.max == 0 { return } - p.count.Add(^uint32(0)) + p.lock.Lock() + defer p.lock.Unlock() + p.count-- p.cond.Signal() } func (device *Device) PopulatePools() { - device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { - s := make([]*QueueOutboundElement, 0, device.BatchSize()) - return &s - }) - device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { s := make([]*QueueInboundElement, 0, device.BatchSize()) - return &s + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} }) device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) @@ -65,28 +66,32 @@ func (device *Device) PopulatePools() { }) } -func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement { - return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement) +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutInboundElementsContainer(c 
*QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.outboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) } -func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement { - return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement) +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.inboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) } func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { diff --git a/device/pools_test.go b/device/pools_test.go index 82d7493e1..2b16f3984 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -15,7 +15,6 @@ import ( ) func TestWaitPool(t *testing.T) { - t.Skip("Currently disabled") var wg sync.WaitGroup var trials atomic.Int32 startTrials := int32(100000) @@ -32,7 +31,9 @@ func TestWaitPool(t *testing.T) { wg.Add(workers) var max atomic.Uint32 updateMax := func() { - count := p.count.Load() + p.lock.Lock() + count := p.count + p.lock.Unlock() if count > p.max { t.Errorf("count (%d) > max (%d)", count, p.max) } diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index 3d80eadb0..bab9625c4 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -5,7 +5,7 @@ package device -import "golang.zx2c4.com/wireguard/conn" +import "github.com/tailscale/wireguard-go/conn" /* Reduce memory consumption for Android */ diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index ea763d01c..9749cb789 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -7,7 +7,7 @@ package device -import "golang.zx2c4.com/wireguard/conn" +import "github.com/tailscale/wireguard-go/conn" const ( QueueStagedSize = conn.IdealBatchSize diff --git a/device/receive.go b/device/receive.go index e24d29f5b..02c8f21fc 100644 --- a/device/receive.go +++ b/device/receive.go @@ -6,17 +6,16 @@ package device import ( - "bytes" "encoding/binary" "errors" "net" "sync" "time" + "github.com/tailscale/wireguard-go/conn" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { @@ -27,7 +26,6 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 @@ -35,6 +33,11 @@ type QueueInboundElement struct { endpoint conn.Endpoint } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement +} + // clearPointers clears elem fields that contain pointers. // This makes the garbage collector's life easier and // avoids accidentally keeping other objects around unnecessarily. 
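The WaitPool revision above moves count under the mutex; condensed, the bounded-pool pattern looks like this (boundedPool is an illustrative stand-in for device.WaitPool):

package main

import (
	"fmt"
	"sync"
)

// boundedPool condenses device.WaitPool: at most max values may be checked
// out at once; Get blocks until a Put frees a slot.
type boundedPool struct {
	pool  sync.Pool
	cond  sync.Cond
	lock  sync.Mutex
	count uint32 // Get calls not yet Put back, guarded by lock
	max   uint32
}

func newBoundedPool(max uint32, newFn func() any) *boundedPool {
	p := &boundedPool{pool: sync.Pool{New: newFn}, max: max}
	p.cond = sync.Cond{L: &p.lock}
	return p
}

func (p *boundedPool) Get() any {
	p.lock.Lock()
	for p.count >= p.max {
		p.cond.Wait()
	}
	p.count++
	p.lock.Unlock()
	return p.pool.Get()
}

func (p *boundedPool) Put(x any) {
	p.pool.Put(x)
	p.lock.Lock()
	defer p.lock.Unlock()
	p.count--
	p.cond.Signal()
}

func main() {
	p := newBoundedPool(1, func() any { return make([]byte, 4) })
	b := p.Get().([]byte)
	go p.Put(b) // frees the single slot
	_ = p.Get() // blocks until the Put above runs
	fmt.Println("acquired after release")
}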
@@ -87,7 +90,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int - elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) for i := range bufsArrs { @@ -170,15 +173,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetInboundElementsSlice() + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) bufsArrs[i] = device.GetMessageBuffer() bufs[i] = bufsArrs[i][:] continue @@ -217,18 +219,16 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive default: } } - for peer, elems := range elemsByPeer { + for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { - peer.queue.inbound.c <- elems - for _, elem := range *elems { - device.queue.decryption.c <- elem - } + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } delete(elemsByPeer, peer) } @@ -241,26 +241,28 @@ func (device *Device) RoutineDecryption(id int) { defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) device.log.Verbosef("Routine: decryption worker %d - started", id) - for elem := range device.queue.decryption.c { - // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // decrypt and release to consumer - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - // copy counter to nonce - binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) - elem.packet, err = elem.keypair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.packet = nil + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { + // split message into fields + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // decrypt and release to consumer + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.packet = nil + } } - elem.Unlock() + elemsContainer.Unlock() } } @@ -284,8 +286,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal packet var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) + err := reply.unmarshal(elem.packet) if err != nil { device.log.Verbosef("Failed to decode cookie reply") goto skip @@ -350,8 +351,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal var msg MessageInitiation - reader := 
bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode initiation message") goto skip @@ -359,7 +359,7 @@ func (device *Device) RoutineHandshake(id int) { // consume initiation - peer := device.ConsumeMessageInitiation(&msg) + peer := device.ConsumeMessageInitiation(&msg, elem.endpoint) if peer == nil { device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) goto skip @@ -383,8 +383,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode response message") goto skip @@ -437,12 +436,15 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.inbound.c { - if elems == nil { + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return } - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + validTailPacket := -1 + dataPacketReceived := false + rxBytesLen := uint64(0) + for i, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed continue @@ -452,21 +454,22 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { continue } - peer.SetEndpointFromPacket(elem.endpoint) + validTailPacket = i if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) peer.timersHandshakeComplete() peer.SendStagedPackets() } - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) + if ep, ok := elem.endpoint.(conn.PeerAwareEndpoint); ok { + ep.FromPeer(peer.handshake.remoteStatic) + } + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } - peer.timersDataReceived() + dataPacketReceived = true switch elem.packet[0] >> 4 { case 4: @@ -509,17 +512,28 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } + + peer.rxBytes.Add(rxBytesLen) + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } if len(bufs) > 0 { _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) if err != nil && !device.isClosed() { device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } bufs = bufs[:0] - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index d22bf264e..c8bb0792f 100644 --- a/device/send.go +++ b/device/send.go @@ -6,7 +6,6 @@ package device import ( - "bytes" "encoding/binary" "errors" "net" @@ -14,10 +13,11 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/tun" 
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow @@ -45,18 +45,25 @@ import ( */ type QueueOutboundElement struct { + buffer *[MaxMessageSize]byte // slice holding the packet data + // packet is always a slice of "buffer". The starting offset in buffer + // is either: + // a) MessageEncapsulatingTransportSize+MessageTransportHeaderSize (plaintext) + // b) 0 (post-encryption) + packet []byte + nonce uint64 // nonce for encryption + keypair *Keypair // keypair for encryption + peer *Peer // related peer +} + +type QueueOutboundElementsContainer struct { sync.Mutex - buffer *[MaxMessageSize]byte // slice holding the packet data - packet []byte // slice of "buffer" (always!) - nonce uint64 // nonce for encryption - keypair *Keypair // keypair for encryption - peer *Peer // related peer + elems []*QueueOutboundElement } func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -78,15 +85,15 @@ func (elem *QueueOutboundElement) clearPointers() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() - elems := peer.device.GetOutboundElementsSlice() - *elems = append(*elems, elem) + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case peer.queue.staged <- elems: + case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } } peer.SendStagedPackets() @@ -120,16 +127,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - var buf [MessageInitiationSize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() + buf := make([]byte, MessageEncapsulatingTransportSize+MessageInitiationSize) + packet := buf[MessageEncapsulatingTransportSize:] + _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{packet}) + err = peer.SendBuffers([][]byte{buf}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -151,10 +157,9 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - var buf [MessageResponseSize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() + buf := make([]byte, MessageEncapsulatingTransportSize+MessageResponseSize) + packet := buf[MessageEncapsulatingTransportSize:] + _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() @@ -168,7 +173,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{packet}) + err = peer.SendBuffers([][]byte{buf}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -185,11 +190,12 @@ func (device *Device) 
SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - var buf [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, reply) + buf := make([]byte, MessageEncapsulatingTransportSize+MessageCookieReplySize) + packet := buf[MessageEncapsulatingTransportSize:] + _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{buf}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) + return nil } @@ -218,10 +224,10 @@ func (device *Device) RoutineReadFromTUN() { readErr error elems = make([]*QueueOutboundElement, batchSize) bufs = make([][]byte, batchSize) - elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) - offset = MessageTransportHeaderSize + offset = MessageEncapsulatingTransportSize + MessageTransportHeaderSize ) for i := range elems { @@ -275,10 +281,10 @@ func (device *Device) RoutineReadFromTUN() { } elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetOutboundElementsSlice() + elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) elems[i] = device.NewOutboundElement() bufs[i] = elems[i].buffer[:] } @@ -288,11 +294,11 @@ func (device *Device) RoutineReadFromTUN() { peer.StagePackets(elemsForPeer) peer.SendStagedPackets() } else { - for _, elem := range *elemsForPeer { + for _, elem := range elemsForPeer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elemsForPeer) + device.PutOutboundElementsContainer(elemsForPeer) } delete(elemsByPeer, peer) } @@ -316,7 +322,7 @@ func (device *Device) RoutineReadFromTUN() { } } -func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { for { select { case peer.queue.staged <- elems: @@ -325,11 +331,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { } select { case tooOld := <-peer.queue.staged: - for _, elem := range *tooOld { + for _, elem := range tooOld.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(tooOld) + peer.device.PutOutboundElementsContainer(tooOld) default: } } @@ -348,54 +354,52 @@ top: } for { - var elemsOOO *[]*QueueOutboundElement + var elemsContainerOOO *QueueOutboundElementsContainer select { - case elems := <-peer.queue.staged: + case elemsContainer := <-peer.queue.staged: i := 0 - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { elem.peer = peer elem.nonce = keypair.sendNonce.Add(1) - 1 if elem.nonce >= RejectAfterMessages { keypair.sendNonce.Store(RejectAfterMessages) - if elemsOOO == nil { - elemsOOO = peer.device.GetOutboundElementsSlice() + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } - *elemsOOO = append(*elemsOOO, elem) + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) continue } else { - (*elems)[i] = elem + elemsContainer.elems[i] = elem i++ } elem.keypair = keypair - elem.Lock() } - *elems = (*elems)[:i] + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - if elemsOOO != nil { - 
peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans } - if len(*elems) == 0 { - peer.device.PutOutboundElementsSlice(elems) + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) goto top } // add to parallel and sequential queue if peer.isRunning.Load() { - peer.queue.outbound.c <- elems - for _, elem := range *elems { - peer.device.queue.encryption.c <- elem - } + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } - if elemsOOO != nil { + if elemsContainerOOO != nil { goto top } default: @@ -407,12 +411,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elems := <-peer.queue.staged: - for _, elem := range *elems { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) default: return } @@ -446,32 +450,37 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elem := range device.queue.encryption.c { - // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { + // populate header fields + header := elem.buffer[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) - elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) + // pad content to multiple of 16 + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) 
- // encrypt content and release to consumer + // encrypt content and release to consumer - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + + // re-slice packet to include encapsulating transport space + elem.packet = elem.buffer[:MessageEncapsulatingTransportSize+len(elem.packet)] + } + elemsContainer.Unlock() } } @@ -485,9 +494,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.outbound.c { + for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] - if elems == nil { + if elemsContainer == nil { return } if !peer.isRunning.Load() { @@ -497,17 +506,18 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } + device.PutOutboundElementsContainer(elemsContainer) continue } dataSent := false - for _, elem := range *elems { - elem.Lock() - if len(elem.packet) != MessageKeepaliveSize { + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + if len(elem.packet[MessageEncapsulatingTransportSize:]) != MessageKeepaliveSize { dataSent = true } bufs = append(bufs, elem.packet) @@ -520,11 +530,18 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if dataSent { peer.timersDataSent() } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr + } + } if err != nil { device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue diff --git a/device/sticky_default.go b/device/sticky_default.go index 10382565e..732f84ce7 100644 --- a/device/sticky_default.go +++ b/device/sticky_default.go @@ -3,8 +3,8 @@ package device import ( - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/rwcancel" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { diff --git a/device/sticky_linux.go b/device/sticky_linux.go index f9230f8c3..6eeced2b0 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -20,8 +20,8 @@ import ( "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/rwcancel" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { @@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl if !ok { break } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() + pePtr.peer.endpoint.Lock() + if &pePtr.peer.endpoint.val != pePtr.endpoint { + 
pePtr.peer.endpoint.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { - pePtr.peer.Unlock() + if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { + pePtr.peer.endpoint.Unlock() break } - pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc() - pePtr.peer.Unlock() + pePtr.peer.endpoint.clearSrcOnTx = true + pePtr.peer.endpoint.Unlock() } attr = attr[attrhdr.Len:] } @@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl device.peers.RLock() i := uint32(1) for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil { - peer.RUnlock() + peer.endpoint.Lock() + if peer.endpoint.val == nil { + peer.endpoint.Unlock() continue } - nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint) + nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) if nativeEP == nil { - peer.RUnlock() + peer.endpoint.Unlock() continue } if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { - peer.RUnlock() + peer.endpoint.Unlock() break } nlmsg := struct { @@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl reqPeerLock.Lock() reqPeer[i] = peerEndpointPtr{ peer: peer, - endpoint: &peer.endpoint, + endpoint: &peer.endpoint.val, } reqPeerLock.Unlock() - peer.RUnlock() + peer.endpoint.Unlock() i++ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) if err != nil { diff --git a/device/timers.go b/device/timers.go index e28732c42..d4a4ed4e5 100644 --- a/device/timers.go +++ b/device/timers.go @@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(true) } @@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) { func expiredNewHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(false) } diff --git a/device/tun.go b/device/tun.go index 2a2ace920..960ecca40 100644 --- a/device/tun.go +++ b/device/tun.go @@ -8,7 +8,7 @@ package device import ( "fmt" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/tun" ) const DefaultMTU = 1420 diff --git a/device/uapi.go b/device/uapi.go index 617dcd333..4987cdae0 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,7 +18,7 @@ import ( "sync" "time" - "golang.zx2c4.com/wireguard/ipc" + "github.com/tailscale/wireguard-go/ipc" ) type IPCError struct { @@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error { for _, peer := range device.peers.keyMap { // Serialize peer state. - // Do the work in an anonymous function so that we can use defer. 
- func() { - peer.RLock() - defer peer.RUnlock() - - keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) - keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) - sendf("protocol_version=1") - if peer.endpoint != nil { - sendf("endpoint=%s", peer.endpoint.DstToString()) - } - - nano := peer.lastHandshakeNano.Load() - secs := nano / time.Second.Nanoseconds() - nano %= time.Second.Nanoseconds() - - sendf("last_handshake_time_sec=%d", secs) - sendf("last_handshake_time_nsec=%d", nano) - sendf("tx_bytes=%d", peer.txBytes.Load()) - sendf("rx_bytes=%d", peer.rxBytes.Load()) - sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) - - device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { - sendf("allowed_ip=%s", prefix.String()) - return true - }) - }() + peer.handshake.mutex.RLock() + keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) + keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() + sendf("protocol_version=1") + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) + } + peer.endpoint.Unlock() + + nano := peer.lastHandshakeNano.Load() + secs := nano / time.Second.Nanoseconds() + nano %= time.Second.Nanoseconds() + + sendf("last_handshake_time_sec=%d", secs) + sendf("last_handshake_time_nsec=%d", nano) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) + + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }) } }() @@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() { return } if peer.created { - peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil } if peer.device.isUp() { peer.Start() @@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) } - peer.Lock() - defer peer.Unlock() - peer.endpoint = endpoint + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint case "persistent_keepalive_interval": device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) diff --git a/go.mod b/go.mod index c04e1bb61..9c9b02a66 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,16 @@ -module golang.zx2c4.com/wireguard +module github.com/tailscale/wireguard-go go 1.20 require ( - golang.org/x/crypto v0.6.0 - golang.org/x/net v0.7.0 - golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 + golang.org/x/crypto v0.13.0 + golang.org/x/net v0.15.0 + golang.org/x/sys v0.12.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 ) require ( github.com/google/btree v1.0.1 // indirect - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) diff --git a/go.sum b/go.sum index cfeaee623..6bcecea3f 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,14 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -golang.org/x/crypto v0.6.0 
h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/ipc/namedpipe/namedpipe_test.go b/ipc/namedpipe/namedpipe_test.go index 998453be0..de7d0f63c 100644 --- a/ipc/namedpipe/namedpipe_test.go +++ b/ipc/namedpipe/namedpipe_test.go @@ -20,8 +20,8 @@ import ( "testing" "time" + "github.com/tailscale/wireguard-go/ipc/namedpipe" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/namedpipe" ) func randomPipePath() string { diff --git a/ipc/uapi_fake.go b/ipc/uapi_fake.go new file mode 100644 index 000000000..a2e0f85b3 --- /dev/null +++ b/ipc/uapi_fake.go @@ -0,0 +1,17 @@ +//go:build wasm || plan9 || aix || solaris || illumos + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package ipc + +// Made up sentinel error codes for {js,wasip1}/wasm, and plan9. 
+const ( + IpcErrorIO = 1 + IpcErrorInvalid = 2 + IpcErrorPortInUse = 3 + IpcErrorUnknown = 4 + IpcErrorProtocol = 5 +) diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go index 1562a1834..bfdf1bffd 100644 --- a/ipc/uapi_linux.go +++ b/ipc/uapi_linux.go @@ -9,8 +9,8 @@ import ( "net" "os" + "github.com/tailscale/wireguard-go/rwcancel" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" ) type UAPIListener struct { diff --git a/ipc/uapi_wasm.go b/ipc/uapi_tamago.go similarity index 76% rename from ipc/uapi_wasm.go rename to ipc/uapi_tamago.go index fa84684aa..85a725a76 100644 --- a/ipc/uapi_wasm.go +++ b/ipc/uapi_tamago.go @@ -1,3 +1,5 @@ +//go:build tamago + /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. @@ -5,7 +7,7 @@ package ipc -// Made up sentinel error codes for {js,wasip1}/wasm. +// Made up sentinel error codes for tamago platform. const ( IpcErrorIO = 1 IpcErrorInvalid = 2 diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index aa023c92a..bc30ae0dd 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -8,8 +8,8 @@ package ipc import ( "net" + "github.com/tailscale/wireguard-go/ipc/namedpipe" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/namedpipe" ) // TODO: replace these with actual standard windows error numbers from the win package diff --git a/main.go b/main.go index e01611694..55000e9f9 100644 --- a/main.go +++ b/main.go @@ -14,11 +14,11 @@ import ( "runtime" "strconv" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/ipc" + "github.com/tailscale/wireguard-go/tun" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/tun" ) const ( diff --git a/main_windows.go b/main_windows.go index a4dc46f2c..689a9a79a 100644 --- a/main_windows.go +++ b/main_windows.go @@ -12,11 +12,11 @@ import ( "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/ipc" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/tun" ) const ( diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go index e397c0e8a..ceb87e4da 100644 --- a/rwcancel/rwcancel.go +++ b/rwcancel/rwcancel.go @@ -1,4 +1,4 @@ -//go:build !windows && !wasm +//go:build !windows && !wasm && !plan9 && !tamago /* SPDX-License-Identifier: MIT * diff --git a/rwcancel/rwcancel_stub.go b/rwcancel/rwcancel_stub.go index 2a98b2b4a..60ae9af0e 100644 --- a/rwcancel/rwcancel_stub.go +++ b/rwcancel/rwcancel_stub.go @@ -1,4 +1,4 @@ -//go:build windows || wasm +//go:build windows || wasm || plan9 || tamago // SPDX-License-Identifier: MIT diff --git a/tun/checksum.go b/tun/checksum.go index f4f847164..6634050c3 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -1,42 +1,712 @@ package tun -import "encoding/binary" - -// TODO: Explore SIMD and/or other assembly optimizations. 
-func checksumNoFold(b []byte, initial uint64) uint64 { - ac := initial - i := 0 - n := len(b) - for n >= 4 { - ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) - n -= 4 - i += 4 - } - for n >= 2 { - ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 - } - if n == 1 { - ac += uint64(b[i]) << 8 - } - return ac -} - -func checksum(b []byte, initial uint64) uint16 { - ac := checksumNoFold(b, initial) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - ac = (ac >> 16) + (ac & 0xffff) - return uint16(ac) -} - -func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { - sum := checksumNoFold(srcAddr, 0) - sum = checksumNoFold(dstAddr, sum) - sum = checksumNoFold([]byte{0, protocol}, sum) - tmp := make([]byte, 2) - binary.BigEndian.PutUint16(tmp, totalLen) - return checksumNoFold(tmp, sum) +import ( + "encoding/binary" + "math/bits" + "strconv" + + "golang.org/x/sys/cpu" +) + +// checksumGeneric64 is a reference implementation of checksum using 64 bit +// arithmetic for use in testing or when an architecture-specific implementation +// is not available. +func checksumGeneric64(b []byte, initial uint16) uint16 { + var ac uint64 + var carry uint64 + + if cpu.IsBigEndian { + ac = uint64(initial) + } else { + ac = uint64(bits.ReverseBytes16(initial)) + } + + for len(b) >= 128 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry) + ac, carry = 
bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry) + } + b = b[128:] + } + if len(b) >= 64 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry) + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry) + } else { + ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry) + } else { + ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry) + } else { + ac, carry = bits.Add64(ac, uint64(b[0]), carry) + } + } + + folded := ipChecksumFold64(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32 is a reference implementation of checksum using 32 bit +// arithmetic for use in testing or when an architecture-specific 
implementation +// is not available. +func checksumGeneric32(b []byte, initial uint16) uint16 { + var ac uint32 + var carry uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + ac, carry =
bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry) + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry) + } else { + ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry) + } else { + ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry) + } else { + ac, carry = bits.Add32(ac, uint32(b[0]), carry) + } + } + + folded := ipChecksumFold32(ac, carry) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric32Alternate is an alternate reference implementation of +// checksum using 32 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. 
+func checksumGeneric32Alternate(b []byte, initial uint16) uint16 { + var ac uint32 + + if cpu.IsBigEndian { + ac = uint32(initial) + } else { + ac = uint32(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + ac += uint32(binary.BigEndian.Uint16(b[16:18])) + ac += uint32(binary.BigEndian.Uint16(b[18:20])) + ac += uint32(binary.BigEndian.Uint16(b[20:22])) + ac += uint32(binary.BigEndian.Uint16(b[22:24])) + ac += uint32(binary.BigEndian.Uint16(b[24:26])) + ac += uint32(binary.BigEndian.Uint16(b[26:28])) + ac += uint32(binary.BigEndian.Uint16(b[28:30])) + ac += uint32(binary.BigEndian.Uint16(b[30:32])) + ac += uint32(binary.BigEndian.Uint16(b[32:34])) + ac += uint32(binary.BigEndian.Uint16(b[34:36])) + ac += uint32(binary.BigEndian.Uint16(b[36:38])) + ac += uint32(binary.BigEndian.Uint16(b[38:40])) + ac += uint32(binary.BigEndian.Uint16(b[40:42])) + ac += uint32(binary.BigEndian.Uint16(b[42:44])) + ac += uint32(binary.BigEndian.Uint16(b[44:46])) + ac += uint32(binary.BigEndian.Uint16(b[46:48])) + ac += uint32(binary.BigEndian.Uint16(b[48:50])) + ac += uint32(binary.BigEndian.Uint16(b[50:52])) + ac += uint32(binary.BigEndian.Uint16(b[52:54])) + ac += uint32(binary.BigEndian.Uint16(b[54:56])) + ac += uint32(binary.BigEndian.Uint16(b[56:58])) + ac += uint32(binary.BigEndian.Uint16(b[58:60])) + ac += uint32(binary.BigEndian.Uint16(b[60:62])) + ac += uint32(binary.BigEndian.Uint16(b[62:64])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + ac += uint32(binary.LittleEndian.Uint16(b[16:18])) + ac += uint32(binary.LittleEndian.Uint16(b[18:20])) + ac += uint32(binary.LittleEndian.Uint16(b[20:22])) + ac += uint32(binary.LittleEndian.Uint16(b[22:24])) + ac += uint32(binary.LittleEndian.Uint16(b[24:26])) + ac += uint32(binary.LittleEndian.Uint16(b[26:28])) + ac += uint32(binary.LittleEndian.Uint16(b[28:30])) + ac += uint32(binary.LittleEndian.Uint16(b[30:32])) + ac += uint32(binary.LittleEndian.Uint16(b[32:34])) + ac += uint32(binary.LittleEndian.Uint16(b[34:36])) + ac += uint32(binary.LittleEndian.Uint16(b[36:38])) + ac += uint32(binary.LittleEndian.Uint16(b[38:40])) + ac += uint32(binary.LittleEndian.Uint16(b[40:42])) + ac += uint32(binary.LittleEndian.Uint16(b[42:44])) + ac += uint32(binary.LittleEndian.Uint16(b[44:46])) + ac += uint32(binary.LittleEndian.Uint16(b[46:48])) + ac += uint32(binary.LittleEndian.Uint16(b[48:50])) + ac += uint32(binary.LittleEndian.Uint16(b[50:52])) + ac += uint32(binary.LittleEndian.Uint16(b[52:54])) + ac += uint32(binary.LittleEndian.Uint16(b[54:56])) + ac += uint32(binary.LittleEndian.Uint16(b[56:58])) + ac += uint32(binary.LittleEndian.Uint16(b[58:60])) + ac += uint32(binary.LittleEndian.Uint16(b[60:62])) + ac += uint32(binary.LittleEndian.Uint16(b[62:64])) + } + b = b[64:] + } + if 
len(b) >= 32 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + ac += uint32(binary.BigEndian.Uint16(b[16:18])) + ac += uint32(binary.BigEndian.Uint16(b[18:20])) + ac += uint32(binary.BigEndian.Uint16(b[20:22])) + ac += uint32(binary.BigEndian.Uint16(b[22:24])) + ac += uint32(binary.BigEndian.Uint16(b[24:26])) + ac += uint32(binary.BigEndian.Uint16(b[26:28])) + ac += uint32(binary.BigEndian.Uint16(b[28:30])) + ac += uint32(binary.BigEndian.Uint16(b[30:32])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + ac += uint32(binary.LittleEndian.Uint16(b[16:18])) + ac += uint32(binary.LittleEndian.Uint16(b[18:20])) + ac += uint32(binary.LittleEndian.Uint16(b[20:22])) + ac += uint32(binary.LittleEndian.Uint16(b[22:24])) + ac += uint32(binary.LittleEndian.Uint16(b[24:26])) + ac += uint32(binary.LittleEndian.Uint16(b[26:28])) + ac += uint32(binary.LittleEndian.Uint16(b[28:30])) + ac += uint32(binary.LittleEndian.Uint16(b[30:32])) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + ac += uint32(binary.BigEndian.Uint16(b[8:10])) + ac += uint32(binary.BigEndian.Uint16(b[10:12])) + ac += uint32(binary.BigEndian.Uint16(b[12:14])) + ac += uint32(binary.BigEndian.Uint16(b[14:16])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + ac += uint32(binary.LittleEndian.Uint16(b[8:10])) + ac += uint32(binary.LittleEndian.Uint16(b[10:12])) + ac += uint32(binary.LittleEndian.Uint16(b[12:14])) + ac += uint32(binary.LittleEndian.Uint16(b[14:16])) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + ac += uint32(binary.BigEndian.Uint16(b[4:6])) + ac += uint32(binary.BigEndian.Uint16(b[6:8])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + ac += uint32(binary.LittleEndian.Uint16(b[4:6])) + ac += uint32(binary.LittleEndian.Uint16(b[6:8])) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b[:2])) + ac += uint32(binary.BigEndian.Uint16(b[2:4])) + } else { + ac += uint32(binary.LittleEndian.Uint16(b[:2])) + ac += uint32(binary.LittleEndian.Uint16(b[2:4])) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac += uint32(binary.BigEndian.Uint16(b)) + } else { + ac += uint32(binary.LittleEndian.Uint16(b)) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac += uint32(b[0]) << 8 + } else { 
+ ac += uint32(b[0]) + } + } + + folded := ipChecksumFold32(ac, 0) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +// checksumGeneric64Alternate is an alternate reference implementation of +// checksum using 64 bit arithmetic for use in testing or when an +// architecture-specific implementation is not available. +func checksumGeneric64Alternate(b []byte, initial uint16) uint16 { + var ac uint64 + + if cpu.IsBigEndian { + ac = uint64(initial) + } else { + ac = uint64(bits.ReverseBytes16(initial)) + } + + for len(b) >= 64 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + ac += uint64(binary.LittleEndian.Uint32(b[16:20])) + ac += uint64(binary.LittleEndian.Uint32(b[20:24])) + ac += uint64(binary.LittleEndian.Uint32(b[24:28])) + ac += uint64(binary.LittleEndian.Uint32(b[28:32])) + ac += uint64(binary.LittleEndian.Uint32(b[32:36])) + ac += uint64(binary.LittleEndian.Uint32(b[36:40])) + ac += uint64(binary.LittleEndian.Uint32(b[40:44])) + ac += uint64(binary.LittleEndian.Uint32(b[44:48])) + ac += uint64(binary.LittleEndian.Uint32(b[48:52])) + ac += uint64(binary.LittleEndian.Uint32(b[52:56])) + ac += uint64(binary.LittleEndian.Uint32(b[56:60])) + ac += uint64(binary.LittleEndian.Uint32(b[60:64])) + } + b = b[64:] + } + if len(b) >= 32 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac += uint64(binary.LittleEndian.Uint32(b[12:16])) + ac += uint64(binary.LittleEndian.Uint32(b[16:20])) + ac += uint64(binary.LittleEndian.Uint32(b[20:24])) + ac += uint64(binary.LittleEndian.Uint32(b[24:28])) + ac += uint64(binary.LittleEndian.Uint32(b[28:32])) + } + b = b[32:] + } + if len(b) >= 16 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + ac += uint64(binary.LittleEndian.Uint32(b[8:12])) + ac 
+= uint64(binary.LittleEndian.Uint32(b[12:16])) + } + b = b[16:] + } + if len(b) >= 8 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + } else { + ac += uint64(binary.LittleEndian.Uint32(b[:4])) + ac += uint64(binary.LittleEndian.Uint32(b[4:8])) + } + b = b[8:] + } + if len(b) >= 4 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint32(b)) + } else { + ac += uint64(binary.LittleEndian.Uint32(b)) + } + b = b[4:] + } + if len(b) >= 2 { + if cpu.IsBigEndian { + ac += uint64(binary.BigEndian.Uint16(b)) + } else { + ac += uint64(binary.LittleEndian.Uint16(b)) + } + b = b[2:] + } + if len(b) >= 1 { + if cpu.IsBigEndian { + ac += uint64(b[0]) << 8 + } else { + ac += uint64(b[0]) + } + } + + folded := ipChecksumFold64(ac, 0) + if !cpu.IsBigEndian { + folded = bits.ReverseBytes16(folded) + } + return folded +} + +func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 { + sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry)) + // if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff + // therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is + // no need to save the carry flag + sum = (sum >> 16) + (sum & 0xffff) + carry + // sum <= 0x1_fffe therefore this is the last fold needed: + // if (sum >> 16) > 0 then + // (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore + // the addition will not overflow + // otherwise (sum >> 16) == 0 and sum will be unchanged + sum = (sum >> 16) + (sum & 0xffff) + return uint16(sum) +} + +func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 { + sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry + // sum <= 0x1_ffff: + // 0xffff + 0xffff = 0x1_fffe + // initialCarry is 0 or 1, for a combined maximum of 0x1_ffff + sum = (sum >> 16) + (sum & 0xffff) + // sum <= 0x1_0000 therefore this is the last fold needed: + // if (sum >> 16) > 0 then + // (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore + // the addition will not overflow + // otherwise (sum >> 16) == 0 and sum will be unchanged + sum = (sum >> 16) + (sum & 0xffff) + return uint16(sum) +} + +func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) { + sum, carry = initial, carryIn + switch len(addr) { + case 4: // IPv4 + if cpu.IsBigEndian { + sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry) + } else { + sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry) + } + case 16: // IPv6 + if cpu.IsBigEndian { + sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry) + sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry) + } else { + sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry) + sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry) + } + default: + panic("bad addr length") + } + return sum, carry +} + +func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) { + sum, carry = initial, carryIn + switch len(addr) { + case 4: // IPv4 + if cpu.IsBigEndian { + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry) + } else { + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry) + } + case 16: // IPv6 + if cpu.IsBigEndian { + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry) + sum, carry = bits.Add32(sum, 
binary.BigEndian.Uint32(addr[8:12]), carry) + sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry) + } else { + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry) + sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry) + } + default: + panic("bad addr length") + } + return sum, carry +} + +func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + var sum uint64 + if cpu.IsBigEndian { + sum = uint64(totalLen) + uint64(protocol) + } else { + sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8 + } + sum, carry := addrPartialChecksum64(srcAddr, sum, 0) + sum, carry = addrPartialChecksum64(dstAddr, sum, carry) + + foldedSum := ipChecksumFold64(sum, carry) + if !cpu.IsBigEndian { + foldedSum = bits.ReverseBytes16(foldedSum) + } + return foldedSum +} + +func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + var sum uint32 + if cpu.IsBigEndian { + sum = uint32(totalLen) + uint32(protocol) + } else { + sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8 + } + sum, carry := addrPartialChecksum32(srcAddr, sum, 0) + sum, carry = addrPartialChecksum32(dstAddr, sum, carry) + + foldedSum := ipChecksumFold32(sum, carry) + if !cpu.IsBigEndian { + foldedSum = bits.ReverseBytes16(foldedSum) + } + return foldedSum +} + +// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and +// dstAddr must be 4 or 16 bytes in length. +func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 { + if strconv.IntSize < 64 { + return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen) + } + return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen) } diff --git a/tun/checksum_amd64.go b/tun/checksum_amd64.go new file mode 100644 index 000000000..4fb684ec7 --- /dev/null +++ b/tun/checksum_amd64.go @@ -0,0 +1,23 @@ +package tun + +import "golang.org/x/sys/cpu" + +var checksum = checksumAMD64 + +// Checksum computes an IP checksum starting with the provided initial value. +// The length of data should be at least 128 bytes for best performance. Smaller +// buffers will still compute a correct result. 
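+//
+// Checksum returns the folded, uncomplemented ones-complement sum. As an
+// illustrative sketch (srcIP, dstIP and segment are placeholders), a
+// transport checksum field can be derived by chaining the pseudo-header
+// sum into the segment sum and complementing the result:
+//
+//	partial := PseudoHeaderChecksum(6 /* TCP */, srcIP, dstIP, uint16(len(segment)))
+//	field := ^Checksum(segment, partial)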
+func Checksum(data []byte, initial uint16) uint16 { + return checksum(data, initial) +} + +func init() { + if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 { + checksum = checksumAVX2 + return + } + if cpu.X86.HasSSE2 { + checksum = checksumSSE2 + return + } +} diff --git a/tun/checksum_amd64_test.go b/tun/checksum_amd64_test.go new file mode 100644 index 000000000..7a0b68140 --- /dev/null +++ b/tun/checksum_amd64_test.go @@ -0,0 +1,45 @@ +//go:build amd64 + +package tun + +import ( + "golang.org/x/sys/cpu" +) + +var archChecksumFuncs = []archChecksumDetails{ + { + name: "generic32", + available: true, + f: checksumGeneric32, + }, + { + name: "generic64", + available: true, + f: checksumGeneric64, + }, + { + name: "generic32Alternate", + available: true, + f: checksumGeneric32Alternate, + }, + { + name: "generic64Alternate", + available: true, + f: checksumGeneric64Alternate, + }, + { + name: "AMD64", + available: true, + f: checksumAMD64, + }, + { + name: "SSE2", + available: cpu.X86.HasSSE2, + f: checksumSSE2, + }, + { + name: "AVX2", + available: cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2, + f: checksumAVX2, + }, +} diff --git a/tun/checksum_generated_amd64.go b/tun/checksum_generated_amd64.go new file mode 100644 index 000000000..b4a29419b --- /dev/null +++ b/tun/checksum_generated_amd64.go @@ -0,0 +1,18 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. + +package tun + +// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2) +// +//go:noescape +func checksumAVX2(b []byte, initial uint16) uint16 + +// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2) +// +//go:noescape +func checksumSSE2(b []byte, initial uint16) uint16 + +// checksumAMD64 computes an IP checksum using amd64 baseline instructions +// +//go:noescape +func checksumAMD64(b []byte, initial uint16) uint16 diff --git a/tun/checksum_generated_amd64.s b/tun/checksum_generated_amd64.s new file mode 100644 index 000000000..5f2e4c525 --- /dev/null +++ b/tun/checksum_generated_amd64.s @@ -0,0 +1,851 @@ +// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT. 
+ +#include "textflag.h" + +DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" +DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff" +DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112 + +// func checksumAVX2(b []byte, initial uint16) uint16 +// Requires: AVX, AVX2, BMI2 +TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
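+	// For example, with 42 bytes remaining: 42 & 0x7 = 2, so the last 8
+	// bytes of the buffer are loaded and shifted right by 64 - 2*8 = 48
+	// bits, keeping only the final 2 bytes; the other 40 bytes are then
+	// consumed by the 8- and 32-byte blocks below.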
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + VPXOR Y0, Y0, Y0 + VPXOR Y1, Y1, Y1 + VPXOR Y2, Y2, Y2 + VPXOR Y3, Y3, Y3 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 16(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 32(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 48(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 64(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 80(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 96(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 112(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 128(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 144(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 160(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 176(DX), Y4 + VPADDD Y4, Y3, Y3 + VPMOVZXWD 192(DX), Y4 + VPADDD Y4, Y0, Y0 + VPMOVZXWD 208(DX), Y4 + VPADDD Y4, Y1, Y1 + VPMOVZXWD 224(DX), Y4 + VPADDD Y4, Y2, Y2 + VPMOVZXWD 240(DX), Y4 + VPADDD Y4, Y3, Y3 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + VPMOVZXWD (DX), Y4 + VPADDD Y4, Y0, Y0 + ADDQ $0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. 
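+	// The mask table stores one 16-byte mask for each even remainder
+	// length 2 through 14, 16 bytes apart, hence the scale-8 index below:
+	// e.g. for BX = 6 the mask at offset 6*8 - 16 = 32 keeps only the
+	// last 6 bytes of the overlapped 16-byte read.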
+ LEAQ xmmLoadMasks<>+0(SB), CX + VMOVDQU -16(DX)(BX*1), X4 + VPAND -16(CX)(BX*8), X4, X4 + VPMOVZXWD X4, Y4 + VPADDD Y4, Y0, Y0 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + VPADDD Y1, Y0, Y0 + VPADDD Y2, Y0, Y0 + VPADDD Y3, Y0, Y0 + + // extract the YMM into a pair of XMM and sum them + VEXTRACTI128 $0x01, Y0, X1 + VPADDD X0, X1, X0 + + // extract the XMM into GP64 + VPEXTRQ $0x00, X0, CX + VPEXTRQ $0x01, X0, DX + + // no more AVX code, clear upper registers to avoid SSE slowdowns + VZEROUPPER + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + RORXQ $0x20, AX, CX + ADCL CX, AX + RORXL $0x10, AX, CX + ADCW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumSSE2(b []byte, initial uint16) uint16 +// Requires: SSE2 +TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. + ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // skip all SIMD for small buffers + CMPQ BX, $0x00000100 + JGE startSIMD + + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + JMP foldAndReturn + +startSIMD: + PXOR X0, X0 + PXOR X1, X1 + PXOR X2, X2 + PXOR X3, X3 + PXOR X4, X4 + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + + // Number of 256 byte iterations + SHRQ $0x08, CX + JZ smallLoop + +bigLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 16(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 32(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 48(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 64(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 80(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 96(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 112(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 128(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 144(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 160(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 176(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + MOVOU 192(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X2 + MOVOU 208(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X1 + PADDD X6, X3 + MOVOU 224(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X2 + PADDD X6, X0 + MOVOU 240(DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X3 + PADDD X6, X1 + ADDQ $0x00000100, DX + DECQ CX + JNZ bigLoop + CMPQ BX, $0x10 + JLT doneSmallLoop + + // now read a single 16 byte unit of data at a time +smallLoop: + MOVOU (DX), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + ADDQ 
$0x10, DX + SUBQ $0x10, BX + CMPQ BX, $0x10 + JGE smallLoop + +doneSmallLoop: + CMPQ BX, $0x00 + JE doneSIMD + + // There are between 1 and 15 bytes remaining. Perform an overlapped read. + LEAQ xmmLoadMasks<>+0(SB), CX + MOVOU -16(DX)(BX*1), X5 + PAND -16(CX)(BX*8), X5 + MOVOA X5, X6 + PUNPCKHWL X4, X5 + PUNPCKLWL X4, X6 + PADDD X5, X0 + PADDD X6, X1 + +doneSIMD: + // Multi-chain loop is done, combine the accumulators + PADDD X1, X0 + PADDD X2, X0 + PADDD X3, X0 + + // extract the XMM into GP64 + MOVQ X0, CX + PSRLDQ $0x08, X0 + MOVQ X0, DX + ADDQ CX, AX + ADCQ DX, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET + +// func checksumAMD64(b []byte, initial uint16) uint16 +TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34 + MOVWQZX initial+24(FP), AX + XCHGB AH, AL + MOVQ b_base+0(FP), DX + MOVQ b_len+8(FP), BX + + // handle odd length buffers; they are difficult to handle in general + TESTQ $0x00000001, BX + JZ lengthIsEven + MOVBQZX -1(DX)(BX*1), CX + DECQ BX + ADDQ CX, AX + +lengthIsEven: + // handle tiny buffers (<=31 bytes) specially + CMPQ BX, $0x1f + JGT bufferIsNotTiny + XORQ CX, CX + XORQ SI, SI + XORQ DI, DI + + // shift twice to start because length is guaranteed to be even + // n = n >> 2; CF = originalN & 2 + SHRQ $0x02, BX + JNC handleTiny4 + + // tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:] + MOVWQZX (DX), CX + ADDQ $0x02, DX + +handleTiny4: + // n = n >> 1; CF = originalN & 4 + SHRQ $0x01, BX + JNC handleTiny8 + + // tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:] + MOVLQZX (DX), SI + ADDQ $0x04, DX + +handleTiny8: + // n = n >> 1; CF = originalN & 8 + SHRQ $0x01, BX + JNC handleTiny16 + + // tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:] + MOVQ (DX), DI + ADDQ $0x08, DX + +handleTiny16: + // n = n >> 1; CF = originalN & 16 + // n == 0 now, otherwise we would have branched after comparing with tinyBufferSize + SHRQ $0x01, BX + JNC handleTinyFinish + ADDQ (DX), AX + ADCQ 8(DX), AX + +handleTinyFinish: + // CF should be included from the previous add, so we use ADCQ. + // If we arrived via the JNC above, then CF=0 due to the branch condition, + // so ADCQ will still produce the correct result. 
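+	// Worked example: a 14-byte buffer (length is already even here) takes
+	// the 2-, 4- and 8-byte paths in turn: 14 = 0b1110, so CF is set on
+	// each of the first three shifts and clear for the 16-byte test.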
+ ADCQ CX, AX + ADCQ SI, AX + ADCQ DI, AX + JMP foldAndReturn + +bufferIsNotTiny: + // Number of 256 byte iterations into loop counter + MOVQ BX, CX + + // Update number of bytes remaining after the loop completes + ANDQ $0xff, BX + SHRQ $0x08, CX + JZ startCleanup + CLC + XORQ SI, SI + XORQ DI, DI + XORQ R8, R8 + XORQ R9, R9 + XORQ R10, R10 + XORQ R11, R11 + XORQ R12, R12 + +bigLoop: + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ 32(DX), DI + ADCQ 40(DX), DI + ADCQ 48(DX), DI + ADCQ 56(DX), DI + ADCQ $0x00, R8 + ADDQ 64(DX), R9 + ADCQ 72(DX), R9 + ADCQ 80(DX), R9 + ADCQ 88(DX), R9 + ADCQ $0x00, R10 + ADDQ 96(DX), R11 + ADCQ 104(DX), R11 + ADCQ 112(DX), R11 + ADCQ 120(DX), R11 + ADCQ $0x00, R12 + ADDQ 128(DX), AX + ADCQ 136(DX), AX + ADCQ 144(DX), AX + ADCQ 152(DX), AX + ADCQ $0x00, SI + ADDQ 160(DX), DI + ADCQ 168(DX), DI + ADCQ 176(DX), DI + ADCQ 184(DX), DI + ADCQ $0x00, R8 + ADDQ 192(DX), R9 + ADCQ 200(DX), R9 + ADCQ 208(DX), R9 + ADCQ 216(DX), R9 + ADCQ $0x00, R10 + ADDQ 224(DX), R11 + ADCQ 232(DX), R11 + ADCQ 240(DX), R11 + ADCQ 248(DX), R11 + ADCQ $0x00, R12 + ADDQ $0x00000100, DX + SUBQ $0x01, CX + JNZ bigLoop + ADDQ SI, AX + ADCQ DI, AX + ADCQ R8, AX + ADCQ R9, AX + ADCQ R10, AX + ADCQ R11, AX + ADCQ R12, AX + + // accumulate CF (twice, in case the first time overflows) + ADCQ $0x00, AX + ADCQ $0x00, AX + +startCleanup: + // Accumulate carries in this register. It is never expected to overflow. + XORQ SI, SI + + // We will perform an overlapped read for buffers with length not a multiple of 8. + // Overlapped in this context means some memory will be read twice, but a shift will + // eliminate the duplicated data. This extra read is performed at the end of the buffer to + // preserve any alignment that may exist for the start of the buffer. 
+ MOVQ BX, CX + SHRQ $0x03, BX + ANDQ $0x07, CX + JZ handleRemaining8 + LEAQ (DX)(BX*8), DI + MOVQ -8(DI)(CX*1), DI + + // Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8) + SHLQ $0x03, CX + NEGQ CX + ADDQ $0x40, CX + SHRQ CL, DI + ADDQ DI, AX + ADCQ $0x00, SI + +handleRemaining8: + SHRQ $0x01, BX + JNC handleRemaining16 + ADDQ (DX), AX + ADCQ $0x00, SI + ADDQ $0x08, DX + +handleRemaining16: + SHRQ $0x01, BX + JNC handleRemaining32 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ $0x00, SI + ADDQ $0x10, DX + +handleRemaining32: + SHRQ $0x01, BX + JNC handleRemaining64 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ $0x00, SI + ADDQ $0x20, DX + +handleRemaining64: + SHRQ $0x01, BX + JNC handleRemaining128 + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ $0x00, SI + ADDQ $0x40, DX + +handleRemaining128: + SHRQ $0x01, BX + JNC handleRemainingComplete + ADDQ (DX), AX + ADCQ 8(DX), AX + ADCQ 16(DX), AX + ADCQ 24(DX), AX + ADCQ 32(DX), AX + ADCQ 40(DX), AX + ADCQ 48(DX), AX + ADCQ 56(DX), AX + ADCQ 64(DX), AX + ADCQ 72(DX), AX + ADCQ 80(DX), AX + ADCQ 88(DX), AX + ADCQ 96(DX), AX + ADCQ 104(DX), AX + ADCQ 112(DX), AX + ADCQ 120(DX), AX + ADCQ $0x00, SI + ADDQ $0x80, DX + +handleRemainingComplete: + ADDQ SI, AX + +foldAndReturn: + // add CF and fold + MOVL AX, CX + ADCQ $0x00, CX + SHRQ $0x20, AX + ADDQ CX, AX + MOVWQZX AX, CX + SHRQ $0x10, AX + ADDQ CX, AX + MOVW AX, CX + SHRQ $0x10, AX + ADDW CX, AX + ADCW $0x00, AX + XCHGB AH, AL + MOVW AX, ret+32(FP) + RET diff --git a/tun/checksum_generic.go b/tun/checksum_generic.go new file mode 100644 index 000000000..2ef201a1f --- /dev/null +++ b/tun/checksum_generic.go @@ -0,0 +1,15 @@ +// This file contains IP checksum algorithms that are not specific to any +// architecture and don't use hardware acceleration. 
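+// (The checksumGeneric32/checksumGeneric64 implementations themselves are
+// compiled for every architecture; this file only provides the exported
+// Checksum entry point for builds without an assembly implementation.)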
+ +//go:build !amd64 + +package tun + +import "strconv" + +func Checksum(data []byte, initial uint16) uint16 { + if strconv.IntSize < 64 { + return checksumGeneric32(data, initial) + } + return checksumGeneric64(data, initial) +} diff --git a/tun/checksum_generic_test.go b/tun/checksum_generic_test.go new file mode 100644 index 000000000..401a7bb88 --- /dev/null +++ b/tun/checksum_generic_test.go @@ -0,0 +1,26 @@ +//go:build !amd64 + +package tun + +var archChecksumFuncs = []archChecksumDetails{ + { + name: "generic32", + available: true, + f: checksumGeneric32, + }, + { + name: "generic32Alternate", + available: true, + f: checksumGeneric32Alternate, + }, + { + name: "generic64", + available: true, + f: checksumGeneric64, + }, + { + name: "generic64Alternate", + available: true, + f: checksumGeneric64Alternate, + }, +} diff --git a/tun/checksum_test.go b/tun/checksum_test.go new file mode 100644 index 000000000..f5b8f1869 --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,619 @@ +package tun + +import ( + "fmt" + "math" + "math/rand" + "net/netip" + "sort" + "syscall" + "testing" + "unsafe" + + "gvisor.dev/gvisor/pkg/tcpip" + gvisorChecksum "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +type archChecksumDetails struct { + name string + available bool + f func([]byte, uint16) uint16 +} + +func fillRandomBuffer(seed int64, buf []byte) { + rng := rand.New(rand.NewSource(seed)) + n, err := rng.Read(buf) + if err != nil { + panic(err) + } + if n != len(buf) { + panic("incomplete random buffer") + } +} + +func deterministicRandomBytes(seed int64, length int) []byte { + buf := make([]byte, length) + fillRandomBuffer(seed, buf) + return buf +} + +func getPageAlignedRandomBytes(seed int64, length int) []byte { + alignment := syscall.Getpagesize() + buf := make([]byte, length+(alignment-1)) + bufPtr := uintptr(unsafe.Pointer(&buf[0])) + alignedBufPtr := (bufPtr + uintptr(alignment-1)) & ^uintptr(alignment-1) + alignedStart := int(alignedBufPtr - bufPtr) + + buf = buf[alignedStart : alignedStart+length] + fillRandomBuffer(seed, buf) + return buf +} + +func TestChecksum(t *testing.T) { + alignedBuf := getPageAlignedRandomBytes(10, 8192) + allOnes := make([]byte, 65535) + for i := range allOnes { + allOnes[i] = 0xff + } + allFE := make([]byte, 65535) + for i := range allFE { + allFE[i] = 0xfe + } + + tests := []struct { + name string + data []byte + initial uint16 + want uint16 + }{ + { + name: "empty", + data: []byte{}, + initial: 0, + want: 0, + }, + { + name: "max initial", + data: []byte{}, + initial: math.MaxUint16, + want: 0xffff, + }, + { + name: "odd length", + data: []byte{0x01, 0x02, 0x01}, + initial: 0, + want: 0x0202, + }, + { + name: "tiny", + data: []byte{0x01, 0x02, 0x01, 0x02, 0x01, 0x02}, + initial: 0, + want: 0x0306, + }, + { + name: "initial", + data: []byte{0x01, 0x02, 0x01, 0x02, 0x01, 0x02}, + initial: 0x1000, + want: 0x1306, + }, + // cleanup0 through cleanup15 is 1024 (handled by large SIMD loops) + + // 32 (handled by small SIMD loops) + n, where n ranges from 0 to 15 + // to cover all of the leftover byte sizes that are possible after small + // SIMD loops that handle 16 bytes. 
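+		// (In the SIMD kernels, 1056 = 4 iterations of the 256-byte big
+		// loop + 2 iterations of the 16-byte small loop; n then exercises
+		// every possible masked-tail size.)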
+ { + name: "cleanup0", + data: deterministicRandomBytes(1, 1056), + initial: 0, + want: 0x11ec, + }, + { + name: "cleanup1", + data: deterministicRandomBytes(1, 1057), + initial: 0, + want: 0xc5ec, + }, + { + name: "cleanup2", + data: deterministicRandomBytes(1, 1058), + initial: 0, + want: 0xc6ad, + }, + { + name: "cleanup3", + data: deterministicRandomBytes(1, 1059), + initial: 0, + want: 0x86ae, + }, + { + name: "cleanup4", + data: deterministicRandomBytes(1, 1060), + initial: 0, + want: 0x878e, + }, + { + name: "cleanup5", + data: deterministicRandomBytes(1, 1061), + initial: 0, + want: 0xdb8e, + }, + { + name: "cleanup6", + data: deterministicRandomBytes(1, 1062), + initial: 0, + want: 0xdbd5, + }, + { + name: "cleanup7", + data: deterministicRandomBytes(1, 1063), + initial: 0, + want: 0xcfd6, + }, + { + name: "cleanup8", + data: deterministicRandomBytes(1, 1064), + initial: 0, + want: 0xd090, + }, + { + name: "cleanup9", + data: deterministicRandomBytes(1, 1065), + initial: 0, + want: 0x0791, + }, + { + name: "cleanup10", + data: deterministicRandomBytes(1, 1066), + initial: 0, + want: 0x079f, + }, + { + name: "cleanup11", + data: deterministicRandomBytes(1, 1067), + initial: 0, + want: 0xba9f, + }, + { + name: "cleanup12", + data: deterministicRandomBytes(1, 1068), + initial: 0, + want: 0xbb0c, + }, + { + name: "cleanup13", + data: deterministicRandomBytes(1, 1069), + initial: 0, + want: 0x770d, + }, + { + name: "cleanup14", + data: deterministicRandomBytes(1, 1070), + initial: 0, + want: 0x780a, + }, + { + name: "cleanup15", + data: deterministicRandomBytes(1, 1071), + initial: 0, + want: 0x640b, + }, + // small1 through small15 covers small sizes that are not large enough + // to do overlapped reads. + { + name: "small1", + data: deterministicRandomBytes(2, 1), + initial: 0x1122, + want: 0x4022, + }, + { + name: "small2", + data: deterministicRandomBytes(2, 2), + initial: 0x1122, + want: 0x40a4, + }, + { + name: "small3", + data: deterministicRandomBytes(2, 3), + initial: 0x1122, + want: 0xc2a4, + }, + { + name: "small4", + data: deterministicRandomBytes(2, 4), + initial: 0x1122, + want: 0xc36f, + }, + { + name: "small5", + data: deterministicRandomBytes(2, 5), + initial: 0x1122, + want: 0xa570, + }, + { + name: "small6", + data: deterministicRandomBytes(2, 6), + initial: 0x1122, + want: 0xa669, + }, + { + name: "small7", + data: deterministicRandomBytes(2, 7), + initial: 0x1122, + want: 0x0f6a, + }, + { + name: "small8", + data: deterministicRandomBytes(2, 8), + initial: 0x1122, + want: 0x0fd9, + }, + { + name: "small9", + data: deterministicRandomBytes(2, 9), + initial: 0x1122, + want: 0x40d9, + }, + { + name: "small10", + data: deterministicRandomBytes(2, 10), + initial: 0x1122, + want: 0x411d, + }, + { + name: "small11", + data: deterministicRandomBytes(2, 11), + initial: 0x1122, + want: 0x011e, + }, + { + name: "small12", + data: deterministicRandomBytes(2, 12), + initial: 0x1122, + want: 0x01c8, + }, + { + name: "small13", + data: deterministicRandomBytes(2, 13), + initial: 0x1122, + want: 0x4dc8, + }, + { + name: "small14", + data: deterministicRandomBytes(2, 14), + initial: 0x1122, + want: 0x4eb5, + }, + { + name: "small15", + data: deterministicRandomBytes(2, 15), + initial: 0x1122, + want: 0xa4b5, + }, + // other small-ish sizes + { + name: "small16", + data: deterministicRandomBytes(1, 16), + initial: 0, + want: 0x02fa, + }, + { + name: "small32", + data: deterministicRandomBytes(1, 32), + initial: 0, + want: 0x03ee, + }, + { + name: "small64", + data: 
deterministicRandomBytes(1, 64), + initial: 0, + want: 0x3f85, + }, + { + name: "medium", + data: deterministicRandomBytes(1, 1400), + initial: 0, + want: 0xbea5, + }, + { + name: "big", + data: deterministicRandomBytes(2, 65000), + initial: 0, + want: 0x3ba7, + }, + { + name: "big-initial", + data: deterministicRandomBytes(2, 65000), + initial: 0x1234, + want: 0x4ddb, + }, + { + // big-small-loop is intended to exercise a few iterations of a big + // initial loop of 128 bytes or larger + a smaller loop of 16 bytes + // + some leftover + name: "big-small-loop", + data: deterministicRandomBytes(3, 1094), + initial: 0x9999, + want: 0xe65b, + }, + { + name: "page-aligned", + data: alignedBuf[:4096], + initial: 0, + want: 0x963b, + }, + { + name: "32-aligned", + data: alignedBuf[32:4128], + initial: 0, + want: 0x30c4, + }, + { + name: "16-aligned", + data: alignedBuf[16:4112], + initial: 0, + want: 0xaeff, + }, + { + name: "8-aligned", + data: alignedBuf[8:4104], + initial: 0, + want: 0x6c3b, + }, + { + name: "4-aligned", + data: alignedBuf[4:4100], + initial: 0, + want: 0x2e4a, + }, + { + name: "2-aligned", + data: alignedBuf[2:4098], + initial: 0, + want: 0xc702, + }, + { + name: "unaligned", + data: alignedBuf[1:4097], + initial: 0, + want: 0x3bc7, + }, + { + name: "unalignedAndOdd", + data: alignedBuf[1:4096], + initial: 0, + want: 0x3b13, + }, + { + name: "fe1282", + data: allFE[:1282], + initial: 0, + want: 0x7c7c, + }, + { + name: "fe", + data: allFE, + initial: 0, + want: 0x7e81, + }, + { + name: "maximum", + data: allOnes, + initial: 0, + want: 0xff00, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, fd := range archChecksumFuncs { + t.Run(fd.name, func(t *testing.T) { + if !fd.available { + t.Skip("can not run on this system") + } + if got := fd.f(tt.data, tt.initial); got != tt.want { + t.Errorf("%s checksum = %04x, want %04x", fd.name, got, tt.want) + } + }) + } + t.Run("reference", func(t *testing.T) { + if got := gvisorChecksum.Checksum(tt.data, tt.initial); got != tt.want { + t.Errorf("reference checksum = %04x, want %04x", got, tt.want) + } + }) + }) + } +} + +func TestPseudoHeaderChecksumNoFold(t *testing.T) { + tests := []struct { + name string + protocol uint8 + srcAddr []byte + dstAddr []byte + totalLen uint16 + want uint16 + }{ + { + name: "ipv4", + protocol: syscall.IPPROTO_TCP, + srcAddr: netip.MustParseAddr("192.168.1.1").AsSlice(), + dstAddr: netip.MustParseAddr("192.168.1.2").AsSlice(), + totalLen: 1492, + want: 0x892e, + }, + { + name: "ipv6", + protocol: syscall.IPPROTO_TCP, + srcAddr: netip.MustParseAddr("2001:db8:3333:4444:5555:6666:7777:8888").AsSlice(), + dstAddr: netip.MustParseAddr("2001:db8:aaaa:bbbb:cccc:dddd:eeee:ffff").AsSlice(), + totalLen: 1492, + want: 0x947f, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Run("pseudoHeaderChecksum32", func(t *testing.T) { + got := pseudoHeaderChecksum32(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + if got != tt.want { + t.Errorf("got %04x, want %04x", got, tt.want) + } + }) + t.Run("pseudoHeaderChecksum64", func(t *testing.T) { + got := pseudoHeaderChecksum64(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + if got != tt.want { + t.Errorf("got %04x, want %04x", got, tt.want) + } + }) + t.Run("reference", func(t *testing.T) { + got := header.PseudoHeaderChecksum( + tcpip.TransportProtocolNumber(tt.protocol), + tcpip.AddrFromSlice(tt.srcAddr), + tcpip.AddrFromSlice(tt.dstAddr), + tt.totalLen) + if got != tt.want { + t.Errorf("got %04x, want 
%04x", got, tt.want) + } + }) + }) + } +} + +func FuzzChecksum(f *testing.F) { + buf := getPageAlignedRandomBytes(1234, 65536) + + f.Add([]byte{}, uint16(0)) + f.Add([]byte{}, uint16(0x1234)) + f.Add([]byte{}, uint16(0)) + f.Add(buf[:15], uint16(0x1234)) + f.Add(buf[:256], uint16(0x1234)) + f.Add(buf[:1280], uint16(0x1234)) + f.Add(buf[:1288], uint16(0x1234)) + f.Add(buf[1:1050], uint16(0x1234)) + + f.Fuzz(func(t *testing.T, data []byte, initial uint16) { + want := gvisorChecksum.Checksum(data, initial) + + for _, fd := range archChecksumFuncs { + t.Run(fd.name, func(t *testing.T) { + if !fd.available { + t.Skip("can not run on this system") + } + if got := fd.f(data, initial); got != want { + t.Errorf("%s checksum = %04x, want %04x", fd.name, got, want) + } + }) + } + }) +} + +var result uint16 + +func BenchmarkChecksum(b *testing.B) { + offsets := []int{ // offsets from page alignment + 0, + 1, + 2, + 4, + 8, + 16, + } + lengths := []int{ + 0, + 7, + 15, + 16, + 31, + 64, + 90, + 95, + 128, + 256, + 512, + 1024, + 1240, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + 16384, + 65536, + } + if !sort.IntsAreSorted(offsets) { + b.Fatal("offsets are not sorted") + } + largestLength := lengths[len(lengths)-1] + if !sort.IntsAreSorted(lengths) { + b.Fatal("lengths are not sorted") + } + largestOffset := lengths[len(offsets)-1] + alignedBuf := getPageAlignedRandomBytes(1, largestOffset+largestLength) + var r uint16 + for _, offset := range offsets { + name := fmt.Sprintf("%vAligned", offset) + if offset == 0 { + name = "pageAligned" + } + offsetBuf := alignedBuf[offset:] + b.Run(name, func(b *testing.B) { + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + for _, fd := range archChecksumFuncs { + b.Run(fd.name, func(b *testing.B) { + if !fd.available { + b.Skip("can not run on this system") + } + b.SetBytes(int64(length)) + for i := 0; i < b.N; i++ { + r += fd.f(offsetBuf[:length], 0) + } + }) + } + }) + } + }) + } + result = r +} + +func BenchmarkPseudoHeaderChecksum(b *testing.B) { + tests := []struct { + name string + protocol uint8 + srcAddr []byte + dstAddr []byte + totalLen uint16 + want uint16 + }{ + { + name: "ipv4", + protocol: syscall.IPPROTO_TCP, + srcAddr: []byte{192, 168, 1, 1}, + dstAddr: []byte{192, 168, 1, 2}, + totalLen: 1492, + want: 0x892e, + }, + { + name: "ipv6", + protocol: syscall.IPPROTO_TCP, + srcAddr: netip.MustParseAddr("2001:db8:3333:4444:5555:6666:7777:8888").AsSlice(), + dstAddr: netip.MustParseAddr("2001:db8:aaaa:bbbb:cccc:dddd:eeee:ffff").AsSlice(), + totalLen: 1492, + want: 0x892e, + }, + } + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + b.Run("pseudoHeaderChecksum32", func(b *testing.B) { + for i := 0; i < b.N; i++ { + result += pseudoHeaderChecksum32(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + } + }) + b.Run("pseudoHeaderChecksum64", func(b *testing.B) { + for i := 0; i < b.N; i++ { + result += pseudoHeaderChecksum64(tt.protocol, tt.srcAddr, tt.dstAddr, tt.totalLen) + } + }) + }) + } +} diff --git a/tun/generate_amd64.go b/tun/generate_amd64.go new file mode 100644 index 000000000..543e21132 --- /dev/null +++ b/tun/generate_amd64.go @@ -0,0 +1,579 @@ +//go:build ignore + +//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go + +package main + +import ( + "fmt" + "math" + "math/bits" + + . 
"github.com/mmcloughlin/avo/build" + "github.com/mmcloughlin/avo/operand" + "github.com/mmcloughlin/avo/reg" +) + +const checksumSignature = "func(b []byte, initial uint16) uint16" + +func loadParams() (accum, buf, n reg.GPVirtual) { + accum, buf, n = GP64(), GP64(), GP64() + Load(Param("initial"), accum) + XCHGB(accum.As8H(), accum.As8L()) + Load(Param("b").Base(), buf) + Load(Param("b").Len(), n) + return +} + +type simdStrategy int + +const ( + sse2 = iota + avx2 +) + +const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes. + +func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) { + TEXT(name, NOSPLIT|NOFRAME, checksumSignature) + Pragma("noescape") + Doc(doc) + + accum64, buf, n := loadParams() + + handleOddLength(n, buf, accum64) + // no chance of overflow because accum64 was initialized by a uint16 and + // handleOddLength adds at most a uint8 + handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny")) + Label("bufferIsNotTiny") + + const simdReadSize = 16 + + if minSIMDSize > tinyBufferSize { + Comment("skip all SIMD for small buffers") + if minSIMDSize <= math.MaxUint8 { + CMPQ(n, operand.U8(minSIMDSize)) + } else { + CMPQ(n, operand.U32(minSIMDSize)) + } + JGE(operand.LabelRef("startSIMD")) + + handleRemaining(n, buf, accum64, minSIMDSize-1) + JMP(operand.LabelRef("foldAndReturn")) + } + + Label("startSIMD") + + // chains is the number of accumulators to use. This improves speed via + // reduced data dependency. We combine the accumulators once when the big + // loop is complete. + simdAccumulate := make([]reg.VecVirtual, chains) + for i := range simdAccumulate { + switch strategy { + case sse2: + simdAccumulate[i] = XMM() + PXOR(simdAccumulate[i], simdAccumulate[i]) + case avx2: + simdAccumulate[i] = YMM() + VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i]) + } + } + var zero reg.VecVirtual + if strategy == sse2 { + zero = XMM() + PXOR(zero, zero) + } + + // Number of loads per big loop + const unroll = 16 + // Number of bytes + loopSize := uint64(simdReadSize * unroll) + if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 { + panic("loopSize is not a power of 2") + } + loopCount := GP64() + + MOVQ(n, loopCount) + Comment("Update number of bytes remaining after the loop completes") + ANDQ(operand.Imm(loopSize-1), n) + Comment(fmt.Sprintf("Number of %d byte iterations", loopSize)) + SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount) + JZ(operand.LabelRef("smallLoop")) + Label("bigLoop") + for i := 0; i < unroll; i++ { + chain := i % chains + switch strategy { + case sse2: + sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains]) + case avx2: + avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain]) + } + } + ADDQ(operand.U32(loopSize), buf) + DECQ(loopCount) + JNZ(operand.LabelRef("bigLoop")) + + Label("bigCleanup") + + CMPQ(n, operand.Imm(uint64(simdReadSize))) + JLT(operand.LabelRef("doneSmallLoop")) + + Commentf("now read a single %d byte unit of data at a time", simdReadSize) + Label("smallLoop") + + switch strategy { + case sse2: + sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1]) + case avx2: + avx2AccumulateStep(0, buf, simdAccumulate[0]) + } + ADDQ(operand.Imm(uint64(simdReadSize)), buf) + SUBQ(operand.Imm(uint64(simdReadSize)), n) + CMPQ(n, operand.Imm(uint64(simdReadSize))) + JGE(operand.LabelRef("smallLoop")) + + Label("doneSmallLoop") + CMPQ(n, operand.Imm(0)) + 
JE(operand.LabelRef("doneSIMD")) + + Commentf("There are between 1 and %d bytes remaining. Perform an overlapped read.", simdReadSize-1) + + maskDataPtr := GP64() + LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr) + dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize} + // scale 8 is only correct here because n is guaranteed to be even and we + // do not generate masks for odd lengths + maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16} + remainder := XMM() + + switch strategy { + case sse2: + MOVOU(dataAddr, remainder) + PAND(maskAddr, remainder) + low := XMM() + MOVOA(remainder, low) + PUNPCKHWL(zero, remainder) + PUNPCKLWL(zero, low) + PADDD(remainder, simdAccumulate[0]) + PADDD(low, simdAccumulate[1]) + case avx2: + // Note: this is very similar to the sse2 path but MOVOU has a massive + // performance hit if used here, presumably due to switching between SSE + // and AVX2 modes. + VMOVDQU(dataAddr, remainder) + VPAND(maskAddr, remainder, remainder) + + temp := YMM() + VPMOVZXWD(remainder, temp) + VPADDD(temp, simdAccumulate[0], simdAccumulate[0]) + } + + Label("doneSIMD") + + Comment("Multi-chain loop is done, combine the accumulators") + for i := range simdAccumulate { + if i == 0 { + continue + } + switch strategy { + case sse2: + PADDD(simdAccumulate[i], simdAccumulate[0]) + case avx2: + VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0]) + } + } + + if strategy == avx2 { + Comment("extract the YMM into a pair of XMM and sum them") + tmp := YMM() + VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX()) + + xAccumulate := XMM() + VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate) + simdAccumulate = []reg.VecVirtual{xAccumulate} + } + + Comment("extract the XMM into GP64") + low, high := GP64(), GP64() + switch strategy { + case sse2: + MOVQ(simdAccumulate[0], low) + PSRLDQ(operand.Imm(8), simdAccumulate[0]) + MOVQ(simdAccumulate[0], high) + case avx2: + VPEXTRQ(operand.Imm(0), simdAccumulate[0], low) + VPEXTRQ(operand.Imm(1), simdAccumulate[0], high) + + Comment("no more AVX code, clear upper registers to avoid SSE slowdowns") + VZEROUPPER() + } + ADDQ(low, accum64) + ADCQ(high, accum64) + Label("foldAndReturn") + foldWithCF(accum64, strategy == avx2) + XCHGB(accum64.As8H(), accum64.As8L()) + Store(accum64.As16(), ReturnIndex(0)) + RET() +} + +// handleOddLength generates instructions to incorporate the last byte into +// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to +// handle that if overflow is possible. 
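+//
+// Adding the raw byte value is correct because accumulation is done on
+// byte-swapped (little-endian) data: a trailing odd byte b pads to the
+// big-endian 16-bit word b<<8, which is simply b after byte-swapping.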
+func handleOddLength(n, buf, accum64 reg.GPVirtual) {
+	Comment("handle odd length buffers; they are difficult to handle in general")
+	TESTQ(operand.U32(1), n)
+	JZ(operand.LabelRef("lengthIsEven"))
+
+	tmp := GP64()
+	MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp)
+	DECQ(n)
+	ADDQ(tmp, accum64)
+
+	Label("lengthIsEven")
+}
+
+func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) {
+	high, low := XMM(), XMM()
+	MOVOU(operand.Mem{Disp: offset, Base: buf}, high)
+	MOVOA(high, low)
+	PUNPCKHWL(zero, high)
+	PUNPCKLWL(zero, low)
+	PADDD(high, accumulate1)
+	PADDD(low, accumulate2)
+}
+
+func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) {
+	tmp := YMM()
+	VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp)
+	VPADDD(tmp, accumulate, accumulate)
+}
+
+func generateAMD64Checksum(name, doc string) {
+	TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
+	Pragma("noescape")
+	Doc(doc)
+
+	accum64, buf, n := loadParams()
+
+	handleOddLength(n, buf, accum64)
+	// no chance of overflow because accum64 was initialized by a uint16 and
+	// handleOddLength adds at most a uint8
+	handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
+	Label("bufferIsNotTiny")
+
+	const (
+		// numChains is the number of accumulators and carry counters to use.
+		// This improves speed via reduced data dependency. We combine the
+		// accumulators and carry counters once when the loop is complete.
+		numChains = 4
+		unroll    = 32         // The number of 64-bit reads to perform per iteration of the loop.
+		loopSize  = 8 * unroll // The number of bytes read per iteration of the loop.
+	)
+	if bits.Len(loopSize) != bits.Len(loopSize-1)+1 {
+		panic("loopSize is not a power of 2")
+	}
+	loopCount := GP64()
+
+	Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize))
+	MOVQ(n, loopCount)
+	Comment("Update number of bytes remaining after the loop completes")
+	ANDQ(operand.Imm(loopSize-1), n)
+	SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount)
+	JZ(operand.LabelRef("startCleanup"))
+	CLC()
+
+	chains := make([]struct {
+		accum   reg.GPVirtual
+		carries reg.GPVirtual
+	}, numChains)
+	for i := range chains {
+		if i == 0 {
+			chains[i].accum = accum64
+		} else {
+			chains[i].accum = GP64()
+			XORQ(chains[i].accum, chains[i].accum)
+		}
+		chains[i].carries = GP64()
+		XORQ(chains[i].carries, chains[i].carries)
+	}
+
+	Label("bigLoop")
+
+	var curChain int
+	for i := 0; i < unroll; i++ {
+		// It is significantly faster to use an ADCX/ADOX pair instead of
+		// plain ADC, since the pair yields two independent dependency
+		// chains; however, those instructions require ADX support, which
+		// was added after AVX2, and when AVX2 is available it is faster
+		// still than ADCX/ADOX.
+		//
+		// However, multiple dependency chains using multiple accumulators,
+		// with CF occasionally stored into temporary counters, seem to work
+		// almost as well.
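+		//
+		// Concretely: every fourth load starts a fresh ADDQ chain on the
+		// next accumulator, after banking the previous chain's CF into its
+		// carry counter with ADCQ $0 (visible in the generated
+		// checksumAMD64).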
+ addr := operand.Mem{Disp: i * 8, Base: buf} + + if i%4 == 0 { + if i > 0 { + ADCQ(operand.Imm(0), chains[curChain].carries) + curChain = (curChain + 1) % len(chains) + } + ADDQ(addr, chains[curChain].accum) + } else { + ADCQ(addr, chains[curChain].accum) + } + } + ADCQ(operand.Imm(0), chains[curChain].carries) + ADDQ(operand.U32(loopSize), buf) + SUBQ(operand.Imm(1), loopCount) + JNZ(operand.LabelRef("bigLoop")) + for i := range chains { + if i == 0 { + ADDQ(chains[i].carries, accum64) + continue + } + ADCQ(chains[i].accum, accum64) + ADCQ(chains[i].carries, accum64) + } + + accumulateCF(accum64) + + Label("startCleanup") + handleRemaining(n, buf, accum64, loopSize-1) + Label("foldAndReturn") + foldWithCF(accum64, false) + + XCHGB(accum64.As8H(), accum64.As8L()) + Store(accum64.As16(), ReturnIndex(0)) + RET() +} + +// handleTinyBuffers computes checksums if the buffer length (the n parameter) +// is less than 32. After computing the checksum, a jump to returnLabel will +// be executed. Otherwise, if the buffer length is at least 32, nothing will be +// modified; a jump to continueLabel will be executed instead. +// +// When jumping to returnLabel, CF may be set and must be accommodated e.g. +// using foldWithCF or accumulateCF. +// +// Anecdotally, this appears to be faster than attempting to coordinate an +// overlapped read (which would also require special handling for buffers +// smaller than 8). +func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) { + Comment("handle tiny buffers (<=31 bytes) specially") + CMPQ(n, operand.Imm(tinyBufferSize)) + JGT(continueLabel) + + tmp2, tmp4, tmp8 := GP64(), GP64(), GP64() + XORQ(tmp2, tmp2) + XORQ(tmp4, tmp4) + XORQ(tmp8, tmp8) + + Comment("shift twice to start because length is guaranteed to be even", + "n = n >> 2; CF = originalN & 2") + SHRQ(operand.Imm(2), n) + JNC(operand.LabelRef("handleTiny4")) + Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]") + MOVWQZX(operand.Mem{Base: buf}, tmp2) + ADDQ(operand.Imm(2), buf) + + Label("handleTiny4") + Comment("n = n >> 1; CF = originalN & 4") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTiny8")) + Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]") + MOVLQZX(operand.Mem{Base: buf}, tmp4) + ADDQ(operand.Imm(4), buf) + + Label("handleTiny8") + Comment("n = n >> 1; CF = originalN & 8") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTiny16")) + Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]") + MOVQ(operand.Mem{Base: buf}, tmp8) + ADDQ(operand.Imm(8), buf) + + Label("handleTiny16") + Comment("n = n >> 1; CF = originalN & 16", + "n == 0 now, otherwise we would have branched after comparing with tinyBufferSize") + SHRQ(operand.Imm(1), n) + JNC(operand.LabelRef("handleTinyFinish")) + ADDQ(operand.Mem{Base: buf}, accum) + ADCQ(operand.Mem{Base: buf, Disp: 8}, accum) + + Label("handleTinyFinish") + Comment("CF should be included from the previous add, so we use ADCQ.", + "If we arrived via the JNC above, then CF=0 due to the branch condition,", + "so ADCQ will still produce the correct result.") + ADCQ(tmp2, accum) + ADCQ(tmp4, accum) + ADCQ(tmp8, accum) + + JMP(returnLabel) +} + +// handleRemaining generates a series of conditional unrolled additions, +// starting with 8 bytes long and doubling each time until the length reaches +// max. 
This is the reverse order of what may be intuitive, but makes the branch +// conditions convenient to compute: perform one right shift each time and test +// against CF. +// +// When done, CF may be set and must be accommodated e.g., using foldWithCF or +// accumulateCF. +// +// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer +// will be performed, overlapping with data that will be read later. The +// duplicate data will be shifted off. +// +// The original buffer length must have been at least 8 bytes long, even if +// n < 8, otherwise this will access memory before the start of the buffer, +// which may be unsafe. +func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) { + Comment("Accumulate carries in this register. It is never expected to overflow.") + carries := GP64() + XORQ(carries, carries) + + Comment("We will perform an overlapped read for buffers with length not a multiple of 8.", + "Overlapped in this context means some memory will be read twice, but a shift will", + "eliminate the duplicated data. This extra read is performed at the end of the buffer to", + "preserve any alignment that may exist for the start of the buffer.") + leftover := reg.RCX + MOVQ(n, leftover) + SHRQ(operand.Imm(3), n) // n is now the number of 64 bit reads remaining + ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end + JZ(operand.LabelRef("handleRemaining8")) + endBuf := GP64() + // endBuf is the position near the end of the buffer that is just past the + // last multiple of 8: (buf + len(buf)) & ^0x7 + LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf) + + overlapRead := GP64() + // equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)]) + MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead) + + Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)") + SHLQ(operand.Imm(3), leftover) // leftover = leftover * 8 + NEGQ(leftover) // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression + ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8) + SHRQ(reg.CL, overlapRead) // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL + + ADDQ(overlapRead, accum64) + ADCQ(operand.Imm(0), carries) + + for curBytes := 8; curBytes <= max; curBytes *= 2 { + Label(fmt.Sprintf("handleRemaining%d", curBytes)) + SHRQ(operand.Imm(1), n) + if curBytes*2 <= max { + JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2))) + } else { + JNC(operand.LabelRef("handleRemainingComplete")) + } + + numLoads := curBytes / 8 + for i := 0; i < numLoads; i++ { + addr := operand.Mem{Base: buf, Disp: i * 8} + // It is possible to add the multiple dependency chains trick here + // that generateAMD64Checksum uses but anecdotally it does not + // appear to outweigh the cost. 
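+			//
+			// e.g. the handleRemaining32 block is emitted as one ADDQ
+			// followed by three ADCQs, with CF then banked into the
+			// carries register once per block.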
+ if i == 0 { + ADDQ(addr, accum64) + continue + } + ADCQ(addr, accum64) + } + ADCQ(operand.Imm(0), carries) + + if curBytes > math.MaxUint8 { + ADDQ(operand.U32(uint64(curBytes)), buf) + } else { + ADDQ(operand.U8(uint64(curBytes)), buf) + } + if curBytes*2 >= max { + continue + } + JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2))) + } + Label("handleRemainingComplete") + ADDQ(carries, accum64) +} + +func accumulateCF(accum64 reg.GPVirtual) { + Comment("accumulate CF (twice, in case the first time overflows)") + // accum64 += CF + ADCQ(operand.Imm(0), accum64) + // accum64 += CF again if the previous add overflowed. The previous add was + // 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can + // never overflow. + ADCQ(operand.Imm(0), accum64) +} + +// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value +// according to ones-complement arithmetic. BMI2 instructions will be used if +// allowBMI2 is true (requires fewer instructions). +func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) { + Comment("add CF and fold") + + // CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff + + tmp := GP64() + if allowBMI2 { + // effectively, tmp = accum >> 32 (technically, this is a rotate) + RORXQ(operand.Imm(32), accum, tmp) + // accum as uint32 = uint32(accum) + uint32(tmp64) + CF; max value 0xffff_ffff + CF set + ADCL(tmp.As32(), accum.As32()) + // effectively, tmp64 as uint32 = uint32(accum) >> 16 (also a rotate) + RORXL(operand.Imm(16), accum.As32(), tmp.As32()) + // accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set + ADCW(tmp.As16(), accum.As16()) + } else { + // tmp = uint32(accum); max value 0xffff_ffff + // MOVL clears the upper 32 bits of a GP64 so this is equivalent to the + // non-existent MOVLQZX. 
+ MOVL(accum.As32(), tmp.As32()) + // tmp += CF; max value 0x1_0000_0000, CF unset + ADCQ(operand.Imm(0), tmp) + // accum = accum >> 32; max value 0xffff_ffff + SHRQ(operand.Imm(32), accum) + // accum = accum + tmp; max value 0x1_ffff_ffff + CF unset + ADDQ(tmp, accum) + // tmp = uint16(accum); max value 0xffff + MOVWQZX(accum.As16(), tmp) + // accum = accum >> 16; max value 0x1_ffff + SHRQ(operand.Imm(16), accum) + // accum = accum + tmp; max value 0x2_fffe + CF unset + ADDQ(tmp, accum) + // tmp as uint16 = uint16(accum); max value 0xffff + MOVW(accum.As16(), tmp.As16()) + // accum = accum >> 16; max value 0x2 + SHRQ(operand.Imm(16), accum) + // accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set + ADDW(tmp.As16(), accum.As16()) + } + // accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe + ADCW(operand.Imm(0), accum.As16()) +} + +func generateLoadMasks() { + var offset int + // xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14 + GLOBL("xmmLoadMasks", RODATA|NOPTR) + + for n := 2; n < 16; n += 2 { + var pattern [16]byte + for i := 0; i < len(pattern); i++ { + if i < len(pattern)-n { + pattern[i] = 0 + continue + } + pattern[i] = 0xff + } + DATA(offset, operand.String(pattern[:])) + offset += len(pattern) + } +} + +func main() { + generateLoadMasks() + generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2) + generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2) + generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions") + Generate() +} diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go index ccd32ede3..81f4d3180 100644 --- a/tun/netstack/examples/http_client.go +++ b/tun/netstack/examples/http_client.go @@ -13,9 +13,9 @@ import ( "net/http" "net/netip" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun/netstack" ) func main() { diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go index f5b7a8ff8..30f454454 100644 --- a/tun/netstack/examples/http_server.go +++ b/tun/netstack/examples/http_server.go @@ -14,9 +14,9 @@ import ( "net/http" "net/netip" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun/netstack" ) func main() { diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go index 2eef0fbc2..fc991b156 100644 --- a/tun/netstack/examples/ping_client.go +++ b/tun/netstack/examples/ping_client.go @@ -17,9 +17,9 @@ import ( "golang.org/x/net/icmp" "golang.org/x/net/ipv4" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun/netstack" ) func main() { diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 596cfcd8a..d8e70bb03 100644 --- a/tun/netstack/tun.go +++ 
b/tun/netstack/tun.go @@ -22,10 +22,10 @@ import ( "syscall" "time" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/tun" "golang.org/x/net/dns/dnsmessage" - "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,7 +43,7 @@ type netTun struct { ep *channel.Endpoint stack *stack.Stack events chan tun.Event - incomingPacket chan *bufferv2.View + incomingPacket chan *buffer.View mtu int dnsServers []netip.Addr hasV4, hasV6 bool @@ -61,7 +61,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, ep: channel.New(1024, uint32(mtu), ""), stack: stack.New(opts), events: make(chan tun.Event, 10), - incomingPacket: make(chan *bufferv2.View), + incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } @@ -84,7 +84,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, - AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { @@ -140,7 +140,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { continue } - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) @@ -198,7 +198,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ } return tcpip.FullAddress{ NIC: 1, - Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } @@ -453,7 +453,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } - remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr)) + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) return res.Count, &PingAddr{remoteAddr}, nil } diff --git a/tun/offload.go b/tun/offload.go new file mode 100644 index 000000000..6db437c34 --- /dev/null +++ b/tun/offload.go @@ -0,0 +1,220 @@ +package tun + +import ( + "encoding/binary" + "fmt" +) + +// GSOType represents the type of segmentation offload. +type GSOType int + +const ( + GSONone GSOType = iota + GSOTCPv4 + GSOTCPv6 + GSOUDPL4 +) + +func (g GSOType) String() string { + switch g { + case GSONone: + return "GSONone" + case GSOTCPv4: + return "GSOTCPv4" + case GSOTCPv6: + return "GSOTCPv6" + case GSOUDPL4: + return "GSOUDPL4" + default: + return "unknown" + } +} + +// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO +// specification. It is a common representation of GSO metadata that can be +// applied to support packet GSO across tun.Device implementations. +type GSOOptions struct { + // GSOType represents the type of segmentation offload. + GSOType GSOType + // HdrLen is the sum of the layer 3 and 4 header lengths. This field may be + // zero when GSOType == GSONone. + HdrLen uint16 + // CsumStart is the head byte index of the packet data to be checksummed, + // i.e. the start of the TCP or UDP header. 
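+ // For example, for TCP over IPv4 with a 20-byte IP header, CsumStart is
+ // 20 and CsumOffset is 16 (the checksum field's offset within the TCP
+ // header); for UDP, CsumOffset is 6.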
+ CsumStart uint16 + // CsumOffset is the offset from CsumStart where the 2-byte checksum value + // should be placed. + CsumOffset uint16 + // GSOSize is the size of each segment exclusive of HdrLen. The tail segment + // may be smaller than this value. + GSOSize uint16 + // NeedsCsum may be set even where GSOType == GSONone. When set, the checksum + // at CsumStart + CsumOffset must be a partial checksum, i.e. the + // pseudo-header sum. + NeedsCsum bool +} + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +const ( + // defined here to avoid importing any platform-specific packages + ipProtoTCP = 6 + ipProtoUDP = 17 +) + +// GSOSplit splits packets from 'in' into outBufs[][outOffset:], writing +// the size of each element into sizes. It returns the number of buffers +// populated, and/or an error. Callers may pass an 'in' slice that overlaps with +// the first element of outBufs, i.e. &in[0] may be equal to +// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the +// value of options.NeedsCsum. The length of each outBufs element must be +// greater than or equal to the length of 'in', otherwise output may be +// silently truncated. +func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) { + cSumAt := int(options.CsumStart) + int(options.CsumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + if len(in) < int(options.HdrLen) { + return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen) + } + + // Handle the conditions where we are copying a single element to outBufs. + payloadLen := len(in) - int(options.HdrLen) + if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { + if len(in) > len(outBufs[0][outOffset:]) { + return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:])) + } + if options.NeedsCsum { + // The initial value at the checksum offset should be summed with + // the checksum we compute. This is typically the pseudo-header sum.
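+ // Folding it in as the initial accumulator value is equivalent to
+ // summing it inline, since ones-complement addition is commutative
+ // and associative.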
+ initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial)) + } + sizes[0] = copy(outBufs[0][outOffset:], in) + return 1, nil + } + + if options.HdrLen < options.CsumStart { + return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart) + } + + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType) + } + if len(in) < 20 { + return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20) + } + case 6: + if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType) + } + if len(in) < 40 { + return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + iphLen := int(options.CsumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if ipVersion == 4 { + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + transportCsumAt := int(options.CsumStart + options.CsumOffset) + var firstTCPSeqNum uint32 + var protocol uint8 + if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 { + protocol = ipProtoTCP + if len(in) < int(options.CsumStart)+20 { + return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)", + len(in), options.CsumStart, 20) + } + firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:]) + } else { + protocol = ipProtoUDP + } + nextSegmentDataAt := int(options.HdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBufs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(options.HdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBufs[i][outOffset:] + + copy(out, in[:iphLen]) + if ipVersion == 4 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + out[10], out[11] = 0, 0 // clear ipv4 header checksum + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^Checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. 
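+ // (IPv6 has no header checksum, and its payload length field at byte
+ // offset 4 excludes the 40-byte fixed header.)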
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // copy transport header + copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen]) + + if protocol == ipProtoTCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(options.GSOSize*uint16(i)) + binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[options.CsumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart)) + } + + // payload + copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // transport checksum + out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + transportHeaderLen := int(options.HdrLen - options.CsumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSum := PseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum = ^Checksum(out[options.CsumStart:totalLen], transportCSum) + binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum) + + nextSegmentDataAt += int(options.GSOSize) + } + return i, nil +} diff --git a/tun/offload_linux.go b/tun/offload_linux.go new file mode 100644 index 000000000..fb6ac5b94 --- /dev/null +++ b/tun/offload_linux.go @@ -0,0 +1,911 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "unsafe" + + "github.com/tailscale/wireguard-go/conn" + "golang.org/x/sys/unix" +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) { + var gsoType GSOType + switch v.gsoType { + case unix.VIRTIO_NET_HDR_GSO_NONE: + gsoType = GSONone + case unix.VIRTIO_NET_HDR_GSO_TCPV4: + gsoType = GSOTCPv4 + case unix.VIRTIO_NET_HDR_GSO_TCPV6: + gsoType = GSOTCPv6 + case unix.VIRTIO_NET_HDR_GSO_UDP_L4: + gsoType = GSOUDPL4 + default: + return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType) + } + return GSOOptions{ + GSOType: gsoType, + HdrLen: v.hdrLen, + CsumStart: v.csumStart, + CsumOffset: v.csumOffset, + GSOSize: v.gsoSize, + NeedsCsum: v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0, + }, nil +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// tcpFlowKey represents the key for a TCP flow. 
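+// Addresses are stored zero-padded in 16-byte fields so IPv4 and IPv6 flows
+// share one comparable key type, suitable for direct use as a map key.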
+type tcpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool +} + +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. +type tcpGROTable struct { + itemsByFlow map[tcpFlowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize), + itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize) + } + return t +} + +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. 
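+// Only metadata is held here; the packet bytes themselves remain in the
+// caller's bufs slice and are located via bufsIndex.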
+type tcpGROItem struct { + key tcpFlowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// udpFlowKey represents the key for a UDP flow. +type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. +type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize), + itemsPool: make([][]udpGROItem, conn.IdealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. 
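+// As with tcpGROItem, only metadata is tracked; the packet bytes remain in
+// the caller's bufs slice at bufsIndex.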
+type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. +type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. +func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend +} + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. 
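+// The result indicates whether pkt extends item's data (coalesceAppend),
+// directly precedes it in sequence space (coalescePrepend), or cannot be
+// merged (coalesceUnavailable).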
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := PseudoHeaderChecksum(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^Checksum(pkt[iphLen:], cSum) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess +) + +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. 
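+ // Coalescing is only an optimization: rather than pay for a
+ // reallocation we decline, and the packet is written out on its own.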
+ return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. This function may swap bufs elements in the +// event of a prepend as item's bufs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) + copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) + // Flip the slice headers in bufs as part of prepend. The index of item + // is already being tracked for writing. + bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + } else { + pktHead = bufs[item.bufsIndex][bufsOffset:] + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments uint8 = 0x20 +) + +const ( + maxUint16 = 1<<16 - 1 +) + +type groResult int + +const ( + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced +) + +// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// existing packets tracked in table. 
It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return groResultNoop + } + if len(pkt) < iphLen+tcphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return groResultNoop + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return groResultTableInsert + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. + // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. + item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return groResultCoalesced + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return groResultNoop + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return groResultTableInsert +} + +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. 
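+// For each merged item it rewrites the IP total/payload length (and the IPv4
+// header checksum), encodes a virtio_net_hdr describing the coalesced
+// super-packet, and seeds the TCP checksum field with the pseudo-header sum
+// for downstream checksum offload; unmerged items get a zeroed header.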
+func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if item.key.isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. 
Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +func packetIsGROCandidate(b []byte, gro groDisablementFlags) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 && gro.canTCPGRO() { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP && gro.canUDPGRO() { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 && gro.canTCPGRO() { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && gro.canUDPGRO() { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. 
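+ // In other words, UDP GRO is append-only: unlike TCP there is no
+ // sequence number with which to re-establish ordering, so merging
+ // anywhere but the tail could reorder datagrams within the flow.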
+ item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + +// handleGRO evaluates bufs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. gro indicates if TCP and UDP GRO +// are supported/enabled. +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error { + for i := range bufs { + if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { + return errors.New("invalid offset") + } + var result groResult + switch packetIsGROCandidate(bufs[i][offset:], gro) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) + } + switch result { + case groResultNoop: + hdr := virtioNetHdr{} + err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + fallthrough + case groResultTableInsert: + *toWrite = append(*toWrite, i) + } + } + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) +} diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go new file mode 100644 index 000000000..407037863 --- /dev/null +++ b/tun/offload_linux_test.go @@ -0,0 +1,764 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package tun + +import ( + "net/netip" + "testing" + + "github.com/tailscale/wireguard-go/conn" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + offset = virtioNetHdrLen +) + +var ( + ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") + ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") + ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") + ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") + ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") + ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") +) + +func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 28 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_UDP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + udpH := header.UDP(b[offset+20:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 48 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_UDP, + HopLimit: 64, + PayloadLength: uint16(payloadLen + udphLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + udpH := header.UDP(b[offset+40:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 40 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_TCP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + tcpH := header.TCP(b[offset+20:]) + 
tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 60 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_TCP, + HopLimit: 64, + PayloadLength: uint16(segmentSize + 20), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + tcpH := header.TCP(b[offset+40:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func Test_handleVirtioRead(t *testing.T) { + tests := []struct { + name string + hdr virtioNetHdr + pktIn []byte + wantLens []int + wantErr bool + }{ + { + "tcp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + gsoSize: 100, + hdrLen: 40, + csumStart: 20, + csumOffset: 16, + }, + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{140, 140}, + false, + }, + { + "tcp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + gsoSize: 100, + hdrLen: 60, + csumStart: 40, + csumOffset: 16, + }, + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{160, 160}, + false, + }, + { + "udp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 28, + csumStart: 20, + csumOffset: 6, + }, + udp4Packet(ip4PortA, ip4PortB, 200), + []int{128, 128}, + false, + }, + { + "udp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 48, + csumStart: 40, + csumOffset: 6, + }, + udp6Packet(ip6PortA, ip6PortB, 200), + []int{148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := make([][]byte, conn.IdealBatchSize) + sizes := make([]int, conn.IdealBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + tt.hdr.encode(tt.pktIn) + n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if n != 
len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) + } + for i := range tt.wantLens { + if tt.wantLens[i] != sizes[i] { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + } + } + }) + } +} + +func flipTCP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func flipUDP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func Fuzz_handleGRO(f *testing.F) { + pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) + pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) + pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) + pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) + pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) + pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) + pkt6 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt7 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt8 := udp4Packet(ip4PortA, ip4PortC, 100) + pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, 0, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, gro int, offset int) { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} + toWrite := make([]int, 0, len(pkts)) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite) + if len(toWrite) > len(pkts) { + t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) + } + seenWriteI := make(map[int]bool) + for _, writeI := range toWrite { + if writeI < 0 || writeI > len(pkts)-1 { + t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + } + if seenWriteI[writeI] { + t.Errorf("duplicate toWrite value: %d", writeI) + } + seenWriteI[writeI] = true + } + }) +} + +func Test_handleGRO(t *testing.T) { + tests := []struct { + name string + pktsIn [][]byte + gro groDisablementFlags + wantToWrite []int + wantLens []int + wantErr bool + }{ + { + "multiple protocols and flows", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + 0, + []int{0, 1, 2, 4, 5, 7, 9}, + []int{240, 228, 128, 140, 260, 160, 248}, + false, + }, + { + "multiple protocols and flows no UDP GRO", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, 
header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + udpGRODisabled, + []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, + []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + false, + }, + { + "PSH interleaved", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + }, + 0, + []int{0, 2, 4, 6}, + []int{240, 240, 260, 260}, + false, + }, + { + "coalesceItemInvalidCSum", + [][]byte{ + flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4Packet(ip4PortA, ip4PortB, 100), + }, + 0, + []int{0, 1, 3, 4}, + []int{140, 240, 128, 228}, + false, + }, + { + "out of order", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + 0, + []int{0}, + []int{340}, + false, + }, + { + "unequal TTL", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + }, + 0, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal ToS", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + }, + 0, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags more fragments set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields 
*header.IPv4Fields) { + fields.Flags = 1 + }), + }, + 0, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags DF set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + }, + 0, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "ipv6 unequal hop limit", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + }, + 0, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + { + "ipv6 unequal traffic class", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + }, + 0, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toWrite := make([]int, 0, len(tt.pktsIn)) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(toWrite) != len(tt.wantToWrite) { + t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) + } + for i, pktI := range tt.wantToWrite { + if tt.wantToWrite[i] != toWrite[i] { + t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + } + if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { + t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + } + } + }) + } +} + +func Test_packetIsGROCandidate(t *testing.T) { + tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp4TooShort := tcp4[:39] + ip4InvalidHeaderLen := make([]byte, len(tcp4)) + copy(ip4InvalidHeaderLen, tcp4) + ip4InvalidHeaderLen[0] = 0x46 + ip4InvalidProtocol := make([]byte, len(tcp4)) + copy(ip4InvalidProtocol, tcp4) + ip4InvalidProtocol[9] = unix.IPPROTO_GRE + + tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp6TooShort := tcp6[:59] + ip6InvalidProtocol := make([]byte, len(tcp6)) + copy(ip6InvalidProtocol, tcp6) + ip6InvalidProtocol[6] = unix.IPPROTO_GRE + + udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:] + udp4TooShort := udp4[:27] + + udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:] + udp6TooShort := udp6[:47] + + tests := []struct { + name string + b []byte + gro groDisablementFlags + want groCandidateType + }{ + { + "tcp4", + tcp4, + 0, + tcp4GROCandidate, + }, + { + "tcp4 no support", + tcp4, + tcpGRODisabled, + notGROCandidate, + }, + { + "tcp6", + tcp6, + 0, + tcp6GROCandidate, + }, + { + "tcp6 no support", + tcp6, + tcpGRODisabled, + notGROCandidate, + }, + { + "udp4", + udp4, + 0, + udp4GROCandidate, + }, + { + "udp4 no support", + 
udp4, + udpGRODisabled, + notGROCandidate, + }, + { + "udp6", + udp6, + 0, + udp6GROCandidate, + }, + { + "udp6 no support", + udp6, + udpGRODisabled, + notGROCandidate, + }, + { + "udp4 too short", + udp4TooShort, + 0, + notGROCandidate, + }, + { + "udp6 too short", + udp6TooShort, + 0, + notGROCandidate, + }, + { + "tcp4 too short", + tcp4TooShort, + 0, + notGROCandidate, + }, + { + "tcp6 too short", + tcp6TooShort, + 0, + notGROCandidate, + }, + { + "invalid IP version", + []byte{0x00}, + 0, + notGROCandidate, + }, + { + "invalid IP header len", + ip4InvalidHeaderLen, + 0, + notGROCandidate, + }, + { + "ip4 invalid protocol", + ip4InvalidProtocol, + 0, + notGROCandidate, + }, + { + "ip6 invalid protocol", + ip6InvalidProtocol, + 0, + notGROCandidate, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := packetIsGROCandidate(tt.b, tt.gro); got != tt.want { + t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_udpPacketsCanCoalesce(t *testing.T) { + udp4a := udp4Packet(ip4PortA, ip4PortB, 100) + udp4b := udp4Packet(ip4PortA, ip4PortB, 100) + udp4c := udp4Packet(ip4PortA, ip4PortB, 110) + + type args struct { + pkt []byte + iphLen uint8 + gsoSize uint16 + item udpGROItem + bufs [][]byte + bufsOffset int + } + tests := []struct { + name string + args args + want canCoalesce + }{ + { + "coalesceAppend equal gso", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceAppend smaller gso", + args{ + pkt: udp4a[offset : len(udp4a)-90], + iphLen: 20, + gsoSize: 10, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceUnavailable smaller gso previously appended", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4c, + udp4b, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable larger following smaller", + args{ + pkt: udp4c[offset:], + iphLen: 20, + gsoSize: 110, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4c, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { + t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tun/offload_test.go b/tun/offload_test.go new file mode 100644 index 000000000..82a37b9cc --- /dev/null +++ b/tun/offload_test.go @@ -0,0 +1,95 @@ +package tun + +import ( + "net/netip" + "testing" + + "github.com/tailscale/wireguard-go/conn" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func Fuzz_GSOSplit(f *testing.F) { + const segmentSize = 100 + + tcpFields := &header.TCPFields{ + SrcPort: 1, + DstPort: 1, + SeqNum: 1, + AckNum: 1, + DataOffset: 20, + Flags: header.TCPFlagAck | header.TCPFlagPsh, + WindowSize: 3000, + } + udpFields := &header.UDPFields{ + SrcPort: 1, + DstPort: 1, + Length: 8 + segmentSize, + } + + gsoTCPv4 := make([]byte, 20+20+segmentSize) + header.IPv4(gsoTCPv4).Encode(&header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.1").AsSlice()), + 
DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.2").AsSlice()), + Protocol: ipProtoTCP, + TTL: 64, + TotalLength: uint16(len(gsoTCPv4)), + }) + header.TCP(gsoTCPv4[20:]).Encode(tcpFields) + + gsoUDPv4 := make([]byte, 20+8+segmentSize) + header.IPv4(gsoUDPv4).Encode(&header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.1").AsSlice()), + DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.2").AsSlice()), + Protocol: ipProtoUDP, + TTL: 64, + TotalLength: uint16(len(gsoUDPv4)), + }) + header.UDP(gsoUDPv4[20:]).Encode(udpFields) + + gsoTCPv6 := make([]byte, 40+20+segmentSize) + header.IPv6(gsoTCPv6).Encode(&header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::1").AsSlice()), + DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::2").AsSlice()), + TransportProtocol: ipProtoTCP, + HopLimit: 64, + PayloadLength: uint16(20 + segmentSize), + }) + header.TCP(gsoTCPv6[40:]).Encode(tcpFields) + + gsoUDPv6 := make([]byte, 40+8+segmentSize) + header.IPv6(gsoUDPv6).Encode(&header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::1").AsSlice()), + DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::2").AsSlice()), + TransportProtocol: ipProtoUDP, + HopLimit: 64, + PayloadLength: uint16(8 + segmentSize), + }) + header.UDP(gsoUDPv6[40:]).Encode(udpFields) + + out := make([][]byte, conn.IdealBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + sizes := make([]int, conn.IdealBatchSize) + + f.Add(gsoTCPv4, int(GSOTCPv4), uint16(40), uint16(20), uint16(16), uint16(100), false) + f.Add(gsoUDPv4, int(GSOUDPL4), uint16(28), uint16(20), uint16(6), uint16(100), false) + f.Add(gsoTCPv6, int(GSOTCPv6), uint16(60), uint16(40), uint16(16), uint16(100), false) + f.Add(gsoUDPv6, int(GSOUDPL4), uint16(48), uint16(40), uint16(6), uint16(100), false) + + f.Fuzz(func(t *testing.T, pkt []byte, gsoType int, hdrLen, csumStart, csumOffset, gsoSize uint16, needsCsum bool) { + options := GSOOptions{ + GSOType: GSOType(gsoType), + HdrLen: hdrLen, + CsumStart: csumStart, + CsumOffset: csumOffset, + GSOSize: gsoSize, + NeedsCsum: needsCsum, + } + n, _ := GSOSplit(pkt, options, out, sizes, 0) + if n > len(sizes) { + t.Errorf("n (%d) > len(sizes): %d", n, len(sizes)) + } + }) +} diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go deleted file mode 100644 index 39a7180c5..000000000 --- a/tun/tcp_offload_linux.go +++ /dev/null @@ -1,627 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "bytes" - "encoding/binary" - "errors" - "io" - "unsafe" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" -) - -const tcpFlagsOffset = 13 - -const ( - tcpFlagFIN uint8 = 0x01 - tcpFlagPSH uint8 = 0x08 - tcpFlagACK uint8 = 0x10 -) - -// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The - kernel symbol is virtio_net_hdr.
-type virtioNetHdr struct { - flags uint8 - gsoType uint8 - hdrLen uint16 - gsoSize uint16 - csumStart uint16 - csumOffset uint16 -} - -func (v *virtioNetHdr) decode(b []byte) error { - if len(b) < virtioNetHdrLen { - return io.ErrShortBuffer - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) - return nil -} - -func (v *virtioNetHdr) encode(b []byte) error { - if len(b) < virtioNetHdrLen { - return io.ErrShortBuffer - } - copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) - return nil -} - -const ( - // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the - // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). - virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) -) - -// flowKey represents the key for a flow. -type flowKey struct { - srcAddr, dstAddr [16]byte - srcPort, dstPort uint16 - rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. -} - -// tcpGROTable holds flow and coalescing information for the purposes of GRO. -type tcpGROTable struct { - itemsByFlow map[flowKey][]tcpGROItem - itemsPool [][]tcpGROItem -} - -func newTCPGROTable() *tcpGROTable { - t := &tcpGROTable{ - itemsByFlow: make(map[flowKey][]tcpGROItem, conn.IdealBatchSize), - itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), - } - for i := range t.itemsPool { - t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize) - } - return t -} - -func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { - key := flowKey{} - addrSize := dstAddr - srcAddr - copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) - copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) - key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) - key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) - key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) - return key -} - -// lookupOrInsert looks up a flow for the provided packet and metadata, -// returning the packets found for the flow, or inserting a new one if none -// is found. -func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) - items, ok := t.itemsByFlow[key] - if ok { - return items, ok - } - // TODO: insert() performs another map lookup. This could be rearranged to avoid. - t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) - return nil, false -} - -// insert an item in the table for the provided packet and packet metadata. -func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) - item := tcpGROItem{ - key: key, - bufsIndex: uint16(bufsIndex), - gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), - iphLen: uint8(tcphOffset), - tcphLen: uint8(tcphLen), - sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), - pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, - } - items, ok := t.itemsByFlow[key] - if !ok { - items = t.newItems() - } - items = append(items, item) - t.itemsByFlow[key] = items -} - -func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { - items, _ := t.itemsByFlow[item.key] - items[i] = item -} - -func (t *tcpGROTable) deleteAt(key flowKey, i int) { - items, _ := t.itemsByFlow[key] - items = append(items[:i], items[i+1:]...) 
- t.itemsByFlow[key] = items -} - -// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime -// of a GRO evaluation across a vector of packets. -type tcpGROItem struct { - key flowKey - sentSeq uint32 // the sequence number - bufsIndex uint16 // the index into the original bufs slice - numMerged uint16 // the number of packets merged into this item - gsoSize uint16 // payload size - iphLen uint8 // ip header len - tcphLen uint8 // tcp header len - pshSet bool // psh flag is set -} - -func (t *tcpGROTable) newItems() []tcpGROItem { - var items []tcpGROItem - items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] - return items -} - -func (t *tcpGROTable) reset() { - for k, items := range t.itemsByFlow { - items = items[:0] - t.itemsPool = append(t.itemsPool, items) - delete(t.itemsByFlow, k) - } -} - -// canCoalesce represents the outcome of checking if two TCP packets are -// candidates for coalescing. -type canCoalesce int - -const ( - coalescePrepend canCoalesce = -1 - coalesceUnavailable canCoalesce = 0 - coalesceAppend canCoalesce = 1 -) - -// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet -// described by item. This function makes considerations that match the kernel's -// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. -func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { - pktTarget := bufs[item.bufsIndex][bufsOffset:] - if tcphLen != item.tcphLen { - // cannot coalesce with unequal tcp options len - return coalesceUnavailable - } - if tcphLen > 20 { - if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { - // cannot coalesce with unequal tcp options - return coalesceUnavailable - } - } - if pkt[0]>>4 == 6 { - if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 { - // cannot coalesce with unequal Traffic class values - return coalesceUnavailable - } - if pkt[7] != pktTarget[7] { - // cannot coalesce with unequal Hop limit values - return coalesceUnavailable - } - } else { - if pkt[1] != pktTarget[1] { - // cannot coalesce with unequal ToS values - return coalesceUnavailable - } - if pkt[6]>>5 != pktTarget[6]>>5 { - // cannot coalesce with unequal DF or reserved bits. MF is checked - // further up the stack. - return coalesceUnavailable - } - if pkt[8] != pktTarget[8] { - // cannot coalesce with unequal TTL values - return coalesceUnavailable - } - } - // seq adjacency - lhsLen := item.gsoSize - lhsLen += item.numMerged * item.gsoSize - if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective - if item.pshSet { - // We cannot append to a segment that has the PSH flag set, PSH - // can only be set on the final segment in a reassembled group. - return coalesceUnavailable - } - if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { - // A smaller than gsoSize packet has been appended previously. - // Nothing can come after a smaller packet on the end. - return coalesceUnavailable - } - if gsoSize > item.gsoSize { - // We cannot have a larger packet following a smaller one. - return coalesceUnavailable - } - return coalesceAppend - } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective - if pshSet { - // We cannot prepend with a segment that has the PSH flag set, PSH - // can only be set on the final segment in a reassembled group. 
- return coalesceUnavailable - } - if gsoSize < item.gsoSize { - // We cannot have a larger packet following a smaller one. - return coalesceUnavailable - } - if gsoSize > item.gsoSize && item.numMerged > 0 { - // There's at least one previous merge, and we're larger than all - // previous. This would put multiple smaller packets on the end. - return coalesceUnavailable - } - return coalescePrepend - } - return coalesceUnavailable -} - -func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { - srcAddrAt := ipv4SrcAddrOffset - addrSize := 4 - if isV6 { - srcAddrAt = ipv6SrcAddrOffset - addrSize = 16 - } - tcpTotalLen := uint16(len(pkt) - int(iphLen)) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) - return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0 -} - -// coalesceResult represents the result of attempting to coalesce two TCP -// packets. -type coalesceResult int - -const ( - coalesceInsufficientCap coalesceResult = 0 - coalescePSHEnding coalesceResult = 1 - coalesceItemInvalidCSum coalesceResult = 2 - coalescePktInvalidCSum coalesceResult = 3 - coalesceSuccess coalesceResult = 4 -) - -// coalesceTCPPackets attempts to coalesce pkt with the packet described by -// item, returning the outcome. This function may swap bufs elements in the -// event of a prepend as item's bufs index is already being tracked for writing -// to a Device. -func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { - var pktHead []byte // the packet that will end up at the front - headersLen := item.iphLen + item.tcphLen - coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) - - // Copy data - if mode == coalescePrepend { - pktHead = pkt - if cap(pkt)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if pshSet { - return coalescePSHEnding - } - if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { - return coalesceItemInvalidCSum - } - } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { - return coalescePktInvalidCSum - } - item.sentSeq = seq - extendBy := coalescedLen - len(pktHead) - bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) - copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) - // Flip the slice headers in bufs as part of prepend. The index of item - // is already being tracked for writing. - bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] - } else { - pktHead = bufs[item.bufsIndex][bufsOffset:] - if cap(pktHead)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { - return coalesceItemInvalidCSum - } - } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { - return coalescePktInvalidCSum - } - if pshSet { - // We are appending a segment with PSH set. - item.pshSet = pshSet - pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH - } - extendBy := len(pkt) - int(headersLen) - bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) 
- copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) - } - - if gsoSize > item.gsoSize { - item.gsoSize = gsoSize - } - hdr := virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(headersLen), - gsoSize: uint16(item.gsoSize), - csumStart: uint16(item.iphLen), - csumOffset: 16, - } - - // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the - // (IPv4) header checksum. - if isV6 { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len - } else { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pktHead[10], pktHead[11] = 0, 0 // clear checksum field - binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length - iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum - binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field - } - hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:]) - - // Calculate the pseudo header checksum and place it at the TCP checksum - // offset. Downstream checksum offloading will combine this with computation - // of the tcp header and payload checksum. - addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := bufsOffset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen))) - binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) - - item.numMerged++ - return coalesceSuccess -} - -const ( - ipv4FlagMoreFragments uint8 = 0x20 -) - -const ( - ipv4SrcAddrOffset = 12 - ipv6SrcAddrOffset = 8 - maxUint16 = 1<<16 - 1 -) - -// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It will return false when pktI is not -// coalesced, otherwise true. This indicates to the caller if bufs[pktI] -// should be written to the Device. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) { - pkt := bufs[pktI][offset:] - if len(pkt) > maxUint16 { - // A valid IPv4 or IPv6 packet will never exceed this. 
- return false - } - iphLen := int((pkt[0] & 0x0F) * 4) - if isV6 { - iphLen = 40 - ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) - if ipv6HPayloadLen != len(pkt)-iphLen { - return false - } - } else { - totalLen := int(binary.BigEndian.Uint16(pkt[2:])) - if totalLen != len(pkt) { - return false - } - } - if len(pkt) < iphLen { - return false - } - tcphLen := int((pkt[iphLen+12] >> 4) * 4) - if tcphLen < 20 || tcphLen > 60 { - return false - } - if len(pkt) < iphLen+tcphLen { - return false - } - if !isV6 { - if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { - // no GRO support for fragmented segments for now - return false - } - } - tcpFlags := pkt[iphLen+tcpFlagsOffset] - var pshSet bool - // not a candidate if any non-ACK flags (except PSH+ACK) are set - if tcpFlags != tcpFlagACK { - if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return false - } - pshSet = true - } - gsoSize := uint16(len(pkt) - tcphLen - iphLen) - // not a candidate if payload len is 0 - if gsoSize < 1 { - return false - } - seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) - srcAddrOffset := ipv4SrcAddrOffset - addrLen := 4 - if isV6 { - srcAddrOffset = ipv6SrcAddrOffset - addrLen = 16 - } - items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - if !existing { - return false - } - for i := len(items) - 1; i >= 0; i-- { - // In the best case of packets arriving in order iterating in reverse is - // more efficient if there are multiple items for a given flow. This - // also enables a natural table.deleteAt() in the - // coalesceItemInvalidCSum case without the need for index tracking. - // This algorithm makes a best effort to coalesce in the event of - // unordered packets, where pkt may land anywhere in items from a - // sequence number perspective, however once an item is inserted into - // the table it is never compared across other items later. - item := items[i] - can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) - if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) - switch result { - case coalesceSuccess: - table.updateAt(item, i) - return true - case coalesceItemInvalidCSum: - // delete the item with an invalid csum - table.deleteAt(item.key, i) - case coalescePktInvalidCSum: - // no point in inserting an item that we can't coalesce - return false - default: - } - } - } - // failed to coalesce with any other packets; store the item in the flow - table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return false -} - -func isTCP4NoIPOptions(b []byte) bool { - if len(b) < 40 { - return false - } - if b[0]>>4 != 4 { - return false - } - if b[0]&0x0F != 5 { - return false - } - if b[9] != unix.IPPROTO_TCP { - return false - } - return true -} - -func isTCP6NoEH(b []byte) bool { - if len(b) < 60 { - return false - } - if b[0]>>4 != 6 { - return false - } - if b[6] != unix.IPPROTO_TCP { - return false - } - return true -} - -// handleGRO evaluates bufs for GRO, and writes the indices of the resulting -// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be -// empty (but non-nil), and are passed in to save allocs as the caller may reset -// and recycle them across vectors of packets. 
-func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { - for i := range bufs { - if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { - return errors.New("invalid offset") - } - var coalesced bool - switch { - case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp4Table, false) - case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp6Table, true) - } - if !coalesced { - hdr := virtioNetHdr{} - err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - *toWrite = append(*toWrite, i) - } - } - return nil -} - -// tcpTSO splits packets from in into outBuffs, writing the size of each -// element into sizes. It returns the number of buffers populated, and/or an -// error. -func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { - iphLen := int(hdr.csumStart) - srcAddrOffset := ipv6SrcAddrOffset - addrLen := 16 - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { - in[10], in[11] = 0, 0 // clear ipv4 header checksum - srcAddrOffset = ipv4SrcAddrOffset - addrLen = 4 - } - tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) - in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum - firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) - nextSegmentDataAt := int(hdr.hdrLen) - i := 0 - for ; nextSegmentDataAt < len(in); i++ { - if i == len(outBuffs) { - return i - 1, ErrTooManySegments - } - nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) - if nextSegmentEnd > len(in) { - nextSegmentEnd = len(in) - } - segmentDataLen := nextSegmentEnd - nextSegmentDataAt - totalLen := int(hdr.hdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBuffs[i][outOffset:] - - copy(out, in[:iphLen]) - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { - // For IPv4 we are responsible for incrementing the ID field, - // updating the total len field, and recalculating the header - // checksum. - if i > 0 { - id := binary.BigEndian.Uint16(out[4:]) - id += uint16(i) - binary.BigEndian.PutUint16(out[4:], id) - } - binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) - ipv4CSum := ^checksum(out[:iphLen], 0) - binary.BigEndian.PutUint16(out[10:], ipv4CSum) - } else { - // For IPv6 we are responsible for updating the payload length field. 
- binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) - } - - // TCP header - copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) - tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) - binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) - if nextSegmentEnd != len(in) { - // FIN and PSH should only be set on last segment - clearFlags := tcpFlagFIN | tcpFlagPSH - out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags - } - - // payload - copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) - - // TCP checksum - tcpHLen := int(hdr.hdrLen - hdr.csumStart) - tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) - tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold) - binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) - - nextSegmentDataAt += int(hdr.gsoSize) - } - return i, nil -} - -func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { - cSumAt := cSumStart + cSumOffset - // The initial value at the checksum offset should be summed with the - // checksum we compute. This is typically the pseudo-header checksum. - initial := binary.BigEndian.Uint16(in[cSumAt:]) - in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) - return nil -} diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go deleted file mode 100644 index 9160e18cd..000000000 --- a/tun/tcp_offload_linux_test.go +++ /dev/null @@ -1,411 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "net/netip" - "testing" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - offset = virtioNetHdrLen -) - -var ( - ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") - ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") - ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") - ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") - ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") - ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") -) - -func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { - totalLen := 40 + segmentSize - b := make([]byte, offset+int(totalLen), 65535) - ipv4H := header.IPv4(b[offset:]) - srcAs4 := srcIPPort.Addr().As4() - dstAs4 := dstIPPort.Addr().As4() - ipFields := &header.IPv4Fields{ - SrcAddr: tcpip.Address(srcAs4[:]), - DstAddr: tcpip.Address(dstAs4[:]), - Protocol: unix.IPPROTO_TCP, - TTL: 64, - TotalLength: uint16(totalLen), - } - if ipFn != nil { - ipFn(ipFields) - } - ipv4H.Encode(ipFields) - tcpH := header.TCP(b[offset+20:]) - tcpH.Encode(&header.TCPFields{ - SrcPort: srcIPPort.Port(), - DstPort: dstIPPort.Port(), - SeqNum: seq, - AckNum: 1, - DataOffset: 20, - Flags: flags, - WindowSize: 3000, - }) - ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) - pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) - tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) - return b -} - -func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { - return 
tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) -} - -func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { - totalLen := 60 + segmentSize - b := make([]byte, offset+int(totalLen), 65535) - ipv6H := header.IPv6(b[offset:]) - srcAs16 := srcIPPort.Addr().As16() - dstAs16 := dstIPPort.Addr().As16() - ipFields := &header.IPv6Fields{ - SrcAddr: tcpip.Address(srcAs16[:]), - DstAddr: tcpip.Address(dstAs16[:]), - TransportProtocol: unix.IPPROTO_TCP, - HopLimit: 64, - PayloadLength: uint16(segmentSize + 20), - } - if ipFn != nil { - ipFn(ipFields) - } - ipv6H.Encode(ipFields) - tcpH := header.TCP(b[offset+40:]) - tcpH.Encode(&header.TCPFields{ - SrcPort: srcIPPort.Port(), - DstPort: dstIPPort.Port(), - SeqNum: seq, - AckNum: 1, - DataOffset: 20, - Flags: flags, - WindowSize: 3000, - }) - pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) - tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) - return b -} - -func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { - return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) -} - -func Test_handleVirtioRead(t *testing.T) { - tests := []struct { - name string - hdr virtioNetHdr - pktIn []byte - wantLens []int - wantErr bool - }{ - { - "tcp4", - virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, - gsoSize: 100, - hdrLen: 40, - csumStart: 20, - csumOffset: 16, - }, - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), - []int{140, 140}, - false, - }, - { - "tcp6", - virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, - gsoSize: 100, - hdrLen: 60, - csumStart: 40, - csumOffset: 16, - }, - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), - []int{160, 160}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - out := make([][]byte, conn.IdealBatchSize) - sizes := make([]int, conn.IdealBatchSize) - for i := range out { - out[i] = make([]byte, 65535) - } - tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) - if err != nil { - if tt.wantErr { - return - } - t.Fatalf("got err: %v", err) - } - if n != len(tt.wantLens) { - t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) - } - for i := range tt.wantLens { - if tt.wantLens[i] != sizes[i] { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) - } - } - }) - } -} - -func flipTCP4Checksum(b []byte) []byte { - at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 - b[at] ^= 0xFF - b[at+1] ^= 0xFF - return b -} - -func Fuzz_handleGRO(f *testing.F) { - pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) - pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) - pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) - pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) - pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) - pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) - f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) - f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { - pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5} - toWrite := 
make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) - if len(toWrite) > len(pkts) { - t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) - } - seenWriteI := make(map[int]bool) - for _, writeI := range toWrite { - if writeI < 0 || writeI > len(pkts)-1 { - t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) - } - if seenWriteI[writeI] { - t.Errorf("duplicate toWrite value: %d", writeI) - } - seenWriteI[writeI] = true - } - }) -} - -func Test_handleGRO(t *testing.T) { - tests := []struct { - name string - pktsIn [][]byte - wantToWrite []int - wantLens []int - wantErr bool - }{ - { - "multiple flows", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2 - }, - []int{0, 2, 3, 5}, - []int{240, 140, 260, 160}, - false, - }, - { - "PSH interleaved", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 - }, - []int{0, 2, 4, 6}, - []int{240, 240, 260, 260}, - false, - }, - { - "coalesceItemInvalidCSum", - [][]byte{ - flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - }, - []int{0, 1}, - []int{140, 240}, - false, - }, - { - "out of order", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - }, - []int{0}, - []int{340}, - false, - }, - { - "tcp4 unequal TTL", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.TTL++ - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal ToS", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.TOS++ - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal flags more fragments set", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.Flags = 1 - }), - }, - []int{0, 
1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal flags DF set", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.Flags = 2 - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp6 unequal hop limit", - [][]byte{ - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), - tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { - fields.HopLimit++ - }), - }, - []int{0, 1}, - []int{160, 160}, - false, - }, - { - "tcp6 unequal traffic class", - [][]byte{ - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), - tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { - fields.TrafficClass++ - }), - }, - []int{0, 1}, - []int{160, 160}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) - if err != nil { - if tt.wantErr { - return - } - t.Fatalf("got err: %v", err) - } - if len(toWrite) != len(tt.wantToWrite) { - t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) - } - for i, pktI := range tt.wantToWrite { - if tt.wantToWrite[i] != toWrite[i] { - t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) - } - if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { - t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) - } - } - }) - } -} - -func Test_isTCP4NoIPOptions(t *testing.T) { - valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] - invalidLen := valid[:39] - invalidHeaderLen := make([]byte, len(valid)) - copy(invalidHeaderLen, valid) - invalidHeaderLen[0] = 0x46 - invalidProtocol := make([]byte, len(valid)) - copy(invalidProtocol, valid) - invalidProtocol[9] = unix.IPPROTO_TCP + 1 - - tests := []struct { - name string - b []byte - want bool - }{ - { - "valid", - valid, - true, - }, - { - "invalid length", - invalidLen, - false, - }, - { - "invalid version", - []byte{0x00}, - false, - }, - { - "invalid header len", - invalidHeaderLen, - false, - }, - { - "invalid protocol", - invalidProtocol, - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isTCP4NoIPOptions(tt.b); got != tt.want { - t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 deleted file mode 100644 index 5461e79a0..000000000 --- a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 +++ /dev/null @@ -1,8 +0,0 @@ -go test fuzz v1 -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -int(34) diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d deleted file mode 100644 index b441819e7..000000000 --- a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d +++ /dev/null @@ -1,8 +0,0 @@ -go test fuzz v1 -[]byte("0") -[]byte("0") -[]byte("0") 
-[]byte("0") -[]byte("0") -[]byte("0") -int(-48) diff --git a/tun/tun.go b/tun/tun.go index 0ae53d073..719a60631 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -51,3 +51,26 @@ type Device interface { // lifetime of a Device. BatchSize() int } + +// GRODevice is a Device extended with methods for disabling GRO. Certain OS +// versions may have offload bugs. Where these bugs negatively impact throughput +// or break connectivity entirely we can use these methods to disable the +// related offload. +// +// Linux has the following known, GRO bugs. +// +// torvalds/linux@e269d79c7d35aa3808b1f3c1737d63dab504ddc8 broke virtio_net +// TCP & UDP GRO causing GRO writes to return EINVAL. The bug was then +// resolved later in +// torvalds/linux@89add40066f9ed9abe5f7f886fe5789ff7e0c50e. The offending +// commit was pulled into various LTS releases. +// +// UDP GRO writes end up blackholing/dropping packets destined for a +// vxlan/geneve interface on kernel versions prior to 6.8.5. +type GRODevice interface { + Device + // DisableUDPGRO disables UDP GRO if it is enabled. + DisableUDPGRO() + // DisableTCPGRO disables TCP GRO if it is enabled. + DisableTCPGRO() +} diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 12cd49f74..7cdbf8825 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,9 +17,9 @@ import ( "time" "unsafe" + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/rwcancel" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/rwcancel" ) const ( @@ -48,9 +48,34 @@ type NativeTun struct { readOpMu sync.Mutex // readOpMu guards readBuff readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr - writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable - toWrite []int - tcp4GROTable, tcp6GROTable *tcpGROTable + writeOpMu sync.Mutex // writeOpMu guards the following fields + toWrite []int + tcpGROTable *tcpGROTable + udpGROTable *udpGROTable + gro groDisablementFlags +} + +type groDisablementFlags int + +const ( + tcpGRODisabled groDisablementFlags = 1 << iota + udpGRODisabled +) + +func (g *groDisablementFlags) disableTCPGRO() { + *g |= tcpGRODisabled +} + +func (g *groDisablementFlags) canTCPGRO() bool { + return (*g)&tcpGRODisabled == 0 +} + +func (g *groDisablementFlags) disableUDPGRO() { + *g |= udpGRODisabled +} + +func (g *groDisablementFlags) canUDPGRO() bool { + return (*g)&udpGRODisabled == 0 } func (tun *NativeTun) File() *os.File { @@ -244,21 +269,15 @@ func (tun *NativeTun) setMTU(n int) error { defer unix.Close(fd) - // do ioctl call - var ifr [ifReqSize]byte - copy(ifr[:], name) - *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return fmt.Errorf("failed to set MTU of TUN device: %w", errno) + req, err := unix.NewIfreq(name) + if err != nil { + return fmt.Errorf("unix.NewIfreq(%q): %w", name, err) + } + req.SetUint32(uint32(n)) + err = unix.IoctlIfreq(fd, unix.SIOCSIFMTU, req) + if err != nil { + return fmt.Errorf("failed to set MTU of TUN device %q: %w", name, err) } - return nil } @@ -333,8 +352,8 @@ func (tun *NativeTun) nameSlow() (string, error) { func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.writeOpMu.Lock() defer func() { - tun.tcp4GROTable.reset() - tun.tcp6GROTable.reset() + tun.tcpGROTable.reset() + tun.udpGROTable.reset() tun.writeOpMu.Unlock() }() var ( @@ 
-343,7 +362,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite) if err != nil { return 0, err } @@ -377,68 +396,32 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e return 0, err } in = in[virtioNetHdrLen:] - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { - if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { - // This means CHECKSUM_PARTIAL in skb context. We are responsible - // for computing the checksum starting at hdr.csumStart and placing - // at hdr.csumOffset. - err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) - if err != nil { - return 0, err - } - } - if len(in) > len(bufs[0][offset:]) { - return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) - } - n := copy(bufs[0][offset:], in) - sizes[0] = n - return 1, nil - } - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { - return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) - } - ipVersion := in[0] >> 4 - switch ipVersion { - case 4: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { - return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) - } - case 6: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { - return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) - } - default: - return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + options, err := hdr.toGSOOptions() + if err != nil { + return 0, err } - if len(in) <= int(hdr.csumStart+12) { - return 0, errors.New("packet is too short") - } - // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // Don't trust HdrLen from the kernel as it can be equal to the length // of the entire first packet when the kernel is handling it as part of a - // FORWARD path. Instead, parse the TCP header length and add it onto - // csumStart, which is synonymous for IP header length. - tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) - if tcpHLen < 20 || tcpHLen > 60 { - // A TCP header must be between 20 and 60 bytes in length. - return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) - } - hdr.hdrLen = hdr.csumStart + tcpHLen - - if len(in) < int(hdr.hdrLen) { - return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) - } + // FORWARD path. Instead, parse the transport header length and add it onto + // CsumStart, which is synonymous for IP header length. + if options.GSOType == GSOUDPL4 { + options.HdrLen = options.CsumStart + 8 + } else if options.GSOType != GSONone { + if len(in) <= int(options.CsumStart+12) { + return 0, errors.New("packet is too short") + } - if hdr.hdrLen < hdr.csumStart { - return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) - } - cSumAt := int(hdr.csumStart + hdr.csumOffset) - if cSumAt+1 >= len(in) { - return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + tcpHLen := uint16(in[options.CsumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. 
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + options.HdrLen = options.CsumStart + tcpHLen } - return tcpTSO(in, hdr, bufs, sizes, offset) + return GSOSplit(in, options, bufs, sizes, offset) } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { @@ -495,9 +478,26 @@ func (tun *NativeTun) BatchSize() int { return tun.batchSize } +// DisableUDPGRO disables UDP GRO if it is enabled. See the GRODevice interface +// for cases where it should be called. +func (tun *NativeTun) DisableUDPGRO() { + tun.writeOpMu.Lock() + tun.gro.disableUDPGRO() + tun.writeOpMu.Unlock() +} + +// DisableTCPGRO disables TCP GRO if it is enabled. See the GRODevice interface +// for cases where it should be called. +func (tun *NativeTun) DisableTCPGRO() { + tun.writeOpMu.Lock() + tun.gro.disableTCPGRO() + tun.writeOpMu.Unlock() +} + const ( // TODO: support TSO with ECN bits - tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 ) func (tun *NativeTun) initFromFlags(name string) error { @@ -519,12 +519,19 @@ func (tun *NativeTun) initFromFlags(name string) error { } got := ifr.Uint16() if got&unix.IFF_VNET_HDR != 0 { - err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) + // tunTCPOffloads were added in Linux v2.6. We require their support + // if IFF_VNET_HDR is set. + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) if err != nil { return } tun.vnetHdr = true tun.batchSize = conn.IdealBatchSize + // tunUDPOffloads were added in Linux v6.2. We do not return an + // error if they are unsupported at runtime. + if unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) != nil { + tun.gro.disableUDPGRO() + } } else { tun.batchSize = 1 } @@ -575,8 +582,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - tcp4GROTable: newTCPGROTable(), - tcp6GROTable: newTCPGROTable(), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), } @@ -628,12 +635,12 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - tcp4GROTable: newTCPGROTable(), - tcp6GROTable: newTCPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { diff --git a/tun/tun_plan9.go b/tun/tun_plan9.go new file mode 100644 index 000000000..7b66eadf6 --- /dev/null +++ b/tun/tun_plan9.go @@ -0,0 +1,147 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package tun + +import ( + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" +) + +type NativeTun struct { + name string // "/net/ipifc/2" + ctlFile *os.File + dataFile *os.File + events chan Event + errors chan error + closeOnce sync.Once +} + +func CreateTUN(_ string, mtu int) (Device, error) { + ctl, err := os.OpenFile("/net/ipifc/clone", os.O_RDWR, 0) + if err != nil { + return nil, err + } + nbuf := make([]byte, 5) + n, err := ctl.Read(nbuf) + if err != nil { + ctl.Close() + return nil, fmt.Errorf("error reading from clone file: %w", err) + } + ifn, err := strconv.Atoi(strings.TrimSpace(string(nbuf[:n]))) + if err != nil { + ctl.Close() + return nil, fmt.Errorf("error converting clone result %q to int: %w", nbuf[:n], err) + } + + if _, err := fmt.Fprintf(ctl, "bind pkt\n"); err != nil { + ctl.Close() + return nil, fmt.Errorf("error binding to pkt: %w", err) + } + if mtu > 0 { + if _, err := fmt.Fprintf(ctl, "mtu %d\n", mtu); err != nil { + ctl.Close() + return nil, fmt.Errorf("error setting MTU: %w", err) + } + } + + dataFile, err := os.OpenFile(fmt.Sprintf("/net/ipifc/%d/data", ifn), os.O_RDWR, 0) + if err != nil { + ctl.Close() + return nil, err + } + + tun := &NativeTun{ + ctlFile: ctl, + dataFile: dataFile, + name: fmt.Sprintf("/net/ipifc/%d", ifn), + events: make(chan Event, 10), + errors: make(chan error, 5), + } + tun.events <- EventUp + + return tun, nil +} + +func (tun *NativeTun) Name() (string, error) { + return tun.name, nil +} + +func (tun *NativeTun) File() *os.File { + return tun.ctlFile +} + +func (tun *NativeTun) Events() <-chan Event { + return tun.events +} + +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + select { + case err := <-tun.errors: + return 0, err + default: + n, err := tun.dataFile.Read(bufs[0][offset:]) + if n == 1 && bufs[0][offset] == 0 { + // EOF + err = io.EOF + n = 0 + } + sizes[0] = n + return 1, err + } +} + +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + for i, buf := range bufs { + if _, err := tun.dataFile.Write(buf[offset:]); err != nil { + return i, err + } + } + return len(bufs), nil +} + +func (tun *NativeTun) Close() error { + var err1, err2 error + tun.closeOnce.Do(func() { + _, err1 = fmt.Fprintf(tun.ctlFile, "unbind\n") + if err := tun.ctlFile.Close(); err != nil && err1 == nil { + err1 = err + } + err2 = tun.dataFile.Close() + }) + if err1 != nil { + return err1 + } + return err2 +} + +func (tun *NativeTun) MTU() (int, error) { + var buf [100]byte + f, err := os.Open(tun.name + "/status") + if err != nil { + return 0, err + } + defer f.Close() + n, err := f.Read(buf[:]) + _, res, ok := strings.Cut(string(buf[:n]), " maxtu ") + if ok { + if mtus, _, ok := strings.Cut(res, " "); ok { + mtu, err := strconv.Atoi(mtus) + if err != nil { + return 0, fmt.Errorf("error converting mtu %q to int: %w", mtus, err) + } + return mtu, nil + } + } + return 0, fmt.Errorf("no 'maxtu' field found in %s/status", tun.name) +} + +func (tun *NativeTun) BatchSize() int { + return 1 +} diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 0cb4ce192..34f29805d 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } update := tun.forcedMTU != mtu tun.forcedMTU = mtu if update { diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index d07e8601a..e7507c26c 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -11,7 +11,7 @@ import ( "net/netip" "os" - "golang.zx2c4.com/wireguard/tun" + "github.com/tailscale/wireguard-go/tun" ) func Ping(dst, src netip.Addr) []byte {
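The new fuzz targets above (Fuzz_handleGRO and Fuzz_GSOSplit) only replay their seed corpus under a plain go test run; sustained fuzzing must be requested explicitly per target. For example, with the standard Go fuzzing flags (stock go test tooling, not part of this diff):

    go test -fuzz=Fuzz_GSOSplit -fuzztime=30s ./tun

Go accepts a single matching -fuzz pattern per invocation, so Fuzz_handleGRO is fuzzed the same way in a separate run, on a Linux host since it exercises the Linux GRO path.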
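The GRODevice interface added to tun/tun.go is the hook callers use to react to the kernel GRO bugs its doc comment lists. A minimal sketch of such a caller follows; the bug probe is hypothetical and not part of this diff:

    package main

    import (
        "log"

        "github.com/tailscale/wireguard-go/tun"
    )

    // hasKnownUDPGROBug is a hypothetical stand-in for a runtime kernel
    // version probe; this diff does not provide one.
    func hasKnownUDPGROBug() bool { return false }

    func main() {
        dev, err := tun.CreateTUN("wg0", 1420)
        if err != nil {
            log.Fatal(err)
        }
        defer dev.Close()
        // Only some Device implementations carry the GRO disablement
        // methods, so type-assert before using them.
        if gro, ok := dev.(tun.GRODevice); ok && hasKnownUDPGROBug() {
            gro.DisableUDPGRO()
        }
    }

Both methods are safe to call while the device is in use, because the Linux NativeTun guards its gro flags with writeOpMu, as the tun_linux.go hunk above shows.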
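handleVirtioRead above deliberately recomputes HdrLen rather than trusting the kernel-supplied value, treating CsumStart as the IP header length: UDP adds its fixed 8-byte header, while TCP adds the header length parsed from the data-offset field. A self-contained sketch of that arithmetic (the helper name and shape are illustrative, not from this diff):

    package main

    import (
        "errors"
        "fmt"
    )

    // transportHeaderEnd mirrors the HdrLen derivation above: csumStart is
    // the IP header length, so segment payload begins at csumStart plus the
    // transport header length (8 bytes for UDP, data offset * 4 for TCP).
    func transportHeaderEnd(pkt []byte, csumStart uint16, udp bool) (uint16, error) {
        if udp {
            return csumStart + 8, nil
        }
        if len(pkt) <= int(csumStart+12) {
            return 0, errors.New("packet is too short")
        }
        // The TCP data offset is the high nibble of byte 12 of the TCP
        // header, counted in 32-bit words.
        tcpHLen := uint16(pkt[csumStart+12]>>4) * 4
        if tcpHLen < 20 || tcpHLen > 60 {
            return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
        }
        return csumStart + tcpHLen, nil
    }

    func main() {
        // A 20-byte IPv4 header followed by a minimal 20-byte TCP header:
        // data offset nibble 5 (5*4 = 20), so payload starts at byte 40.
        pkt := make([]byte, 60)
        pkt[20+12] = 5 << 4
        fmt.Println(transportHeaderEnd(pkt, 20, false)) // 40 <nil>
    }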