From 2d23faf068de2742cbe54b2192ad0f359c8dc450 Mon Sep 17 00:00:00 2001 From: Shailend Chand Date: Wed, 12 Nov 2025 22:00:10 -0800 Subject: [PATCH] Add RTNLGRP_LINK netlink multicast support for recv Only netstack supports sending link events; requests to join multicast groups for netlink sockets continue to be denied when hostinet is in use. connect() and sendmsg() to RTMGRP_LINK still fail. PiperOrigin-RevId: 831688843 --- pkg/abi/linux/netlink.go | 5 + pkg/sentry/inet/BUILD | 9 + pkg/sentry/inet/inet.go | 2 +- pkg/sentry/inet/namespace.go | 28 +- pkg/sentry/inet/nlmcast.go | 143 ++++++ pkg/sentry/inet/test_stack.go | 2 +- pkg/sentry/socket/hostinet/stack.go | 2 +- pkg/sentry/socket/netlink/BUILD | 1 + pkg/sentry/socket/netlink/provider.go | 13 + pkg/sentry/socket/netlink/route/protocol.go | 34 +- pkg/sentry/socket/netlink/socket.go | 189 +++++++- pkg/sentry/socket/netstack/BUILD | 10 + pkg/sentry/socket/netstack/stack.go | 119 ++++- pkg/tcpip/stack/stack.go | 128 +++-- test/syscalls/linux/socket_netlink_route.cc | 443 ++++++++++++++++-- .../linux/socket_netlink_route_util.cc | 1 + .../linux/socket_netlink_route_util.h | 1 + test/syscalls/linux/socket_netlink_util.cc | 13 +- test/syscalls/linux/socket_netlink_util.h | 12 + 19 files changed, 1012 insertions(+), 143 deletions(-) create mode 100644 pkg/sentry/inet/nlmcast.go diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index 2be0b7553c..91b7bf2516 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -157,3 +157,8 @@ type NetlinkErrorMessage struct { Error int32 Header NetlinkMessageHeader } + +// RTNetlink multicast groups, from uapi/linux/rtnetlink.h. 
+const ( + RTNLGRP_LINK = 1 +) diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 035dc3b099..067656ac70 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -26,6 +26,13 @@ declare_mutex( prefix = "abstractSocketNamespace", ) +declare_mutex( + name = "nlmcast_table_mutex", + out = "nlmcast_table_mutex.go", + package = "inet", + prefix = "nlmcastTable", +) + go_library( name = "inet", srcs = [ @@ -35,6 +42,8 @@ go_library( "inet.go", "namespace.go", "namespace_refs.go", + "nlmcast.go", + "nlmcast_table_mutex.go", "test_stack.go", ], deps = [ diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index b5e4dadd9d..1166096329 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -32,7 +32,7 @@ type Stack interface { Interfaces() map[int32]Interface // RemoveInterface removes the specified network interface. - RemoveInterface(idx int32) error + RemoveInterface(ctx context.Context, idx int32) error // InterfaceAddrs returns all network interface addresses as a mapping from // interface indexes to a slice of associated interface address properties. diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go index 0ffeee03d6..d2a4bf4618 100644 --- a/pkg/sentry/inet/namespace.go +++ b/pkg/sentry/inet/namespace.go @@ -45,6 +45,9 @@ type Namespace struct { // abstractSockets tracks abstract sockets that are in use. abstractSockets AbstractSocketNamespace + + // netlinkMcastTable manages multicast group membership for netlink sockets. + netlinkMcastTable *McastTable } // NewRootNamespace creates the root network namespace, with creator @@ -52,10 +55,14 @@ type Namespace struct { // networking will function if the network is namespaced. 
func NewRootNamespace(stack Stack, creator NetworkStackCreator, userNS *auth.UserNamespace) *Namespace { n := &Namespace{ - stack: stack, - creator: creator, - isRoot: true, - userNS: userNS, + stack: stack, + creator: creator, + isRoot: true, + userNS: userNS, + netlinkMcastTable: NewNetlinkMcastTable(), + } + if eventPublishingStack, ok := stack.(InterfaceEventPublisher); ok { + eventPublishingStack.AddInterfaceEventSubscriber(n.netlinkMcastTable) } n.abstractSockets.init() return n @@ -79,8 +86,9 @@ func (n *Namespace) GetInode() *nsfs.Inode { // NewNamespace creates a new network namespace from the root. func NewNamespace(root *Namespace, userNS *auth.UserNamespace) *Namespace { n := &Namespace{ - creator: root.creator, - userNS: userNS, + creator: root.creator, + userNS: userNS, + netlinkMcastTable: NewNetlinkMcastTable(), } n.init() return n @@ -148,6 +156,9 @@ func (n *Namespace) init() { if err != nil { panic(err) } + if eventPublishingStack, ok := n.stack.(InterfaceEventPublisher); ok { + eventPublishingStack.AddInterfaceEventSubscriber(n.netlinkMcastTable) + } } n.abstractSockets.init() } @@ -162,6 +173,11 @@ func (n *Namespace) AbstractSockets() *AbstractSocketNamespace { return &n.abstractSockets } +// NetlinkMcastTable returns the netlink multicast group table. +func (n *Namespace) NetlinkMcastTable() *McastTable { + return n.netlinkMcastTable +} + // NetworkStackCreator allows new instances of a network stack to be created. It // is used by the kernel to create new network namespaces when requested. type NetworkStackCreator interface { diff --git a/pkg/sentry/inet/nlmcast.go b/pkg/sentry/inet/nlmcast.go new file mode 100644 index 0000000000..9f3a521899 --- /dev/null +++ b/pkg/sentry/inet/nlmcast.go @@ -0,0 +1,143 @@ +// Copyright 2025 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inet + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" +) + +const ( + routeProtocol = linux.NETLINK_ROUTE + routeLinkMcastGroup = linux.RTNLGRP_LINK +) + +// InterfaceEventSubscriber allows clients to subscribe to events published by an inet.Stack. +// +// It is a rough parallel to the objects in Linux that subscribe to netdev +// events by calling register_netdevice_notifier(). +type InterfaceEventSubscriber interface { + // OnInterfaceChangeEvent is called by InterfaceEventPublishers when an interface change event takes place. + OnInterfaceChangeEvent(ctx context.Context, idx int32, i Interface) + + // OnInterfaceDeleteEvent is called by InterfaceEventPublishers when an interface delete event takes place. + OnInterfaceDeleteEvent(ctx context.Context, idx int32, i Interface) +} + +// InterfaceEventPublisher is the interface event publishing aspect of an inet.Stack. +// +// The Linux parallel is how it notifies subscribers via call_netdev_notifiers(). +type InterfaceEventPublisher interface { + AddInterfaceEventSubscriber(sub InterfaceEventSubscriber) +} + +// NetlinkSocket corresponds to a netlink socket. +type NetlinkSocket interface { + // Protocol returns the netlink protocol value. + Protocol() int + + // Groups returns the bitmap of multicast groups the socket is bound to. + Groups() uint64 + + // HandleInterfaceChangeEvent is called on NetlinkSockets that are members of the RTNLGRP_LINK + // multicast group when an interface is modified. 
+ HandleInterfaceChangeEvent(context.Context, int32, Interface) + + // HandleInterfaceDeleteEvent is called on NetlinkSockets that are members of the RTNLGRP_LINK + // multicast group when an interface is deleted. + HandleInterfaceDeleteEvent(context.Context, int32, Interface) +} + +// McastTable holds multicast group membership information for netlink sockets. +// It corresponds roughly to Linux's struct netlink_table. +// +// +stateify savable +type McastTable struct { + mu nlmcastTableMutex `state:"nosave"` + socks map[int]map[NetlinkSocket]struct{} +} + +// WithTableLocked runs fn with the table mutex held. +func (m *McastTable) WithTableLocked(fn func()) { + m.mu.Lock() + defer m.mu.Unlock() + fn() +} + +// AddSocket adds a NetlinkSocket to the multicast-group table. +// +// Preconditions: the netlink multicast table is locked. +func (m *McastTable) AddSocket(s NetlinkSocket) { + p := s.Protocol() + if _, ok := m.socks[p]; !ok { + m.socks[p] = make(map[NetlinkSocket]struct{}) + } + if _, ok := m.socks[p][s]; ok { + return + } + m.socks[p][s] = struct{}{} +} + +// RemoveSocket removes a NetlinkSocket from the multicast-group table. +// +// Preconditions: the netlink multicast table is locked. +func (m *McastTable) RemoveSocket(s NetlinkSocket) { + p := s.Protocol() + if _, ok := m.socks[p]; !ok { + return + } + if _, ok := m.socks[p][s]; !ok { + return + } + delete(m.socks[p], s) +} + +func (m *McastTable) forEachMcastSock(protocol int, mcastGroup int, fn func(s NetlinkSocket)) { + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.socks[protocol]; !ok { + return + } + for s := range m.socks[protocol] { + // If the socket is not bound to the multicast group, skip it. + if s.Groups()&(1<<(mcastGroup-1)) == 0 { + continue + } + fn(s) + } +} + +// OnInterfaceChangeEvent implements InterfaceEventSubscriber.OnInterfaceChangeEvent. 
+func (m *McastTable) OnInterfaceChangeEvent(ctx context.Context, idx int32, i Interface) { + // Relay the event to RTNLGRP_LINK subscribers. + m.forEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) { + s.HandleInterfaceChangeEvent(ctx, idx, i) + }) +} + +// OnInterfaceDeleteEvent implements InterfaceEventSubscriber.OnInterfaceDeleteEvent. +func (m *McastTable) OnInterfaceDeleteEvent(ctx context.Context, idx int32, i Interface) { + // Relay the event to RTNLGRP_LINK subscribers. + m.forEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) { + s.HandleInterfaceDeleteEvent(ctx, idx, i) + }) +} + +// NewNetlinkMcastTable creates a new McastTable. +func NewNetlinkMcastTable() *McastTable { + return &McastTable{ + socks: make(map[int]map[NetlinkSocket]struct{}), + } +} diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index d083aef3e0..894c3e96ab 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -61,7 +61,7 @@ func (s *TestStack) Destroy() { } // RemoveInterface implements Stack. -func (s *TestStack) RemoveInterface(idx int32) error { +func (s *TestStack) RemoveInterface(ctx context.Context, idx int32) error { delete(s.InterfacesMap, idx) return nil } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 5419d991e5..72f23b8db0 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -152,7 +152,7 @@ func (s *Stack) Interfaces() map[int32]inet.Interface { } // RemoveInterface implements inet.Stack.RemoveInterface. 
-func (*Stack) RemoveInterface(idx int32) error { +func (*Stack) RemoveInterface(ctx context.Context, idx int32) error { return removeInterface(idx) } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 997e43728e..15399c9148 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/abi/linux/errno", + "//pkg/atomicbitops", "//pkg/context", "//pkg/errors/linuxerr", "//pkg/hostarch", diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index df302cd827..bf895439d2 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/nlmsg" @@ -51,6 +52,18 @@ type Protocol interface { ProcessMessage(ctx context.Context, s *Socket, msg *nlmsg.Message, ms *nlmsg.MessageSet) *syserr.Error } +// RouteProtocol corresponds to the NETLINK_ROUTE family. +type RouteProtocol interface { + Protocol + + // AddNewLinkMessage is called when an interface is mutated or created by the stack. + // It is the rough equivalent of Linux's rtnetlink_event(). + AddNewLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) + + // AddDelLinkMessage is called when an interface is deleted by the stack. + AddDelLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) +} + // Provider is a function that creates a new Protocol for a specific netlink // protocol. 
// diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 34ec373154..fa1b0f546c 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -101,7 +101,7 @@ func (p *Protocol) dumpLinks(ctx context.Context, s *netlink.Socket, msg *nlmsg. } for idx, i := range stack.Interfaces() { - addNewLinkMessage(ms, idx, i) + p.AddNewLinkMessage(ms, idx, i) } return nil @@ -158,7 +158,7 @@ func (p *Protocol) getLink(ctx context.Context, s *netlink.Socket, msg *nlmsg.Me return syserr.ErrInvalidArgument } - addNewLinkMessage(ms, idx, i) + p.AddNewLinkMessage(ms, idx, i) found = true break } @@ -232,16 +232,10 @@ func (p *Protocol) delLink(ctx context.Context, s *netlink.Socket, msg *nlmsg.Me return syserr.ErrNoDevice } } - return syserr.FromError(stack.RemoveInterface(ifinfomsg.Index)) + return syserr.FromError(stack.RemoveInterface(ctx, ifinfomsg.Index)) } -// addNewLinkMessage appends RTM_NEWLINK message for the given interface into -// the message set. -func addNewLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) { - m := ms.AddMessage(linux.NetlinkMessageHeader{ - Type: linux.RTM_NEWLINK, - }) - +func writeLinkInfo(m *nlmsg.Message, idx int32, i inet.Interface) { m.Put(&linux.InterfaceInfoMessage{ Family: linux.AF_UNSPEC, Type: i.DeviceType, @@ -264,6 +258,26 @@ func addNewLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) { // TODO(gvisor.dev/issue/578): There are many more attributes. } +// AddNewLinkMessage appends an RTM_NEWLINK message for the given interface into +// the message set. +// AddNewLinkMessage implements netlink.RouteProtocol.AddNewLinkMessage. 
+func (p *Protocol) AddNewLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.RTM_NEWLINK, + }) + writeLinkInfo(m, idx, i) +} + +// AddDelLinkMessage appends an RTM_DELLINK message for the given interface into +// the message set. +// AddDelLinkMessage implements netlink.RouteProtocol.AddDelLinkMessage. +func (p *Protocol) AddDelLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.RTM_DELLINK, + }) + writeLinkInfo(m, idx, i) +} + // dumpAddrs handles RTM_GETADDR dump requests. func (p *Protocol) dumpAddrs(ctx context.Context, s *netlink.Socket, msg *nlmsg.Message, ms *nlmsg.MessageSet) *syserr.Error { // RTM_GETADDR dump requests need not contain anything more than the diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 0167c24060..47f39995ef 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -16,12 +16,14 @@ package netlink import ( + "fmt" "io" "math" "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" @@ -55,6 +57,9 @@ const ( // maxBufferSize is the largest size a send buffer can grow to. maxSendBufferSize = 4 << 20 // 4MB + + // supportedGroups is the set of multicast groups that are supported. + supportedGroups = 1 << (linux.RTNLGRP_LINK - 1) ) var errNoFilter = syserr.New("no filter attached", errno.ENOENT) @@ -92,6 +97,14 @@ type Socket struct { // sent to userspace. connection transport.ConnectedEndpoint + // netns is the network namespace associated with the socket. + // A netlink socket is immutably bound to a network namespace. + netns *inet.Namespace + + // groups is a bitmap of the set of multicast groups this socket is bound to. 
+ // Writing to it requires the per-netns table lock to be held, reading it does not. + groups atomicbitops.Uint64 + // mu protects the fields below. mu sync.Mutex `state:"nosave"` @@ -110,13 +123,11 @@ type Socket struct { // TODO(gvisor.dev/issue/1119): We don't actually support filtering, // this is just bookkeeping for tracking add/remove. filter bool - - // netns is the network namespace associated with the socket. - netns *inet.Namespace } var _ socket.Socket = (*Socket)(nil) var _ transport.Credentialer = (*Socket)(nil) +var _ inet.NetlinkSocket = (*Socket)(nil) // New creates a new Socket. func New(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { @@ -157,6 +168,12 @@ func (s *Socket) Stack() inet.Stack { // Release implements vfs.FileDescriptionImpl.Release. func (s *Socket) Release(ctx context.Context) { + if s.groups.Load() != 0 { + s.netns.NetlinkMcastTable().WithTableLocked(func() { + s.netns.NetlinkMcastTable().RemoveSocket(s) + }) + } + t := kernel.TaskFromContext(ctx) t.Kernel().DeleteSocket(&s.vfsfd) s.connection.Release(ctx) @@ -304,6 +321,125 @@ func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error { return nil } +func (s *Socket) checkMcastSupport(t *kernel.Task) *syserr.Error { + // Currently only ROUTE family sockets support multicast. + if s.Protocol() != linux.NETLINK_ROUTE { + return syserr.ErrNotSupported + } + // Not all inet.Stacks relay interface events, currently only netstack/tcpip does. + if _, ok := s.Stack().(inet.InterfaceEventPublisher); !ok { + return syserr.ErrNotSupported + } + // man 7 netlink: "Only processes with an effective UID of 0 or the CAP_NET_ADMIN + // capability may send or listen to a netlink multicast group." + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return syserr.ErrPermissionDenied + } + return nil +} + +// preconditions: the netlink multicast table is locked. 
+func (s *Socket) joinGroups(t *kernel.Task, groups uint64) *syserr.Error { + if groups&supportedGroups != groups { + return syserr.ErrNotSupported + } + if err := s.checkMcastSupport(t); err != nil { + return err + } + + oldGroups := s.groups.Load() + s.groups.Store(groups) + if oldGroups == 0 && s.groups.Load() != 0 { + s.netns.NetlinkMcastTable().AddSocket(s) + } else if oldGroups != 0 && s.groups.Load() == 0 { + s.netns.NetlinkMcastTable().RemoveSocket(s) + } + return nil +} + +// preconditions: the netlink multicast table is locked. +func (s *Socket) joinGroup(t *kernel.Task, group uint32) *syserr.Error { + if group == 0 || group > 64 { + return syserr.ErrInvalidArgument + } + groups := uint64(1) << (group - 1) + if groups&supportedGroups != groups { + return syserr.ErrNotSupported + } + if err := s.checkMcastSupport(t); err != nil { + return err + } + + oldGroups := s.groups.Load() + s.groups.Store(oldGroups | groups) + if oldGroups == 0 { + s.netns.NetlinkMcastTable().AddSocket(s) + } + return nil +} + +// preconditions: the netlink multicast table is locked. +func (s *Socket) leaveGroup(t *kernel.Task, group uint32) *syserr.Error { + if group == 0 || group > 64 { + return syserr.ErrInvalidArgument + } + groups := uint64(1) << (group - 1) + if groups&supportedGroups != groups { + return syserr.ErrNotSupported + } + if err := s.checkMcastSupport(t); err != nil { + return err + } + + s.groups.Store(s.groups.Load() &^ groups) + if s.groups.Load() == 0 { + s.netns.NetlinkMcastTable().RemoveSocket(s) + } + return nil +} + +// Protocol implements inet.NetlinkSocket.Protocol. +func (s *Socket) Protocol() int { + return s.protocol.Protocol() +} + +// Groups implements inet.NetlinkSocket.Groups. +func (s *Socket) Groups() uint64 { + return s.groups.Load() +} + +// HandleInterfaceChangeEvent implements inet.NetlinkSocket.HandleInterfaceChangeEvent. 
+func (s *Socket) HandleInterfaceChangeEvent(ctx context.Context, idx int32, i inet.Interface) { + routeProtocol, ok := s.protocol.(RouteProtocol) + if !ok { + panic(fmt.Sprintf("Non-ROUTE netlink socket (protocol %d) cannot handle interface events", s.Protocol())) + } + + s.mu.Lock() + portID := s.portID + s.mu.Unlock() + ms := nlmsg.NewMessageSet(portID, 0) + routeProtocol.AddNewLinkMessage(ms, idx, i) + // TODO(b/456238795): Implement netlink ENOBUFS. + s.SendResponse(ctx, ms) +} + +// HandleInterfaceDeleteEvent implements inet.NetlinkSocket.HandleInterfaceDeleteEvent. +func (s *Socket) HandleInterfaceDeleteEvent(ctx context.Context, idx int32, i inet.Interface) { + routeProtocol, ok := s.protocol.(RouteProtocol) + if !ok { + panic(fmt.Sprintf("Non-ROUTE netlink socket (protocol %d) cannot handle interface events", s.Protocol())) + } + + s.mu.Lock() + portID := s.portID + s.mu.Unlock() + ms := nlmsg.NewMessageSet(portID, 0) + routeProtocol.AddDelLinkMessage(ms, idx, i) + // TODO(b/456238795): Implement netlink ENOBUFS. + s.SendResponse(ctx, ms) +} + // Bind implements socket.Socket.Bind. func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { a, err := ExtractSockAddr(sockaddr) @@ -311,14 +447,18 @@ func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { return err } - // No support for multicast groups yet. if a.Groups != 0 { - return syserr.ErrPermissionDenied + var err *syserr.Error + s.netns.NetlinkMcastTable().WithTableLocked(func() { + err = s.joinGroups(t, uint64(a.Groups)) + }) + if err != nil { + return err + } } s.mu.Lock() defer s.mu.Unlock() - return s.bindPort(t, int32(a.PortID)) } @@ -329,7 +469,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr return err } - // No support for multicast groups yet. + // No support for sending to destination multicast groups yet. 
if a.Groups != 0 { return syserr.ErrPermissionDenied } @@ -417,13 +557,19 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch } case linux.SOL_NETLINK: switch name { + case linux.NETLINK_LIST_MEMBERSHIPS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + return primitive.AllocateUint64(s.groups.Load()), nil + case linux.NETLINK_BROADCAST_ERROR, linux.NETLINK_CAP_ACK, linux.NETLINK_DUMP_STRICT_CHK, linux.NETLINK_EXT_ACK, - linux.NETLINK_LIST_MEMBERSHIPS, linux.NETLINK_NO_ENOBUFS, linux.NETLINK_PKTINFO: + // Not supported. } } @@ -528,15 +674,36 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy } case linux.SOL_NETLINK: switch name { - case linux.NETLINK_ADD_MEMBERSHIP, - linux.NETLINK_BROADCAST_ERROR, + case linux.NETLINK_ADD_MEMBERSHIP: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + group := hostarch.ByteOrder.Uint32(opt) + var err *syserr.Error + s.netns.NetlinkMcastTable().WithTableLocked(func() { + err = s.joinGroup(t, group) + }) + return err + + case linux.NETLINK_DROP_MEMBERSHIP: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + group := hostarch.ByteOrder.Uint32(opt) + var err *syserr.Error + s.netns.NetlinkMcastTable().WithTableLocked(func() { + err = s.leaveGroup(t, group) + }) + return err + + case linux.NETLINK_BROADCAST_ERROR, linux.NETLINK_CAP_ACK, - linux.NETLINK_DROP_MEMBERSHIP, linux.NETLINK_DUMP_STRICT_CHK, linux.NETLINK_EXT_ACK, linux.NETLINK_LISTEN_ALL_NSID, linux.NETLINK_NO_ENOBUFS, linux.NETLINK_PKTINFO: + // Not supported. 
} } diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index cebd8db5ba..0bb273a64d 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -1,3 +1,4 @@ +load("//pkg/sync/locking:locking.bzl", "declare_mutex") load("//tools:defs.bzl", "go_library", "proto_library") package( @@ -5,10 +6,18 @@ package( licenses = ["notice"], ) +declare_mutex( + name = "netstack_link_mutex", + out = "netstack_link_mutex.go", + package = "netstack", + prefix = "netstackLink", +) + go_library( name = "netstack", srcs = [ "netstack.go", + "netstack_link_mutex.go", "netstack_state.go", "provider.go", "save_restore.go", @@ -48,6 +57,7 @@ go_library( "//pkg/sentry/socket/netstack/packetmmap", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/sync/locking", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index d4934cf7bd..7e8d4cdd09 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -41,6 +41,46 @@ import ( // +stateify savable type Stack struct { Stack *stack.Stack `state:".(*stack.Stack)"` + + eventSubscriber inet.InterfaceEventSubscriber + + // linkMu serializes link creation, modification and deletion. + // It is a rough parallel to the per-netns rtnl_mutex in Linux. + linkMu netstackLinkMutex `state:"nosave"` +} + +// AddInterfaceEventSubscriber implements inet.InterfaceEventPublisher.AddInterfaceEventSubscriber. 
+func (s *Stack) AddInterfaceEventSubscriber(sub inet.InterfaceEventSubscriber) { + if s.eventSubscriber != nil { + panic("AddInterfaceEventSubscriber called twice: multiple subscribers yet to be supported") + } + s.eventSubscriber = sub +} + +func makeInterfaceInfo(ni *stack.NICInfo) inet.Interface { + return inet.Interface{ + Name: ni.Name, + Addr: []byte(ni.LinkAddress), + Flags: uint32(nicStateFlagsToLinux(ni.Flags)), + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), + MTU: ni.MTU, + } +} + +func (s *Stack) sendChangeEvent(ctx context.Context, id tcpip.NICID) { + if s.eventSubscriber == nil { + return + } + if nicInfo, ok := s.Stack.SingleNICInfo(id); ok { + s.eventSubscriber.OnInterfaceChangeEvent(ctx, int32(id), makeInterfaceInfo(nicInfo)) + } +} + +func (s *Stack) sendDeleteEvent(ctx context.Context, id tcpip.NICID, nicInfo *stack.NICInfo) { + if s.eventSubscriber == nil { + return + } + s.eventSubscriber.OnInterfaceDeleteEvent(ctx, int32(id), makeInterfaceInfo(nicInfo)) } // EnableSaveRestore enables netstack s/r. @@ -90,22 +130,19 @@ func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - is[int32(id)] = inet.Interface{ - Name: ni.Name, - Addr: []byte(ni.LinkAddress), - Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), - MTU: ni.MTU, - } + is[int32(id)] = makeInterfaceInfo(&ni) } return is } // RemoveInterface implements inet.Stack.RemoveInterface. 
+func (s *Stack) RemoveInterface(ctx context.Context, idx int32) error { + s.linkMu.Lock() + defer s.linkMu.Unlock() + nic := tcpip.NICID(idx) - nicInfo, ok := s.Stack.NICInfo()[nic] + nicInfo, ok := s.Stack.SingleNICInfo(nic) if !ok { return syserr.ErrUnknownNICID.ToError() } @@ -115,7 +152,12 @@ func (s *Stack) RemoveInterface(idx int32) error { return syserr.ErrNotSupported.ToError() } - return syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)).ToError() + if err := syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)); err != nil { + return err.ToError() + } + s.sendDeleteEvent(ctx, nic, nicInfo) + return nil + } // SetInterface implements inet.Stack.SetInterface. @@ -180,10 +222,18 @@ func (s *Stack) SetInterface(ctx context.Context, msg *nlmsg.Message) *syserr.Er // Netstack interfaces are always up. } - return s.setLink(ctx, tcpip.NICID(ifinfomsg.Index), attrs) + s.linkMu.Lock() + defer s.linkMu.Unlock() + return s.setLinkLocked(ctx, tcpip.NICID(ifinfomsg.Index), attrs) } -func (s *Stack) setLink(ctx context.Context, id tcpip.NICID, linkAttrs map[uint16]nlmsg.BytesView) *syserr.Error { +// precondition: s.linkMu is held. +func (s *Stack) setLinkLocked(ctx context.Context, id tcpip.NICID, linkAttrs map[uint16]nlmsg.BytesView) *syserr.Error { + oldNicInfo, ok := s.Stack.SingleNICInfo(id) + if !ok { + return syserr.ErrUnknownNICID + } + // IFLA_NET_NS_FD has to be handled first, because other parameters may be reset. 
if v, ok := linkAttrs[linux.IFLA_NET_NS_FD]; ok { fd, ok := v.Uint32() @@ -202,12 +252,21 @@ func (s *Stack) setLink(ctx context.Context, id tcpip.NICID, linkAttrs map[uint1 peer := ns.Stack().(*Stack) if peer.Stack != s.Stack { var err tcpip.Error + oldID := id + id, err = s.Stack.SetNICStack(id, peer.Stack) if err != nil { return syserr.TranslateNetstackError(err) } + + s.sendDeleteEvent(ctx, oldID, oldNicInfo) // inform about exit from old ns + peer.sendChangeEvent(ctx, id) // inform about entry into new ns + // TODO: Once we support IFLA_LINK_NETNSID, we need to call sendChangeEvent on + // the peer interface if this interface is part of a veth pair. } } + + changed := false for t, v := range linkAttrs { switch t { case linux.IFLA_MASTER: @@ -215,35 +274,55 @@ func (s *Stack) setLink(ctx context.Context, id tcpip.NICID, linkAttrs map[uint1 if !ok { return syserr.ErrInvalidArgument } + if mid, ok := s.Stack.GetNICCoordinatorID(id); ok && mid == tcpip.NICID(master) { + continue + } if master != 0 { if err := s.Stack.SetNICCoordinator(id, tcpip.NICID(master)); err != nil { return syserr.TranslateNetstackError(err) } + changed = true } case linux.IFLA_ADDRESS: if len(v) != tcpip.LinkAddressSize { return syserr.ErrInvalidArgument } addr := tcpip.LinkAddress(v) + if oldNicInfo.LinkAddress == addr { + continue + } if err := s.Stack.SetNICAddress(id, addr); err != nil { return syserr.TranslateNetstackError(err) } + changed = true case linux.IFLA_IFNAME: + if oldNicInfo.Name == v.String() { + continue + } if err := s.Stack.SetNICName(id, v.String()); err != nil { return syserr.TranslateNetstackError(err) } + changed = true case linux.IFLA_MTU: mtu, ok := v.Uint32() if !ok { return syserr.ErrInvalidArgument } + if oldNicInfo.MTU == mtu { + continue + } if err := s.Stack.SetNICMTU(id, mtu); err != nil { return syserr.TranslateNetstackError(err) } + changed = true case linux.IFLA_TXQLEN: // TODO(b/340388892): support IFLA_TXQLEN. 
} } + + if changed { + s.sendChangeEvent(ctx, id) + } return nil } @@ -298,6 +377,8 @@ func (s *Stack) newVeth(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesVie } } } + + s.linkMu.Lock() ep, peerEP := veth.NewPair(defaultMTU, veth.DefaultBacklogSize) id := s.Stack.NextNICID() peerID := peerStack.Stack.NextNICID() @@ -308,16 +389,21 @@ func (s *Stack) newVeth(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesVie Name: ifname, }) if err != nil { + s.linkMu.Unlock() return syserr.TranslateNetstackError(err) } - if err := s.setLink(ctx, id, linkAttrs); err != nil { + if err := s.setLinkLocked(ctx, id, linkAttrs); err != nil { + s.linkMu.Unlock() peerEP.Close() return err } + s.linkMu.Unlock() if peerName == "" { peerName = fmt.Sprintf("veth%d", peerID) } + peerStack.linkMu.Lock() + defer peerStack.linkMu.Unlock() err = peerStack.Stack.CreateNICWithOptions(peerID, packetsocket.New(ethernet.New(peerEP)), stack.NICOptions{ Name: peerName, }) @@ -326,7 +412,7 @@ func (s *Stack) newVeth(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesVie return syserr.TranslateNetstackError(err) } if peerLinkAttrs != nil { - if err := peerStack.setLink(ctx, peerID, peerLinkAttrs); err != nil { + if err := peerStack.setLinkLocked(ctx, peerID, peerLinkAttrs); err != nil { peerStack.Stack.RemoveNIC(peerID) peerEP.Close() return err @@ -337,6 +423,9 @@ func (s *Stack) newVeth(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesVie } func (s *Stack) newBridge(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesView, linkInfoAttrs map[uint16]nlmsg.BytesView) *syserr.Error { + s.linkMu.Lock() + defer s.linkMu.Unlock() + ifname := "" if v, ok := linkAttrs[linux.IFLA_IFNAME]; ok { @@ -350,7 +439,7 @@ func (s *Stack) newBridge(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesV if err != nil { return syserr.TranslateNetstackError(err) } - if err := s.setLink(ctx, id, linkAttrs); err != nil { + if err := s.setLinkLocked(ctx, id, linkAttrs); err != nil { return err } diff --git 
a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c486a91c81..de6ae3cb52 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1066,6 +1066,18 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) (func(), tcpip.Error) { return nic.remove(true /* closeLinkEndpoint */) } +// GetNICCoordinatorID returns the ID of the coordinator device of a NIC. +func (s *Stack) GetNICCoordinatorID(id tcpip.NICID) (tcpip.NICID, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if nic, ok := s.nics[id]; ok { + if nic.Primary != nil { + return nic.Primary.id, true + } + } + return 0, false +} + // SetNICCoordinator sets a coordinator device. func (s *Stack) SetNICCoordinator(id tcpip.NICID, mid tcpip.NICID) tcpip.Error { s.mu.Lock() @@ -1176,65 +1188,83 @@ func (s *Stack) HasNIC(id tcpip.NICID) bool { return ok } -// NICInfo returns a map of NICIDs to their associated information. -func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { - s.mu.RLock() - defer s.mu.RUnlock() +type forwardingFn func(tcpip.NetworkProtocolNumber) (bool, tcpip.Error) - type forwardingFn func(tcpip.NetworkProtocolNumber) (bool, tcpip.Error) - forwardingValue := func(forwardingFn forwardingFn, proto tcpip.NetworkProtocolNumber, nicID tcpip.NICID, fnName string) (forward bool, ok bool) { - switch forwarding, err := forwardingFn(proto); err.(type) { - case nil: - return forwarding, true - case *tcpip.ErrUnknownProtocol: - panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nicID)) - case *tcpip.ErrNotSupported: - // Not all network protocols support forwarding. 
- default: - panic(fmt.Sprintf("nic(id=%d).%s(%d): %s", nicID, fnName, proto, err)) - } - return false, false +func forwardingValue(forwardingFn forwardingFn, proto tcpip.NetworkProtocolNumber, nicID tcpip.NICID, fnName string) (forward bool, ok bool) { + switch forwarding, err := forwardingFn(proto); err.(type) { + case nil: + return forwarding, true + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nicID)) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + default: + panic(fmt.Sprintf("nic(id=%d).%s(%d): %s", nicID, fnName, proto, err)) } + return false, false +} - nics := make(map[tcpip.NICID]NICInfo) - for id, nic := range s.nics { - flags := NICStateFlags{ - Up: true, // Netstack interfaces are always up. - Running: nic.Enabled(), - Promiscuous: nic.Promiscuous(), - Loopback: nic.IsLoopback(), - } +// precondition: s.mu is held. +func (s *Stack) nicInfo(nic *nic, id tcpip.NICID) *NICInfo { + flags := NICStateFlags{ + Up: true, // Netstack interfaces are always up. 
+ Running: nic.Enabled(), + Promiscuous: nic.Promiscuous(), + Loopback: nic.IsLoopback(), + } - netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats) - for proto, netEP := range nic.networkEndpoints { - netStats[proto] = netEP.Stats() + netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats) + for proto, netEP := range nic.networkEndpoints { + netStats[proto] = netEP.Stats() + } + + info := NICInfo{ + Name: nic.name, + LinkAddress: nic.NetworkLinkEndpoint.LinkAddress(), + ProtocolAddresses: nic.primaryAddresses(), + Flags: flags, + MTU: nic.NetworkLinkEndpoint.MTU(), + Stats: nic.stats.local, + NetworkStats: netStats, + Context: nic.context, + ARPHardwareType: nic.NetworkLinkEndpoint.ARPHardwareType(), + Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), + MulticastForwarding: make(map[tcpip.NetworkProtocolNumber]bool), + } + + for proto := range s.networkProtocols { + if forwarding, ok := forwardingValue(nic.forwarding, proto, id, "forwarding"); ok { + info.Forwarding[proto] = forwarding } - info := NICInfo{ - Name: nic.name, - LinkAddress: nic.NetworkLinkEndpoint.LinkAddress(), - ProtocolAddresses: nic.primaryAddresses(), - Flags: flags, - MTU: nic.NetworkLinkEndpoint.MTU(), - Stats: nic.stats.local, - NetworkStats: netStats, - Context: nic.context, - ARPHardwareType: nic.NetworkLinkEndpoint.ARPHardwareType(), - Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), - MulticastForwarding: make(map[tcpip.NetworkProtocolNumber]bool), + if multicastForwarding, ok := forwardingValue(nic.multicastForwarding, proto, id, "multicastForwarding"); ok { + info.MulticastForwarding[proto] = multicastForwarding } + } - for proto := range s.networkProtocols { - if forwarding, ok := forwardingValue(nic.forwarding, proto, id, "forwarding"); ok { - info.Forwarding[proto] = forwarding - } + return &info +} - if multicastForwarding, ok := forwardingValue(nic.multicastForwarding, proto, id, "multicastForwarding"); ok { - 
info.MulticastForwarding[proto] = multicastForwarding - } - } +// SingleNICInfo returns the NICInfo for the given NICID. +func (s *Stack) SingleNICInfo(id tcpip.NICID) (*NICInfo, bool) { + s.mu.RLock() + defer s.mu.RUnlock() - nics[id] = info + if nic, ok := s.nics[id]; !ok { + return nil, false + } else { + return s.nicInfo(nic, id), true + } +} + +// NICInfo returns a map of NICIDs to their associated information. +func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := make(map[tcpip.NICID]NICInfo) + for id, nic := range s.nics { + nics[id] = *s.nicInfo(nic, id) } return nics } diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 55f6e5afae..977baa2687 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -16,10 +16,11 @@ #include #include #include -#include #include #include #include +#include +#include #include #include #include @@ -280,6 +281,29 @@ TEST_P(NetlinkSetLinkTest, ChangeLinkName) { EXPECT_TRUE(found) << "Netlink response does not contain any links."; } +struct MtuRequest { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + uint32_t mtu; +}; + +MtuRequest GetMtuRequest(const Link& link, uint16_t nlmsg_type, uint32_t mtu) { + MtuRequest req = {}; + + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = nlmsg_type; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = link.index; + req.rtattr.rta_type = IFLA_MTU; + req.rtattr.rta_len = RTA_LENGTH(sizeof(uint32_t)); + req.mtu = mtu; + + return req; +} + TEST_P(NetlinkSetLinkTest, ChangeMTU) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); SKIP_IF(IsRunningWithHostinet()); @@ -289,23 +313,10 @@ TEST_P(NetlinkSetLinkTest, ChangeMTU) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - struct 
request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - struct rtattr rtattr; - uint32_t mtu; - } req = {}; - // Change the MTU. - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = GetParam(); - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - req.ifm.ifi_index = loopback_link.index; - req.rtattr.rta_type = IFLA_MTU; - req.rtattr.rta_len = RTA_LENGTH(sizeof(uint32_t)); - req.mtu = loopback_link.mtu + 10; + uint16_t nlmsg_type = GetParam(); + MtuRequest req = + GetMtuRequest(loopback_link, nlmsg_type, loopback_link.mtu + 10); EXPECT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req))); // Update the local loopback_link's MTU to the requested value. @@ -1739,55 +1750,55 @@ void addattr(struct nlmsghdr* n, int maxlen, int type, const void* data, n->nlmsg_len = NLMSG_ALIGN(n->nlmsg_len) + RTA_ALIGN(len); } -TEST(NetlinkRouteTest, VethAdd) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - SKIP_IF(IsRunningWithHostinet()); +struct VethRequest { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + char buf[1024]; +}; - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - char buf[1024]; - }; - - struct request req = {}; +struct VethRequest GetVethRequest(uint32_t seq, const char* ifname_first, + const char* ifname_second) { + struct VethRequest req = {}; req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifinfomsg)); req.hdr.nlmsg_type = RTM_NEWLINK; req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE; - req.hdr.nlmsg_seq = kSeq; + req.hdr.nlmsg_seq = seq; req.ifm.ifi_family = AF_UNSPEC; req.ifm.ifi_index = 0; - req.ifm.ifi_change = IFF_UP; - req.ifm.ifi_flags = IFF_UP; - const char veth_first[] = "veth_first"; - addattr(&req.hdr, sizeof(req), IFLA_IFNAME, veth_first, strlen(veth_first)); + addattr(&req.hdr, sizeof(req), IFLA_IFNAME, ifname_first, + 
strlen(ifname_first)); - struct rtattr* linkinfo; - linkinfo = NLMSG_TAIL(&req.hdr); + struct rtattr* linkinfo = NLMSG_TAIL(&req.hdr); { addattr(&req.hdr, sizeof(req), IFLA_LINKINFO, nullptr, 0); addattr(&req.hdr, sizeof(req), IFLA_INFO_KIND, "veth", 4); - - struct rtattr *veth_data, *peer_data; - veth_data = NLMSG_TAIL(&req.hdr); + struct rtattr* veth_data = NLMSG_TAIL(&req.hdr); { addattr(&req.hdr, sizeof(req), IFLA_INFO_DATA, NULL, 0); - peer_data = NLMSG_TAIL(&req.hdr); + struct rtattr* peer_data = NLMSG_TAIL(&req.hdr); { struct ifinfomsg ifm = {}; addattr(&req.hdr, sizeof(req), VETH_INFO_PEER, &ifm, sizeof(ifm)); - const char veth_second[] = "veth_second"; - addattr(&req.hdr, sizeof(req), IFLA_IFNAME, veth_second, - strlen(veth_second)); + addattr(&req.hdr, sizeof(req), IFLA_IFNAME, ifname_second, + strlen(ifname_second)); } peer_data->rta_len = (uint64_t)NLMSG_TAIL(&req.hdr) - (uint64_t)peer_data; } veth_data->rta_len = (uint64_t)NLMSG_TAIL(&req.hdr) - (uint64_t)veth_data; } linkinfo->rta_len = (uint64_t)NLMSG_TAIL(&req.hdr) - (uint64_t)linkinfo; + + return req; +} + +TEST(NetlinkRouteTest, VethAdd) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + VethRequest req = GetVethRequest(kSeq, "veth1", "veth2"); EXPECT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len)); } @@ -1811,6 +1822,350 @@ TEST(NetlinkRouteTest, LookupAllAddrOrder) { freeifaddrs(if_addr_list); } } + +struct NameRequest { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + char name[IFNAMSIZ]; +}; + +NameRequest GetNameRequest(const Link& link, const char* name, uint32_t seq) { + NameRequest req = {}; + req.hdr.nlmsg_type = RTM_SETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = link.index; + + const size_t payload_len = 
strlen(name) + 1; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(payload_len); + memcpy(req.name, name, payload_len); + + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(struct ifinfomsg)) + RTA_SPACE(payload_len); + return req; +}; + +TEST(NetlinkRouteTest, LinkMulticastGroupBasic) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + + // nlsk_bound_group joins RTMGRP_LINK via bind(). + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + addr.nl_groups = RTMGRP_LINK; + FileDescriptor nlsk_bound_group = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &addr)); + + // nlsk_sockopt_group joins RTMGRP_LINK via setsockopt(). + addr = {}; + addr.nl_family = AF_NETLINK; + FileDescriptor nlsk_sockopt_group = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &addr)); + unsigned int group = RTMGRP_LINK; + ASSERT_THAT(setsockopt(nlsk_sockopt_group.get(), SOL_NETLINK, + NETLINK_ADD_MEMBERSHIP, &group, sizeof(group)), + SyscallSucceeds()); + int64_t res_groups; + socklen_t res_groups_len = sizeof(res_groups); + EXPECT_THAT( + getsockopt(nlsk_sockopt_group.get(), SOL_NETLINK, + NETLINK_LIST_MEMBERSHIPS, &res_groups, &res_groups_len), + SyscallSucceeds()); + EXPECT_EQ(res_groups_len, sizeof(res_groups)); + EXPECT_EQ(res_groups, RTMGRP_LINK); + + FileDescriptor control_fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + // Change the name of the loopback interface. 
+ const Link link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + std::string old_loopback_name = link.name; + NameRequest name_request = GetNameRequest(link, "lo_test", kSeq); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_fd, kSeq, &name_request, + name_request.hdr.nlmsg_len)); + auto restore_loopback_name = Cleanup([&]() { + NameRequest name_request = + GetNameRequest(link, old_loopback_name.c_str(), kSeq); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_fd, kSeq, &name_request, + name_request.hdr.nlmsg_len)); + }); + const Link link_newname = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + + // Change the MTU of the loopback interface. + MtuRequest mtu_request = GetMtuRequest(link, RTM_SETLINK, link.mtu + 10); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_fd, kSeq, &mtu_request, + sizeof(mtu_request))); + auto restore_mtu = Cleanup([&]() { + MtuRequest mtu_request = GetMtuRequest(link, RTM_SETLINK, link.mtu); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_fd, kSeq, &mtu_request, + sizeof(mtu_request))); + }); + const Link link_newmtu = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + + struct TestCase { + const char* name; + FileDescriptor* nlsk; + const char* event_name; + const Link& link; + }; + std::vector test_cases = { + { + .name = "bound_group", + .nlsk = &nlsk_bound_group, + .event_name = "name_change", + .link = link_newname, + }, + { + .name = "sockopt_group", + .nlsk = &nlsk_sockopt_group, + .event_name = "name_change", + .link = link_newname, + }, + { + .name = "bound_group", + .nlsk = &nlsk_bound_group, + .event_name = "mtu_change", + .link = link_newmtu, + }, + { + .name = "sockopt_group", + .nlsk = &nlsk_sockopt_group, + .event_name = "mtu_change", + .link = link_newmtu, + }, + }; + + for (const auto& tc : test_cases) { + SCOPED_TRACE(std::string(tc.name) + " " + tc.event_name); + + struct pollfd pfd = {.fd = tc.nlsk->get(), .events = POLLIN}; + constexpr int kPollTimeoutMs = 1000; + int poll_ret = RetryEINTR(poll)(&pfd, 1, 
kPollTimeoutMs); + ASSERT_EQ(poll_ret, 1); + + bool got_msg = false; + ASSERT_NO_ERRNO(NetlinkResponse( + *tc.nlsk, + [&](const struct nlmsghdr* hdr) { + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (msg->ifi_index != tc.link.index) { + return; + } + CheckLinkMsg(hdr, tc.link); + got_msg = true; + }, + /*expect_nlmsgerr=*/false)); + EXPECT_TRUE(got_msg); + } +} + +struct VethRequest GetSetNetNSRequest(uint32_t seq, int if_index, int ns_fd) { + struct VethRequest req = {}; + + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifinfomsg)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = if_index; + addattr(&req.hdr, sizeof(req), IFLA_NET_NS_FD, &ns_fd, sizeof(ns_fd)); + + return req; +} + +// To verify the namespaced nature of the netlink multicast groups. +TEST(NetlinkRouteTest, LinkMulticastGroupNamespaced) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + + FileDescriptor control_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + VethRequest req = GetVethRequest(kSeq, "veth1", "veth2"); + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(control_nlsk, kSeq, &req, req.hdr.nlmsg_len)); + + int inner_veth_idx = if_nametoindex("veth2"); + ASSERT_NE(inner_veth_idx, 0); + + struct sockaddr_nl mcast_addr = {}; + mcast_addr.nl_family = AF_NETLINK; + mcast_addr.nl_groups = RTMGRP_LINK; + FileDescriptor root_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); + + const FileDescriptor root_nsfd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/thread-self/ns/net", O_RDONLY)); + Cleanup restore_netns = Cleanup([&] { + ASSERT_THAT(setns(root_nsfd.get(), CLONE_NEWNET), + SyscallSucceedsWithValue(0)); + }); + + // Enter a new network namespace. 
+ ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0)); + FileDescriptor inner_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); + + // And move veth2 into it. + const FileDescriptor inner_nsfd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/thread-self/ns/net", O_RDONLY)); + VethRequest set_netns_req = + GetSetNetNSRequest(kSeq, inner_veth_idx, inner_nsfd.get()); + EXPECT_NO_ERRNO(NetlinkRequestAckOrError(control_nlsk, kSeq, &set_netns_req, + set_netns_req.hdr.nlmsg_len)); + + constexpr int kPollTimeoutMs = 1000; + bool got_msg = false; + // We expect an RTM_DELINK message for veth2 in the root netns socket. + // But an RTM_NEWLINK is also expected for veth1 because its peer was moved. + // Hence the two attempts. N.B. gVisor does not send the RTM_NEWLINK because + // IFLA_LINK_NETNSID is not yet supported. + for (int i = 0; i < 2; ++i) { + struct pollfd pfd = {.fd = root_nlsk.get(), .events = POLLIN}; + ASSERT_EQ(RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs), 1) + << "root_nlsk: Did not get veth2 DELLINK"; + + ASSERT_NO_ERRNO(NetlinkResponse( + root_nlsk, + [&](const struct nlmsghdr* hdr) { + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (hdr->nlmsg_type != RTM_DELLINK) return; + if (msg->ifi_index != inner_veth_idx) return; + got_msg = true; + }, + /*expect_nlmsgerr=*/false)); + if (got_msg) break; + } + EXPECT_TRUE(got_msg) << "root_nlsk: Did not get veth2 DELLINK"; + + // We expect an RTM_NEWLINK message for veth2 in the inner netns socket. + { + struct pollfd pfd = {.fd = inner_nlsk.get(), .events = POLLIN}; + ASSERT_EQ(RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs), 1) + << "inner_nlsk: Did not get veth2 NEWLINK"; + + bool got_msg = false; + ASSERT_NO_ERRNO(NetlinkResponse( + inner_nlsk, + [&](const struct nlmsghdr* hdr) { + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + ASSERT_EQ(hdr->nlmsg_type, RTM_NEWLINK); + if (msg->ifi_index == 1) return; // Ignore the loopback interface. 
+ + char ifname[IF_NAMESIZE]; + EXPECT_NE(if_indextoname(msg->ifi_index, ifname), nullptr); + EXPECT_STREQ(ifname, "veth2"); + got_msg = true; + }, + /*expect_nlmsgerr=*/false)); + EXPECT_TRUE(got_msg) << "inner_nlsk: Did not get veth2 NEWLINK"; + } +} + +// NOOP requests should not result in any netlink multicast messages. +TEST(NetlinkRouteTest, LinkMulticastGroupNoop) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + + struct sockaddr_nl mcast_addr = {}; + mcast_addr.nl_family = AF_NETLINK; + mcast_addr.nl_groups = RTMGRP_LINK; + FileDescriptor nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); + + // Issue a request to set the name of the loopback interface to the same name. + const Link link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + NameRequest name_request = GetNameRequest(link, link.name.c_str(), kSeq); + FileDescriptor control_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_nlsk, kSeq, &name_request, + name_request.hdr.nlmsg_len)); + + // We expect no RTM_NEWLINK message for the loopback interface. + struct pollfd pfd = {.fd = nlsk.get(), .events = POLLIN}; + constexpr int kPollTimeoutMs = 500; + bool got_msg = false; + if (RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs) >= 1) { + ASSERT_NO_ERRNO(NetlinkResponse( + nlsk, + [&](const struct nlmsghdr* hdr) { + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (hdr->nlmsg_type != RTM_NEWLINK) return; + if (msg->ifi_index != link.index) return; + got_msg = true; + }, + /*expect_nlmsgerr=*/false)); + } + EXPECT_FALSE(got_msg) + << "Should not get a newlink event for the loopback interface."; +} + +// Userspace should know that it failed to keep up with its recvmsg()s, and the +// kernel alerts it to this by having recvmsg() return ENOBUFS. 
+TEST(NetlinkRouteTest, LinkMulticastGroupEnobufs) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + // TODO(b/456238795): enable this test once gVisor returns ENOBUFS. + if (IsRunningOnGvisor()) { + GTEST_SKIP() << "gVisor never returns ENOBUFS."; + } + + struct sockaddr_nl mcast_addr = {}; + mcast_addr.nl_family = AF_NETLINK; + mcast_addr.nl_groups = RTMGRP_LINK; + FileDescriptor nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); + + // N.B. gvisor ignores the SO_RCVBUF value. + constexpr int kSmallRcvBufSize = 512; + ASSERT_THAT(setsockopt(nlsk.get(), SOL_SOCKET, SO_RCVBUF, &kSmallRcvBufSize, + sizeof(kSmallRcvBufSize)), + SyscallSucceeds()); + int recv_buf_size; + socklen_t rec_buf_size_len = sizeof(recv_buf_size); + ASSERT_THAT(getsockopt(nlsk.get(), SOL_SOCKET, SO_RCVBUF, &recv_buf_size, + &rec_buf_size_len), + SyscallSucceeds()); + + // Generate enough link events to overflow poor nlsk's receive buffer. + FileDescriptor control_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + Link link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + constexpr int kMinimumNewlinkMsgSize = 32; + const int num_msgs = recv_buf_size / kMinimumNewlinkMsgSize; + for (int i = 0; i < num_msgs || link.name != "lo"; ++i) { + std::string name = link.name == "lo" ? 
"lo_test" : "lo"; + NameRequest name_request = GetNameRequest(link, name.c_str(), kSeq); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_nlsk, kSeq, &name_request, + name_request.hdr.nlmsg_len)); + link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + } + + std::vector buf(kSmallRcvBufSize); + struct iovec iov = {}; + iov.iov_base = buf.data(); + iov.iov_len = buf.size(); + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + EXPECT_THAT(RetryEINTR(recvmsg)(nlsk.get(), &msg, 0), + SyscallFailsWithErrno(ENOBUFS)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index 61faeb50e2..f77ece9ed1 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -290,6 +290,7 @@ PosixErrorOr> DumpLinks() { rta_address == nullptr ? "" : std::string(reinterpret_cast(RTA_DATA(rta_address))); + links.back().flags = msg->ifi_flags; })); return links; } diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index fdb1bb5a0b..2b37eb9e9e 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -31,6 +31,7 @@ struct Link { std::string name; uint32_t mtu; std::string address; + unsigned int flags; }; PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index c100066df6..9b810d02c0 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -38,14 +38,17 @@ namespace gvisor { namespace testing { PosixErrorOr NetlinkBoundSocket(int protocol) { - FileDescriptor fd; - ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); - struct sockaddr_nl addr = {}; addr.nl_family = AF_NETLINK; + return NetlinkBoundSocket(protocol, &addr); +} 
- RETURN_ERROR_IF_SYSCALL_FAIL( - bind(fd.get(), reinterpret_cast(&addr), sizeof(addr))); +PosixErrorOr NetlinkBoundSocket( + int protocol, const struct sockaddr_nl* addr) { + FileDescriptor fd; + ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); + + RETURN_ERROR_IF_SYSCALL_FAIL(bind(fd.get(), AsSockAddr(addr), sizeof(*addr))); MaybeSave(); return std::move(fd); diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index b21f513f69..c6661b7e90 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -35,6 +35,10 @@ namespace testing { // Returns a bound netlink socket. PosixErrorOr NetlinkBoundSocket(int protocol); +// Returns a bound netlink socket. +PosixErrorOr NetlinkBoundSocket(int protocol, + const struct sockaddr_nl* addr); + // Returns the port ID of the passed socket. PosixErrorOr NetlinkPortID(int fd); @@ -86,6 +90,14 @@ void InitNetlinkAttr(struct nlattr* attr, int payload_size, uint16_t attr_type); // Helper function to find a netlink attribute in a message. const struct nfattr* FindNfAttr(const struct nlmsghdr* hdr, const struct nfgenmsg* msg, int16_t attr); + +inline sockaddr* AsSockAddr(sockaddr_nl* s) { + return reinterpret_cast(s); +} +inline const sockaddr* AsSockAddr(const sockaddr_nl* s) { + return reinterpret_cast(s); +} + } // namespace testing } // namespace gvisor