From 806d1cc3e4724530866c838125c19fe2b08e4073 Mon Sep 17 00:00:00 2001 From: Parth Sarthi Date: Fri, 31 Oct 2025 15:33:49 -0700 Subject: [PATCH] nftables: Added uint8/16 and int8/16/64 BytesView converters. PiperOrigin-RevId: 826652100 --- .../socket/netlink/netfilter/protocol.go | 6 +- pkg/sentry/socket/netlink/nlmsg/message.go | 55 +++++++++++++++++++ .../linux/socket_netlink_netfilter.cc | 22 ++++---- 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/pkg/sentry/socket/netlink/netfilter/protocol.go b/pkg/sentry/socket/netlink/netfilter/protocol.go index 5deadd180d..adacd29b9e 100644 --- a/pkg/sentry/socket/netlink/netfilter/protocol.go +++ b/pkg/sentry/socket/netlink/netfilter/protocol.go @@ -442,7 +442,7 @@ func (p *Protocol) addChain(attrs map[uint16]nlmsg.BytesView, tab *nftables.Tabl return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("Nftables: Chain binding attribute is not supported for chains with a hook")) } - bcInfo, err = p.chainParseHook(nil, family, nlmsg.AttrsView(hookDataBytes)) + bcInfo, err = p.chainParseHook(nil, family, nlmsg.AttrsView(hookDataBytes), attrs) if err != nil { return err } @@ -494,7 +494,7 @@ func (p *Protocol) addChain(attrs map[uint16]nlmsg.BytesView, tab *nftables.Tabl // chainParseHook parses the hook attributes and returns a complete // BaseChainInfo. -func (p *Protocol) chainParseHook(chain *nftables.Chain, family stack.AddressFamily, hdata nlmsg.AttrsView) (*nftables.BaseChainInfo, *syserr.AnnotatedError) { +func (p *Protocol) chainParseHook(chain *nftables.Chain, family stack.AddressFamily, hdata nlmsg.AttrsView, attrs map[uint16]nlmsg.BytesView) (*nftables.BaseChainInfo, *syserr.AnnotatedError) { hookAttrs, ok := nftables.NfParse(hdata) if !ok { return nil, syserr.NewAnnotatedError(syserr.ErrInvalidArgument, fmt.Sprintf("Nftables: Failed to parse hook attributes")) @@ -530,7 +530,7 @@ func (p *Protocol) chainParseHook(chain *nftables.Chain, family stack.AddressFam // All families default to filter type. hookInfo.ChainType = nftables.BaseChainTypeFilter - if chainTypeBytes, ok := hookAttrs[linux.NFTA_CHAIN_TYPE]; ok { + if chainTypeBytes, ok := attrs[linux.NFTA_CHAIN_TYPE]; ok { // TODO - b/434243967: Support base chain types other than filter. switch chainType := chainTypeBytes.String(); chainType { case "filter": diff --git a/pkg/sentry/socket/netlink/nlmsg/message.go b/pkg/sentry/socket/netlink/nlmsg/message.go index cf5d0d27af..75d3babf1d 100644 --- a/pkg/sentry/socket/netlink/nlmsg/message.go +++ b/pkg/sentry/socket/netlink/nlmsg/message.go @@ -374,6 +374,28 @@ func (v *BytesView) String() string { return string(b) } +// Uint8 converts the raw attribute value to uint8. +func (v *BytesView) Uint8() (uint8, bool) { + attr := []byte(*v) + val := primitive.Uint8(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return uint8(val), true +} + +// Uint16 converts the raw attribute value to uint16. +func (v *BytesView) Uint16() (uint16, bool) { + attr := []byte(*v) + val := primitive.Uint16(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return uint16(val), true +} + // Uint32 converts the raw attribute value to uint32. func (v *BytesView) Uint32() (uint32, bool) { attr := []byte(*v) @@ -396,6 +418,28 @@ func (v *BytesView) Uint64() (uint64, bool) { return uint64(val), true } +// Int8 converts the raw attribute value to int8. +func (v *BytesView) Int8() (int8, bool) { + attr := []byte(*v) + val := primitive.Int8(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return int8(val), true +} + +// Int16 converts the raw attribute value to int32. +func (v *BytesView) Int16() (int16, bool) { + attr := []byte(*v) + val := primitive.Int16(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return int16(val), true +} + // Int32 converts the raw attribute value to int32. func (v *BytesView) Int32() (int32, bool) { attr := []byte(*v) @@ -407,6 +451,17 @@ func (v *BytesView) Int32() (int32, bool) { return int32(val), true } +// Int64 converts the raw attribute value to int32. +func (v *BytesView) Int64() (int64, bool) { + attr := []byte(*v) + val := primitive.Int64(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return int64(val), true +} + // NetToHostU16 converts a uint16 in network byte order to // host byte order value. func NetToHostU16(v uint16) uint16 { diff --git a/test/syscalls/linux/socket_netlink_netfilter.cc b/test/syscalls/linux/socket_netlink_netfilter.cc index 9998001e16..f981f4e218 100644 --- a/test/syscalls/linux/socket_netlink_netfilter.cc +++ b/test/syscalls/linux/socket_netlink_netfilter.cc @@ -932,7 +932,6 @@ TEST(NetlinkNetfilterTest, ErrNewBaseChainWithInvalidPolicy) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_table_request_buffer = @@ -955,6 +954,7 @@ TEST(NetlinkNetfilterTest, ErrNewBaseChainWithInvalidPolicy) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 5) @@ -1122,7 +1122,6 @@ TEST(NetlinkNetfilterTest, ErrNewBaseChainWithInvalidChainType) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1139,6 +1138,7 @@ TEST(NetlinkNetfilterTest, ErrNewBaseChainWithInvalidChainType) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1167,7 +1167,6 @@ TEST(NetlinkNetfilterTest, ErrNewNATBaseChainWithInvalidPriority) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1184,6 +1183,7 @@ TEST(NetlinkNetfilterTest, ErrNewNATBaseChainWithInvalidPriority) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1212,7 +1212,6 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewNetDevBaseChain) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1229,6 +1228,7 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewNetDevBaseChain) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1257,7 +1257,6 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewInetBaseChainAtIngress) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1274,6 +1273,7 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewInetBaseChainAtIngress) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1302,7 +1302,6 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewBaseChainWithChainCounters) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1319,6 +1318,7 @@ TEST(NetlinkNetfilterTest, ErrUnsupportedNewBaseChainWithChainCounters) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .RawAttr(NFTA_CHAIN_COUNTERS, nullptr, 0) .Build()) @@ -1540,7 +1540,6 @@ TEST(NetlinkNetfilterTest, AddBaseChainWithDropPolicy) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1557,6 +1556,7 @@ TEST(NetlinkNetfilterTest, AddBaseChainWithDropPolicy) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1787,7 +1787,6 @@ TEST(NetlinkNetfilterTest, GetBaseChain) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1804,6 +1803,7 @@ TEST(NetlinkNetfilterTest, GetBaseChain) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .RawAttr(NFTA_CHAIN_USERDATA, test_user_data, expected_udata_size) @@ -1857,7 +1857,6 @@ TEST(NetlinkNetfilterTest, ErrDeleteChainWithNoTableNameSpecified) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -1874,6 +1873,7 @@ TEST(NetlinkNetfilterTest, ErrDeleteChainWithNoTableNameSpecified) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -1984,7 +1984,6 @@ TEST(NetlinkNetfilterTest, DeleteBaseChain) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -2001,6 +2000,7 @@ TEST(NetlinkNetfilterTest, DeleteBaseChain) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3) @@ -2041,7 +2041,6 @@ TEST(NetlinkNetfilterTest, DeleteBaseChainByHandle) { NlNestedAttr() .U32Attr(NFTA_HOOK_HOOKNUM, test_hook_num) .U32Attr(NFTA_HOOK_PRIORITY, test_hook_priority) - .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .Build(); std::vector add_request_buffer = @@ -2058,6 +2057,7 @@ TEST(NetlinkNetfilterTest, DeleteBaseChainByHandle) { .U32Attr(NFTA_CHAIN_POLICY, test_policy) .RawAttr(NFTA_CHAIN_HOOK, nested_hook_data.data(), nested_hook_data.size()) + .StrAttr(NFTA_CHAIN_TYPE, test_chain_type_name) .U32Attr(NFTA_CHAIN_FLAGS, test_chain_flags) .Build()) .SeqEnd(kSeq + 3)