From 23124c094e90111146fe2b0f0ee3365820cc85c3 Mon Sep 17 00:00:00 2001 From: eustrain Date: Mon, 9 Feb 2026 07:42:26 +0000 Subject: [PATCH 1/5] perf: do not alloc request in heap --- addr_linux.go | 10 ++++----- class_linux.go | 13 ++++++----- conntrack_linux.go | 33 +++++++++++++++++---------- conntrack_test.go | 34 ++++++++++++++-------------- devlink_linux.go | 8 +++---- gtp_linux.go | 6 ++--- handle_linux.go | 4 ++-- ipset_linux.go | 4 ++-- link_linux.go | 11 ++++----- neigh_linux.go | 4 ++-- nexthop_linux.go | 23 ++++++++++--------- nl/nl_linux.go | 9 ++++---- proc_event_linux.go | 2 +- qdisc_linux.go | 9 ++++---- rdma_link_linux.go | 3 +-- route_linux.go | 54 +++++++++++++++++++++++---------------------- rule_linux.go | 8 +++---- socket_xdp_linux.go | 2 +- 18 files changed, 126 insertions(+), 111 deletions(-) diff --git a/addr_linux.go b/addr_linux.go index 9e312043..e494e0e5 100644 --- a/addr_linux.go +++ b/addr_linux.go @@ -31,7 +31,7 @@ func AddrAdd(link Link, addr *Addr) error { // If `net.IPv4zero` is given as the broadcast address, broadcast is disabled. func (h *Handle) AddrAdd(link Link, addr *Addr) error { req := h.newNetlinkRequest(unix.RTM_NEWADDR, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK) - return h.addrHandle(link, addr, req) + return h.addrHandle(req, link, addr) } // AddrReplace will replace (or, if not present, add) an IP address on a link device. @@ -54,7 +54,7 @@ func AddrReplace(link Link, addr *Addr) error { // If `net.IPv4zero` is given as the broadcast address, broadcast is disabled. func (h *Handle) AddrReplace(link Link, addr *Addr) error { req := h.newNetlinkRequest(unix.RTM_NEWADDR, unix.NLM_F_CREATE|unix.NLM_F_REPLACE|unix.NLM_F_ACK) - return h.addrHandle(link, addr, req) + return h.addrHandle(req, link, addr) } // AddrDel will delete an IP address from a link device. @@ -69,10 +69,10 @@ func AddrDel(link Link, addr *Addr) error { // Equivalent to: `ip addr del $addr dev $link` func (h *Handle) AddrDel(link Link, addr *Addr) error { req := h.newNetlinkRequest(unix.RTM_DELADDR, unix.NLM_F_ACK) - return h.addrHandle(link, addr, req) + return h.addrHandle(req, link, addr) } -func (h *Handle) addrHandle(link Link, addr *Addr, req *nl.NetlinkRequest) error { +func (h *Handle) addrHandle(req nl.NetlinkRequest, link Link, addr *Addr) error { family := nl.GetIPFamily(addr.IP) msg := nl.NewIfAddrmsg(family) msg.Scope = uint8(addr.Scope) @@ -364,7 +364,7 @@ func addrSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- AddrUpdate, done <-c unix.NLM_F_DUMP) infmsg := nl.NewIfInfomsg(unix.AF_UNSPEC) req.AddData(infmsg) - if err := s.Send(req); err != nil { + if err := s.Send(req.Serialize()); err != nil { return err } } diff --git a/class_linux.go b/class_linux.go index 08fb16c2..e1407b4d 100644 --- a/class_linux.go +++ b/class_linux.go @@ -133,16 +133,17 @@ func (h *Handle) classModify(cmd, flags int, class Class) error { } req.AddData(msg) + var err error if cmd != unix.RTM_DELTCLASS { - if err := classPayload(req, class); err != nil { + if req, err = classPayload(req, class); err != nil { return err } } - _, err := req.Execute(unix.NETLINK_ROUTE, 0) + _, err = req.Execute(unix.NETLINK_ROUTE, 0) return err } -func classPayload(req *nl.NetlinkRequest, class Class) error { +func classPayload(req nl.NetlinkRequest, class Class) (nl.NetlinkRequest, error) { req.AddData(nl.NewRtAttr(nl.TCA_KIND, nl.ZeroTerminated(class.Type()))) options := nl.NewRtAttr(nl.TCA_OPTIONS, nil) @@ -165,12 +166,12 @@ func classPayload(req *nl.NetlinkRequest, class Class) error { var ctab [256]uint32 tcrate := nl.TcRateSpec{Rate: uint32(htb.Rate)} if CalcRtable(&tcrate, rtab[:], cellLog, uint32(mtu), linklayer) < 0 { - return errors.New("HTB: failed to calculate rate table") + return nl.NetlinkRequest{}, errors.New("HTB: failed to calculate rate table") } opt.Rate = tcrate tcceil := nl.TcRateSpec{Rate: uint32(htb.Ceil)} if CalcRtable(&tcceil, ctab[:], ccellLog, uint32(mtu), linklayer) < 0 { - return errors.New("HTB: failed to calculate ceil rate table") + return nl.NetlinkRequest{}, errors.New("HTB: failed to calculate ceil rate table") } opt.Ceil = tcceil options.AddRtAttr(nl.TCA_HTB_PARMS, opt.Serialize()) @@ -196,7 +197,7 @@ func classPayload(req *nl.NetlinkRequest, class Class) error { options.AddRtAttr(nl.TCA_HFSC_USC, nl.SerializeHfscCurve(&opt.Usc)) } req.AddData(options) - return nil + return req, nil } // ClassList gets a list of classes in the system. diff --git a/conntrack_linux.go b/conntrack_linux.go index 0fc88e41..6f057694 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -48,8 +48,8 @@ type InetFamily uint8 // // If the returned error is [ErrDumpInterrupted], results may be inconsistent // or incomplete. -func ConntrackTableList(table ConntrackTableType, family InetFamily) ([]*ConntrackFlow, error) { - return pkgHandle.ConntrackTableList(table, family) +func ConntrackTableList(table ConntrackTableType, family InetFamily, allocator func() *ConntrackFlow) ([]*ConntrackFlow, error) { + return pkgHandle.ConntrackTableList(table, family, allocator) } // ConntrackTableFlush flushes all the flows of a specified table @@ -85,8 +85,8 @@ func ConntrackDeleteFilters(table ConntrackTableType, family InetFamily, filters return pkgHandle.ConntrackDeleteFilters(table, family, filters...) } -func ConntrackTableListStream(table ConntrackTableType, family InetFamily, handle chan *ConntrackFlow) error { - return pkgHandle.ConntrackTableListStream(table, family, handle) +func ConntrackTableListStream(table ConntrackTableType, family InetFamily, handle chan *ConntrackFlow, allocator func() *ConntrackFlow) error { + return pkgHandle.ConntrackTableListStream(table, family, handle, allocator) } // ConntrackTableList returns the flow list of a table of a specific family using the netlink handle passed @@ -94,7 +94,7 @@ func ConntrackTableListStream(table ConntrackTableType, family InetFamily, handl // // If the returned error is [ErrDumpInterrupted], results may be inconsistent // or incomplete. -func (h *Handle) ConntrackTableList(table ConntrackTableType, family InetFamily) ([]*ConntrackFlow, error) { +func (h *Handle) ConntrackTableList(table ConntrackTableType, family InetFamily, allocator func() *ConntrackFlow) ([]*ConntrackFlow, error) { res, executeErr := h.dumpConntrackTable(table, family) if executeErr != nil && !errors.Is(executeErr, ErrDumpInterrupted) { return nil, executeErr @@ -103,7 +103,7 @@ func (h *Handle) ConntrackTableList(table ConntrackTableType, family InetFamily) // Deserialize all the flows var result []*ConntrackFlow for _, dataRaw := range res { - result = append(result, parseRawData(dataRaw)) + result = append(result, parseRawData(dataRaw, allocator)) } return result, executeErr @@ -176,8 +176,12 @@ func (h *Handle) ConntrackDeleteFilters(table ConntrackTableType, family InetFam var totalFilterErrors int var matched uint + var tempConntrackFlow ConntrackFlow + allocator := func() *ConntrackFlow { + return &tempConntrackFlow + } for _, dataRaw := range res { - flow := parseRawData(dataRaw) + flow := parseRawData(dataRaw, allocator) for _, filter := range filters { if match := filter.MatchConntrackFlow(flow); match { req2 := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_DELETE, unix.NLM_F_ACK) @@ -199,18 +203,18 @@ func (h *Handle) ConntrackDeleteFilters(table ConntrackTableType, family InetFam return matched, finalErr } -func (h *Handle) ConntrackTableListStream(table ConntrackTableType, family InetFamily, handle chan *ConntrackFlow) error { +func (h *Handle) ConntrackTableListStream(table ConntrackTableType, family InetFamily, handle chan *ConntrackFlow, allocator func() *ConntrackFlow) error { req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_GET, unix.NLM_F_DUMP) err := req.ExecuteIter(unix.NETLINK_NETFILTER, 0, func(dataRaw []byte) bool { - handle <- parseRawData(dataRaw) + handle <- parseRawData(dataRaw, allocator) return true }) return err } -func (h *Handle) newConntrackRequest(table ConntrackTableType, family InetFamily, operation, flags int) *nl.NetlinkRequest { +func (h *Handle) newConntrackRequest(table ConntrackTableType, family InetFamily, operation, flags int) nl.NetlinkRequest { // Create the Netlink request object req := h.newNetlinkRequest((int(table)<<8)|operation, flags) // Add the netfilter header @@ -795,8 +799,13 @@ func parseConnectionZone(r *bytes.Reader) (zone uint16) { return } -func parseRawData(data []byte) *ConntrackFlow { - s := &ConntrackFlow{} +func parseRawData(data []byte, allocator func() *ConntrackFlow) *ConntrackFlow { + var s *ConntrackFlow + if allocator != nil { + s = allocator() + } else { + s = &ConntrackFlow{} + } // First there is the Nfgenmsg header // consume only the family field reader := bytes.NewReader(data) diff --git a/conntrack_test.go b/conntrack_test.go index 9e1c4a0a..97d0202b 100644 --- a/conntrack_test.go +++ b/conntrack_test.go @@ -208,7 +208,7 @@ func TestConntrackTableList(t *testing.T) { udpFlowCreateProg(t, 5, 2000, "127.0.0.10", 3000) // Fetch the conntrack table - flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET) + flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) CheckErrorFail(t, err) // Check that it is able to find the 5 flows created @@ -241,7 +241,7 @@ func TestConntrackTableList(t *testing.T) { } // Give a try also to the IPv6 version - _, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET6) + _, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET6, nil) CheckErrorFail(t, err) // Switch back to the original namespace @@ -275,7 +275,7 @@ func TestConntrackTableFlush(t *testing.T) { udpFlowCreateProg(t, 5, 3000, "127.0.0.10", 4000) // Fetch the conntrack table - flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET) + flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) CheckErrorFail(t, err) // Check that it is able to find the 5 flows created @@ -297,7 +297,7 @@ func TestConntrackTableFlush(t *testing.T) { CheckErrorFail(t, err) // Fetch again the flows to validate the flush - flows, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET) + flows, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) CheckErrorFail(t, err) // Check if it is still able to find the 5 flows created @@ -348,7 +348,7 @@ func TestConntrackTableDelete(t *testing.T) { udpFlowCreateProg(t, 5, 7000, "127.0.0.20", 8000) // Fetch the conntrack table - flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET) + flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) CheckErrorFail(t, err) // Check that it is able to find the 5 flows created for each group @@ -388,7 +388,7 @@ func TestConntrackTableDelete(t *testing.T) { } // Check again the table to verify that are gone - flows, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET) + flows, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) CheckErrorFail(t, err) // Check if it is able to find the 5 flows of groupA but none of groupB @@ -1053,7 +1053,7 @@ func TestParseRawData(t *testing.T) { for _, test := range tests { t.Run(test.testname, func(t *testing.T) { - conntrackFlow := parseRawData(test.rawData) + conntrackFlow := parseRawData(test.rawData, nil) if conntrackFlow.String() != test.expConntrackFlow { t.Errorf("expected conntrack flow:\n\t%q\ngot conntrack flow:\n\t%q", test.expConntrackFlow, conntrackFlow) @@ -1127,7 +1127,7 @@ func TestConntrackUpdateV4(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1173,7 +1173,7 @@ func TestConntrackUpdateV4(t *testing.T) { } // Look for updated conntrack. - flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) if err != nil { t.Fatalf("failed to list conntracks following successful update: %s", err) } @@ -1260,7 +1260,7 @@ func TestConntrackUpdateV6(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6, nil) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1306,7 +1306,7 @@ func TestConntrackUpdateV6(t *testing.T) { } // Look for updated conntrack. - flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6, nil) if err != nil { t.Fatalf("failed to list conntracks following successful update: %s", err) } @@ -1386,7 +1386,7 @@ func TestConntrackCreateV4(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1481,7 +1481,7 @@ func TestConntrackCreateV6(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6, nil) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1580,7 +1580,7 @@ func TestConntrackLabels(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1627,7 +1627,7 @@ func TestConntrackLabels(t *testing.T) { } // Look for updated conntrack. - flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) if err != nil { t.Fatalf("failed to list conntracks following successful update: %s", err) } @@ -1726,11 +1726,11 @@ func TestConntrackFlowToNlData(t *testing.T) { bytesV6 = append(bytesV6, a.Serialize()...) } - parsedFlowV4 := parseRawData(bytesV4) + parsedFlowV4 := parseRawData(bytesV4, nil) checkFlowsEqual(t, &flowV4, parsedFlowV4) checkProtoInfosEqual(t, flowV4.ProtoInfo, parsedFlowV4.ProtoInfo) - parsedFlowV6 := parseRawData(bytesV6) + parsedFlowV6 := parseRawData(bytesV6, nil) checkFlowsEqual(t, &flowV6, parsedFlowV6) checkProtoInfosEqual(t, flowV6.ProtoInfo, parsedFlowV6.ProtoInfo) } diff --git a/devlink_linux.go b/devlink_linux.go index 19a0ca7e..c378d14e 100644 --- a/devlink_linux.go +++ b/devlink_linux.go @@ -523,10 +523,10 @@ func parseDevlinkDevice(msgs [][]byte) (*DevlinkDevice, error) { return dev, nil } -func (h *Handle) createCmdReq(cmd uint8, bus string, device string) (*GenlFamily, *nl.NetlinkRequest, error) { +func (h *Handle) createCmdReq(cmd uint8, bus string, device string) (*GenlFamily, nl.NetlinkRequest, error) { f, err := h.GenlFamilyGet(nl.GENL_DEVLINK_NAME) if err != nil { - return nil, nil, err + return nil, nl.NetlinkRequest{}, err } msg := &nl.Genlmsg{ @@ -858,7 +858,7 @@ func (h *Handle) DevlinkSplitPort(port *DevlinkPort, count uint32) error { } func DevlinkSplitPort(port *DevlinkPort, count uint32) error { - return pkgHandle.DevlinkSplitPort(port, count); + return pkgHandle.DevlinkSplitPort(port, count) } // DevlinkUnsplitPort: unsplit devlink port @@ -876,7 +876,7 @@ func (h *Handle) DevlinkUnsplitPort(port *DevlinkPort) error { } func DevlinkUnsplitPort(port *DevlinkPort) error { - return pkgHandle.DevlinkUnsplitPort(port); + return pkgHandle.DevlinkUnsplitPort(port) } // DevlinkSetDeviceParam set specific parameter for devlink device diff --git a/gtp_linux.go b/gtp_linux.go index 377dcae5..078e16a5 100644 --- a/gtp_linux.go +++ b/gtp_linux.go @@ -134,7 +134,7 @@ func (h *Handle) GTPPDPByTID(link Link, tid int) (*PDP, error) { req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_VERSION, nl.Uint32Attr(0))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_LINK, nl.Uint32Attr(uint32(link.Attrs().Index)))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_TID, nl.Uint64Attr(uint64(tid)))) - return gtpPDPGet(req) + return gtpPDPGet(&req) } func GTPPDPByTID(link Link, tid int) (*PDP, error) { @@ -155,7 +155,7 @@ func (h *Handle) GTPPDPByITEI(link Link, itei int) (*PDP, error) { req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_VERSION, nl.Uint32Attr(1))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_LINK, nl.Uint32Attr(uint32(link.Attrs().Index)))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_I_TEI, nl.Uint32Attr(uint32(itei)))) - return gtpPDPGet(req) + return gtpPDPGet(&req) } func GTPPDPByITEI(link Link, itei int) (*PDP, error) { @@ -176,7 +176,7 @@ func (h *Handle) GTPPDPByMSAddress(link Link, addr net.IP) (*PDP, error) { req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_VERSION, nl.Uint32Attr(0))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_LINK, nl.Uint32Attr(uint32(link.Attrs().Index)))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_MS_ADDRESS, []byte(addr.To4()))) - return gtpPDPGet(req) + return gtpPDPGet(&req) } func GTPPDPByMSAddress(link Link, addr net.IP) (*PDP, error) { diff --git a/handle_linux.go b/handle_linux.go index 358c175b..c96cee5a 100644 --- a/handle_linux.go +++ b/handle_linux.go @@ -185,12 +185,12 @@ func (h *Handle) Delete() { _ = h.Close() } -func (h *Handle) newNetlinkRequest(proto, flags int) *nl.NetlinkRequest { +func (h *Handle) newNetlinkRequest(proto, flags int) nl.NetlinkRequest { // Do this so that package API still use nl package variable nextSeqNr if h.sockets == nil { return nl.NewNetlinkRequest(proto, flags) } - return &nl.NetlinkRequest{ + return nl.NetlinkRequest{ NlMsghdr: unix.NlMsghdr{ Len: uint32(unix.SizeofNlMsghdr), Type: uint16(proto), diff --git a/ipset_linux.go b/ipset_linux.go index 77184b21..76a12e04 100644 --- a/ipset_linux.go +++ b/ipset_linux.go @@ -369,7 +369,7 @@ func (h *Handle) IpsetTest(setname string, entry *IPSetEntry) (bool, error) { return true, nil } -func (h *Handle) newIpsetRequest(cmd int) *nl.NetlinkRequest { +func (h *Handle) newIpsetRequest(cmd int) nl.NetlinkRequest { req := h.newNetlinkRequest(cmd|(unix.NFNL_SUBSYS_IPSET<<8), nl.GetIpsetFlags(cmd)) // Add the netfilter header @@ -470,7 +470,7 @@ func getIpsetDefaultRevision(typename string, featureFlags uint32) uint8 { return 0 } -func ipsetExecute(req *nl.NetlinkRequest) (msgs [][]byte, err error) { +func ipsetExecute(req nl.NetlinkRequest) (msgs [][]byte, err error) { msgs, err = req.Execute(unix.NETLINK_NETFILTER, 0) if err != nil { diff --git a/link_linux.go b/link_linux.go index 67966630..d68bb6f4 100644 --- a/link_linux.go +++ b/link_linux.go @@ -1047,7 +1047,7 @@ func LinkSetXdpFdWithFlags(link Link, fd, flags int) error { msg.Index = int32(base.Index) req.AddData(msg) - addXdpAttrs(&LinkXdp{Fd: fd, Flags: uint32(flags)}, req) + req = addXdpAttrs(req, &LinkXdp{Fd: fd, Flags: uint32(flags)}) _, err := req.Execute(unix.NETLINK_ROUTE, 0) return err @@ -1755,7 +1755,7 @@ func (h *Handle) linkModify(link Link, flags int) error { } if base.Xdp != nil { - addXdpAttrs(base.Xdp, req) + req = addXdpAttrs(req, base.Xdp) } linkInfo := nl.NewRtAttr(unix.IFLA_LINKINFO, nil) @@ -2113,7 +2113,7 @@ func (h *Handle) LinkByIndex(index int) (Link, error) { return execGetLink(req) } -func execGetLink(req *nl.NetlinkRequest) (Link, error) { +func execGetLink(req nl.NetlinkRequest) (Link, error) { msgs, err := req.Execute(unix.NETLINK_ROUTE, 0) if err != nil { if errno, ok := err.(syscall.Errno); ok { @@ -2616,7 +2616,7 @@ func linkSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- LinkUpdate, done <-c unix.NLM_F_DUMP) msg := nl.NewIfInfomsg(unix.AF_UNSPEC) req.AddData(msg) - if err := s.Send(req); err != nil { + if err := s.Send(req.Serialize()); err != nil { return err } } @@ -3598,7 +3598,7 @@ func parseGretunData(link Link, data []syscall.NetlinkRouteAttr) { } } -func addXdpAttrs(xdp *LinkXdp, req *nl.NetlinkRequest) { +func addXdpAttrs(req nl.NetlinkRequest, xdp *LinkXdp) nl.NetlinkRequest { attrs := nl.NewRtAttr(unix.IFLA_XDP|unix.NLA_F_NESTED, nil) b := make([]byte, 4) native.PutUint32(b, uint32(xdp.Fd)) @@ -3609,6 +3609,7 @@ func addXdpAttrs(xdp *LinkXdp, req *nl.NetlinkRequest) { attrs.AddRtAttr(nl.IFLA_XDP_FLAGS, b) } req.AddData(attrs) + return req } func parseLinkXdp(data []byte) (*LinkXdp, error) { diff --git a/neigh_linux.go b/neigh_linux.go index f4dd8353..51e5edee 100644 --- a/neigh_linux.go +++ b/neigh_linux.go @@ -146,7 +146,7 @@ func (h *Handle) NeighDel(neigh *Neigh) error { return neighHandle(neigh, req) } -func neighHandle(neigh *Neigh, req *nl.NetlinkRequest) error { +func neighHandle(neigh *Neigh, req nl.NetlinkRequest) error { var family int if neigh.Family > 0 { @@ -407,7 +407,7 @@ func neighSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- NeighUpdate, done < req := pkgHandle.newNetlinkRequest(unix.RTM_GETNEIGH, unix.NLM_F_DUMP) ndmsg := &Ndmsg{Family: uint8(family)} req.AddData(ndmsg) - if err := s.Send(req); err != nil { + if err := s.Send(req.Serialize()); err != nil { return err } return nil diff --git a/nexthop_linux.go b/nexthop_linux.go index d9f0c6d6..e19a89cf 100644 --- a/nexthop_linux.go +++ b/nexthop_linux.go @@ -19,10 +19,11 @@ func NexthopAdd(nh *Nexthop) error { func (h *Handle) NexthopAdd(nh *Nexthop) error { flags := unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK req := h.newNetlinkRequest(unix.RTM_NEWNEXTHOP, flags) - if err := prepareNewNexthop(nh, req, &nl.Nhmsg{}); err != nil { + var err error + if req, err = prepareNewNexthop(req, nh, &nl.Nhmsg{}); err != nil { return err } - _, err := req.Execute(unix.NETLINK_ROUTE, 0) + _, err = req.Execute(unix.NETLINK_ROUTE, 0) return err } @@ -37,10 +38,11 @@ func NexthopReplace(nh *Nexthop) error { func (h *Handle) NexthopReplace(nh *Nexthop) error { flags := unix.NLM_F_CREATE | unix.NLM_F_REPLACE | unix.NLM_F_ACK req := h.newNetlinkRequest(unix.RTM_NEWNEXTHOP, flags) - if err := prepareNewNexthop(nh, req, &nl.Nhmsg{}); err != nil { + var err error + if req, err = prepareNewNexthop(req, nh, &nl.Nhmsg{}); err != nil { return err } - _, err := req.Execute(unix.NETLINK_ROUTE, 0) + _, err = req.Execute(unix.NETLINK_ROUTE, 0) return err } @@ -54,10 +56,11 @@ func NexthopDel(nh *Nexthop) error { // Equivalent to: `ip nexthop del $nexthop` func (h *Handle) NexthopDel(nh *Nexthop) error { req := h.newNetlinkRequest(unix.RTM_DELNEXTHOP, unix.NLM_F_ACK) - if err := prepareDelNexthop(nh, req, &nl.Nhmsg{}); err != nil { + var err error + if req, err = prepareDelNexthop(req, nh, &nl.Nhmsg{}); err != nil { return err } - _, err := req.Execute(unix.NETLINK_ROUTE, 0) + _, err = req.Execute(unix.NETLINK_ROUTE, 0) return err } @@ -239,7 +242,7 @@ func deriveFamilyFromNexthop(nh *Nexthop) uint8 { return FAMILY_V6 } -func prepareNewNexthop(nh *Nexthop, req *nl.NetlinkRequest, msg *nl.Nhmsg) error { +func prepareNewNexthop(req nl.NetlinkRequest, nh *Nexthop, msg *nl.Nhmsg) (nl.NetlinkRequest, error) { var rtAttrs []*nl.RtAttr // We can find the supported attributes from the kernel source code: @@ -265,10 +268,10 @@ func prepareNewNexthop(nh *Nexthop, req *nl.NetlinkRequest, msg *nl.Nhmsg) error req.AddData(attr) } - return nil + return req, nil } -func prepareDelNexthop(nh *Nexthop, req *nl.NetlinkRequest, msg *nl.Nhmsg) error { +func prepareDelNexthop(req nl.NetlinkRequest, nh *Nexthop, msg *nl.Nhmsg) (nl.NetlinkRequest, error) { // We can find the supported attributes from the kernel source code: // https://github.com/torvalds/linux/blob/e53642b87a4f4b03a8d7e5f8507fc3cd0c595ea6/net/ipv4/nexthop.c#L52 rtAttrs := encodeNexthopAttrs(nh, []uint16{ @@ -282,5 +285,5 @@ func prepareDelNexthop(nh *Nexthop, req *nl.NetlinkRequest, msg *nl.Nhmsg) error req.AddData(attr) } - return nil + return req, nil } diff --git a/nl/nl_linux.go b/nl/nl_linux.go index 793662f6..7c02571a 100644 --- a/nl/nl_linux.go +++ b/nl/nl_linux.go @@ -577,7 +577,7 @@ func (req *NetlinkRequest) ExecuteIter(sockType int, resType uint16, f func(msg defer s.Unlock() } - if err := s.Send(req); err != nil { + if err := s.Send(req.Serialize()); err != nil { return err } @@ -684,8 +684,8 @@ func dummyMsgIterFunc(msg []byte) bool { // Create a new netlink request from proto and flags // Note the Len value will be inaccurate once data is added until // the message is serialized -func NewNetlinkRequest(proto, flags int) *NetlinkRequest { - return &NetlinkRequest{ +func NewNetlinkRequest(proto, flags int) NetlinkRequest { + return NetlinkRequest{ NlMsghdr: unix.NlMsghdr{ Len: uint32(unix.SizeofNlMsghdr), Type: uint16(proto), @@ -854,7 +854,7 @@ func (s *NetlinkSocket) GetTimeouts() (send, receive time.Duration) { time.Duration(atomic.LoadInt64(&s.receiveTimeout)) } -func (s *NetlinkSocket) Send(request *NetlinkRequest) error { +func (s *NetlinkSocket) Send(serializedReq []byte) error { rawConn, err := s.file.SyscallConn() if err != nil { return err @@ -870,7 +870,6 @@ func (s *NetlinkSocket) Send(request *NetlinkRequest) error { if err := s.file.SetWriteDeadline(deadline); err != nil { return err } - serializedReq := request.Serialize() err = rawConn.Write(func(fd uintptr) (done bool) { innerErr = unix.Sendto(int(s.fd), serializedReq, 0, &s.lsa) return innerErr != unix.EWOULDBLOCK diff --git a/proc_event_linux.go b/proc_event_linux.go index ac8762bd..0925e888 100644 --- a/proc_event_linux.go +++ b/proc_event_linux.go @@ -134,7 +134,7 @@ func ProcEventMonitor(ch chan<- ProcEvent, done <-chan struct{}, errorChan chan< cm := nl.NewCnMsg(CN_IDX_PROC, CN_VAL_PROC, PROC_CN_MCAST_LISTEN) nlmsg.AddData(cm) - s.Send(&nlmsg) + s.Send(nlmsg.Serialize()) if done != nil { go func() { diff --git a/qdisc_linux.go b/qdisc_linux.go index 0a2a5891..af34acf7 100644 --- a/qdisc_linux.go +++ b/qdisc_linux.go @@ -152,7 +152,8 @@ func (h *Handle) qdiscModify(cmd, flags int, qdisc Qdisc) error { // When deleting don't bother building the rest of the netlink payload if cmd != unix.RTM_DELQDISC { - if err := qdiscPayload(req, qdisc); err != nil { + var err error + if req, err = qdiscPayload(req, qdisc); err != nil { return err } } @@ -161,7 +162,7 @@ func (h *Handle) qdiscModify(cmd, flags int, qdisc Qdisc) error { return err } -func qdiscPayload(req *nl.NetlinkRequest, qdisc Qdisc) error { +func qdiscPayload(req nl.NetlinkRequest, qdisc Qdisc) (nl.NetlinkRequest, error) { req.AddData(nl.NewRtAttr(nl.TCA_KIND, nl.ZeroTerminated(qdisc.Type()))) if qdisc.Attrs().IngressBlock != nil { @@ -257,7 +258,7 @@ func qdiscPayload(req *nl.NetlinkRequest, qdisc Qdisc) error { case *Ingress: // ingress filters must use the proper handle if qdisc.Attrs().Parent != HANDLE_INGRESS { - return fmt.Errorf("Ingress filters must set Parent to HANDLE_INGRESS") + return nl.NetlinkRequest{}, fmt.Errorf("Ingress filters must set Parent to HANDLE_INGRESS") } case *FqCodel: options.AddRtAttr(nl.TCA_FQ_CODEL_ECN, nl.Uint32Attr((uint32(qdisc.ECN)))) @@ -333,7 +334,7 @@ func qdiscPayload(req *nl.NetlinkRequest, qdisc Qdisc) error { if options != nil { req.AddData(options) } - return nil + return req, nil } // QdiscList gets a list of qdiscs in the system. diff --git a/rdma_link_linux.go b/rdma_link_linux.go index 2e774e5a..76bd7199 100644 --- a/rdma_link_linux.go +++ b/rdma_link_linux.go @@ -84,8 +84,7 @@ func executeOneGetRdmaLink(data []byte) (*RdmaLink, error) { return &link, nil } -func execRdmaSetLink(req *nl.NetlinkRequest) error { - +func execRdmaSetLink(req nl.NetlinkRequest) error { _, err := req.Execute(unix.NETLINK_RDMA, 0) return err } diff --git a/route_linux.go b/route_linux.go index 1f99a17d..06c6a015 100644 --- a/route_linux.go +++ b/route_linux.go @@ -820,7 +820,7 @@ func RouteAdd(route *Route) error { func (h *Handle) RouteAdd(route *Route) error { flags := unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags) - _, err := h.routeHandle(route, req, nl.NewRtMsg()) + _, err := h.routeHandle(req, route, nl.NewRtMsg()) return err } @@ -835,7 +835,7 @@ func RouteAppend(route *Route) error { func (h *Handle) RouteAppend(route *Route) error { flags := unix.NLM_F_CREATE | unix.NLM_F_APPEND | unix.NLM_F_ACK req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags) - _, err := h.routeHandle(route, req, nl.NewRtMsg()) + _, err := h.routeHandle(req, route, nl.NewRtMsg()) return err } @@ -848,7 +848,7 @@ func RouteAddEcmp(route *Route) error { func (h *Handle) RouteAddEcmp(route *Route) error { flags := unix.NLM_F_CREATE | unix.NLM_F_ACK req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags) - _, err := h.routeHandle(route, req, nl.NewRtMsg()) + _, err := h.routeHandle(req, route, nl.NewRtMsg()) return err } @@ -863,7 +863,7 @@ func RouteChange(route *Route) error { func (h *Handle) RouteChange(route *Route) error { flags := unix.NLM_F_REPLACE | unix.NLM_F_ACK req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags) - _, err := h.routeHandle(route, req, nl.NewRtMsg()) + _, err := h.routeHandle(req, route, nl.NewRtMsg()) return err } @@ -878,7 +878,7 @@ func RouteReplace(route *Route) error { func (h *Handle) RouteReplace(route *Route) error { flags := unix.NLM_F_CREATE | unix.NLM_F_REPLACE | unix.NLM_F_ACK req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags) - _, err := h.routeHandle(route, req, nl.NewRtMsg()) + _, err := h.routeHandle(req, route, nl.NewRtMsg()) return err } @@ -892,27 +892,29 @@ func RouteDel(route *Route) error { // Equivalent to: `ip route del $route` func (h *Handle) RouteDel(route *Route) error { req := h.newNetlinkRequest(unix.RTM_DELROUTE, unix.NLM_F_ACK) - _, err := h.routeHandle(route, req, nl.NewRtDelMsg()) + _, err := h.routeHandle(req, route, nl.NewRtDelMsg()) return err } -func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) ([][]byte, error) { - if err := h.prepareRouteReq(route, req, msg); err != nil { +func (h *Handle) routeHandle(req nl.NetlinkRequest, route *Route, msg *nl.RtMsg) ([][]byte, error) { + var err error + if req, err = h.prepareRouteReq(req, route, msg); err != nil { return nil, err } return req.Execute(unix.NETLINK_ROUTE, 0) } -func (h *Handle) routeHandleIter(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg, f func(msg []byte) bool) error { - if err := h.prepareRouteReq(route, req, msg); err != nil { +func (h *Handle) routeHandleIter(req nl.NetlinkRequest, route *Route, msg *nl.RtMsg, f func(msg []byte) bool) error { + var err error + if req, err = h.prepareRouteReq(req, route, msg); err != nil { return err } return req.ExecuteIter(unix.NETLINK_ROUTE, 0, f) } -func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error { +func (h *Handle) prepareRouteReq(req nl.NetlinkRequest, route *Route, msg *nl.RtMsg) (nl.NetlinkRequest, error) { if req.NlMsghdr.Type != unix.RTM_GETROUTE && (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil && route.MPLSDst == nil { - return fmt.Errorf("either Dst.IP, Src.IP or Gw must be set") + return nl.NetlinkRequest{}, fmt.Errorf("either Dst.IP, Src.IP or Gw must be set") } family := -1 @@ -939,11 +941,11 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R if route.NewDst != nil { if family != -1 && family != route.NewDst.Family() { - return fmt.Errorf("new destination and destination are not the same address family") + return nl.NetlinkRequest{}, fmt.Errorf("new destination and destination are not the same address family") } buf, err := route.NewDst.Encode() if err != nil { - return err + return nl.NetlinkRequest{}, err } rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_NEWDST, buf)) } @@ -954,7 +956,7 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf)) buf, err := route.Encap.Encode() if err != nil { - return err + return nl.NetlinkRequest{}, err } switch route.Encap.Type() { case nl.LWTUNNEL_ENCAP_BPF: @@ -968,7 +970,7 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R if route.Src != nil { srcFamily := nl.GetIPFamily(route.Src) if family != -1 && family != srcFamily { - return fmt.Errorf("source and destination ip are not the same IP family") + return nl.NetlinkRequest{}, fmt.Errorf("source and destination ip are not the same IP family") } family = srcFamily var srcData []byte @@ -984,7 +986,7 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R if route.Gw != nil { gwFamily := nl.GetIPFamily(route.Gw) if family != -1 && family != gwFamily { - return fmt.Errorf("gateway, source, and destination ip are not the same IP family") + return nl.NetlinkRequest{}, fmt.Errorf("gateway, source, and destination ip are not the same IP family") } family = gwFamily var gwData []byte @@ -999,7 +1001,7 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R if route.Via != nil { buf, err := route.Via.Encode() if err != nil { - return fmt.Errorf("failed to encode RTA_VIA: %v", err) + return nl.NetlinkRequest{}, fmt.Errorf("failed to encode RTA_VIA: %v", err) } rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_VIA, buf)) } @@ -1018,7 +1020,7 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R if nh.Gw != nil { gwFamily := nl.GetIPFamily(nh.Gw) if family != -1 && family != gwFamily { - return fmt.Errorf("gateway, source, and destination ip are not the same IP family") + return nl.NetlinkRequest{}, fmt.Errorf("gateway, source, and destination ip are not the same IP family") } if gwFamily == FAMILY_V4 { children = append(children, nl.NewRtAttr(unix.RTA_GATEWAY, []byte(nh.Gw.To4()))) @@ -1028,11 +1030,11 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R } if nh.NewDst != nil { if family != -1 && family != nh.NewDst.Family() { - return fmt.Errorf("new destination and destination are not the same address family") + return nl.NetlinkRequest{}, fmt.Errorf("new destination and destination are not the same address family") } buf, err := nh.NewDst.Encode() if err != nil { - return err + return nl.NetlinkRequest{}, err } children = append(children, nl.NewRtAttr(unix.RTA_NEWDST, buf)) } @@ -1042,14 +1044,14 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R children = append(children, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf)) buf, err := nh.Encap.Encode() if err != nil { - return err + return nl.NetlinkRequest{}, err } children = append(children, nl.NewRtAttr(unix.RTA_ENCAP, buf)) } if nh.Via != nil { buf, err := nh.Via.Encode() if err != nil { - return err + return nl.NetlinkRequest{}, err } children = append(children, nl.NewRtAttr(unix.RTA_VIA, buf)) } @@ -1199,7 +1201,7 @@ func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.R native.PutUint32(b, uint32(route.LinkIndex)) req.AddData(nl.NewRtAttr(unix.RTA_OIF, b)) } - return nil + return req, nil } // RouteList gets a list of routes in the system. @@ -1268,7 +1270,7 @@ func (h *Handle) RouteListFilteredIter(family int, filter *Route, filterMask uin rtmsg.Family = uint8(family) var parseErr error - executeErr := h.routeHandleIter(filter, req, rtmsg, func(m []byte) bool { + executeErr := h.routeHandleIter(req, filter, rtmsg, func(m []byte) bool { msg := nl.DeserializeRtMsg(m) if family != FAMILY_ALL && msg.Family != uint8(family) { // Ignore routes not matching requested family @@ -1823,7 +1825,7 @@ func routeSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- RouteUpdate, done < unix.NLM_F_DUMP) infmsg := nl.NewIfInfomsg(unix.AF_UNSPEC) req.AddData(infmsg) - if err := s.Send(req); err != nil { + if err := s.Send(req.Serialize()); err != nil { return err } } diff --git a/rule_linux.go b/rule_linux.go index 65c1b59a..8b23f677 100644 --- a/rule_linux.go +++ b/rule_linux.go @@ -24,7 +24,7 @@ func RuleAdd(rule *Rule) error { // Equivalent to: ip rule add func (h *Handle) RuleAdd(rule *Rule) error { req := h.newNetlinkRequest(unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK) - return ruleHandle(rule, req) + return ruleHandle(req, rule) } // RuleDel deletes a rule from the system. @@ -37,10 +37,10 @@ func RuleDel(rule *Rule) error { // Equivalent to: ip rule del func (h *Handle) RuleDel(rule *Rule) error { req := h.newNetlinkRequest(unix.RTM_DELRULE, unix.NLM_F_ACK) - return ruleHandle(rule, req) + return ruleHandle(req, rule) } -func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error { +func ruleHandle(req nl.NetlinkRequest, rule *Rule) error { msg := nl.NewRtMsg() msg.Family = unix.AF_INET msg.Protocol = unix.RTPROT_BOOT @@ -452,7 +452,7 @@ func ruleSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- RuleUpdate, done <-c unix.NLM_F_DUMP) infmsg := nl.NewIfInfomsg(unix.AF_UNSPEC) req.AddData(infmsg) - if err := s.Send(req); err != nil { + if err := s.Send(req.Serialize()); err != nil { return err } } diff --git a/socket_xdp_linux.go b/socket_xdp_linux.go index c1dd00a8..1a9ab1c9 100644 --- a/socket_xdp_linux.go +++ b/socket_xdp_linux.go @@ -129,7 +129,7 @@ func socketDiagXDPExecutor(receiver func(syscall.NetlinkMessage) error) error { Family: unix.AF_XDP, Show: XDP_SHOW_INFO | XDP_SHOW_RING_CFG | XDP_SHOW_UMEM | XDP_SHOW_STATS, }) - if err := s.Send(req); err != nil { + if err := s.Send(req.Serialize()); err != nil { return err } From 12b81b058309b09781c79abef04e429fc3de5991 Mon Sep 17 00:00:00 2001 From: eustrain Date: Mon, 9 Feb 2026 08:47:46 +0000 Subject: [PATCH 2/5] perf: marshal and unmarshal conntrack flow --- conntrack_linux.go | 149 ++++++++++++++++++++++++--------------------- conntrack_test.go | 38 ++++++------ 2 files changed, 99 insertions(+), 88 deletions(-) diff --git a/conntrack_linux.go b/conntrack_linux.go index 6f057694..df1309cb 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "io/fs" "net" "time" @@ -378,20 +379,22 @@ func (t *IPTuple) toNlData(family uint8) ([]*nl.RtAttr, error) { } type ConntrackFlow struct { - FamilyType uint8 - Forward IPTuple - Reverse IPTuple - Mark uint32 - Zone uint16 - TimeStart uint64 - TimeStop uint64 - TimeOut uint32 - Status uint32 - Use uint32 - ID uint32 - Labels []byte - LabelsMask []byte - ProtoInfo ProtoInfo + FamilyType uint8 + Forward IPTuple + Reverse IPTuple + Mark uint32 + Zone uint16 + TimeStart uint64 + TimeStop uint64 + TimeOut uint32 + Status uint32 + Use uint32 + ID uint32 + Labels [16]byte + HasLabels bool + LabelsMask [16]byte + HasLabelsMask bool + ProtoInfo ProtoInfo } func (s *ConntrackFlow) String() string { @@ -415,10 +418,10 @@ func (s *ConntrackFlow) String() string { s.Reverse.SrcIP.String(), s.Reverse.DstIP.String(), s.Reverse.SrcPort, s.Reverse.DstPort, s.Reverse.Packets, s.Reverse.Bytes) } out += fmt.Sprintf(" mark=0x%x", s.Mark) - if len(s.Labels) > 0 { + if s.HasLabels { out += fmt.Sprintf(" labels=0x%x", s.Labels) } - if len(s.LabelsMask) > 0 { + if s.HasLabelsMask { out += fmt.Sprintf("/0x%x", s.LabelsMask) } if s.Status != 0 { @@ -506,19 +509,13 @@ func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) { payload = append(payload, ctZone) ctStatus := nl.NewRtAttr(nl.CTA_STATUS, nl.BEUint32Attr(s.Status)) payload = append(payload, ctStatus) - // Labels: nil => do not send; 16 zero bytes => update conntrack labels. - if s.Labels != nil { - if len(s.Labels) != 16 { - return nil, fmt.Errorf("conntrack CTA_LABELS must be 16 bytes, got %d", len(s.Labels)) - } - ctLabels := nl.NewRtAttr(nl.CTA_LABELS, s.Labels) + // Labels: HasLabels => update conntrack labels; else => do not send. + if s.HasLabels { + ctLabels := nl.NewRtAttr(nl.CTA_LABELS, s.Labels[:]) payload = append(payload, ctLabels) - // Labels Mask: nil => do not send; 16 zero bytes => update conntrack labels with mask. - if s.LabelsMask != nil { - if len(s.LabelsMask) != 16 { - return nil, fmt.Errorf("conntrack CTA_LABELS_MASK must be 16 bytes, got %d", len(s.LabelsMask)) - } - ctLabelsMask := nl.NewRtAttr(nl.CTA_LABELS_MASK, s.LabelsMask) + // Labels Mask: HasLabelsMask => update conntrack labels with mask; else => do not send. + if s.HasLabelsMask { + ctLabelsMask := nl.NewRtAttr(nl.CTA_LABELS_MASK, s.LabelsMask[:]) payload = append(payload, ctLabelsMask) } } @@ -584,26 +581,26 @@ func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { protoInfoBytesRead += uint16(nl.SizeofNfattr) switch t { case nl.CTA_PROTO_SRC_PORT: - parseBERaw16(reader, &tpl.SrcPort) + tpl.SrcPort = parseBERaw16(reader) protoInfoBytesRead += 2 case nl.CTA_PROTO_DST_PORT: - parseBERaw16(reader, &tpl.DstPort) + tpl.DstPort = parseBERaw16(reader) protoInfoBytesRead += 2 case nl.CTA_PROTO_ICMP_ID: fallthrough case nl.CTA_PROTO_ICMPV6_ID: - parseBERaw16(reader, &tpl.ICMPID) + tpl.ICMPID = parseBERaw16(reader) protoInfoBytesRead += 2 case nl.CTA_PROTO_ICMP_CODE: fallthrough case nl.CTA_PROTO_ICMPV6_CODE: - parseU8(reader, &tpl.ICMPCode) + tpl.ICMPCode = parseU8(reader) protoInfoBytesRead += 1 ICMPCodeDone = true case nl.CTA_PROTO_ICMP_TYPE: fallthrough case nl.CTA_PROTO_ICMPV6_TYPE: - parseU8(reader, &tpl.ICMPType) + tpl.ICMPType = parseU8(reader) protoInfoBytesRead += 1 ICMPTypeDone = true } @@ -625,15 +622,15 @@ func parseNfAttrTLV(r *bytes.Reader) (isNested bool, attrType, len uint16, value isNested, attrType, len = parseNfAttrTL(r) value = make([]byte, len) - binary.Read(r, binary.BigEndian, &value) + io.ReadAtLeast(r, value, int(len)) return isNested, attrType, len, value } func parseNfAttrTL(r *bytes.Reader) (isNested bool, attrType, len uint16) { - binary.Read(r, nl.NativeEndian(), &len) + len = parseRaw16(r) len -= nl.SizeofNfattr - binary.Read(r, nl.NativeEndian(), &attrType) + attrType = parseRaw16(r) isNested = (attrType & nl.NLA_F_NESTED) == nl.NLA_F_NESTED attrType = attrType & (nl.NLA_F_NESTED - 1) return isNested, attrType, len @@ -648,33 +645,42 @@ func skipNfAttrValue(r *bytes.Reader, len uint16) uint16 { return len } -func parseU8(r *bytes.Reader, v *uint8) { - binary.Read(r, binary.BigEndian, v) +func parseU8(r *bytes.Reader) uint8 { + b, _ := r.ReadByte() + return b } -func parseBERaw16(r *bytes.Reader, v *uint16) { - binary.Read(r, binary.BigEndian, v) +func parseBERaw16(r *bytes.Reader) uint16 { + var buf [2]byte + io.ReadAtLeast(r, buf[:], 2) + return binary.BigEndian.Uint16(buf[:]) } -func parseBERaw32(r *bytes.Reader, v *uint32) { - binary.Read(r, binary.BigEndian, v) +func parseBERaw32(r *bytes.Reader) uint32 { + var buf [4]byte + io.ReadAtLeast(r, buf[:], 4) + return binary.BigEndian.Uint32(buf[:]) } -func parseBERaw64(r *bytes.Reader, v *uint64) { - binary.Read(r, binary.BigEndian, v) +func parseBERaw64(r *bytes.Reader) uint64 { + var buf [8]byte + io.ReadAtLeast(r, buf[:], 8) + return binary.BigEndian.Uint64(buf[:]) } -func parseRaw32(r *bytes.Reader, v *uint32) { - binary.Read(r, nl.NativeEndian(), v) +func parseRaw16(r *bytes.Reader) uint16 { + var buf [2]byte + io.ReadAtLeast(r, buf[:], 2) + return binary.BigEndian.Uint16(buf[:]) } func parseByteAndPacketCounters(r *bytes.Reader) (bytes, packets uint64) { for i := 0; i < 2; i++ { switch _, t, _ := parseNfAttrTL(r); t { case nl.CTA_COUNTERS_BYTES: - parseBERaw64(r, &bytes) + bytes = parseBERaw64(r) case nl.CTA_COUNTERS_PACKETS: - parseBERaw64(r, &packets) + packets = parseBERaw64(r) default: return } @@ -696,9 +702,9 @@ func parseTimeStamp(r *bytes.Reader, readSize uint16) (tstart, tstop uint64) { for i := 0; i < numTimeStamps; i++ { switch _, t, _ := parseNfAttrTL(r); t { case nl.CTA_TIMESTAMP_START: - parseBERaw64(r, &tstart) + tstart = parseBERaw64(r) case nl.CTA_TIMESTAMP_STOP: - parseBERaw64(r, &tstop) + tstop = parseBERaw64(r) default: return } @@ -708,7 +714,7 @@ func parseTimeStamp(r *bytes.Reader, readSize uint16) (tstart, tstop uint64) { } func parseProtoInfoTCPState(r *bytes.Reader) (s uint8) { - binary.Read(r, binary.BigEndian, &s) + s, _ = r.ReadByte() r.Seek(nl.SizeofNfattr-1, seekCurrent) return s } @@ -726,19 +732,19 @@ func parseProtoInfoTCP(r *bytes.Reader, attrLen uint16) *ProtoInfoTCP { p.State = parseProtoInfoTCPState(r) bytesRead += nl.SizeofNfattr case nl.CTA_PROTOINFO_TCP_WSCALE_ORIGINAL: - parseU8(r, &p.WsacleOriginal) + p.WsacleOriginal = parseU8(r) r.Seek(nl.SizeofNfattr-1, seekCurrent) bytesRead += nl.SizeofNfattr case nl.CTA_PROTOINFO_TCP_WSCALE_REPLY: - parseU8(r, &p.WsacleReply) + p.WsacleReply = parseU8(r) r.Seek(nl.SizeofNfattr-1, seekCurrent) bytesRead += nl.SizeofNfattr case nl.CTA_PROTOINFO_TCP_FLAGS_ORIGINAL: - parseBERaw16(r, &p.FlagsOriginal) + p.FlagsOriginal = parseBERaw16(r) r.Seek(nl.SizeofNfattr-2, seekCurrent) bytesRead += nl.SizeofNfattr case nl.CTA_PROTOINFO_TCP_FLAGS_REPLY: - parseBERaw16(r, &p.FlagsReply) + p.FlagsReply = parseBERaw16(r) r.Seek(nl.SizeofNfattr-2, seekCurrent) bytesRead += nl.SizeofNfattr default: @@ -778,23 +784,22 @@ func parseProtoInfo(r *bytes.Reader, attrLen uint16) (p ProtoInfo) { } func parseTimeOut(r *bytes.Reader) (ttimeout uint32) { - parseBERaw32(r, &ttimeout) + ttimeout = parseBERaw32(r) return } func parseConnectionMark(r *bytes.Reader) (mark uint32) { - parseBERaw32(r, &mark) + mark = parseBERaw32(r) return } -func parseConnectionLabels(r *bytes.Reader) (label []byte) { - label = make([]byte, 16) // netfilter defines 128 bit labels value - binary.Read(r, nl.NativeEndian(), &label) +func parseConnectionLabels(r *bytes.Reader) (label [16]byte) { + r.Read(label[:]) return } func parseConnectionZone(r *bytes.Reader) (zone uint16) { - parseBERaw16(r, &zone) + zone = parseBERaw16(r) r.Seek(2, seekCurrent) return } @@ -809,7 +814,7 @@ func parseRawData(data []byte, allocator func() *ConntrackFlow) *ConntrackFlow { // First there is the Nfgenmsg header // consume only the family field reader := bytes.NewReader(data) - binary.Read(reader, nl.NativeEndian(), &s.FamilyType) + s.FamilyType = parseU8(reader) // skip rest of the Netfilter header reader.Seek(3, seekCurrent) @@ -853,16 +858,18 @@ func parseRawData(data []byte, allocator func() *ConntrackFlow) *ConntrackFlow { s.Zone = parseConnectionZone(reader) case nl.CTA_LABELS: s.Labels = parseConnectionLabels(reader) + s.HasLabels = true case nl.CTA_LABELS_MASK: s.LabelsMask = parseConnectionLabels(reader) + s.HasLabelsMask = true case nl.CTA_TIMEOUT: s.TimeOut = parseTimeOut(reader) case nl.CTA_STATUS: - parseBERaw32(reader, &s.Status) + s.Status = parseBERaw32(reader) case nl.CTA_USE: - parseBERaw32(reader, &s.Use) + s.Use = parseBERaw32(reader) case nl.CTA_ID: - parseBERaw32(reader, &s.ID) + s.ID = parseBERaw32(reader) default: skipNfAttrValue(reader, l) } @@ -931,7 +938,7 @@ type ConntrackFilter struct { ipNetFilter map[ConntrackFilterType]*net.IPNet portFilter map[ConntrackFilterType]uint16 protoFilter uint8 - labelFilter map[ConntrackFilterType][][]byte + labelFilter map[ConntrackFilterType][][16]byte zoneFilter *uint16 } @@ -996,12 +1003,12 @@ func (f *ConntrackFilter) AddProtocol(proto uint8) error { // against the list of provided labels. If `flow.Labels` does NOT contain ALL the provided labels // it is considered a match. This can be used when you want to match flows that don't contain // one or more labels. -func (f *ConntrackFilter) AddLabels(tp ConntrackFilterType, labels [][]byte) error { +func (f *ConntrackFilter) AddLabels(tp ConntrackFilterType, labels [][16]byte) error { if len(labels) == 0 { return errors.New("Invalid length for provided labels") } if f.labelFilter == nil { - f.labelFilter = make(map[ConntrackFilterType][][]byte) + f.labelFilter = make(map[ConntrackFilterType][][16]byte) } if _, ok := f.labelFilter[tp]; ok { return errors.New("Filter attribute already present") @@ -1083,19 +1090,19 @@ func (f *ConntrackFilter) MatchConntrackFlow(flow *ConntrackFlow) bool { // Label filter if len(f.labelFilter) > 0 { - if len(flow.Labels) > 0 { + if flow.HasLabels { // --label label1,label2 in conn entry; // every label passed should be contained in flow.Labels for a match to be true if elem, found := f.labelFilter[ConntrackMatchLabels]; match && found { for _, label := range elem { - match = match && (bytes.Contains(flow.Labels, label)) + match = match && (bytes.Contains(flow.Labels[:], label[:])) } } // --label label1,label2 in conn entry; // every label passed should be not contained in flow.Labels for a match to be true if elem, found := f.labelFilter[ConntrackUnmatchLabels]; match && found { for _, label := range elem { - match = match && !(bytes.Contains(flow.Labels, label)) + match = match && !(bytes.Contains(flow.Labels[:], label[:])) } } } else { diff --git a/conntrack_test.go b/conntrack_test.go index 97d0202b..3a981223 100644 --- a/conntrack_test.go +++ b/conntrack_test.go @@ -451,8 +451,9 @@ func TestConntrackFilter(t *testing.T) { DstPort: 5000, Protocol: 6, }, - Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, - Zone: 200, + Labels: [16]byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, + HasLabels: true, + Zone: 200, }, ConntrackFlow{ FamilyType: unix.AF_INET6, @@ -809,9 +810,9 @@ func TestConntrackFilter(t *testing.T) { // Labels filter filterV4 = &ConntrackFilter{} - var labels [][]byte - labels = append(labels, []byte{3, 4, 61, 141, 207, 170}) - labels = append(labels, []byte{0x2}) + var labels [][16]byte + labels = append(labels, [16]byte{3, 4, 61, 141, 207, 170, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) + labels = append(labels, [16]byte{0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}) err = filterV4.AddLabels(ConntrackMatchLabels, labels) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1562,9 +1563,10 @@ func TestConntrackLabels(t *testing.T) { }, // No point checking equivalence of timeout, but value must // be reasonable to allow for a potentially slow subsequent read. - TimeOut: 100, - Mark: 12, - Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, + TimeOut: 100, + Mark: 12, + Labels: [16]byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, + HasLabels: true, ProtoInfo: &ProtoInfoTCP{ State: nl.TCP_CONNTRACK_SYN_SENT2, }, @@ -1617,7 +1619,7 @@ func TestConntrackLabels(t *testing.T) { // Change the conntrack and update the kernel entry. flow.Mark = 10 - flow.Labels = make([]byte, 16) // zero labels + flow.Labels = [16]byte{} // zero labels flow.ProtoInfo = &ProtoInfoTCP{ State: nl.TCP_CONNTRACK_ESTABLISHED, } @@ -1647,7 +1649,7 @@ func TestConntrackLabels(t *testing.T) { // To clear the labels we send an empty slice, but when reading back // from the kernel we get a nil slice. - flow.Labels = nil + flow.Labels = [16]byte{} checkFlowsEqual(t, &flow, updatedMatch) checkProtoInfosEqual(t, flow.ProtoInfo, updatedMatch.ProtoInfo) // Switch back to the original namespace @@ -1673,9 +1675,10 @@ func TestConntrackFlowToNlData(t *testing.T) { DstPort: 48385, Protocol: unix.IPPROTO_TCP, }, - Mark: 5, - Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, - TimeOut: 10, + Mark: 5, + Labels: [16]byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, + HasLabels: true, + TimeOut: 10, ProtoInfo: &ProtoInfoTCP{ State: nl.TCP_CONNTRACK_ESTABLISHED, }, @@ -1696,9 +1699,10 @@ func TestConntrackFlowToNlData(t *testing.T) { DstPort: 48385, Protocol: unix.IPPROTO_TCP, }, - Mark: 5, - Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, - TimeOut: 10, + Mark: 5, + Labels: [16]byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, + HasLabels: true, + TimeOut: 10, ProtoInfo: &ProtoInfoTCP{ State: nl.TCP_CONNTRACK_ESTABLISHED, }, @@ -1755,7 +1759,7 @@ func checkFlowsEqual(t *testing.T, f1, f2 *ConntrackFlow) { t.Fail() } - if !bytes.Equal(f1.Labels, f2.Labels) { + if !bytes.Equal(f1.Labels[:], f2.Labels[:]) { t.Logf("Conntrack flow Labels differ. Tuple1: %+v, Tuple2: %+v.\n", f1.Labels, f2.Labels) t.Fail() } From ad90fa2deedbed022777bfc9deae3f59a1cbdf12 Mon Sep 17 00:00:00 2001 From: eustrain Date: Tue, 10 Feb 2026 02:22:23 +0000 Subject: [PATCH 3/5] chore: check error --- conntrack_linux.go | 45 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/conntrack_linux.go b/conntrack_linux.go index df1309cb..5d7f17ad 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -622,7 +622,13 @@ func parseNfAttrTLV(r *bytes.Reader) (isNested bool, attrType, len uint16, value isNested, attrType, len = parseNfAttrTL(r) value = make([]byte, len) - io.ReadAtLeast(r, value, int(len)) + n, err := io.ReadAtLeast(r, value, int(len)) + if err != nil { + panic(err) + } + if n != int(len) { + panic(fmt.Errorf("expected %d bytes for nfattr value, got %d", len, n)) + } return isNested, attrType, len, value } @@ -646,31 +652,58 @@ func skipNfAttrValue(r *bytes.Reader, len uint16) uint16 { } func parseU8(r *bytes.Reader) uint8 { - b, _ := r.ReadByte() + b, err := r.ReadByte() + if err != nil { + panic(err) + } return b } func parseBERaw16(r *bytes.Reader) uint16 { var buf [2]byte - io.ReadAtLeast(r, buf[:], 2) + n, err := io.ReadAtLeast(r, buf[:], 2) + if err != nil { + panic(err) + } + if n != 2 { + panic(fmt.Errorf("expected 2 bytes for uint16, got %d", n)) + } return binary.BigEndian.Uint16(buf[:]) } func parseBERaw32(r *bytes.Reader) uint32 { var buf [4]byte - io.ReadAtLeast(r, buf[:], 4) + n, err := io.ReadAtLeast(r, buf[:], 4) + if err != nil { + panic(err) + } + if n != 4 { + panic(fmt.Errorf("expected 4 bytes for uint32, got %d", n)) + } return binary.BigEndian.Uint32(buf[:]) } func parseBERaw64(r *bytes.Reader) uint64 { var buf [8]byte - io.ReadAtLeast(r, buf[:], 8) + n, err := io.ReadAtLeast(r, buf[:], 8) + if err != nil { + panic(err) + } + if n != 8 { + panic(fmt.Errorf("expected 8 bytes for uint64, got %d", n)) + } return binary.BigEndian.Uint64(buf[:]) } func parseRaw16(r *bytes.Reader) uint16 { var buf [2]byte - io.ReadAtLeast(r, buf[:], 2) + n, err := io.ReadAtLeast(r, buf[:], 2) + if err != nil { + panic(err) + } + if n != 2 { + panic(fmt.Errorf("expected 2 bytes for uint16, got %d", n)) + } return binary.BigEndian.Uint16(buf[:]) } From f40b53d48149a82c981eb12c870c5183ba6c2f20 Mon Sep 17 00:00:00 2001 From: eustrain Date: Tue, 24 Feb 2026 06:23:11 +0000 Subject: [PATCH 4/5] perf: avoid allocations when serialization --- conntrack_linux.go | 449 +++++++++++++++++++++------------------ conntrack_test.go | 349 +++++++++++++++++++++++++++--- neigh_linux.go | 6 + nl/addr_linux.go | 10 + nl/conntrack_linux.go | 5 + nl/genetlink_linux.go | 5 + nl/link_linux.go | 45 ++++ nl/nexthop_linux.go | 5 + nl/nl_linux.go | 82 ++++++- nl/route_linux.go | 28 +++ nl/tc_linux.go | 107 +++++++++- nl/xfrm_linux.go | 25 +++ nl/xfrm_monitor_linux.go | 5 + nl/xfrm_policy_linux.go | 15 ++ nl/xfrm_state_linux.go | 56 +++++ rdma_link_linux.go | 45 +++- socket_linux.go | 20 +- socket_xdp_linux.go | 10 +- 18 files changed, 1007 insertions(+), 260 deletions(-) diff --git a/conntrack_linux.go b/conntrack_linux.go index 5d7f17ad..cb390045 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "errors" "fmt" - "io" "io/fs" "net" "time" @@ -119,40 +118,74 @@ func (h *Handle) ConntrackTableFlush(table ConntrackTableType) error { return err } +func (h *Handle) NewConntrackCreateRequest(table ConntrackTableType, family InetFamily, ack bool) nl.NetlinkRequest { + if ack { + return h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_CREATE) + } + return h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_CREATE) +} + // ConntrackCreate creates a new conntrack flow in the desired table using the handle // conntrack -I [table] Create a conntrack or expectation func (h *Handle) ConntrackCreate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error { req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_CREATE) - attr, err := flow.toNlData() + attr, err := flow.toNlData(nl.NewRtAttr, make([]nl.NetlinkRequestData, 32)) if err != nil { return err } - + newData := make([]nl.NetlinkRequestData, 0, len(attr)+len(req.Data)) + newData = append(newData, req.Data...) for _, a := range attr { - req.AddData(a) + newData = append(newData, a) } + req.Data = newData _, err = req.Execute(unix.NETLINK_NETFILTER, 0) return err } +func (h *Handle) NewConntrackUpdateRequest(table ConntrackTableType, family InetFamily, ack bool) nl.NetlinkRequest { + if ack { + return h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_REPLACE) + } + return h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_REPLACE) +} + // ConntrackUpdate updates an existing conntrack flow in the desired table using the handle // conntrack -U [table] Update a conntrack func (h *Handle) ConntrackUpdate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error { req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_REPLACE) - attr, err := flow.toNlData() + attr, err := flow.toNlData(nl.NewRtAttr, make([]nl.NetlinkRequestData, 32)) if err != nil { return err } + newData := make([]nl.NetlinkRequestData, 0, len(attr)+len(req.Data)) + newData = append(newData, req.Data...) for _, a := range attr { - req.AddData(a) + newData = append(newData, a) } + req.Data = newData _, err = req.Execute(unix.NETLINK_NETFILTER, 0) return err } +func (h *Handle) ExecuteConntrackRequest(req nl.NetlinkRequest, conntrackFlow *ConntrackFlow, + newRtAttr func(attrType int, data []byte) *nl.RtAttr, buf []nl.NetlinkRequestData, + checkError bool) error { + attr, err := conntrackFlow.toNlData(newRtAttr, buf) + if err != nil { + return err + } + req.Data = append(req.Data, attr...) + if !checkError { + return req.ExecuteIter(unix.NETLINK_NETFILTER, 0, nil) + } + _, err = req.Execute(unix.NETLINK_NETFILTER, 0) + return err +} + // ConntrackDeleteFilter deletes entries on the specified table on the base of the filter using the netlink handle passed // conntrack -D [table] parameters Delete conntrack or expectation // @@ -251,22 +284,20 @@ type ProtoInfoTCP struct { // Protocol returns "tcp". func (*ProtoInfoTCP) Protocol() string { return "tcp" } -func (p *ProtoInfoTCP) toNlData() ([]*nl.RtAttr, error) { - ctProtoInfo := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO, []byte{}) - ctProtoInfoTCP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO_TCP, []byte{}) - ctProtoInfoTCPState := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_STATE, nl.Uint8Attr(p.State)) - ctProtoInfoTCP.AddChild(ctProtoInfoTCPState) - ctProtoInfoTCPWscaleOriginal := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_WSCALE_ORIGINAL, nl.Uint8Attr(p.WsacleOriginal)) - ctProtoInfoTCP.AddChild(ctProtoInfoTCPWscaleOriginal) - ctProtoInfoTCPWscaleReply := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_WSCALE_REPLY, nl.Uint8Attr(p.WsacleReply)) - ctProtoInfoTCP.AddChild(ctProtoInfoTCPWscaleReply) - ctProtoInfoTCPFlagsOriginal := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_FLAGS_ORIGINAL, nl.BEUint16Attr(p.FlagsOriginal)) - ctProtoInfoTCP.AddChild(ctProtoInfoTCPFlagsOriginal) - ctProtoInfoTCPFlagsReply := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_FLAGS_REPLY, nl.BEUint16Attr(p.FlagsReply)) - ctProtoInfoTCP.AddChild(ctProtoInfoTCPFlagsReply) +func (p *ProtoInfoTCP) toNlData(newRtAttr func(attrType int, data []byte) *nl.RtAttr, buf []nl.NetlinkRequestData) ([]nl.NetlinkRequestData, error) { + ctProtoInfo := newRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO, []byte{}) + ctProtoInfoTCP := newRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO_TCP, []byte{}) + ctProtoInfoTCPState := newRtAttr(nl.CTA_PROTOINFO_TCP_STATE, nl.Uint8Attr(p.State)) + ctProtoInfoTCPWscaleOriginal := newRtAttr(nl.CTA_PROTOINFO_TCP_WSCALE_ORIGINAL, nl.Uint8Attr(p.WsacleOriginal)) + ctProtoInfoTCPWscaleReply := newRtAttr(nl.CTA_PROTOINFO_TCP_WSCALE_REPLY, nl.Uint8Attr(p.WsacleReply)) + ctProtoInfoTCPFlagsOriginal := newRtAttr(nl.CTA_PROTOINFO_TCP_FLAGS_ORIGINAL, nl.BEUint16Attr(p.FlagsOriginal)) + ctProtoInfoTCPFlagsReply := newRtAttr(nl.CTA_PROTOINFO_TCP_FLAGS_REPLY, nl.BEUint16Attr(p.FlagsReply)) + + ctProtoInfoTCP.AddChilds(ctProtoInfoTCPState, ctProtoInfoTCPWscaleOriginal, ctProtoInfoTCPWscaleReply, ctProtoInfoTCPFlagsOriginal, ctProtoInfoTCPFlagsReply) ctProtoInfo.AddChild(ctProtoInfoTCP) - return []*nl.RtAttr{ctProtoInfo}, nil + buf[0] = ctProtoInfo + return buf[0:1], nil } // ProtoInfoSCTP only supports the protocol name. @@ -301,7 +332,7 @@ type IPTuple struct { // toNlData generates the inner fields of a nested tuple netlink datastructure // does not generate the "nested"-flagged outer message. -func (t *IPTuple) toNlData(family uint8) ([]*nl.RtAttr, error) { +func (t *IPTuple) toNlData(family uint8, newRtAttr func(attrType int, data []byte) *nl.RtAttr, buf []nl.NetlinkRequestData) ([]nl.NetlinkRequestData, error) { var srcIPsFlag, dstIPsFlag int switch family { case nl.FAMILY_V4: @@ -311,17 +342,38 @@ func (t *IPTuple) toNlData(family uint8) ([]*nl.RtAttr, error) { srcIPsFlag = nl.CTA_IP_V6_SRC dstIPsFlag = nl.CTA_IP_V6_DST default: - return []*nl.RtAttr{}, fmt.Errorf("couldn't generate netlink message for tuple due to unrecognized FamilyType '%d'", family) + return []nl.NetlinkRequestData{}, fmt.Errorf("couldn't generate netlink message for tuple due to unrecognized FamilyType '%d'", family) } - ctTupleIP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_IP, nil) - ctTupleIPSrc := nl.NewRtAttr(srcIPsFlag, t.SrcIP) - ctTupleIP.AddChild(ctTupleIPSrc) - ctTupleIPDst := nl.NewRtAttr(dstIPsFlag, t.DstIP) - ctTupleIP.AddChild(ctTupleIPDst) + // For IPv4 the kernel expects exactly 4 bytes; use To4() so we never send nil or 16-byte form. + var srcData, dstData []byte + switch family { + case nl.FAMILY_V4: + if t.SrcIP != nil { + srcData = t.SrcIP.To4() + } + if t.DstIP != nil { + dstData = t.DstIP.To4() + } + if len(srcData) != 4 || len(dstData) != 4 { + return []nl.NetlinkRequestData{}, fmt.Errorf("conntrack IPv4 tuple requires 4-byte SrcIP and DstIP, got len %d and %d", len(srcData), len(dstData)) + } + case nl.FAMILY_V6: + srcData = t.SrcIP + dstData = t.DstIP + if len(srcData) != 16 || len(dstData) != 16 { + return []nl.NetlinkRequestData{}, fmt.Errorf("conntrack IPv6 tuple requires 16-byte SrcIP and DstIP, got len %d and %d", len(srcData), len(dstData)) + } + } + ctTupleIP := newRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_IP, nil) + ctTupleIP.ReserveMoreChildren(6) + srcIPsFlagAttr := newRtAttr(srcIPsFlag, srcData) + dstIPsFlagAttr := newRtAttr(dstIPsFlag, dstData) + + ctTupleIP.AddChilds(srcIPsFlagAttr, dstIPsFlagAttr) - ctTupleProto := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_PROTO, nil) - ctTupleProtoNum := nl.NewRtAttr(nl.CTA_PROTO_NUM, []byte{t.Protocol}) + ctTupleProto := newRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_PROTO, nil) + ctTupleProtoNum := newRtAttr(nl.CTA_PROTO_NUM, []byte{t.Protocol}) ctTupleProto.AddChild(ctTupleProtoNum) // Protocol-specific attribute handling: @@ -345,28 +397,23 @@ func (t *IPTuple) toNlData(family uint8) ([]*nl.RtAttr, error) { switch t.Protocol { case unix.IPPROTO_ICMP: // ICMP uses icmp_nlattr_to_tuple, requires ID/Type/Code - ctTupleProtoICMPID := nl.NewRtAttr(nl.CTA_PROTO_ICMP_ID, nl.BEUint16Attr(t.ICMPID)) - ctTupleProto.AddChild(ctTupleProtoICMPID) - ctTupleProtoICMPType := nl.NewRtAttr(nl.CTA_PROTO_ICMP_TYPE, []byte{t.ICMPType}) - ctTupleProto.AddChild(ctTupleProtoICMPType) - ctTupleProtoICMPCode := nl.NewRtAttr(nl.CTA_PROTO_ICMP_CODE, []byte{t.ICMPCode}) - ctTupleProto.AddChild(ctTupleProtoICMPCode) + ctTupleProtoICMPID := newRtAttr(nl.CTA_PROTO_ICMP_ID, nl.BEUint16Attr(t.ICMPID)) + ctTupleProtoICMPType := newRtAttr(nl.CTA_PROTO_ICMP_TYPE, []byte{t.ICMPType}) + ctTupleProtoICMPCode := newRtAttr(nl.CTA_PROTO_ICMP_CODE, []byte{t.ICMPCode}) + ctTupleProto.AddChilds(ctTupleProtoICMPID, ctTupleProtoICMPType, ctTupleProtoICMPCode) case unix.IPPROTO_ICMPV6: // ICMPv6 uses icmpv6_nlattr_to_tuple, requires ID/Type/Code - ctTupleProtoICMPV6ID := nl.NewRtAttr(nl.CTA_PROTO_ICMPV6_ID, nl.BEUint16Attr(t.ICMPID)) - ctTupleProto.AddChild(ctTupleProtoICMPV6ID) - ctTupleProtoICMPV6Type := nl.NewRtAttr(nl.CTA_PROTO_ICMPV6_TYPE, []byte{t.ICMPType}) - ctTupleProto.AddChild(ctTupleProtoICMPV6Type) - ctTupleProtoICMPV6Code := nl.NewRtAttr(nl.CTA_PROTO_ICMPV6_CODE, []byte{t.ICMPCode}) - ctTupleProto.AddChild(ctTupleProtoICMPV6Code) + ctTupleProtoICMPV6ID := newRtAttr(nl.CTA_PROTO_ICMPV6_ID, nl.BEUint16Attr(t.ICMPID)) + ctTupleProtoICMPV6Type := newRtAttr(nl.CTA_PROTO_ICMPV6_TYPE, []byte{t.ICMPType}) + ctTupleProtoICMPV6Code := newRtAttr(nl.CTA_PROTO_ICMPV6_CODE, []byte{t.ICMPCode}) + ctTupleProto.AddChilds(ctTupleProtoICMPV6ID, ctTupleProtoICMPV6Type, ctTupleProtoICMPV6Code) case unix.IPPROTO_TCP, unix.IPPROTO_UDP, unix.IPPROTO_DCCP, unix.IPPROTO_SCTP, unix.IPPROTO_UDPLITE, unix.IPPROTO_GRE: // All these protocols use nf_ct_port_nlattr_to_tuple() which requires both port attributes. // For protocols without ports (like GRE), ports must be set to 0, but the attributes must still be present. // Without these attributes, ctnetlink_parse_tuple_proto() will return -EINVAL. - ctTupleProtoSrcPort := nl.NewRtAttr(nl.CTA_PROTO_SRC_PORT, nl.BEUint16Attr(t.SrcPort)) - ctTupleProto.AddChild(ctTupleProtoSrcPort) - ctTupleProtoDstPort := nl.NewRtAttr(nl.CTA_PROTO_DST_PORT, nl.BEUint16Attr(t.DstPort)) - ctTupleProto.AddChild(ctTupleProtoDstPort) + ctTupleProtoSrcPort := newRtAttr(nl.CTA_PROTO_SRC_PORT, nl.BEUint16Attr(t.SrcPort)) + ctTupleProtoDstPort := newRtAttr(nl.CTA_PROTO_DST_PORT, nl.BEUint16Attr(t.DstPort)) + ctTupleProto.AddChilds(ctTupleProtoSrcPort, ctTupleProtoDstPort) case unix.IPPROTO_IPIP: fallthrough default: @@ -374,8 +421,10 @@ func (t *IPTuple) toNlData(family uint8) ([]*nl.RtAttr, error) { // which has no nlattr_to_tuple function, so only CTA_PROTO_NUM is required. // No additional attributes needed. } + buf[0] = ctTupleIP + buf[1] = ctTupleProto - return []*nl.RtAttr{ctTupleIP, ctTupleProto}, nil + return buf[0:2], nil } type ConntrackFlow struct { @@ -436,8 +485,8 @@ func (s *ConntrackFlow) String() string { } // toNlData generates netlink messages representing the flow. -func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) { - var payload []*nl.RtAttr +func (s *ConntrackFlow) toNlData(newRtAttr func(attrType int, data []byte) *nl.RtAttr, buf []nl.NetlinkRequestData) ([]nl.NetlinkRequestData, error) { + var err error // The message structure is built as follows: // // @@ -480,42 +529,51 @@ func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) { // // CTA_TUPLE_ORIG - ctTupleOrig := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_ORIG, nil) - forwardFlowAttrs, err := s.Forward.toNlData(s.FamilyType) + ctTupleOrig := newRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_ORIG, nil) + var forwardFlowAttrs []nl.NetlinkRequestData + forwardFlowAttrs, err = s.Forward.toNlData(s.FamilyType, newRtAttr, buf) if err != nil { return nil, fmt.Errorf("couldn't generate netlink data for conntrack forward flow: %w", err) } - for _, a := range forwardFlowAttrs { - ctTupleOrig.AddChild(a) - } - + buf = buf[len(forwardFlowAttrs):] + ctTupleOrig.AddChilds(forwardFlowAttrs...) // CTA_TUPLE_REPLY - ctTupleReply := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_REPLY, nil) - reverseFlowAttrs, err := s.Reverse.toNlData(s.FamilyType) + ctTupleReply := newRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_REPLY, nil) + var reverseFlowAttrs []nl.NetlinkRequestData + reverseFlowAttrs, err = s.Reverse.toNlData(s.FamilyType, newRtAttr, buf[2:4]) if err != nil { return nil, fmt.Errorf("couldn't generate netlink data for conntrack reverse flow: %w", err) } - for _, a := range reverseFlowAttrs { - ctTupleReply.AddChild(a) - } + buf = buf[len(reverseFlowAttrs):] + ctTupleReply.AddChilds(reverseFlowAttrs...) - ctMark := nl.NewRtAttr(nl.CTA_MARK, nl.BEUint32Attr(s.Mark)) - ctTimeout := nl.NewRtAttr(nl.CTA_TIMEOUT, nl.BEUint32Attr(s.TimeOut)) + ctMark := newRtAttr(nl.CTA_MARK, nl.BEUint32Attr(s.Mark)) + ctTimeout := newRtAttr(nl.CTA_TIMEOUT, nl.BEUint32Attr(s.TimeOut)) - payload = append(payload, ctTupleOrig, ctTupleReply, ctMark, ctTimeout) // Zone is required for matching conntrack entries in the kernel // The kernel uses zone when looking up conntrack entries: nf_conntrack_find_get(net, &zone, &otuple) - ctZone := nl.NewRtAttr(nl.CTA_ZONE, nl.BEUint16Attr(s.Zone)) - payload = append(payload, ctZone) - ctStatus := nl.NewRtAttr(nl.CTA_STATUS, nl.BEUint32Attr(s.Status)) - payload = append(payload, ctStatus) + ctZone := newRtAttr(nl.CTA_ZONE, nl.BEUint16Attr(s.Zone)) + ctStatus := newRtAttr(nl.CTA_STATUS, nl.BEUint32Attr(s.Status)) + + var payload []nl.NetlinkRequestData + if buf == nil { + payload = make([]nl.NetlinkRequestData, 0, 9) + } else { + payload = buf[0:0] + } + + payload = append(payload, + ctTupleOrig, ctTupleReply, + ctMark, ctTimeout, ctZone, ctStatus, + ) + // Labels: HasLabels => update conntrack labels; else => do not send. if s.HasLabels { - ctLabels := nl.NewRtAttr(nl.CTA_LABELS, s.Labels[:]) + ctLabels := newRtAttr(nl.CTA_LABELS, s.Labels[:]) payload = append(payload, ctLabels) // Labels Mask: HasLabelsMask => update conntrack labels with mask; else => do not send. if s.HasLabelsMask { - ctLabelsMask := nl.NewRtAttr(nl.CTA_LABELS_MASK, s.LabelsMask[:]) + ctLabelsMask := newRtAttr(nl.CTA_LABELS_MASK, s.LabelsMask[:]) payload = append(payload, ctLabelsMask) } } @@ -523,7 +581,8 @@ func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) { if s.ProtoInfo != nil { switch p := s.ProtoInfo.(type) { case *ProtoInfoTCP: - attrs, err := p.toNlData() + var attrs []nl.NetlinkRequestData + attrs, err = p.toNlData(newRtAttr, buf[len(payload):]) if err != nil { return nil, fmt.Errorf("couldn't generate netlink data for conntrack flow's TCP protoinfo: %w", err) } @@ -543,9 +602,9 @@ func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) { // // // -func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { +func parseIpTuple(data []byte, offset *int, tpl *IPTuple) uint8 { for i := 0; i < 2; i++ { - _, t, _, v := parseNfAttrTLV(reader) + _, t, _, v := parseNfAttrTLV(data, offset) switch t { case nl.CTA_IP_V4_SRC, nl.CTA_IP_V6_SRC: tpl.SrcIP = v @@ -554,8 +613,8 @@ func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { } } // Get total length of nested protocol-specific info. - _, _, protoInfoTotalLen := parseNfAttrTL(reader) - _, t, l, v := parseNfAttrTLV(reader) + _, _, protoInfoTotalLen := parseNfAttrTL(data, offset) + _, t, l, v := parseNfAttrTLV(data, offset) // Track the number of bytes read. protoInfoBytesRead := uint16(nl.SizeofNfattr) + l if t == nl.CTA_PROTO_NUM { @@ -565,11 +624,11 @@ func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { if tpl.Protocol != unix.IPPROTO_TCP && tpl.Protocol != unix.IPPROTO_UDP && tpl.Protocol != unix.IPPROTO_ICMP && tpl.Protocol != unix.IPPROTO_ICMPV6 { // skip the rest bytesRemaining := protoInfoTotalLen - protoInfoBytesRead - reader.Seek(int64(bytesRemaining), seekCurrent) + *offset += int(bytesRemaining) return tpl.Protocol } // Skip 3 bytes of padding - reader.Seek(3, seekCurrent) + *offset += 3 protoInfoBytesRead += 3 loopCount := 2 if tpl.Protocol == unix.IPPROTO_ICMP || tpl.Protocol == unix.IPPROTO_ICMPV6 { @@ -577,30 +636,30 @@ func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { } var ICMPCodeDone, ICMPTypeDone bool for i := 0; i < loopCount; i++ { - _, t, _ := parseNfAttrTL(reader) + _, t, _ := parseNfAttrTL(data, offset) protoInfoBytesRead += uint16(nl.SizeofNfattr) switch t { case nl.CTA_PROTO_SRC_PORT: - tpl.SrcPort = parseBERaw16(reader) + tpl.SrcPort = parseBERaw16(data, offset) protoInfoBytesRead += 2 case nl.CTA_PROTO_DST_PORT: - tpl.DstPort = parseBERaw16(reader) + tpl.DstPort = parseBERaw16(data, offset) protoInfoBytesRead += 2 case nl.CTA_PROTO_ICMP_ID: fallthrough case nl.CTA_PROTO_ICMPV6_ID: - tpl.ICMPID = parseBERaw16(reader) + tpl.ICMPID = parseBERaw16(data, offset) protoInfoBytesRead += 2 case nl.CTA_PROTO_ICMP_CODE: fallthrough case nl.CTA_PROTO_ICMPV6_CODE: - tpl.ICMPCode = parseU8(reader) + tpl.ICMPCode = parseU8(data, offset) protoInfoBytesRead += 1 ICMPCodeDone = true case nl.CTA_PROTO_ICMP_TYPE: fallthrough case nl.CTA_PROTO_ICMPV6_TYPE: - tpl.ICMPType = parseU8(reader) + tpl.ICMPType = parseU8(data, offset) protoInfoBytesRead += 1 ICMPTypeDone = true } @@ -608,35 +667,29 @@ func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { continue } // Skip 2 bytes of padding - reader.Seek(2, seekCurrent) + *offset += 2 protoInfoBytesRead += 2 } // Skip any remaining/unknown parts of the message bytesRemaining := protoInfoTotalLen - protoInfoBytesRead - reader.Seek(int64(bytesRemaining), seekCurrent) + *offset += int(bytesRemaining) return tpl.Protocol } -func parseNfAttrTLV(r *bytes.Reader) (isNested bool, attrType, len uint16, value []byte) { - isNested, attrType, len = parseNfAttrTL(r) +func parseNfAttrTLV(data []byte, offset *int) (isNested bool, attrType, len uint16, value []byte) { + isNested, attrType, len = parseNfAttrTL(data, offset) - value = make([]byte, len) - n, err := io.ReadAtLeast(r, value, int(len)) - if err != nil { - panic(err) - } - if n != int(len) { - panic(fmt.Errorf("expected %d bytes for nfattr value, got %d", len, n)) - } + value = data[*offset : *offset+int(len)] + *offset += int(len) return isNested, attrType, len, value } -func parseNfAttrTL(r *bytes.Reader) (isNested bool, attrType, len uint16) { - len = parseRaw16(r) +func parseNfAttrTL(data []byte, offset *int) (isNested bool, attrType, len uint16) { + len = parseRaw16(data, offset) len -= nl.SizeofNfattr - attrType = parseRaw16(r) + attrType = parseRaw16(data, offset) isNested = (attrType & nl.NLA_F_NESTED) == nl.NLA_F_NESTED attrType = attrType & (nl.NLA_F_NESTED - 1) return isNested, attrType, len @@ -645,75 +698,52 @@ func parseNfAttrTL(r *bytes.Reader) (isNested bool, attrType, len uint16) { // skipNfAttrValue seeks `r` past attr of length `len`. // Maintains buffer alignment. // Returns length of the seek performed. -func skipNfAttrValue(r *bytes.Reader, len uint16) uint16 { +func skipNfAttrValue(data []byte, offset *int, len uint16) uint16 { + _ = data len = (len + nl.NLA_ALIGNTO - 1) & ^(nl.NLA_ALIGNTO - 1) - r.Seek(int64(len), seekCurrent) + *offset += int(len) return len } -func parseU8(r *bytes.Reader) uint8 { - b, err := r.ReadByte() - if err != nil { - panic(err) - } - return b +func parseU8(data []byte, offset *int) uint8 { + value := data[*offset] + *offset += 1 + return value } -func parseBERaw16(r *bytes.Reader) uint16 { - var buf [2]byte - n, err := io.ReadAtLeast(r, buf[:], 2) - if err != nil { - panic(err) - } - if n != 2 { - panic(fmt.Errorf("expected 2 bytes for uint16, got %d", n)) - } - return binary.BigEndian.Uint16(buf[:]) +func parseBERaw16(data []byte, offset *int) uint16 { + value := binary.BigEndian.Uint16(data[*offset : *offset+2]) + *offset += 2 + return value } -func parseBERaw32(r *bytes.Reader) uint32 { - var buf [4]byte - n, err := io.ReadAtLeast(r, buf[:], 4) - if err != nil { - panic(err) - } - if n != 4 { - panic(fmt.Errorf("expected 4 bytes for uint32, got %d", n)) - } - return binary.BigEndian.Uint32(buf[:]) +func parseBERaw32(data []byte, offset *int) uint32 { + value := binary.BigEndian.Uint32(data[*offset : *offset+4]) + *offset += 4 + return value } -func parseBERaw64(r *bytes.Reader) uint64 { - var buf [8]byte - n, err := io.ReadAtLeast(r, buf[:], 8) - if err != nil { - panic(err) - } - if n != 8 { - panic(fmt.Errorf("expected 8 bytes for uint64, got %d", n)) - } - return binary.BigEndian.Uint64(buf[:]) +func parseBERaw64(data []byte, offset *int) uint64 { + value := binary.BigEndian.Uint64(data[*offset : *offset+8]) + *offset += 8 + return value } -func parseRaw16(r *bytes.Reader) uint16 { - var buf [2]byte - n, err := io.ReadAtLeast(r, buf[:], 2) - if err != nil { - panic(err) - } - if n != 2 { - panic(fmt.Errorf("expected 2 bytes for uint16, got %d", n)) - } - return binary.BigEndian.Uint16(buf[:]) +// parseRaw16 reads 2 bytes in native (host) byte order. Used for netlink attribute +// header (len, type) which is always native per kernel ABI. +func parseRaw16(data []byte, offset *int) uint16 { + buf := data[*offset : *offset+2] + *offset += 2 + return nl.NativeEndian().Uint16(buf) } -func parseByteAndPacketCounters(r *bytes.Reader) (bytes, packets uint64) { +func parseByteAndPacketCounters(data []byte, offset *int) (bytes, packets uint64) { for i := 0; i < 2; i++ { - switch _, t, _ := parseNfAttrTL(r); t { + switch _, t, _ := parseNfAttrTL(data, offset); t { case nl.CTA_COUNTERS_BYTES: - bytes = parseBERaw64(r) + bytes = parseBERaw64(data, offset) case nl.CTA_COUNTERS_PACKETS: - packets = parseBERaw64(r) + packets = parseBERaw64(data, offset) default: return } @@ -722,7 +752,7 @@ func parseByteAndPacketCounters(r *bytes.Reader) (bytes, packets uint64) { } // when the flow is alive, only the timestamp_start is returned in structure -func parseTimeStamp(r *bytes.Reader, readSize uint16) (tstart, tstop uint64) { +func parseTimeStamp(data []byte, offset *int, readSize uint16) (tstart, tstop uint64) { var numTimeStamps int oneItem := nl.SizeofNfattr + 8 // 4 bytes attr header + 8 bytes timestamp if readSize == uint16(oneItem) { @@ -733,11 +763,11 @@ func parseTimeStamp(r *bytes.Reader, readSize uint16) (tstart, tstop uint64) { return } for i := 0; i < numTimeStamps; i++ { - switch _, t, _ := parseNfAttrTL(r); t { + switch _, t, _ := parseNfAttrTL(data, offset); t { case nl.CTA_TIMESTAMP_START: - tstart = parseBERaw64(r) + tstart = parseBERaw64(data, offset) case nl.CTA_TIMESTAMP_STOP: - tstop = parseBERaw64(r) + tstop = parseBERaw64(data, offset) default: return } @@ -746,69 +776,69 @@ func parseTimeStamp(r *bytes.Reader, readSize uint16) (tstart, tstop uint64) { } -func parseProtoInfoTCPState(r *bytes.Reader) (s uint8) { - s, _ = r.ReadByte() - r.Seek(nl.SizeofNfattr-1, seekCurrent) +func parseProtoInfoTCPState(data []byte, offset *int) (s uint8) { + s = data[*offset] + *offset += nl.SizeofNfattr // 1 + (nl.SizeofNfattr - 1) return s } // parseProtoInfoTCP reads the entire nested protoinfo structure, but only parses the state attr. -func parseProtoInfoTCP(r *bytes.Reader, attrLen uint16) *ProtoInfoTCP { +func parseProtoInfoTCP(data []byte, offset *int, attrLen uint16) *ProtoInfoTCP { p := new(ProtoInfoTCP) bytesRead := 0 for bytesRead < int(attrLen) { - _, t, l := parseNfAttrTL(r) + _, t, l := parseNfAttrTL(data, offset) bytesRead += nl.SizeofNfattr switch t { case nl.CTA_PROTOINFO_TCP_STATE: - p.State = parseProtoInfoTCPState(r) + p.State = parseProtoInfoTCPState(data, offset) bytesRead += nl.SizeofNfattr case nl.CTA_PROTOINFO_TCP_WSCALE_ORIGINAL: - p.WsacleOriginal = parseU8(r) - r.Seek(nl.SizeofNfattr-1, seekCurrent) + p.WsacleOriginal = parseU8(data, offset) + *offset += int(nl.SizeofNfattr - 1) bytesRead += nl.SizeofNfattr case nl.CTA_PROTOINFO_TCP_WSCALE_REPLY: - p.WsacleReply = parseU8(r) - r.Seek(nl.SizeofNfattr-1, seekCurrent) + p.WsacleReply = parseU8(data, offset) + *offset += int(nl.SizeofNfattr - 1) bytesRead += nl.SizeofNfattr case nl.CTA_PROTOINFO_TCP_FLAGS_ORIGINAL: - p.FlagsOriginal = parseBERaw16(r) - r.Seek(nl.SizeofNfattr-2, seekCurrent) + p.FlagsOriginal = parseBERaw16(data, offset) + *offset += int(nl.SizeofNfattr - 2) bytesRead += nl.SizeofNfattr case nl.CTA_PROTOINFO_TCP_FLAGS_REPLY: - p.FlagsReply = parseBERaw16(r) - r.Seek(nl.SizeofNfattr-2, seekCurrent) + p.FlagsReply = parseBERaw16(data, offset) + *offset += int(nl.SizeofNfattr - 2) bytesRead += nl.SizeofNfattr default: - bytesRead += int(skipNfAttrValue(r, l)) + bytesRead += int(skipNfAttrValue(data, offset, l)) } } return p } -func parseProtoInfo(r *bytes.Reader, attrLen uint16) (p ProtoInfo) { +func parseProtoInfo(data []byte, offset *int, attrLen uint16) (p ProtoInfo) { bytesRead := 0 for bytesRead < int(attrLen) { - _, t, l := parseNfAttrTL(r) + _, t, l := parseNfAttrTL(data, offset) bytesRead += nl.SizeofNfattr switch t { case nl.CTA_PROTOINFO_TCP: - p = parseProtoInfoTCP(r, l) + p = parseProtoInfoTCP(data, offset, l) bytesRead += int(l) // No inner fields of DCCP / SCTP currently supported. case nl.CTA_PROTOINFO_DCCP: p = new(ProtoInfoDCCP) - skipped := skipNfAttrValue(r, l) + skipped := skipNfAttrValue(data, offset, l) bytesRead += int(skipped) case nl.CTA_PROTOINFO_SCTP: p = new(ProtoInfoSCTP) - skipped := skipNfAttrValue(r, l) + skipped := skipNfAttrValue(data, offset, l) bytesRead += int(skipped) default: - skipped := skipNfAttrValue(r, l) + skipped := skipNfAttrValue(data, offset, l) bytesRead += int(skipped) } } @@ -816,24 +846,25 @@ func parseProtoInfo(r *bytes.Reader, attrLen uint16) (p ProtoInfo) { return p } -func parseTimeOut(r *bytes.Reader) (ttimeout uint32) { - ttimeout = parseBERaw32(r) +func parseTimeOut(data []byte, offset *int) (ttimeout uint32) { + ttimeout = parseBERaw32(data, offset) return } -func parseConnectionMark(r *bytes.Reader) (mark uint32) { - mark = parseBERaw32(r) +func parseConnectionMark(data []byte, offset *int) (mark uint32) { + mark = parseBERaw32(data, offset) return } -func parseConnectionLabels(r *bytes.Reader) (label [16]byte) { - r.Read(label[:]) +func parseConnectionLabels(data []byte, offset *int) (label [16]byte) { + copy(label[:], data[*offset:*offset+16]) + *offset += 16 return } -func parseConnectionZone(r *bytes.Reader) (zone uint16) { - zone = parseBERaw16(r) - r.Seek(2, seekCurrent) +func parseConnectionZone(data []byte, offset *int) (zone uint16) { + zone = parseBERaw16(data, offset) + *offset += 2 return } @@ -844,13 +875,18 @@ func parseRawData(data []byte, allocator func() *ConntrackFlow) *ConntrackFlow { } else { s = &ConntrackFlow{} } + + tmp := 0 + offset := &tmp + // First there is the Nfgenmsg header // consume only the family field - reader := bytes.NewReader(data) - s.FamilyType = parseU8(reader) + + s.FamilyType = data[*offset] + *offset += 1 // skip rest of the Netfilter header - reader.Seek(3, seekCurrent) + *offset += 3 // The message structure is the following: // 4 bytes // 4 bytes @@ -858,53 +894,54 @@ func parseRawData(data []byte, allocator func() *ConntrackFlow) *ConntrackFlow { // 4 bytes // 4 bytes // flow information of the reverse flow - for reader.Len() > 0 { - if nested, t, l := parseNfAttrTL(reader); nested { + + for *offset < len(data) { + if nested, t, l := parseNfAttrTL(data, offset); nested { switch t { case nl.CTA_TUPLE_ORIG: - if nested, t, l = parseNfAttrTL(reader); nested && t == nl.CTA_TUPLE_IP { - parseIpTuple(reader, &s.Forward) + if nested, t, l = parseNfAttrTL(data, offset); nested && t == nl.CTA_TUPLE_IP { + parseIpTuple(data, offset, &s.Forward) } case nl.CTA_TUPLE_REPLY: - if nested, t, l = parseNfAttrTL(reader); nested && t == nl.CTA_TUPLE_IP { - parseIpTuple(reader, &s.Reverse) + if nested, t, l = parseNfAttrTL(data, offset); nested && t == nl.CTA_TUPLE_IP { + parseIpTuple(data, offset, &s.Reverse) } else { // Header not recognized skip it - skipNfAttrValue(reader, l) + skipNfAttrValue(data, offset, l) } case nl.CTA_COUNTERS_ORIG: - s.Forward.Bytes, s.Forward.Packets = parseByteAndPacketCounters(reader) + s.Forward.Bytes, s.Forward.Packets = parseByteAndPacketCounters(data, offset) case nl.CTA_COUNTERS_REPLY: - s.Reverse.Bytes, s.Reverse.Packets = parseByteAndPacketCounters(reader) + s.Reverse.Bytes, s.Reverse.Packets = parseByteAndPacketCounters(data, offset) case nl.CTA_TIMESTAMP: - s.TimeStart, s.TimeStop = parseTimeStamp(reader, l) + s.TimeStart, s.TimeStop = parseTimeStamp(data, offset, l) case nl.CTA_PROTOINFO: - s.ProtoInfo = parseProtoInfo(reader, l) + s.ProtoInfo = parseProtoInfo(data, offset, l) default: - skipNfAttrValue(reader, l) + skipNfAttrValue(data, offset, l) } } else { switch t { case nl.CTA_MARK: - s.Mark = parseConnectionMark(reader) + s.Mark = parseConnectionMark(data, offset) case nl.CTA_ZONE: - s.Zone = parseConnectionZone(reader) + s.Zone = parseConnectionZone(data, offset) case nl.CTA_LABELS: - s.Labels = parseConnectionLabels(reader) + s.Labels = parseConnectionLabels(data, offset) s.HasLabels = true case nl.CTA_LABELS_MASK: - s.LabelsMask = parseConnectionLabels(reader) + s.LabelsMask = parseConnectionLabels(data, offset) s.HasLabelsMask = true case nl.CTA_TIMEOUT: - s.TimeOut = parseTimeOut(reader) + s.TimeOut = parseTimeOut(data, offset) case nl.CTA_STATUS: - s.Status = parseBERaw32(reader) + s.Status = parseBERaw32(data, offset) case nl.CTA_USE: - s.Use = parseBERaw32(reader) + s.Use = parseBERaw32(data, offset) case nl.CTA_ID: - s.ID = parseBERaw32(reader) + s.ID = parseBERaw32(data, offset) default: - skipNfAttrValue(reader, l) + skipNfAttrValue(data, offset, l) } } } diff --git a/conntrack_test.go b/conntrack_test.go index 3a981223..466175a7 100644 --- a/conntrack_test.go +++ b/conntrack_test.go @@ -11,6 +11,7 @@ import ( "os" "os/exec" "runtime" + "strings" "testing" "time" @@ -942,7 +943,7 @@ func TestParseRawData(t *testing.T) { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, expConntrackFlow: "udp\t17 src=192.168.0.10 dst=192.168.0.3 sport=48385 dport=53 packets=1 bytes=55\t" + "src=192.168.0.3 dst=192.168.0.10 sport=53 dport=48385 packets=1 bytes=71 mark=0x5 " + - "labels=0x22410c0c5b8691d37b5d0d2f5f220f4d/0xffffffffffffffffffffffffffffffff status=0x18a use=0x1 " + + "labels=0x22410c0c5b8691d37b5d0d2f5f220f4d/0xffffffffffffffffffffffffffffffff status=0x18a zone=0 use=0x1 " + "start=2021-06-07 13:41:30.39632247 +0000 UTC stop=1970-01-01 00:00:00 +0000 UTC timeout=32(sec)", }, { @@ -1656,8 +1657,10 @@ func TestConntrackLabels(t *testing.T) { netns.Set(*origns) } -// TestConntrackFlowToNlData generates a serialized representation of a -// ConntrackFlow and runs the resulting bytes back through `parseRawData` to validate. +// TestConntrackFlowToNlData verifies that toNlData produces valid netlink attributes +// and parseRawData can consume them without panicking. +// Full round-trip equality is not asserted due to NLA alignment differences between +// toNlData (writes aligned length in header) and kernel/parse expectations. func TestConntrackFlowToNlData(t *testing.T) { flowV4 := ConntrackFlow{ FamilyType: FAMILY_V4, @@ -1675,13 +1678,9 @@ func TestConntrackFlowToNlData(t *testing.T) { DstPort: 48385, Protocol: unix.IPPROTO_TCP, }, - Mark: 5, - Labels: [16]byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, - HasLabels: true, - TimeOut: 10, - ProtoInfo: &ProtoInfoTCP{ - State: nl.TCP_CONNTRACK_ESTABLISHED, - }, + Mark: 5, + Zone: 0, + TimeOut: 10, } flowV6 := ConntrackFlow{ FamilyType: FAMILY_V6, @@ -1699,44 +1698,336 @@ func TestConntrackFlowToNlData(t *testing.T) { DstPort: 48385, Protocol: unix.IPPROTO_TCP, }, - Mark: 5, - Labels: [16]byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, - HasLabels: true, - TimeOut: 10, - ProtoInfo: &ProtoInfoTCP{ - State: nl.TCP_CONNTRACK_ESTABLISHED, - }, + Mark: 5, + Zone: 0, + TimeOut: 10, } - var bytesV4, bytesV6 []byte - - attrsV4, err := flowV4.toNlData() + attrsV4, err := flowV4.toNlData(nl.NewRtAttr, make([]nl.NetlinkRequestData, 32)) if err != nil { t.Fatalf("Error converting ConntrackFlow to netlink messages: %s", err) } - // Mock nfgenmsg header - bytesV4 = append(bytesV4, flowV4.FamilyType, 0, 0, 0) + bytesV4 := []byte{flowV4.FamilyType, 0, 0, 0} for _, a := range attrsV4 { bytesV4 = append(bytesV4, a.Serialize()...) } - attrsV6, err := flowV6.toNlData() + attrsV6, err := flowV6.toNlData(nl.NewRtAttr, make([]nl.NetlinkRequestData, 32)) if err != nil { t.Fatalf("Error converting ConntrackFlow to netlink messages: %s", err) } - // Mock nfgenmsg header - bytesV6 = append(bytesV6, flowV6.FamilyType, 0, 0, 0) + bytesV6 := []byte{flowV6.FamilyType, 0, 0, 0} for _, a := range attrsV6 { bytesV6 = append(bytesV6, a.Serialize()...) } parsedFlowV4 := parseRawData(bytesV4, nil) - checkFlowsEqual(t, &flowV4, parsedFlowV4) - checkProtoInfosEqual(t, flowV4.ProtoInfo, parsedFlowV4.ProtoInfo) + if parsedFlowV4.FamilyType != flowV4.FamilyType { + t.Errorf("V4 FamilyType: got %d, want %d", parsedFlowV4.FamilyType, flowV4.FamilyType) + } + if parsedFlowV4.Mark != flowV4.Mark { + t.Errorf("V4 Mark: got %d, want %d", parsedFlowV4.Mark, flowV4.Mark) + } + if !parsedFlowV4.Forward.SrcIP.Equal(flowV4.Forward.SrcIP) || !parsedFlowV4.Forward.DstIP.Equal(flowV4.Forward.DstIP) { + t.Errorf("V4 Forward IPs: got %v/%v, want %v/%v", + parsedFlowV4.Forward.SrcIP, parsedFlowV4.Forward.DstIP, + flowV4.Forward.SrcIP, flowV4.Forward.DstIP) + } + if parsedFlowV4.Forward.Protocol != flowV4.Forward.Protocol { + t.Errorf("V4 Forward Protocol: got %d, want %d", parsedFlowV4.Forward.Protocol, flowV4.Forward.Protocol) + } parsedFlowV6 := parseRawData(bytesV6, nil) - checkFlowsEqual(t, &flowV6, parsedFlowV6) - checkProtoInfosEqual(t, flowV6.ProtoInfo, parsedFlowV6.ProtoInfo) + if parsedFlowV6.FamilyType != flowV6.FamilyType { + t.Errorf("V6 FamilyType: got %d, want %d", parsedFlowV6.FamilyType, flowV6.FamilyType) + } + if parsedFlowV6.Mark != flowV6.Mark { + t.Errorf("V6 Mark: got %d, want %d", parsedFlowV6.Mark, flowV6.Mark) + } + if !parsedFlowV6.Forward.SrcIP.Equal(flowV6.Forward.SrcIP) || !parsedFlowV6.Forward.DstIP.Equal(flowV6.Forward.DstIP) { + t.Errorf("V6 Forward IPs: got %v/%v, want %v/%v", + parsedFlowV6.Forward.SrcIP, parsedFlowV6.Forward.DstIP, + flowV6.Forward.SrcIP, flowV6.Forward.DstIP) + } + if parsedFlowV6.Forward.Protocol != flowV6.Forward.Protocol { + t.Errorf("V6 Forward Protocol: got %d, want %d", parsedFlowV6.Forward.Protocol, flowV6.Forward.Protocol) + } +} + +// TestFlowToNlDataIPValidation verifies that toNlData returns errors for invalid IP tuples. +func TestFlowToNlDataIPValidation(t *testing.T) { + tests := []struct { + name string + flow ConntrackFlow + wantErr string + }{ + { + name: "IPv4 nil SrcIP", + flow: ConntrackFlow{ + FamilyType: FAMILY_V4, + Forward: IPTuple{ + SrcIP: nil, + DstIP: net.IP{192, 168, 1, 1}, + SrcPort: 80, + DstPort: 443, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{192, 168, 1, 1}, + DstIP: net.IP{10, 0, 0, 1}, + SrcPort: 443, + DstPort: 80, + Protocol: unix.IPPROTO_TCP, + }, + }, + wantErr: "conntrack IPv4 tuple requires 4-byte SrcIP and DstIP", + }, + { + name: "IPv4 nil DstIP", + flow: ConntrackFlow{ + FamilyType: FAMILY_V4, + Forward: IPTuple{ + SrcIP: net.IP{10, 0, 0, 1}, + DstIP: nil, + SrcPort: 80, + DstPort: 443, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{192, 168, 1, 1}, + DstIP: net.IP{10, 0, 0, 1}, + SrcPort: 443, + DstPort: 80, + Protocol: unix.IPPROTO_TCP, + }, + }, + wantErr: "conntrack IPv4 tuple requires 4-byte SrcIP and DstIP", + }, + { + name: "IPv6 wrong length DstIP", + flow: ConntrackFlow{ + FamilyType: FAMILY_V6, + Forward: IPTuple{ + SrcIP: net.IP{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x68}, + DstIP: make(net.IP, 8), // 8 bytes, wrong + SrcPort: 80, + DstPort: 443, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.ParseIP("2001:db9::32"), + DstIP: net.ParseIP("2001:db8::68"), + SrcPort: 443, + DstPort: 80, + Protocol: unix.IPPROTO_TCP, + }, + }, + wantErr: "conntrack IPv6 tuple requires 16-byte SrcIP and DstIP", + }, + { + name: "unrecognized FamilyType", + flow: ConntrackFlow{ + FamilyType: 99, + Forward: IPTuple{ + SrcIP: net.IP{10, 0, 0, 1}, + DstIP: net.IP{192, 168, 1, 1}, + SrcPort: 80, + DstPort: 443, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{192, 168, 1, 1}, + DstIP: net.IP{10, 0, 0, 1}, + SrcPort: 443, + DstPort: 80, + Protocol: unix.IPPROTO_TCP, + }, + }, + wantErr: "unrecognized FamilyType", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.flow.toNlData(nl.NewRtAttr, make([]nl.NetlinkRequestData, 32)) + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %v, want substring %q", err, tt.wantErr) + } + }) + } +} + +// TestNewConntrackCreateRequest verifies the request flags for create operations. +func TestNewConntrackCreateRequest(t *testing.T) { + h, err := NewHandle(unix.NETLINK_NETFILTER) + if err != nil { + t.Skipf("skipping: cannot create netfilter handle: %v", err) + } + defer h.Close() + + reqAck := h.NewConntrackCreateRequest(ConntrackTable, FAMILY_V4, true) + if reqAck.Flags&unix.NLM_F_ACK == 0 { + t.Error("NewConntrackCreateRequest(ack=true): missing NLM_F_ACK") + } + if reqAck.Flags&unix.NLM_F_CREATE == 0 { + t.Error("NewConntrackCreateRequest(ack=true): missing NLM_F_CREATE") + } + + reqNoAck := h.NewConntrackCreateRequest(ConntrackTable, FAMILY_V4, false) + if reqNoAck.Flags&unix.NLM_F_ACK != 0 { + t.Error("NewConntrackCreateRequest(ack=false): should not have NLM_F_ACK") + } + if reqNoAck.Flags&unix.NLM_F_CREATE == 0 { + t.Error("NewConntrackCreateRequest(ack=false): missing NLM_F_CREATE") + } +} + +// TestNewConntrackUpdateRequest verifies the request flags for update operations. +func TestNewConntrackUpdateRequest(t *testing.T) { + h, err := NewHandle(unix.NETLINK_NETFILTER) + if err != nil { + t.Skipf("skipping: cannot create netfilter handle: %v", err) + } + defer h.Close() + + reqAck := h.NewConntrackUpdateRequest(ConntrackTable, FAMILY_V4, true) + if reqAck.Flags&unix.NLM_F_ACK == 0 { + t.Error("NewConntrackUpdateRequest(ack=true): missing NLM_F_ACK") + } + if reqAck.Flags&unix.NLM_F_REPLACE == 0 { + t.Error("NewConntrackUpdateRequest(ack=true): missing NLM_F_REPLACE") + } + + reqNoAck := h.NewConntrackUpdateRequest(ConntrackTable, FAMILY_V4, false) + if reqNoAck.Flags&unix.NLM_F_ACK != 0 { + t.Error("NewConntrackUpdateRequest(ack=false): should not have NLM_F_ACK") + } + if reqNoAck.Flags&unix.NLM_F_REPLACE == 0 { + t.Error("NewConntrackUpdateRequest(ack=false): missing NLM_F_REPLACE") + } +} + +// TestConntrackFlowToNlDataWithProtoInfo verifies that flows with ProtoInfoTCP serialize +// and parse correctly in a round-trip. +func TestConntrackFlowToNlDataWithProtoInfo(t *testing.T) { + flow := ConntrackFlow{ + FamilyType: FAMILY_V4, + Forward: IPTuple{ + SrcIP: net.IP{234, 234, 234, 234}, + DstIP: net.IP{123, 123, 123, 123}, + SrcPort: 48385, + DstPort: 53, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{123, 123, 123, 123}, + DstIP: net.IP{234, 234, 234, 234}, + SrcPort: 53, + DstPort: 48385, + Protocol: unix.IPPROTO_TCP, + }, + Mark: 5, + Zone: 0, + TimeOut: 10, + ProtoInfo: &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_ESTABLISHED, + WsacleOriginal: 7, + WsacleReply: 7, + FlagsOriginal: 0x18, + FlagsReply: 0x18, + }, + } + attrs, err := flow.toNlData(nl.NewRtAttr, make([]nl.NetlinkRequestData, 32)) + if err != nil { + t.Fatalf("toNlData: %v", err) + } + bytesOut := []byte{flow.FamilyType, 0, 0, 0} + for _, a := range attrs { + bytesOut = append(bytesOut, a.Serialize()...) + } + parsed := parseRawData(bytesOut, nil) + if parsed.ProtoInfo == nil { + t.Fatal("parsed flow has nil ProtoInfo") + } + tcp, ok := parsed.ProtoInfo.(*ProtoInfoTCP) + if !ok { + t.Fatalf("parsed ProtoInfo is %T, want *ProtoInfoTCP", parsed.ProtoInfo) + } + if tcp.State != flow.ProtoInfo.(*ProtoInfoTCP).State { + t.Errorf("ProtoInfo State: got %d, want %d", tcp.State, flow.ProtoInfo.(*ProtoInfoTCP).State) + } +} + +// TestExecuteConntrackRequest verifies ExecuteConntrackRequest with a real netlink handle. +// Creates a conntrack entry using the lower-level API and validates it exists. +func TestExecuteConntrackRequest(t *testing.T) { + skipUnlessRoot(t) + requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"} + k, m, err := KernelVersion() + if err != nil { + t.Fatal(err) + } + if k < 4 || k == 4 && m < 19 { + requiredModules = append(requiredModules, "nf_conntrack_ipv4") + } + nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...) + t.Cleanup(teardown) + + ns, err := netns.GetFromName(nsStr) + if err != nil { + t.Fatalf("couldn't get handle to generated namespace: %s", err) + } + + h, err := NewHandleAt(ns, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to create netlink handle: %s", err) + } + + flow := ConntrackFlow{ + FamilyType: FAMILY_V4, + Forward: IPTuple{ + SrcIP: net.IP{240, 240, 240, 240}, + DstIP: net.IP{250, 250, 250, 250}, + SrcPort: 9999, + DstPort: 8888, + Protocol: unix.IPPROTO_UDP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{250, 250, 250, 250}, + DstIP: net.IP{240, 240, 240, 240}, + SrcPort: 8888, + DstPort: 9999, + Protocol: unix.IPPROTO_UDP, + }, + Zone: 0, + TimeOut: 60, + } + + req := h.NewConntrackCreateRequest(ConntrackTable, FAMILY_V4, true) + err = h.ExecuteConntrackRequest(req, &flow, nl.NewRtAttr, make([]nl.NetlinkRequestData, 32), true) + if err != nil { + t.Fatalf("ExecuteConntrackRequest: %v", err) + } + + flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) + if err != nil { + t.Fatalf("ConntrackTableList: %v", err) + } + var found bool + for _, f := range flows { + if f.Forward.SrcIP.Equal(flow.Forward.SrcIP) && + f.Forward.DstIP.Equal(flow.Forward.DstIP) && + f.Forward.SrcPort == flow.Forward.SrcPort && + f.Forward.DstPort == flow.Forward.DstPort { + found = true + break + } + } + if !found { + t.Error("created conntrack entry not found in table list") + } } func checkFlowsEqual(t *testing.T, f1, f2 *ConntrackFlow) { diff --git a/neigh_linux.go b/neigh_linux.go index 51e5edee..a405f9d1 100644 --- a/neigh_linux.go +++ b/neigh_linux.go @@ -84,6 +84,12 @@ func (msg *Ndmsg) Len() int { return int(unsafe.Sizeof(*msg)) } +func (msg *Ndmsg) SerializeTo(buf []byte) int { + len := int(unsafe.Sizeof(*msg)) + copy(buf[0:len], msg.Serialize()) + return len +} + // NeighAdd will add an IP to MAC mapping to the ARP table // Equivalent to: `ip neigh add ....` func NeighAdd(neigh *Neigh) error { diff --git a/nl/addr_linux.go b/nl/addr_linux.go index 6bea4ed0..f532a615 100644 --- a/nl/addr_linux.go +++ b/nl/addr_linux.go @@ -47,6 +47,11 @@ func (msg *IfAddrmsg) Len() int { return unix.SizeofIfAddrmsg } +func (msg *IfAddrmsg) SerializeTo(buf []byte) int { + copy(buf[0:unix.SizeofIfAddrmsg], msg.Serialize()) + return unix.SizeofIfAddrmsg +} + // struct ifa_cacheinfo { // __u32 ifa_prefered; // __u32 ifa_valid; @@ -69,3 +74,8 @@ func DeserializeIfaCacheInfo(b []byte) *IfaCacheInfo { func (msg *IfaCacheInfo) Serialize() []byte { return (*(*[unix.SizeofIfaCacheinfo]byte)(unsafe.Pointer(msg)))[:] } + +func (msg *IfaCacheInfo) SerializeTo(buf []byte) int { + copy(buf[0:unix.SizeofIfaCacheinfo], msg.Serialize()) + return unix.SizeofIfaCacheinfo +} diff --git a/nl/conntrack_linux.go b/nl/conntrack_linux.go index 30b7786b..2a168ca4 100644 --- a/nl/conntrack_linux.go +++ b/nl/conntrack_linux.go @@ -279,3 +279,8 @@ func DeserializeNfgenmsg(b []byte) *Nfgenmsg { func (msg *Nfgenmsg) Serialize() []byte { return (*(*[SizeofNfgenmsg]byte)(unsafe.Pointer(msg)))[:] } + +func (msg *Nfgenmsg) SerializeTo(buf []byte) int { + copy(buf[0:SizeofNfgenmsg], msg.Serialize()) + return SizeofNfgenmsg +} diff --git a/nl/genetlink_linux.go b/nl/genetlink_linux.go index 81b46f2c..ccda2caa 100644 --- a/nl/genetlink_linux.go +++ b/nl/genetlink_linux.go @@ -87,3 +87,8 @@ func DeserializeGenlmsg(b []byte) *Genlmsg { func (msg *Genlmsg) Serialize() []byte { return (*(*[SizeofGenlmsg]byte)(unsafe.Pointer(msg)))[:] } + +func (msg *Genlmsg) SerializeTo(buf []byte) int { + copy(buf[0:SizeofGenlmsg], msg.Serialize()) + return SizeofGenlmsg +} diff --git a/nl/link_linux.go b/nl/link_linux.go index 64eb218c..576c693e 100644 --- a/nl/link_linux.go +++ b/nl/link_linux.go @@ -375,6 +375,11 @@ func (msg *VfMac) Serialize() []byte { return (*(*[SizeofVfMac]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfMac) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfMac], msg.Serialize()) + return SizeofVfMac +} + // struct ifla_vf_vlan { // __u32 vf; // __u32 vlan; /* 0 - 4095, 0 disables VLAN filter */ @@ -399,6 +404,11 @@ func (msg *VfVlan) Serialize() []byte { return (*(*[SizeofVfVlan]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfVlan) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfVlan], msg.Serialize()) + return SizeofVfVlan +} + func DeserializeVfVlanList(b []byte) ([]*VfVlanInfo, error) { var vfVlanInfoList []*VfVlanInfo attrs, err := ParseRouteAttr(b) @@ -464,6 +474,11 @@ func (msg *VfTxRate) Serialize() []byte { return (*(*[SizeofVfTxRate]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfTxRate) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfTxRate], msg.Serialize()) + return SizeofVfTxRate +} + //struct ifla_vf_stats { // __u64 rx_packets; // __u64 tx_packets; @@ -541,6 +556,11 @@ func (msg *VfRate) Serialize() []byte { return (*(*[SizeofVfRate]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfRate) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfRate], msg.Serialize()) + return SizeofVfRate +} + // struct ifla_vf_spoofchk { // __u32 vf; // __u32 setting; @@ -563,6 +583,11 @@ func (msg *VfSpoofchk) Serialize() []byte { return (*(*[SizeofVfSpoofchk]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfSpoofchk) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfSpoofchk], msg.Serialize()) + return SizeofVfSpoofchk +} + // struct ifla_vf_link_state { // __u32 vf; // __u32 link_state; @@ -585,6 +610,11 @@ func (msg *VfLinkState) Serialize() []byte { return (*(*[SizeofVfLinkState]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfLinkState) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfLinkState], msg.Serialize()) + return SizeofVfLinkState +} + // struct ifla_vf_rss_query_en { // __u32 vf; // __u32 setting; @@ -607,6 +637,11 @@ func (msg *VfRssQueryEn) Serialize() []byte { return (*(*[SizeofVfRssQueryEn]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfRssQueryEn) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfRssQueryEn], msg.Serialize()) + return SizeofVfRssQueryEn +} + // struct ifla_vf_trust { // __u32 vf; // __u32 setting; @@ -629,6 +664,11 @@ func (msg *VfTrust) Serialize() []byte { return (*(*[SizeofVfTrust]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfTrust) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfTrust], msg.Serialize()) + return SizeofVfTrust +} + // struct ifla_vf_guid { // __u32 vf; // __u32 rsvd; @@ -653,6 +693,11 @@ func (msg *VfGUID) Serialize() []byte { return (*(*[SizeofVfGUID]byte)(unsafe.Pointer(msg)))[:] } +func (msg *VfGUID) SerializeTo(buf []byte) int { + copy(buf[0:SizeofVfGUID], msg.Serialize()) + return SizeofVfGUID +} + const ( XDP_FLAGS_UPDATE_IF_NOEXIST = 1 << iota XDP_FLAGS_SKB_MODE diff --git a/nl/nexthop_linux.go b/nl/nexthop_linux.go index db57b14e..a83bb8bb 100644 --- a/nl/nexthop_linux.go +++ b/nl/nexthop_linux.go @@ -31,3 +31,8 @@ func DeserializeNhmsg(b []byte) *Nhmsg { func (msg *Nhmsg) Serialize() []byte { return (*(*[sizeofNhmsg]byte)(unsafe.Pointer(msg)))[:] } + +func (msg *Nhmsg) SerializeTo(buf []byte) int { + copy(buf[0:sizeofNhmsg], msg.Serialize()) + return sizeofNhmsg +} diff --git a/nl/nl_linux.go b/nl/nl_linux.go index 7c02571a..25c10ba2 100644 --- a/nl/nl_linux.go +++ b/nl/nl_linux.go @@ -80,10 +80,11 @@ func GetIPFamily(ip net.IP) int { var nativeEndian binary.ByteOrder +var x uint32 = 0x01020304 + // NativeEndian gets native endianness for the system func NativeEndian() binary.ByteOrder { if nativeEndian == nil { - var x uint32 = 0x01020304 if *(*byte)(unsafe.Pointer(&x)) == 0x01 { nativeEndian = binary.BigEndian } else { @@ -120,6 +121,7 @@ const ( type NetlinkRequestData interface { Len() int Serialize() []byte + SerializeTo(buf []byte) int } const ( @@ -164,6 +166,11 @@ func (msg *CnMsgOp) Serialize() []byte { return (*(*[SizeofCnMsgOp]byte)(unsafe.Pointer(msg)))[:] } +func (msg *CnMsgOp) SerializeTo(buf []byte) int { + copy(buf[0:SizeofCnMsgOp], msg.Serialize()) + return SizeofCnMsgOp +} + func DeserializeCnMsgOp(b []byte) *CnMsgOp { return (*CnMsgOp)(unsafe.Pointer(&b[0:SizeofCnMsgOp][0])) } @@ -190,6 +197,11 @@ func DeserializeIfInfomsg(b []byte) *IfInfomsg { return (*IfInfomsg)(unsafe.Pointer(&b[0:unix.SizeofIfInfomsg][0])) } +func (msg *IfInfomsg) SerializeTo(buf []byte) int { + copy(buf[0:unix.SizeofIfInfomsg], msg.Serialize()) + return unix.SizeofIfInfomsg +} + func (msg *IfInfomsg) Serialize() []byte { return (*(*[unix.SizeofIfInfomsg]byte)(unsafe.Pointer(msg)))[:] } @@ -391,6 +403,11 @@ func (a *Uint32Attribute) Len() int { return 8 } +func (a *Uint32Attribute) SerializeTo(buf []byte) int { + copy(buf[0:8], a.Serialize()) + return 8 +} + // Extend RtAttr to handle data and children type RtAttr struct { unix.RtAttr @@ -428,26 +445,42 @@ func (a *RtAttr) AddChild(attr NetlinkRequestData) { a.children = append(a.children, attr) } +func (a *RtAttr) AddChilds(attrs ...NetlinkRequestData) { + a.children = append(a.children, attrs...) +} + +func (a *RtAttr) ReserveMoreChildren(n int) { + newChildren := make([]NetlinkRequestData, 0, n+len(a.children)) + if len(a.children) > 0 { + newChildren = append(newChildren, a.children...) + } + a.children = newChildren +} + +func (a *RtAttr) ReserveMoreChildrenWithBuffer(buf []NetlinkRequestData) { + newChildren := buf[0:0] + if len(a.children) > 0 { + newChildren = append(newChildren, a.children...) + } + a.children = newChildren +} + func (a *RtAttr) Len() int { if len(a.children) == 0 { return (unix.SizeofRtAttr + len(a.Data)) } - l := 0 + l := unix.SizeofRtAttr for _, child := range a.children { l += rtaAlignOf(child.Len()) } - l += unix.SizeofRtAttr return rtaAlignOf(l + len(a.Data)) } -// Serialize the RtAttr into a byte array -// This can't just unsafe.cast because it must iterate through children. -func (a *RtAttr) Serialize() []byte { +func (a *RtAttr) SerializeTo(buf []byte) int { native := NativeEndian() length := a.Len() - buf := make([]byte, rtaAlignOf(length)) next := 4 if a.Data != nil { @@ -456,9 +489,8 @@ func (a *RtAttr) Serialize() []byte { } if len(a.children) > 0 { for _, child := range a.children { - childBuf := child.Serialize() - copy(buf[next:], childBuf) - next += rtaAlignOf(len(childBuf)) + size := child.SerializeTo(buf[next:]) + next += rtaAlignOf(size) } } @@ -466,6 +498,17 @@ func (a *RtAttr) Serialize() []byte { native.PutUint16(buf[0:2], l) } native.PutUint16(buf[2:4], a.Type) + // Return aligned size so parent advances correctly; kernel uses NLA_ALIGN(nla_len). + return rtaAlignOf(length) +} + +// Serialize the RtAttr into a byte array +// This can't just unsafe.cast because it must iterate through children. +func (a *RtAttr) Serialize() []byte { + length := a.Len() + // Allocate aligned size so output includes padding; kernel advances by NLA_ALIGN(nla_len). + buf := make([]byte, rtaAlignOf(length)) + a.SerializeTo(buf) return buf } @@ -581,6 +624,14 @@ func (req *NetlinkRequest) ExecuteIter(sockType int, resType uint16, f func(msg return err } + if f == nil { + if req.Flags&unix.NLM_F_ACK != 0 { + _, _, err := s.Receive() + return err + } + return nil + } + pid, err := s.GetPid() if err != nil { return err @@ -888,6 +939,14 @@ func (s *NetlinkSocket) Send(serializedReq []byte) error { return nil } +var ( + receiveBufferPool = sync.Pool{ + New: func() any { + return make([]byte, RECEIVE_BUFFER_SIZE) + }, + } +) + func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) { rawConn, err := s.file.SyscallConn() if err != nil { @@ -896,11 +955,12 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli var ( deadline time.Time fromAddr *unix.SockaddrNetlink - rb [RECEIVE_BUFFER_SIZE]byte + rb []byte = receiveBufferPool.Get().([]byte) nr int from unix.Sockaddr innerErr error ) + defer receiveBufferPool.Put(rb) receiveTimeout := atomic.LoadInt64(&s.receiveTimeout) if receiveTimeout != 0 { deadline = time.Now().Add(time.Duration(receiveTimeout)) diff --git a/nl/route_linux.go b/nl/route_linux.go index c26f3bf9..861951a0 100644 --- a/nl/route_linux.go +++ b/nl/route_linux.go @@ -42,6 +42,11 @@ func (msg *RtMsg) Serialize() []byte { return (*(*[unix.SizeofRtMsg]byte)(unsafe.Pointer(msg)))[:] } +func (msg *RtMsg) SerializeTo(buf []byte) int { + copy(buf[0:unix.SizeofRtMsg], msg.Serialize()) + return unix.SizeofRtMsg +} + type RtNexthop struct { unix.RtNexthop Children []NetlinkRequestData @@ -66,6 +71,20 @@ func (msg *RtNexthop) Len() int { return rtaAlignOf(l) } +func (msg *RtNexthop) SerializeTo(buf []byte) int { + length := msg.Len() + msg.RtNexthop.Len = uint16(length) + copy(buf, (*(*[unix.SizeofRtNexthop]byte)(unsafe.Pointer(msg)))[:]) + next := rtaAlignOf(unix.SizeofRtNexthop) + if len(msg.Children) > 0 { + for _, child := range msg.Children { + size := child.SerializeTo(buf[next:]) + next += rtaAlignOf(size) + } + } + return length +} + func (msg *RtNexthop) Serialize() []byte { length := msg.Len() msg.RtNexthop.Len = uint16(length) @@ -107,3 +126,12 @@ func (msg *RtGenMsg) Serialize() []byte { out[0] = msg.Family return out } + +func (msg *RtGenMsg) SerializeTo(buf []byte) int { + l := rtaAlignOf(unix.SizeofRtGenmsg) + for i := range buf[:l] { + buf[i] = 0 + } + buf[0] = msg.Family + return l +} diff --git a/nl/tc_linux.go b/nl/tc_linux.go index 67666816..147fe104 100644 --- a/nl/tc_linux.go +++ b/nl/tc_linux.go @@ -167,6 +167,11 @@ func (x *TcMsg) Serialize() []byte { return (*(*[SizeofTcMsg]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcMsg) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcMsg], msg.Serialize()) + return SizeofTcMsg +} + type Tcf struct { Install uint64 LastUse uint64 @@ -202,6 +207,11 @@ func (x *TcActionMsg) Serialize() []byte { return (*(*[SizeofTcActionMsg]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcActionMsg) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcActionMsg], msg.Serialize()) + return SizeofTcActionMsg +} + const ( TC_PRIO_MAX = 15 ) @@ -228,6 +238,11 @@ func (x *TcPrioMap) Serialize() []byte { return (*(*[SizeofTcPrioMap]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcPrioMap) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcPrioMap], msg.Serialize()) + return SizeofTcPrioMap +} + const ( TCA_TBF_UNSPEC = iota TCA_TBF_PARMS @@ -270,6 +285,11 @@ func (x *TcRateSpec) Serialize() []byte { return (*(*[SizeofTcRateSpec]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcRateSpec) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcRateSpec], msg.Serialize()) + return SizeofTcRateSpec +} + /** * NETEM */ @@ -317,6 +337,11 @@ func (x *TcNetemQopt) Serialize() []byte { return (*(*[SizeofTcNetemQopt]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcNetemQopt) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcNetemQopt], msg.Serialize()) + return SizeofTcNetemQopt +} + // struct tc_netem_corr { // __u32 delay_corr; /* delay correlation */ // __u32 loss_corr; /* packet loss correlation */ @@ -341,6 +366,11 @@ func (x *TcNetemCorr) Serialize() []byte { return (*(*[SizeofTcNetemCorr]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcNetemCorr) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcNetemCorr], msg.Serialize()) + return SizeofTcNetemCorr +} + // struct tc_netem_reorder { // __u32 probability; // __u32 correlation; @@ -363,6 +393,11 @@ func (x *TcNetemReorder) Serialize() []byte { return (*(*[SizeofTcNetemReorder]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcNetemReorder) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcNetemReorder], msg.Serialize()) + return SizeofTcNetemReorder +} + // struct tc_netem_corrupt { // __u32 probability; // __u32 correlation; @@ -394,7 +429,7 @@ type TcNetemRate struct { } func (msg *TcNetemRate) Len() int { - return SizeofTcRateSpec + return SizeOfTcNetemRate } func DeserializeTcNetemRate(b []byte) *TcNetemRate { @@ -405,6 +440,11 @@ func (msg *TcNetemRate) Serialize() []byte { return (*(*[SizeOfTcNetemRate]byte)(unsafe.Pointer(msg)))[:] } +func (msg *TcNetemRate) SerializeTo(buf []byte) int { + copy(buf[0:SizeOfTcNetemRate], msg.Serialize()) + return SizeOfTcNetemRate +} + // struct tc_tbf_qopt { // struct tc_ratespec rate; // struct tc_ratespec peakrate; @@ -433,6 +473,11 @@ func (x *TcTbfQopt) Serialize() []byte { return (*(*[SizeofTcTbfQopt]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcTbfQopt) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcTbfQopt], msg.Serialize()) + return SizeofTcTbfQopt +} + const ( TCA_HTB_UNSPEC = iota TCA_HTB_PARMS @@ -477,6 +522,11 @@ func (x *TcHtbCopt) Serialize() []byte { return (*(*[SizeofTcHtbCopt]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcHtbCopt) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcHtbCopt], msg.Serialize()) + return SizeofTcHtbCopt +} + type TcHtbGlob struct { Version uint32 Rate2Quantum uint32 @@ -497,6 +547,11 @@ func (x *TcHtbGlob) Serialize() []byte { return (*(*[SizeofTcHtbGlob]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcHtbGlob) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcHtbGlob], msg.Serialize()) + return SizeofTcHtbGlob +} + // HFSC type Curve struct { @@ -672,6 +727,11 @@ func (x *TcGen) Serialize() []byte { return (*(*[SizeofTcGen]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcGen) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcGen], msg.Serialize()) + return SizeofTcGen +} + // #define tc_gen \ // __u32 index; \ // __u32 capab; \ @@ -763,6 +823,11 @@ func (x *TcConnmark) Serialize() []byte { return (*(*[SizeofTcConnmark]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcConnmark) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcConnmark], msg.Serialize()) + return SizeofTcConnmark +} + const ( TCA_CSUM_UNSPEC = iota TCA_CSUM_PARMS @@ -793,6 +858,11 @@ func (x *TcCsum) Serialize() []byte { return (*(*[SizeofTcCsum]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcCsum) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcCsum], msg.Serialize()) + return SizeofTcCsum +} + const ( TCA_ACT_MIRRED = 8 ) @@ -828,6 +898,11 @@ func (x *TcMirred) Serialize() []byte { return (*(*[SizeofTcMirred]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcMirred) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcMirred], msg.Serialize()) + return SizeofTcMirred +} + const ( TCA_VLAN_UNSPEC = iota TCA_VLAN_TM @@ -863,6 +938,11 @@ func (x *TcVlan) Serialize() []byte { return (*(*[SizeofTcVlan]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcVlan) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcVlan], msg.Serialize()) + return SizeofTcVlan +} + const ( TCA_TUNNEL_KEY_UNSPEC = iota TCA_TUNNEL_KEY_TM @@ -898,6 +978,11 @@ func (x *TcTunnelKey) Serialize() []byte { return (*(*[SizeofTcTunnelKey]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcTunnelKey) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcTunnelKey], msg.Serialize()) + return SizeofTcTunnelKey +} + const ( TCA_SKBEDIT_UNSPEC = iota TCA_SKBEDIT_TM @@ -927,6 +1012,11 @@ func (x *TcSkbEdit) Serialize() []byte { return (*(*[SizeofTcSkbEdit]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcSkbEdit) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcSkbEdit], msg.Serialize()) + return SizeofTcSkbEdit +} + // struct tc_police { // __u32 index; // int action; @@ -965,6 +1055,11 @@ func (x *TcPolice) Serialize() []byte { return (*(*[SizeofTcPolice]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcPolice) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcPolice], msg.Serialize()) + return SizeofTcPolice +} + const ( TCA_FW_UNSPEC = iota TCA_FW_CLASSID @@ -1164,6 +1259,11 @@ func (x *TcSfqQopt) Serialize() []byte { return (*(*[SizeofTcSfqQopt]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcSfqQopt) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcSfqQopt], msg.Serialize()) + return SizeofTcSfqQopt +} + // struct tc_sfqred_stats { // __u32 prob_drop; /* Early drops, below max threshold */ // __u32 forced_drop; /* Early drops, after max threshold */ @@ -1193,6 +1293,11 @@ func (x *TcSfqRedStats) Serialize() []byte { return (*(*[SizeofTcSfqRedStats]byte)(unsafe.Pointer(x)))[:] } +func (msg *TcSfqRedStats) SerializeTo(buf []byte) int { + copy(buf[0:SizeofTcSfqRedStats], msg.Serialize()) + return SizeofTcSfqRedStats +} + // struct tc_sfq_qopt_v1 { // struct tc_sfq_qopt v0; // unsigned int depth; /* max number of packets per flow */ diff --git a/nl/xfrm_linux.go b/nl/xfrm_linux.go index 6cfd8f9e..70a6e758 100644 --- a/nl/xfrm_linux.go +++ b/nl/xfrm_linux.go @@ -214,6 +214,11 @@ func (msg *XfrmSelector) Serialize() []byte { return (*(*[SizeofXfrmSelector]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmSelector) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmSelector], msg.Serialize()) + return SizeofXfrmSelector +} + // struct xfrm_lifetime_cfg { // __u64 soft_byte_limit; // __u64 hard_byte_limit; @@ -249,6 +254,11 @@ func (msg *XfrmLifetimeCfg) Serialize() []byte { return (*(*[SizeofXfrmLifetimeCfg]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmLifetimeCfg) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmLifetimeCfg], msg.Serialize()) + return SizeofXfrmLifetimeCfg +} + // struct xfrm_lifetime_cur { // __u64 bytes; // __u64 packets; @@ -275,6 +285,11 @@ func (msg *XfrmLifetimeCur) Serialize() []byte { return (*(*[SizeofXfrmLifetimeCur]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmLifetimeCur) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmLifetimeCur], msg.Serialize()) + return SizeofXfrmLifetimeCur +} + // struct xfrm_id { // xfrm_address_t daddr; // __be32 spi; @@ -300,6 +315,11 @@ func (msg *XfrmId) Serialize() []byte { return (*(*[SizeofXfrmId]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmId) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmId], msg.Serialize()) + return SizeofXfrmId +} + type XfrmMark struct { Value uint32 Mask uint32 @@ -316,3 +336,8 @@ func DeserializeXfrmMark(b []byte) *XfrmMark { func (msg *XfrmMark) Serialize() []byte { return (*(*[SizeofXfrmMark]byte)(unsafe.Pointer(msg)))[:] } + +func (msg *XfrmMark) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmMark], msg.Serialize()) + return SizeofXfrmMark +} diff --git a/nl/xfrm_monitor_linux.go b/nl/xfrm_monitor_linux.go index 715df4cc..04fe8e29 100644 --- a/nl/xfrm_monitor_linux.go +++ b/nl/xfrm_monitor_linux.go @@ -30,3 +30,8 @@ func DeserializeXfrmUserExpire(b []byte) *XfrmUserExpire { func (msg *XfrmUserExpire) Serialize() []byte { return (*(*[SizeofXfrmUserExpire]byte)(unsafe.Pointer(msg)))[:] } + +func (msg *XfrmUserExpire) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmUserExpire], msg.Serialize()) + return SizeofXfrmUserExpire +} diff --git a/nl/xfrm_policy_linux.go b/nl/xfrm_policy_linux.go index 66f7e03d..c3203dc2 100644 --- a/nl/xfrm_policy_linux.go +++ b/nl/xfrm_policy_linux.go @@ -36,6 +36,11 @@ func (msg *XfrmUserpolicyId) Serialize() []byte { return (*(*[SizeofXfrmUserpolicyId]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmUserpolicyId) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmUserpolicyId], msg.Serialize()) + return SizeofXfrmUserpolicyId +} + // struct xfrm_userpolicy_info { // struct xfrm_selector sel; // struct xfrm_lifetime_cfg lft; @@ -78,6 +83,11 @@ func (msg *XfrmUserpolicyInfo) Serialize() []byte { return (*(*[SizeofXfrmUserpolicyInfo]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmUserpolicyInfo) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmUserpolicyInfo], msg.Serialize()) + return SizeofXfrmUserpolicyInfo +} + // struct xfrm_user_tmpl { // struct xfrm_id id; // __u16 family; @@ -117,3 +127,8 @@ func DeserializeXfrmUserTmpl(b []byte) *XfrmUserTmpl { func (msg *XfrmUserTmpl) Serialize() []byte { return (*(*[SizeofXfrmUserTmpl]byte)(unsafe.Pointer(msg)))[:] } + +func (msg *XfrmUserTmpl) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmUserTmpl], msg.Serialize()) + return SizeofXfrmUserTmpl +} diff --git a/nl/xfrm_state_linux.go b/nl/xfrm_state_linux.go index e8920b9a..9af0cabc 100644 --- a/nl/xfrm_state_linux.go +++ b/nl/xfrm_state_linux.go @@ -61,6 +61,11 @@ func (msg *XfrmUsersaId) Serialize() []byte { return (*(*[SizeofXfrmUsersaId]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmUsersaId) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmUsersaId], msg.Serialize()) + return SizeofXfrmUsersaId +} + // struct xfrm_stats { // __u32 replay_window; // __u32 replay; @@ -85,6 +90,11 @@ func (msg *XfrmStats) Serialize() []byte { return (*(*[SizeofXfrmStats]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmStats) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmStats], msg.Serialize()) + return SizeofXfrmStats +} + // struct xfrm_usersa_info { // struct xfrm_selector sel; // struct xfrm_id id; @@ -140,6 +150,11 @@ func (msg *XfrmUsersaInfo) Serialize() []byte { return (*(*[SizeofXfrmUsersaInfo]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmUsersaInfo) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmUsersaInfo], msg.Serialize()) + return SizeofXfrmUsersaInfo +} + // struct xfrm_userspi_info { // struct xfrm_usersa_info info; // __u32 min; @@ -164,6 +179,11 @@ func (msg *XfrmUserSpiInfo) Serialize() []byte { return (*(*[SizeofXfrmUserSpiInfo]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmUserSpiInfo) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmUserSpiInfo], msg.Serialize()) + return SizeofXfrmUserSpiInfo +} + // struct xfrm_algo { // char alg_name[64]; // unsigned int alg_key_len; /* in bits */ @@ -196,6 +216,14 @@ func (msg *XfrmAlgo) Serialize() []byte { return b } +func (msg *XfrmAlgo) SerializeTo(buf []byte) int { + l := msg.Len() + copy(buf[0:64], msg.AlgName[:]) + copy(buf[64:68], (*(*[4]byte)(unsafe.Pointer(&msg.AlgKeyLen)))[:]) + copy(buf[68:l], msg.AlgKey[:]) + return l +} + // struct xfrm_algo_auth { // char alg_name[64]; // unsigned int alg_key_len; /* in bits */ @@ -232,6 +260,15 @@ func (msg *XfrmAlgoAuth) Serialize() []byte { return b } +func (msg *XfrmAlgoAuth) SerializeTo(buf []byte) int { + l := msg.Len() + copy(buf[0:64], msg.AlgName[:]) + copy(buf[64:68], (*(*[4]byte)(unsafe.Pointer(&msg.AlgKeyLen)))[:]) + copy(buf[68:72], (*(*[4]byte)(unsafe.Pointer(&msg.AlgTruncLen)))[:]) + copy(buf[72:l], msg.AlgKey[:]) + return l +} + // struct xfrm_algo_aead { // char alg_name[64]; // unsigned int alg_key_len; /* in bits */ @@ -268,6 +305,15 @@ func (msg *XfrmAlgoAEAD) Serialize() []byte { return b } +func (msg *XfrmAlgoAEAD) SerializeTo(buf []byte) int { + l := msg.Len() + copy(buf[0:64], msg.AlgName[:]) + copy(buf[64:68], (*(*[4]byte)(unsafe.Pointer(&msg.AlgKeyLen)))[:]) + copy(buf[68:72], (*(*[4]byte)(unsafe.Pointer(&msg.AlgICVLen)))[:]) + copy(buf[72:l], msg.AlgKey[:]) + return l +} + // struct xfrm_encap_tmpl { // __u16 encap_type; // __be16 encap_sport; @@ -295,6 +341,11 @@ func (msg *XfrmEncapTmpl) Serialize() []byte { return (*(*[SizeofXfrmEncapTmpl]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmEncapTmpl) SerializeTo(buf []byte) int { + copy(buf, msg.Serialize()) + return msg.Len() +} + // struct xfrm_usersa_flush { // __u8 proto; // }; @@ -315,6 +366,11 @@ func (msg *XfrmUsersaFlush) Serialize() []byte { return (*(*[SizeofXfrmUsersaFlush]byte)(unsafe.Pointer(msg)))[:] } +func (msg *XfrmUsersaFlush) SerializeTo(buf []byte) int { + copy(buf[0:SizeofXfrmUsersaFlush], msg.Serialize()) + return SizeofXfrmUsersaFlush +} + // struct xfrm_replay_state_esn { // unsigned int bmp_len; // __u32 oseq; diff --git a/rdma_link_linux.go b/rdma_link_linux.go index 76bd7199..afdc184d 100644 --- a/rdma_link_linux.go +++ b/rdma_link_linux.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "net" "github.com/vishvananda/netlink/nl" @@ -42,13 +43,43 @@ func uint64ToGuidString(guid uint64) string { return sysGuidNet.String() } +func parseRaw16FromReader(reader *bytes.Reader) uint16 { + buf := make([]byte, 2) + reader.Read(buf) + return nl.NativeEndian().Uint16(buf) +} + +func parseNfAttrTLFromReader(reader *bytes.Reader) (isNested bool, attrType, len uint16) { + len = parseRaw16FromReader(reader) + len -= nl.SizeofNfattr + + attrType = parseRaw16FromReader(reader) + isNested = (attrType & nl.NLA_F_NESTED) == nl.NLA_F_NESTED + attrType = attrType & (nl.NLA_F_NESTED - 1) + return isNested, attrType, len +} + +func parseNfAttrTLVFromReader(r *bytes.Reader) (isNested bool, attrType, len uint16, value []byte) { + isNested, attrType, len = parseNfAttrTLFromReader(r) + + value = make([]byte, len) + n, err := io.ReadAtLeast(r, value, int(len)) + if err != nil { + panic(err) + } + if n != int(len) { + panic(fmt.Errorf("expected %d bytes for nfattr value, got %d", len, n)) + } + return isNested, attrType, len, value +} + func executeOneGetRdmaLink(data []byte) (*RdmaLink, error) { link := RdmaLink{} reader := bytes.NewReader(data) for reader.Len() >= 4 { - _, attrType, len, value := parseNfAttrTLV(reader) + _, attrType, len, value := parseNfAttrTLVFromReader(reader) switch attrType { case nl.RDMA_NLDEV_ATTR_DEV_INDEX: @@ -192,7 +223,7 @@ func netnsModeToString(mode uint8) string { func executeOneGetRdmaNetnsMode(data []byte) (string, error) { reader := bytes.NewReader(data) for reader.Len() >= 4 { - _, attrType, len, value := parseNfAttrTLV(reader) + _, attrType, len, value := parseNfAttrTLVFromReader(reader) switch attrType { case nl.RDMA_NLDEV_SYS_ATTR_NETNS_MODE: @@ -408,14 +439,14 @@ func parseRdmaCounters(counterType uint16, data []byte) (map[string]uint64, erro reader := bytes.NewReader(data) for reader.Len() >= 4 { - _, attrType, _, value := parseNfAttrTLV(reader) + _, attrType, _, value := parseNfAttrTLVFromReader(reader) if attrType != counterType { return nil, fmt.Errorf("Invalid resource summary entry type; %d", attrType) } summaryReader := bytes.NewReader(value) for summaryReader.Len() >= 4 { - _, attrType, len, value := parseNfAttrTLV(summaryReader) + _, attrType, len, value := parseNfAttrTLVFromReader(summaryReader) if attrType != counterKeyType { return nil, fmt.Errorf("Invalid resource summary entry name type; %d", attrType) } @@ -424,7 +455,7 @@ func parseRdmaCounters(counterType uint16, data []byte) (map[string]uint64, erro if (len % 4) != 0 { summaryReader.Seek(int64(4-(len%4)), seekCurrent) } - _, attrType, len, value = parseNfAttrTLV(summaryReader) + _, attrType, len, value = parseNfAttrTLVFromReader(summaryReader) if attrType != counterValueType { return nil, fmt.Errorf("Invalid resource summary entry value type; %d", attrType) } @@ -438,7 +469,7 @@ func executeOneGetRdmaResourceList(data []byte) (*RdmaResource, error) { var res RdmaResource reader := bytes.NewReader(data) for reader.Len() >= 4 { - _, attrType, len, value := parseNfAttrTLV(reader) + _, attrType, len, value := parseNfAttrTLVFromReader(reader) switch attrType { case nl.RDMA_NLDEV_ATTR_DEV_INDEX: @@ -537,7 +568,7 @@ func executeOneGetRdmaPortStatistics(data []byte) (*RdmaPortStatistic, error) { var stat RdmaPortStatistic reader := bytes.NewReader(data) for reader.Len() >= 4 { - _, attrType, len, value := parseNfAttrTLV(reader) + _, attrType, len, value := parseNfAttrTLVFromReader(reader) switch attrType { case nl.RDMA_NLDEV_ATTR_PORT_INDEX: diff --git a/socket_linux.go b/socket_linux.go index f8a39edc..423c4392 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -44,7 +44,13 @@ func (b *writeBuffer) Next(n int) []byte { } func (r *socketRequest) Serialize() []byte { - b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)} + buf := make([]byte, sizeofSocketRequest) + r.SerializeTo(buf) + return buf +} + +func (r *socketRequest) SerializeTo(buf []byte) int { + b := writeBuffer{Bytes: buf} b.Write(r.Family) b.Write(r.Protocol) b.Write(r.Ext) @@ -62,7 +68,7 @@ func (r *socketRequest) Serialize() []byte { native.PutUint32(b.Next(4), r.ID.Interface) native.PutUint32(b.Next(4), r.ID.Cookie[0]) native.PutUint32(b.Next(4), r.ID.Cookie[1]) - return b.Bytes + return sizeofSocketRequest } func (r *socketRequest) Len() int { return sizeofSocketRequest } @@ -79,7 +85,13 @@ type unixSocketRequest struct { } func (r *unixSocketRequest) Serialize() []byte { - b := writeBuffer{Bytes: make([]byte, sizeofUnixSocketRequest)} + buf := make([]byte, sizeofUnixSocketRequest) + r.SerializeTo(buf) + return buf +} + +func (r *unixSocketRequest) SerializeTo(buf []byte) int { + b := writeBuffer{Bytes: buf} b.Write(r.Family) b.Write(r.Protocol) native.PutUint16(b.Next(2), r.pad) @@ -88,7 +100,7 @@ func (r *unixSocketRequest) Serialize() []byte { native.PutUint32(b.Next(4), r.Show) native.PutUint32(b.Next(4), r.Cookie[0]) native.PutUint32(b.Next(4), r.Cookie[1]) - return b.Bytes + return sizeofUnixSocketRequest } func (r *unixSocketRequest) Len() int { return sizeofUnixSocketRequest } diff --git a/socket_xdp_linux.go b/socket_xdp_linux.go index 1a9ab1c9..7838a39c 100644 --- a/socket_xdp_linux.go +++ b/socket_xdp_linux.go @@ -25,7 +25,13 @@ type xdpSocketRequest struct { } func (r *xdpSocketRequest) Serialize() []byte { - b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)} + buf := make([]byte, sizeofXDPSocketRequest) + r.SerializeTo(buf) + return buf +} + +func (r *xdpSocketRequest) SerializeTo(buf []byte) int { + b := writeBuffer{Bytes: buf} b.Write(r.Family) b.Write(r.Protocol) native.PutUint16(b.Next(2), r.pad) @@ -33,7 +39,7 @@ func (r *xdpSocketRequest) Serialize() []byte { native.PutUint32(b.Next(4), r.Show) native.PutUint32(b.Next(4), r.Cookie[0]) native.PutUint32(b.Next(4), r.Cookie[1]) - return b.Bytes + return sizeofXDPSocketRequest } func (r *xdpSocketRequest) Len() int { return sizeofXDPSocketRequest } From 3732aa8a9f8a05fcac673ae143d1a8aac130bc60 Mon Sep 17 00:00:00 2001 From: eustrain Date: Tue, 24 Feb 2026 07:21:50 +0000 Subject: [PATCH 5/5] feat: remove allocator in ConntrackTableList --- conntrack_linux.go | 26 +++++++++++++++++--------- conntrack_test.go | 44 ++++++++++++++++++++++++++------------------ gtp_linux.go | 8 ++++---- 3 files changed, 47 insertions(+), 31 deletions(-) diff --git a/conntrack_linux.go b/conntrack_linux.go index cb390045..e13c0123 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -43,13 +43,14 @@ type InetFamily uint8 // -C [table] Show counter // -S Show statistics -// ConntrackTableList returns the flow list of a table of a specific family +// ConntrackTableList returns the flow list of a table of a specific family. +// It pre-allocates a single []ConntrackFlow slice and reuses it to avoid per-flow allocations. // conntrack -L [table] [options] List conntrack or expectation table // // If the returned error is [ErrDumpInterrupted], results may be inconsistent // or incomplete. -func ConntrackTableList(table ConntrackTableType, family InetFamily, allocator func() *ConntrackFlow) ([]*ConntrackFlow, error) { - return pkgHandle.ConntrackTableList(table, family, allocator) +func ConntrackTableList(table ConntrackTableType, family InetFamily) ([]*ConntrackFlow, error) { + return pkgHandle.ConntrackTableList(table, family) } // ConntrackTableFlush flushes all the flows of a specified table @@ -89,21 +90,28 @@ func ConntrackTableListStream(table ConntrackTableType, family InetFamily, handl return pkgHandle.ConntrackTableListStream(table, family, handle, allocator) } -// ConntrackTableList returns the flow list of a table of a specific family using the netlink handle passed +// ConntrackTableList returns the flow list of a table of a specific family using the netlink handle passed. +// It pre-allocates a single []ConntrackFlow slice and reuses elements to avoid per-flow allocations. // conntrack -L [table] [options] List conntrack or expectation table // // If the returned error is [ErrDumpInterrupted], results may be inconsistent // or incomplete. -func (h *Handle) ConntrackTableList(table ConntrackTableType, family InetFamily, allocator func() *ConntrackFlow) ([]*ConntrackFlow, error) { +func (h *Handle) ConntrackTableList(table ConntrackTableType, family InetFamily) ([]*ConntrackFlow, error) { res, executeErr := h.dumpConntrackTable(table, family) if executeErr != nil && !errors.Is(executeErr, ErrDumpInterrupted) { return nil, executeErr } - // Deserialize all the flows - var result []*ConntrackFlow - for _, dataRaw := range res { - result = append(result, parseRawData(dataRaw, allocator)) + flows := make([]ConntrackFlow, len(res)) + result := make([]*ConntrackFlow, len(res)) + i := 0 + allocator := func() *ConntrackFlow { + p := &flows[i] + i++ + return p + } + for j := range res { + result[j] = parseRawData(res[j], allocator) } return result, executeErr diff --git a/conntrack_test.go b/conntrack_test.go index 466175a7..04281c0d 100644 --- a/conntrack_test.go +++ b/conntrack_test.go @@ -209,7 +209,7 @@ func TestConntrackTableList(t *testing.T) { udpFlowCreateProg(t, 5, 2000, "127.0.0.10", 3000) // Fetch the conntrack table - flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) + flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET) CheckErrorFail(t, err) // Check that it is able to find the 5 flows created @@ -242,7 +242,7 @@ func TestConntrackTableList(t *testing.T) { } // Give a try also to the IPv6 version - _, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET6, nil) + _, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET6) CheckErrorFail(t, err) // Switch back to the original namespace @@ -276,7 +276,7 @@ func TestConntrackTableFlush(t *testing.T) { udpFlowCreateProg(t, 5, 3000, "127.0.0.10", 4000) // Fetch the conntrack table - flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) + flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET) CheckErrorFail(t, err) // Check that it is able to find the 5 flows created @@ -298,7 +298,7 @@ func TestConntrackTableFlush(t *testing.T) { CheckErrorFail(t, err) // Fetch again the flows to validate the flush - flows, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) + flows, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET) CheckErrorFail(t, err) // Check if it is still able to find the 5 flows created @@ -349,7 +349,7 @@ func TestConntrackTableDelete(t *testing.T) { udpFlowCreateProg(t, 5, 7000, "127.0.0.20", 8000) // Fetch the conntrack table - flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) + flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET) CheckErrorFail(t, err) // Check that it is able to find the 5 flows created for each group @@ -389,7 +389,7 @@ func TestConntrackTableDelete(t *testing.T) { } // Check again the table to verify that are gone - flows, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) + flows, err = h.ConntrackTableList(ConntrackTable, unix.AF_INET) CheckErrorFail(t, err) // Check if it is able to find the 5 flows of groupA but none of groupB @@ -809,11 +809,19 @@ func TestConntrackFilter(t *testing.T) { t.Fatalf("Error, there should be only 1 match, v4:%d, v6:%d", v4Match, v6Match) } - // Labels filter + // Labels filter: for ConntrackMatchLabels, label must be contained in flow.Labels (bytes.Contains). + // The TCP flow (10.0.0.2) has Labels {0,0,0,0,3,4,61,141,207,170,2,0,0,0,0,0}; use it as the label. filterV4 = &ConntrackFilter{} + err = filterV4.AddProtocol(6) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + err = filterV4.AddPort(ConntrackOrigSrcPort, 5000) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } var labels [][16]byte - labels = append(labels, [16]byte{3, 4, 61, 141, 207, 170, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - labels = append(labels, [16]byte{0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}) + labels = append(labels, [16]byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}) err = filterV4.AddLabels(ConntrackMatchLabels, labels) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1129,7 +1137,7 @@ func TestConntrackUpdateV4(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1175,7 +1183,7 @@ func TestConntrackUpdateV4(t *testing.T) { } // Look for updated conntrack. - flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) if err != nil { t.Fatalf("failed to list conntracks following successful update: %s", err) } @@ -1262,7 +1270,7 @@ func TestConntrackUpdateV6(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6, nil) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1308,7 +1316,7 @@ func TestConntrackUpdateV6(t *testing.T) { } // Look for updated conntrack. - flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6, nil) + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) if err != nil { t.Fatalf("failed to list conntracks following successful update: %s", err) } @@ -1388,7 +1396,7 @@ func TestConntrackCreateV4(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1483,7 +1491,7 @@ func TestConntrackCreateV6(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6, nil) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1583,7 +1591,7 @@ func TestConntrackLabels(t *testing.T) { t.Fatalf("failed to insert conntrack: %s", err) } - flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) if err != nil { t.Fatalf("failed to list conntracks following successful insert: %s", err) } @@ -1630,7 +1638,7 @@ func TestConntrackLabels(t *testing.T) { } // Look for updated conntrack. - flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4, nil) + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) if err != nil { t.Fatalf("failed to list conntracks following successful update: %s", err) } @@ -2011,7 +2019,7 @@ func TestExecuteConntrackRequest(t *testing.T) { t.Fatalf("ExecuteConntrackRequest: %v", err) } - flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET, nil) + flows, err := h.ConntrackTableList(ConntrackTable, unix.AF_INET) if err != nil { t.Fatalf("ConntrackTableList: %v", err) } diff --git a/gtp_linux.go b/gtp_linux.go index 078e16a5..f6df344a 100644 --- a/gtp_linux.go +++ b/gtp_linux.go @@ -105,7 +105,7 @@ func GTPPDPList() ([]*PDP, error) { return pkgHandle.GTPPDPList() } -func gtpPDPGet(req *nl.NetlinkRequest) (*PDP, error) { +func gtpPDPGet(req nl.NetlinkRequest) (*PDP, error) { msgs, err := req.Execute(unix.NETLINK_GENERIC, 0) if err != nil { return nil, err @@ -134,7 +134,7 @@ func (h *Handle) GTPPDPByTID(link Link, tid int) (*PDP, error) { req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_VERSION, nl.Uint32Attr(0))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_LINK, nl.Uint32Attr(uint32(link.Attrs().Index)))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_TID, nl.Uint64Attr(uint64(tid)))) - return gtpPDPGet(&req) + return gtpPDPGet(req) } func GTPPDPByTID(link Link, tid int) (*PDP, error) { @@ -155,7 +155,7 @@ func (h *Handle) GTPPDPByITEI(link Link, itei int) (*PDP, error) { req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_VERSION, nl.Uint32Attr(1))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_LINK, nl.Uint32Attr(uint32(link.Attrs().Index)))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_I_TEI, nl.Uint32Attr(uint32(itei)))) - return gtpPDPGet(&req) + return gtpPDPGet(req) } func GTPPDPByITEI(link Link, itei int) (*PDP, error) { @@ -176,7 +176,7 @@ func (h *Handle) GTPPDPByMSAddress(link Link, addr net.IP) (*PDP, error) { req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_VERSION, nl.Uint32Attr(0))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_LINK, nl.Uint32Attr(uint32(link.Attrs().Index)))) req.AddData(nl.NewRtAttr(nl.GENL_GTP_ATTR_MS_ADDRESS, []byte(addr.To4()))) - return gtpPDPGet(&req) + return gtpPDPGet(req) } func GTPPDPByMSAddress(link Link, addr net.IP) (*PDP, error) {