From e4857b4d9d7fc5f39cea2df15b90dca2f96ac4cb Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 24 Mar 2026 12:06:58 +0100 Subject: [PATCH 1/9] Add dual-stack nftables manager with IPv6 table support --- client/firewall/nftables/acl_linux.go | 44 ++-- client/firewall/nftables/addr_family_linux.go | 81 +++++++ client/firewall/nftables/manager_linux.go | 217 ++++++++++++++++-- client/firewall/nftables/router_linux.go | 123 ++++++---- client/firewall/nftables/router_linux_test.go | 189 ++++++++++++++- 5 files changed, 559 insertions(+), 95 deletions(-) create mode 100644 client/firewall/nftables/addr_family_linux.go diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index a9d066e2fcc..9d2ea726496 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -33,15 +33,12 @@ const ( const flushError = "flush: %w" -var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} -) - type AclManager struct { rConn *nftables.Conn sConn *nftables.Conn wgIface iFaceMapper routingFwChainName string + af addrFamily workTable *nftables.Table chainInputRules *nftables.Chain @@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam wgIface: wgIface, workTable: table, routingFwChainName: routingFwChainName, + af: familyForAddr(table.Family == nftables.TableFamilyIPv4), ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { } if _, ok := ips[r.ip.String()]; ok { - err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) + err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}}) if err != nil { log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) } @@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering( expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), + Offset: m.af.protoOffset, Len: uint32(1), }) - protoData, err := protoToInt(proto) + protoData, err := m.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %v", err) } @@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering( }) } - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) // check if rawIP contains zeroed IPv4 0.0.0.0 value // in that case not add IP match expression into the rule definition - if !bytes.HasPrefix(anyIP, rawIP) { - // source address position - addrOffset := uint32(12) - + if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: addrOffset, - Len: 4, + Offset: m.af.srcAddrOffset, + Len: m.af.addrLen, }, ) // add individual IP for match if no ipset defined @@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) if err != nil { if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { return nil, fmt.Errorf("get set name: %v", err) @@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se Name: name, Table: table, Dynamic: true, - KeyType: nftables.TypeIPAddr, + KeyType: m.af.setKeyType, } if err := m.rConn.AddSet(ipset, nil); err != nil { @@ -707,15 +702,12 @@ func ifname(n string) []byte { return b } -func protoToInt(protocol firewall.Protocol) (uint8, error) { - switch protocol { - case firewall.ProtocolTCP: - return unix.IPPROTO_TCP, nil - case firewall.ProtocolUDP: - return unix.IPPROTO_UDP, nil - case firewall.ProtocolICMP: - return unix.IPPROTO_ICMP, nil - } - return 0, fmt.Errorf("unsupported protocol: %s", protocol) +// ipToBytes converts net.IP to the correct byte length for the address family. +func ipToBytes(ip net.IP, af addrFamily) []byte { + if af.addrLen == 4 { + return ip.To4() + } + return ip.To16() } + diff --git a/client/firewall/nftables/addr_family_linux.go b/client/firewall/nftables/addr_family_linux.go new file mode 100644 index 00000000000..8b6986d0546 --- /dev/null +++ b/client/firewall/nftables/addr_family_linux.go @@ -0,0 +1,81 @@ +package nftables + +import ( + "fmt" + "net" + + "github.com/google/nftables" + "golang.org/x/sys/unix" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +// addrFamily holds protocol-specific constants for nftables expression building. +type addrFamily struct { + // protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6) + protoOffset uint32 + // srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6) + srcAddrOffset uint32 + // dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6) + dstAddrOffset uint32 + // addrLen is the byte length of addresses (4 for v4, 16 for v6) + addrLen uint32 + // totalBits is the address size in bits (32 for v4, 128 for v6) + totalBits int + // setKeyType is the nftables set data type for addresses + setKeyType nftables.SetDatatype + // tableFamily is the nftables table family + tableFamily nftables.TableFamily + // icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6) + icmpProto uint8 +} + +var ( + // afIPv4 defines IPv4 header layout and nftables types. + afIPv4 = addrFamily{ + protoOffset: 9, + srcAddrOffset: 12, + dstAddrOffset: 16, + addrLen: net.IPv4len, + totalBits: 8 * net.IPv4len, + setKeyType: nftables.TypeIPAddr, + tableFamily: nftables.TableFamilyIPv4, + icmpProto: unix.IPPROTO_ICMP, + } + // afIPv6 defines IPv6 header layout and nftables types. + afIPv6 = addrFamily{ + protoOffset: 6, + srcAddrOffset: 8, + dstAddrOffset: 24, + addrLen: net.IPv6len, + totalBits: 8 * net.IPv6len, + setKeyType: nftables.TypeIP6Addr, + tableFamily: nftables.TableFamilyIPv6, + icmpProto: unix.IPPROTO_ICMPV6, + } +) + +// familyForAddr returns the address family for the given IP. +func familyForAddr(is4 bool) addrFamily { + if is4 { + return afIPv4 + } + return afIPv6 +} + +// protoNum converts a firewall protocol to the IP protocol number, +// using the correct ICMP variant for the address family. +func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return af.icmpProto, nil + case firewall.ProtocolALL: + return 0, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index f57b28abc19..1ebd13f4f39 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -49,8 +49,13 @@ type Manager struct { rConn *nftables.Conn wgIface iFaceMapper - router *router - aclManager *AclManager + router *router + aclManager *AclManager + + // IPv6 counterparts, nil when no v6 overlay + router6 *router + aclManager6 *AclManager + notrackOutputChain *nftables.Chain notrackPreroutingChain *nftables.Chain } @@ -62,7 +67,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { wgIface: wgIface, } - workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} + tableName := getTableName() + workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4} var err error m.router, err = newRouter(workTable, wgIface, mtu) @@ -75,9 +81,54 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + return m, nil } +func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error { + workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6} + + var err error + m.router6, err = newRouter(workTable6, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +// hasIPv6 reports whether the manager has IPv6 components initialized. +func (m *Manager) hasIPv6() bool { + return m.router6 != nil +} + +func (m *Manager) initIPv6() error { + workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6) + if err != nil { + return fmt.Errorf("create v6 work table: %w", err) + } + + if err := m.router6.init(workTable6); err != nil { + return fmt.Errorf("v6 router init: %w", err) + } + + if err := m.aclManager6.init(workTable6); err != nil { + return fmt.Errorf("v6 acl manager init: %w", err) + } + + return nil +} + // Init nftables firewall manager func (m *Manager) Init(stateManager *statemanager.Manager) error { workTable, err := m.createWorkTable() @@ -94,6 +145,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return fmt.Errorf("acl manager init: %w", err) } + if m.hasIPv6() { + if err := m.initIPv6(); err != nil { + // Peer has a v6 address: v6 firewall MUST work or we risk fail-open. + return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err) + } + } + if err := m.initNoTrackChains(workTable); err != nil { log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) } @@ -141,12 +199,14 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - rawIP := ip.To4() - if rawIP == nil { - return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) + if ip.To4() != nil { + return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } - return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip) + } + return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -160,8 +220,11 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule") + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -172,14 +235,38 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6Rule(rule) { + return m.aclManager6.DeletePeerRule(rule) + } return m.aclManager.DeletePeerRule(rule) } -// DeleteRouteRule deletes a routing rule +func isIPv6Rule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6 +} + +// isIPv6RouteRule determines whether a route rule belongs to the v6 table. +// For static routes, the destination prefix determines the family. For dynamic +// routes (DomainSet), the sources determine the family since management +// duplicates dynamic rules per family. +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + +// DeleteRouteRule deletes a routing rule. +// Route rules are keyed by content hash, so the rule exists in exactly one +// router. We check v4 first; if the key isn't there, try v6. func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()) { + return m.router6.DeleteRouteRule(rule) + } return m.router.DeleteRouteRule(rule) } @@ -195,14 +282,54 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("IPv6 not initialized, cannot add NAT rule") + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes (DomainSet) need NAT in both tables since resolved IPs + // can be either v4 or v6. + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + if err := m.router.RemoveNatRule(pair); err != nil { + return err + } + + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + return fmt.Errorf("remove v6 NAT rule: %w", err) + } + } + + return nil } // AllowNetbird allows netbird interface traffic @@ -217,6 +344,11 @@ func (m *Manager) AllowNetbird() error { if err := m.aclManager.createDefaultAllowRules(); err != nil { return fmt.Errorf("create default allow rules: %w", err) } + if m.hasIPv6() { + if err := m.aclManager6.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create v6 default allow rules: %w", err) + } + } if err := m.rConn.Flush(); err != nil { return fmt.Errorf("flush allow input netbird rules: %w", err) } @@ -226,7 +358,13 @@ func (m *Manager) AllowNetbird() error { // SetLegacyManagement sets the route manager to use legacy management func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Close closes the firewall manager @@ -238,6 +376,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return fmt.Errorf("reset router: %v", err) } + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + return fmt.Errorf("reset v6 router: %v", err) + } + } + if err := m.cleanupNetbirdTables(); err != nil { return fmt.Errorf("cleanup netbird tables: %v", err) } @@ -299,6 +443,12 @@ func (m *Manager) Flush() error { return err } + if m.hasIPv6() { + if err := m.aclManager6.Flush(); err != nil { + return fmt.Errorf("flush v6 acl: %w", err) + } + } + if err := m.refreshNoTrackChains(); err != nil { log.Errorf("failed to refresh notrack chains: %v", err) } @@ -311,6 +461,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && rule.TranslatedAddress.Is6() { + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -319,6 +472,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasDNATRule(rule.ID()) { + return m.router6.DeleteDNATRule(rule) + } return m.router.DeleteDNATRule(rule) } @@ -327,7 +483,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. @@ -335,6 +510,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -343,6 +521,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -517,7 +698,11 @@ func (m *Manager) refreshNoTrackChains() error { } func (m *Manager) createWorkTable() (*nftables.Table, error) { - tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) + return m.createWorkTableFamily(nftables.TableFamilyIPv4) +} + +func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) { + tables, err := m.rConn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } @@ -529,7 +714,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } } - table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family}) err = m.rConn.Flush() return table, err } diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index fde654c20ca..4627c9acd7e 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -46,8 +46,10 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. + ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 // maxPrefixesSet 1638 prefixes start to fail, taking some margin maxPrefixesSet = 1500 @@ -72,6 +74,7 @@ type router struct { rules map[string]*nftables.Rule ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] + af addrFamily wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool @@ -84,6 +87,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou workTable: workTable, chains: make(map[string]*nftables.Chain), rules: make(map[string]*nftables.Rule), + af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), mtu: mtu, @@ -142,7 +146,7 @@ func (r *router) Reset() error { func (r *router) removeNatPreroutingRules() error { table := &nftables.Table{ Name: tableNat, - Family: nftables.TableFamilyIPv4, + Family: r.af.tableFamily, } chain := &nftables.Chain{ Name: chainNameNatPrerouting, @@ -175,7 +179,7 @@ func (r *router) removeNatPreroutingRules() error { } func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily) if err != nil { return nil, fmt.Errorf("list tables: %w", err) } @@ -407,7 +411,7 @@ func (r *router) AddRouteFiltering( // Handle protocol if proto != firewall.ProtocolALL { - protoNum, err := protoToInt(proto) + protoNum, err := r.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -467,7 +471,17 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo return nil, fmt.Errorf("create or get ipset: %w", err) } - return getIpSetExprs(ref, isSource) + return r.getIpSetExprs(ref, isSource) +} + +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + +func (r *router) hasDNATRule(id string) bool { + _, ok := r.rules[id+dnatSuffix] + return ok } func (r *router) DeleteRouteRule(rule firewall.Rule) error { @@ -516,10 +530,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err Table: r.workTable, // required for prefixes Interval: true, - KeyType: nftables.TypeIPAddr, + KeyType: r.af.setKeyType, } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) nElements := len(elements) maxElements := maxPrefixesSet * 2 @@ -552,23 +566,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err return nfset, nil } -func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { +func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { var elements []nftables.SetElement for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - // nftables needs half-open intervals [firstIP, lastIP) for prefixes // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc firstIP := prefix.Addr() lastIP := calculateLastIP(prefix).Next() elements = append(elements, - // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 - // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + // the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true}, nftables.SetElement{Key: firstIP.AsSlice()}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, ) @@ -578,10 +586,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { // calculateLastIP determines the last IP in a given prefix. func calculateLastIP(prefix netip.Prefix) netip.Addr { - hostMask := ^uint32(0) >> prefix.Masked().Bits() - lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + masked := prefix.Masked() + if masked.Addr().Is4() { + hostMask := ^uint32(0) >> masked.Bits() + lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask + return netip.AddrFrom4(uint32ToBytes(lastIP)) + } - return netip.AddrFrom4(uint32ToBytes(lastIP)) + // IPv6: set host bits to all 1s + b := masked.Addr().As16() + bits := masked.Bits() + for i := bits; i < 128; i++ { + b[i/8] |= 1 << (7 - i%8) + } + return netip.AddrFrom16(b) } // Utility function to convert netip.Addr to uint32. @@ -833,9 +851,12 @@ func (r *router) addPostroutingRules() { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.af.tableFamily == nftables.TableFamilyIPv6 { + overhead = ipv6TCPHeaderSize + } + mss := r.mtu - overhead exprsOut := []expr.Any{ &expr.Meta{ @@ -1294,7 +1315,7 @@ func (r *router) removeExternalChainsRules() error { func (r *router) findExternalChains() []*nftables.Chain { var chains []*nftables.Chain - families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet} for _, family := range families { allChains, err := r.conn.ListChainsOfTableFamily(family) @@ -1460,7 +1481,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { return rule, nil } - protoNum, err := protoToInt(rule.Protocol) + protoNum, err := r.af.protoNum(rule.Protocol) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -1523,7 +1544,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule dnatExprs = append(dnatExprs, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: regProtoMin, RegProtoMax: regProtoMax, @@ -1619,7 +1640,7 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f dnatRule := &nftables.Rule{ Table: &nftables.Table{ Name: tableNat, - Family: nftables.TableFamilyIPv4, + Family: r.af.tableFamily, }, Chain: &nftables.Chain{ Name: chainNameNatPrerouting, @@ -1654,8 +1675,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, + Offset: r.af.dstAddrOffset, + Len: r.af.addrLen, }, &expr.Cmp{ Op: expr.CmpOpEq, @@ -1733,7 +1754,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return fmt.Errorf("get set %s: %w", set.HashedName(), err) } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) if err := r.conn.SetAddElements(nfset, elements); err != nil { return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) } @@ -1755,7 +1776,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol return nil } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1786,7 +1807,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1799,7 +1824,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, RegProtoMax: 0, @@ -1868,45 +1893,44 @@ func (r *router) applyNetwork( } if network.IsPrefix() { - return applyPrefix(network.Prefix, isSource), nil + return r.applyPrefix(network.Prefix, isSource), nil } return nil, nil } // applyPrefix generates nftables expressions for a CIDR prefix -func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { - // dst offset - offset := uint32(16) +func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } ones := prefix.Bits() - // 0.0.0.0/0 doesn't need extra expressions + // unspecified address (/0) doesn't need extra expressions if ones == 0 { return nil } - mask := net.CIDRMask(ones, 32) + mask := net.CIDRMask(ones, r.af.totalBits) + xor := make([]byte, r.af.addrLen) return []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, - // netmask &expr.Bitwise{ DestRegister: 1, SourceRegister: 1, - Len: 4, + Len: r.af.addrLen, Mask: mask, - Xor: []byte{0, 0, 0, 0}, + Xor: xor, }, - // net address &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, @@ -1989,13 +2013,12 @@ func getCtNewExprs() []expr.Any { } } -func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { - - // dst offset - offset := uint32(16) +func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } return []expr.Any{ @@ -2003,7 +2026,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, &expr.Lookup{ SourceRegister: 1, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index f0e34d211f5..c5d6729d984 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) { } // Build CIDR matching expressions - sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) - destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) + testRouter := &router{af: afIPv4} + sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true) + destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false) // Combine all expressions in the correct order // nolint:gocritic @@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) { } } +func TestNftablesCreateIpSet_IPv6(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTableIPv6() + require.NoError(t, err, "Failed to create v6 work table") + defer deleteWorkTableIPv6() + + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) + require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) + defer func() { + require.NoError(t, r.Reset(), "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Single IPv6", + sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")}, + }, + { + name: "Multiple IPv6 Subnets", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::/48"), + netip.MustParsePrefix("fe80::/10"), + }, + }, + { + name: "Overlapping IPv6", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("fd00::1/128"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + }, + }, + { + name: "Mixed prefix lengths", + sources: []netip.Prefix{ + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::1/128"), + netip.MustParsePrefix("fd00:abcd::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setName := firewall.NewPrefixSet(tt.sources).HashedName() + set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) + require.NoError(t, err, "Failed to create IPv6 set") + require.NotNil(t, set) + + assert.Equal(t, setName, set.Name) + assert.True(t, set.Interval) + assert.Equal(t, nftables.TypeIP6Addr, set.KeyType) + + fetchedSet, err := r.conn.GetSetByName(r.workTable, setName) + require.NoError(t, err, "Failed to fetch created set") + + elements, err := r.conn.GetSetElements(fetchedSet) + require.NoError(t, err, "Failed to get set elements") + + uniquePrefixes := make(map[string]bool) + for _, elem := range elements { + if !elem.IntervalEnd && len(elem.Key) == 16 { + ip := netip.AddrFrom16([16]byte(elem.Key)) + uniquePrefixes[ip.String()] = true + } + } + + expectedCount := len(tt.expected) + if expectedCount == 0 { + expectedCount = len(tt.sources) + } + assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch") + + r.conn.DelSet(set) + require.NoError(t, r.conn.Flush()) + }) + } +} + +func createWorkTableIPv6() (*nftables.Table, error) { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return nil, err + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return nil, err + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + } + } + + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6}) + err = sConn.Flush() + return table, err +} + +func deleteWorkTableIPv6() { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + _ = sConn.Flush() + } + } +} + func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { t.Helper() @@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { var metaFound, cmpFound bool - expectedProto, _ := protoToInt(proto) + expectedProto, _ := afIPv4.protoNum(proto) for _, e := range exprs { switch ex := e.(type) { case *expr.Meta: @@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) { } assert.Equal(t, 1, found, "NAT rule should exist in kernel") } + +func TestCalculateLastIP(t *testing.T) { + tests := []struct { + prefix string + want string + }{ + {"10.0.0.0/24", "10.0.0.255"}, + {"10.0.0.0/32", "10.0.0.0"}, + {"0.0.0.0/0", "255.255.255.255"}, + {"192.168.1.0/28", "192.168.1.15"}, + {"fd00::/64", "fd00::ffff:ffff:ffff:ffff"}, + {"fd00::/128", "fd00::"}, + {"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"}, + {"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}, + } + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + prefix := netip.MustParsePrefix(tt.prefix) + got := calculateLastIP(prefix) + assert.Equal(t, tt.want, got.String()) + }) + } +} + +func TestConvertPrefixesToSet_IPv6(t *testing.T) { + r := &router{af: afIPv6} + prefixes := []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::1/128"), + } + + elements := r.convertPrefixesToSet(prefixes) + + // Each prefix produces 2 elements (start + end) + require.Len(t, elements, 4) + + // fd00::/64 start + assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key)) + assert.False(t, elements[0].IntervalEnd) + + // fd00::/64 end (fd00:0:0:1::, one past the last) + assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key)) + assert.True(t, elements[1].IntervalEnd) + + // 2001:db8::1/128 start + assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key)) + assert.False(t, elements[2].IntervalEnd) + + // 2001:db8::1/128 end (2001:db8::2) + assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key)) + assert.True(t, elements[3].IntervalEnd) +} From c13f1af1968b1316158ba747a7b4369e7cdcc159 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 27 Mar 2026 06:50:05 +0100 Subject: [PATCH 2/9] Fix ip6tables-save compat: skip ip6tables-managed tables in external chain scan --- .../firewall/nftables/manager_linux_test.go | 95 +++++++++++++++++++ client/firewall/nftables/router_linux.go | 4 +- 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 75b1e2b6cd9..93544b01b1d 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -389,6 +389,101 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { verifyIptablesOutput(t, stdout, stderr) } +func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + if _, err := exec.LookPath("ip6tables-save"); err != nil { + t.Skipf("ip6tables-save not available on this system: %v", err) + } + + // Seed ip6 tables in the nft backend. Docker may not create them. + seedIp6tables(t) + + ifaceMockV6 := &iFaceMock{ + NameFunc: func() string { return "wt-test" }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.96.0.1"), + Network: netip.MustParsePrefix("100.96.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + } + }, + } + + manager, err := Create(ifaceMockV6, iface.DefaultMTU) + require.NoError(t, err, "create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + require.NoError(t, manager.Close(nil), "close manager") + + stdout, stderr := runIp6tablesSave(t) + verifyIp6tablesOutput(t, stdout, stderr) + }) + + ip := netip.MustParseAddr("fd00::2") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") + require.NoError(t, err, "add v6 peer filtering rule") + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00:1::/64")}, + fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "add v6 route filtering rule") + + err = manager.AddNatRule(fw.RouterPair{ + Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")}, + Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + Masquerade: true, + }) + require.NoError(t, err, "add v6 NAT rule") + + stdout, stderr := runIp6tablesSave(t) + verifyIp6tablesOutput(t, stdout, stderr) +} + +func seedIp6tables(t *testing.T) { + t.Helper() + for _, tc := range []struct{ table, chain string }{ + {"filter", "FORWARD"}, + {"nat", "POSTROUTING"}, + {"mangle", "FORWARD"}, + } { + add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT") + require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table) + del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT") + require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table) + } +} + +func runIp6tablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("ip6tables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + require.NoError(t, cmd.Run(), "ip6tables-save failed") + return stdout.String(), stderr.String() +} + +func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + require.NotContains(t, stdout, "Table `nat' is incompatible", + "ip6tables-save: nat table incompatible. Full output: %s", stdout) + require.NotContains(t, stdout, "Table `mangle' is incompatible", + "ip6tables-save: mangle table incompatible. Full output: %s", stdout) + require.NotContains(t, stdout, "Table `filter' is incompatible", + "ip6tables-save: filter table incompatible. Full output: %s", stdout) +} + func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this system") diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 4627c9acd7e..9f8d9f96612 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1339,8 +1339,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool { return false } - // Skip all iptables-managed tables in the ip family - if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + // Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat) + if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) { return false } From 571527c2d3dc47710ea8c243f2d8d8d51db50511 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 27 Mar 2026 06:56:15 +0100 Subject: [PATCH 3/9] Add iptablesProto helper, filter table fallback, DNAT compat tests --- client/firewall/nftables/manager_linux.go | 8 ++++- .../firewall/nftables/manager_linux_test.go | 29 ++++++++++++++++++- client/firewall/nftables/router_linux.go | 28 ++++++++++++++---- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 1ebd13f4f39..cca661ad14e 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -332,7 +332,13 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { return nil } -// AllowNetbird allows netbird interface traffic +// AllowNetbird allows netbird interface traffic. +// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains, +// which doesn't override DROP rules in external tables (e.g. firewalld). +// Should add passthrough rules to external chains (like the native mode router's +// addExternalChainsRules does) for both the netbird table family and inet tables. +// The netbird table itself is fine (routing chains already exist there), but +// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic. func (m *Manager) AllowNetbird() error { if !m.wgIface.IsUserspaceBind() { return nil diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 93544b01b1d..37337e0a016 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -385,6 +385,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { err = manager.AddNatRule(pair) require.NoError(t, err, "failed to add NAT rule") + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("100.96.0.2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "failed to add DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule") + }) + stdout, stderr = runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) } @@ -446,7 +458,22 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) { }) require.NoError(t, err, "add v6 NAT rule") - stdout, stderr := runIp6tablesSave(t) + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("fd00::2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "add v6 DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule") + }) + + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + stdout, stderr = runIp6tablesSave(t) verifyIp6tablesOutput(t, stdout, stderr) } diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 9f8d9f96612..67ace848e73 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -474,6 +474,13 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo return r.getIpSetExprs(ref, isSource) } +func (r *router) iptablesProto() iptables.Protocol { + if r.af.tableFamily == nftables.TableFamilyIPv6 { + return iptables.ProtocolIPv6 + } + return iptables.ProtocolIPv4 +} + func (r *router) hasRule(id string) bool { _, ok := r.rules[id] return ok @@ -1063,17 +1070,22 @@ func (r *router) acceptFilterTableRules() error { log.Debugf("Used %s to add accept forward and input rules", fw) }() - // Try iptables first and fallback to nftables if iptables is not available - ipt, err := iptables.New() + // Try iptables first and fallback to nftables if iptables is not available. + // Use the correct protocol (iptables vs ip6tables) for the address family. + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { - // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" return r.acceptFilterRulesNftables(r.filterTable) } - return r.acceptFilterRulesIptables(ipt) + if err := r.acceptFilterRulesIptables(ipt); err != nil { + log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err) + fw = "nftables" + return r.acceptFilterRulesNftables(r.filterTable) + } + return nil } func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { @@ -1242,13 +1254,17 @@ func (r *router) removeFilterTableRules() error { return nil } - ipt, err := iptables.New() + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) return r.removeAcceptRulesFromTable(r.filterTable) } - return r.removeAcceptFilterRulesIptables(ipt) + if err := r.removeAcceptFilterRulesIptables(ipt); err != nil { + log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) + } + return nil } func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { From 443d0720a39d0b015de79a41d65d8aaab887d8ed Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 7 Apr 2026 18:15:18 +0200 Subject: [PATCH 4/9] Fix CodeRabbit findings: rollback on init failure, accumulate Close errors, expand test guards --- client/firewall/nftables/manager_linux.go | 36 +++++++++++++++---- .../firewall/nftables/manager_linux_test.go | 6 ++-- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index cca661ad14e..de8985f79e1 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -11,9 +11,11 @@ import ( "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -141,13 +143,31 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } if err := m.aclManager.init(workTable); err != nil { - // TODO: cleanup router + if err := m.router.Reset(); err != nil { + log.Warnf("rollback router after acl init failure: %v", err) + } + if err := m.cleanupNetbirdTables(); err != nil { + log.Warnf("cleanup tables after acl init failure: %v", err) + } + if err := m.rConn.Flush(); err != nil { + log.Warnf("flush after acl init failure: %v", err) + } return fmt.Errorf("acl manager init: %w", err) } if m.hasIPv6() { if err := m.initIPv6(); err != nil { // Peer has a v6 address: v6 firewall MUST work or we risk fail-open. + // Best-effort rollback of already-initialized v4 state. + if err := m.router.Reset(); err != nil { + log.Warnf("rollback v4 router after v6 init failure: %v", err) + } + if err := m.cleanupNetbirdTables(); err != nil { + log.Warnf("cleanup tables after v6 init failure: %v", err) + } + if err := m.rConn.Flush(); err != nil { + log.Warnf("flush after v6 init failure: %v", err) + } return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err) } } @@ -378,29 +398,31 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() + var merr *multierror.Error + if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset router: %v", err) + merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err)) } if m.hasIPv6() { if err := m.router6.Reset(); err != nil { - return fmt.Errorf("reset v6 router: %v", err) + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err)) } } if err := m.cleanupNetbirdTables(); err != nil { - return fmt.Errorf("cleanup netbird tables: %v", err) + merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err)) } if err := m.rConn.Flush(); err != nil { - return fmt.Errorf(flushError, err) + merr = multierror.Append(merr, fmt.Errorf(flushError, err)) } if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - return fmt.Errorf("delete state: %v", err) + merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err)) } - return nil + return nberrors.FormatErrorOrNil(merr) } func (m *Manager) cleanupNetbirdTables() error { diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 37337e0a016..d925f3ef3ca 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -406,8 +406,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) { t.Skip("nftables not supported on this system") } - if _, err := exec.LookPath("ip6tables-save"); err != nil { - t.Skipf("ip6tables-save not available on this system: %v", err) + for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} { + if _, err := exec.LookPath(bin); err != nil { + t.Skipf("%s not available on this system: %v", bin, err) + } } // Seed ip6 tables in the nft backend. Docker may not create them. From 546140a2c226ec2baaf8e054d47ec7d8686d8e6f Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 7 Apr 2026 18:18:15 +0200 Subject: [PATCH 5/9] Extract rollbackInit helper to deduplicate Init cleanup --- client/firewall/nftables/manager_linux.go | 34 ++++++++++------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index de8985f79e1..ececaae6bdf 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -143,31 +143,14 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } if err := m.aclManager.init(workTable); err != nil { - if err := m.router.Reset(); err != nil { - log.Warnf("rollback router after acl init failure: %v", err) - } - if err := m.cleanupNetbirdTables(); err != nil { - log.Warnf("cleanup tables after acl init failure: %v", err) - } - if err := m.rConn.Flush(); err != nil { - log.Warnf("flush after acl init failure: %v", err) - } + m.rollbackInit() return fmt.Errorf("acl manager init: %w", err) } if m.hasIPv6() { if err := m.initIPv6(); err != nil { // Peer has a v6 address: v6 firewall MUST work or we risk fail-open. - // Best-effort rollback of already-initialized v4 state. - if err := m.router.Reset(); err != nil { - log.Warnf("rollback v4 router after v6 init failure: %v", err) - } - if err := m.cleanupNetbirdTables(); err != nil { - log.Warnf("cleanup tables after v6 init failure: %v", err) - } - if err := m.rConn.Flush(); err != nil { - log.Warnf("flush after v6 init failure: %v", err) - } + m.rollbackInit() return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err) } } @@ -203,6 +186,19 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return nil } +// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through. +func (m *Manager) rollbackInit() { + if err := m.router.Reset(); err != nil { + log.Warnf("rollback router: %v", err) + } + if err := m.cleanupNetbirdTables(); err != nil { + log.Warnf("cleanup tables: %v", err) + } + if err := m.rConn.Flush(); err != nil { + log.Warnf("flush: %v", err) + } +} + // AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set From 024cc6cba924b2cfe0fb0979d22bc4b855c10e3e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 7 Apr 2026 18:22:14 +0200 Subject: [PATCH 6/9] Split Init to reduce cognitive complexity --- client/firewall/nftables/manager_linux.go | 25 ++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index ececaae6bdf..8688c821ab4 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -133,6 +133,16 @@ func (m *Manager) initIPv6() error { // Init nftables firewall manager func (m *Manager) Init(stateManager *statemanager.Manager) error { + if err := m.initFirewall(); err != nil { + return err + } + + m.persistState(stateManager) + + return nil +} + +func (m *Manager) initFirewall() error { workTable, err := m.createWorkTable() if err != nil { return fmt.Errorf("create work table: %w", err) @@ -159,12 +169,16 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) } + return nil +} + +// persistState saves the current interface state for potential recreation on restart. +// Unlike iptables, which requires tracking individual rules, nftables maintains +// a known state (our netbird table plus a few static rules). This allows for easy +// cleanup using Close() without needing to store specific rules. +func (m *Manager) persistState(stateManager *statemanager.Manager) { stateManager.RegisterState(&ShutdownState{}) - // We only need to record minimal interface state for potential recreation. - // Unlike iptables, which requires tracking individual rules, nftables maintains - // a known state (our netbird table plus a few static rules). This allows for easy - // cleanup using Close() without needing to store specific rules. if err := stateManager.UpdateState(&ShutdownState{ InterfaceState: &InterfaceState{ NameStr: m.wgIface.Name(), @@ -176,14 +190,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - // persist early go func() { if err := stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } }() - - return nil } // rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through. From 212ab6316fbe132dc572a9fa6e0863aaee0195f7 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 8 Apr 2026 10:22:43 +0200 Subject: [PATCH 7/9] Fix AddOutputDNAT using removed standalone functions after merge from main --- client/firewall/nftables/router_linux.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 4c3e6b587c7..02f8288fec3 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1928,7 +1928,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, return err } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1953,7 +1953,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1966,7 +1970,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, }, From 62fd61fa8fd438b3d41376218e1ae8691fea4d0e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 9 Apr 2026 09:48:51 +0200 Subject: [PATCH 8/9] Address review feedback for nftables IPv6 firewall - Reorder declarations in addr_family_linux.go (vars before struct) - Add router6.Reset() to rollbackInit for partial IPv6 init failure - Share ipFwdState between router and router6 (same sysctls) - Add todo for separating v4/v6 forwarding state --- client/firewall/nftables/addr_family_linux.go | 40 +++++++++---------- client/firewall/nftables/manager_linux.go | 9 +++++ .../routemanager/ipfwdstate/ipfwdstate.go | 6 ++- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/client/firewall/nftables/addr_family_linux.go b/client/firewall/nftables/addr_family_linux.go index 8b6986d0546..0c90d704a5e 100644 --- a/client/firewall/nftables/addr_family_linux.go +++ b/client/firewall/nftables/addr_family_linux.go @@ -10,26 +10,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" ) -// addrFamily holds protocol-specific constants for nftables expression building. -type addrFamily struct { - // protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6) - protoOffset uint32 - // srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6) - srcAddrOffset uint32 - // dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6) - dstAddrOffset uint32 - // addrLen is the byte length of addresses (4 for v4, 16 for v6) - addrLen uint32 - // totalBits is the address size in bits (32 for v4, 128 for v6) - totalBits int - // setKeyType is the nftables set data type for addresses - setKeyType nftables.SetDatatype - // tableFamily is the nftables table family - tableFamily nftables.TableFamily - // icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6) - icmpProto uint8 -} - var ( // afIPv4 defines IPv4 header layout and nftables types. afIPv4 = addrFamily{ @@ -55,6 +35,26 @@ var ( } ) +// addrFamily holds protocol-specific constants for nftables expression building. +type addrFamily struct { + // protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6) + protoOffset uint32 + // srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6) + srcAddrOffset uint32 + // dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6) + dstAddrOffset uint32 + // addrLen is the byte length of addresses (4 for v4, 16 for v6) + addrLen uint32 + // totalBits is the address size in bits (32 for v4, 128 for v6) + totalBits int + // setKeyType is the nftables set data type for addresses + setKeyType nftables.SetDatatype + // tableFamily is the nftables table family + tableFamily nftables.TableFamily + // icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6) + icmpProto uint8 +} + // familyForAddr returns the address family for the given IP. func familyForAddr(is4 bool) addrFamily { if is4 { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index bf5bc8e3836..c3c1c1a6590 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -101,6 +101,10 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt return fmt.Errorf("create v6 router: %w", err) } + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw) if err != nil { return fmt.Errorf("create v6 acl manager: %w", err) @@ -202,6 +206,11 @@ func (m *Manager) rollbackInit() { if err := m.router.Reset(); err != nil { log.Warnf("rollback router: %v", err) } + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + log.Warnf("rollback v6 router: %v", err) + } + } if err := m.cleanupNetbirdTables(); err != nil { log.Warnf("cleanup tables: %v", err) } diff --git a/client/internal/routemanager/ipfwdstate/ipfwdstate.go b/client/internal/routemanager/ipfwdstate/ipfwdstate.go index da81c18f9cd..2be1c2ae797 100644 --- a/client/internal/routemanager/ipfwdstate/ipfwdstate.go +++ b/client/internal/routemanager/ipfwdstate/ipfwdstate.go @@ -9,7 +9,11 @@ import ( ) // IPForwardingState is a struct that keeps track of the IP forwarding state. -// todo: read initial state of the IP forwarding from the system and reset the state based on it +// todo: read initial state of the IP forwarding from the system and reset the state based on it. +// todo: separate v4/v6 forwarding state, since the sysctls are independent +// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables +// manager shares one instance between both routers, which works only because +// EnableIPForwarding enables both sysctls in a single call. type IPForwardingState struct { enabledCounter int } From cd9e90ac2f0a80b72b0d895476f2d76b2e35b5cd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 9 Apr 2026 16:55:08 +0800 Subject: [PATCH 9/9] [client] Add dual-stack iptables manager with ip6tables support (#5708) --- client/android/client.go | 79 ++++-- client/android/route_command.go | 7 +- client/anonymize/anonymize.go | 6 + client/cmd/ssh.go | 4 +- client/cmd/ssh_test.go | 4 +- client/firewall/iptables/acl_linux.go | 31 ++- client/firewall/iptables/manager_linux.go | 236 ++++++++++++++++-- client/firewall/iptables/router_linux.go | 68 +++-- client/firewall/iptables/rule.go | 1 + client/firewall/iptables/state_linux.go | 30 +++ .../uspfilter/allow_netbird_windows.go | 53 ++-- client/firewall/uspfilter/conntrack/common.go | 7 +- client/firewall/uspfilter/filter.go | 3 +- client/firewall/uspfilter/filter_test.go | 38 +++ client/firewall/uspfilter/localip_test.go | 1 - client/iface/configurer/usp.go | 2 +- client/iface/device/adapter.go | 2 +- client/iface/device/device_android.go | 2 +- client/iface/wgproxy/bind/proxy.go | 22 +- client/internal/debug/debug_test.go | 49 +++- client/internal/dns/service_listener.go | 8 +- client/internal/dnsfwd/manager.go | 1 + client/internal/engine.go | 12 +- .../lazyconn/activity/listener_bind.go | 23 +- client/internal/peer/status.go | 6 +- client/internal/routemanager/server/server.go | 3 +- .../routemanager/systemops/systemops.go | 7 +- .../systemops/systemops_generic.go | 62 +++-- client/ios/NetBirdSDK/client.go | 74 ++++-- client/server/network.go | 42 +++- client/ui/network.go | 2 +- client/wasm/cmd/main.go | 115 +++++++-- client/wasm/internal/ssh/client.go | 15 +- proxy/cmd/proxy/cmd/debug.go | 21 +- proxy/internal/debug/client.go | 22 +- proxy/internal/debug/handler.go | 18 +- route/route.go | 61 +++++ route/route_test.go | 108 ++++++++ shared/relay/client/dialer/quic/quic.go | 2 +- 39 files changed, 1014 insertions(+), 233 deletions(-) create mode 100644 route/route_test.go diff --git a/client/android/client.go b/client/android/client.go index 70ebc0011d3..a8766afd243 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -238,41 +238,82 @@ func (c *Client) Networks() *NetworkArray { return nil } + routesMap := routeManager.GetClientRoutesWithNetID() + v6Merged := route.V6ExitMergeSet(routesMap) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + networkArray := &NetworkArray{ items: make([]Network, 0), } - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - for id, routes := range routeManager.GetClientRoutesWithNetID() { + for id, routes := range routesMap { if len(routes) == 0 { continue } + if _, skip := v6Merged[id]; skip { + continue + } + + network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged) + if network == nil { + continue + } + networkArray.Add(*network) + } + return networkArray +} - r := routes[0] - domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) - netStr := r.Network.String() +func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network { + r := routes[0] + netStr := r.Network.String() + if r.IsDynamic() { + netStr = r.Domains.SafeString() + } - if r.IsDynamic() { - netStr = r.Domains.SafeString() + routePeer, err := c.findBestRoutePeer(routes) + if err != nil { + log.Errorf("could not get peer info for route %s: %v", id, err) + return nil + } + + network := &Network{ + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: selected, + Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains), + } + + if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) { + network.Network = "0.0.0.0/0, ::/0" + } + + return network +} + +// findBestRoutePeer returns the peer actively routing traffic for the given +// HA route group. Falls back to the first connected peer, then the first peer. +func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) { + netStr := routes[0].Network.String() + + fullStatus := c.recorder.GetFullStatus() + for _, p := range fullStatus.Peers { + if _, ok := p.GetRoutes()[netStr]; ok { + return p, nil } + } - routePeer, err := c.recorder.GetPeer(routes[0].Peer) + for _, r := range routes { + p, err := c.recorder.GetPeer(r.Peer) if err != nil { - log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) continue } - network := Network{ - Name: string(id), - Network: netStr, - Peer: routePeer.FQDN, - Status: routePeer.ConnStatus.String(), - IsSelected: routeSelector.IsSelected(id), - Domains: domains, + if p.ConnStatus == peer.StatusConnected { + return p, nil } - networkArray.Add(network) } - return networkArray + return c.recorder.GetPeer(routes[0].Peer) } // OnUpdatedHostDNS update the DNS servers addresses for root zones diff --git a/client/android/route_command.go b/client/android/route_command.go index b47d5ca6ce0..5e735733574 100644 --- a/client/android/route_command.go +++ b/client/android/route_command.go @@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager, netID := route.NetID(id) routes := []route.NetID{netID} - log.Debugf("%s with id: %s", operationName, id) + routesMap := manager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) - if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("%s with ids: %v", operationName, routes) + + if err := routeOperation(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when %s: %s", operationName, err) return fmt.Errorf("error %s: %w", operationName, err) } diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 52beaaa47c1..b7b6a20dd60 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -9,6 +9,7 @@ import ( "net/url" "regexp" "slices" + "strconv" "strings" ) @@ -97,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { } func (a *Anonymizer) AnonymizeIPString(ip string) string { + // Handle CIDR notation (e.g. "2001:db8::/32") + if prefix, err := netip.ParsePrefix(ip); err == nil { + return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits()) + } + addr, err := netip.ParseAddr(ip) if err != nil { return ip diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 0acf0b13334..de5150b1f0f 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -787,10 +787,10 @@ func isUnixSocket(path string) bool { return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") } -// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. +// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack). func normalizeLocalHost(host string) string { if host == "*" { - return "0.0.0.0" + return "" } return host } diff --git a/client/cmd/ssh_test.go b/client/cmd/ssh_test.go index 43291fa87c1..16ffadb90e5 100644 --- a/client/cmd/ssh_test.go +++ b/client/cmd/ssh_test.go @@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) { { name: "wildcard bind all interfaces", spec: "*:8080:localhost:80", - expectedLocal: "0.0.0.0:8080", + expectedLocal: ":8080", expectedRemote: "localhost:80", expectError: false, - description: "Wildcard * should bind to all interfaces (0.0.0.0)", + description: "Wildcard * should bind to all interfaces (dual-stack)", }, { name: "wildcard for port only", diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d83798f0981..4740c41273d 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -36,6 +36,7 @@ type aclManager struct { entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore + v6 bool stateManager *statemanager.Manager } @@ -47,6 +48,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, }, nil } @@ -81,7 +83,11 @@ func (m *aclManager) AddPeerFiltering( chain := chainNameInputRules ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) - specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) + if m.v6 && ipsetName != "" { + ipsetName += "-v6" + } + proto := protoForFamily(protocol, m.v6) + specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName) mangleSpecs := slices.Clone(specs) mangleSpecs = append(mangleSpecs, @@ -105,6 +111,7 @@ func (m *aclManager) AddPeerFiltering( ip: ip.String(), chain: chain, specs: specs, + v6: m.v6, }}, nil } @@ -157,6 +164,7 @@ func (m *aclManager) AddPeerFiltering( ipsetName: ipsetName, ip: ip.String(), chain: chain, + v6: m.v6, } m.updateState() @@ -376,8 +384,13 @@ func (m *aclManager) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.ACLEntries = m.entries - currentState.ACLIPsetStore = m.ipsetStore + if m.v6 { + currentState.ACLEntries6 = m.entries + currentState.ACLIPsetStore6 = m.ipsetStore + } else { + currentState.ACLEntries = m.entries + currentState.ACLIPsetStore = m.ipsetStore + } if err := m.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -385,6 +398,15 @@ func (m *aclManager) updateState() { } // filterRuleSpecs returns the specs of a filtering rule +// protoForFamily translates ICMP to ICMPv6 for ip6tables. +// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp". +func protoForFamily(protocol firewall.Protocol, v6 bool) string { + if v6 && protocol == firewall.ProtocolICMP { + return "ipv6-icmp" + } + return string(protocol) +} + func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { // don't use IP matching if IP is 0.0.0.0 matchByIP := !ip.IsUnspecified() @@ -437,6 +459,9 @@ func (m *aclManager) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if m.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2fc6f8ec8dc..c278924f250 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -17,6 +17,10 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) +type resetter interface { + Reset() error +} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex @@ -27,6 +31,11 @@ type Manager struct { aclMgr *aclManager router *router rawSupported bool + + // IPv6 counterparts, nil when no v6 overlay + ipv6Client *iptables.IPTables + aclMgr6 *aclManager + router6 *router } // iFaceMapper defines subset methods of interface required for manager @@ -58,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + return m, nil } +func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error { + ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return fmt.Errorf("init ip6tables: %w", err) + } + m.ipv6Client = ip6Client + + m.router6, err = newRouter(ip6Client, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclMgr6, err = newAclManager(ip6Client, wgIface) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +func (m *Manager) hasIPv6() bool { + return m.ipv6Client != nil +} + func (m *Manager) Init(stateManager *statemanager.Manager) error { state := &ShutdownState{ InterfaceState: &InterfaceState{ @@ -75,13 +118,8 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - if err := m.router.init(stateManager); err != nil { - return fmt.Errorf("router init: %w", err) - } - - if err := m.aclMgr.init(stateManager); err != nil { - // TODO: cleanup router - return fmt.Errorf("acl manager init: %w", err) + if err := m.initChains(stateManager); err != nil { + return err } if err := m.initNoTrackChain(); err != nil { @@ -98,6 +136,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return nil } +// initChains initializes router and ACL chains for both address families, +// rolling back on failure. +func (m *Manager) initChains(stateManager *statemanager.Manager) error { + type initStep struct { + name string + init func(*statemanager.Manager) error + mgr resetter + } + + steps := []initStep{ + {"router", m.router.init, m.router}, + {"acl manager", m.aclMgr.init, m.aclMgr}, + } + if m.hasIPv6() { + steps = append(steps, + initStep{"v6 router", m.router6.init, m.router6}, + initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6}, + ) + } + + var initialized []initStep + for _, s := range steps { + if err := s.init(stateManager); err != nil { + for i := len(initialized) - 1; i >= 0; i-- { + if rerr := initialized[i].mgr.Reset(); rerr != nil { + log.Warnf("rollback %s: %v", initialized[i].name, rerr) + } + } + return fmt.Errorf("%s init: %w", s.name, err) + } + initialized = append(initialized, s) + } + return nil +} + // AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported @@ -113,7 +186,13 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if ip.To4() != nil { + return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + } + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip) + } + return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -127,25 +206,48 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule") + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + // DeletePeerRule from the firewall by rule definition func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6IptRule(rule) { + return m.aclMgr6.DeletePeerRule(rule) + } return m.aclMgr.DeletePeerRule(rule) } +func isIPv6IptRule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.v6 +} + +// DeleteRouteRule deletes a routing rule. +// Route rules are keyed by content hash. Check v4 first, try v6 if not found. func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()) { + return m.router6.DeleteRouteRule(rule) + } return m.router.DeleteRouteRule(rule) } @@ -161,18 +263,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("IPv6 not initialized, cannot add NAT rule") + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes need NAT in both tables + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + if err := m.router.RemoveNatRule(pair); err != nil { + return err + } + + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + return fmt.Errorf("remove v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Reset firewall to the default state @@ -186,6 +333,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) } + if m.hasIPv6() { + if err := m.aclMgr6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err)) + } + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err)) + } + } + if err := m.aclMgr.Reset(); err != nil { merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) } @@ -209,19 +365,16 @@ func (m *Manager) AllowNetbird() error { return nil } - _, err := m.AddPeerFiltering( - nil, - net.IP{0, 0, 0, 0}, - firewall.ProtocolALL, - nil, - nil, - firewall.ActionAccept, - "", - ) - if err != nil { - return fmt.Errorf("allow netbird interface traffic: %w", err) + var merr *multierror.Error + if _, err := m.aclMgr.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow netbird interface traffic: %w", err)) } - return nil + if m.hasIPv6() { + if _, err := m.aclMgr6.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow v6 netbird interface traffic: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } // Flush doesn't need to be implemented for this manager @@ -251,6 +404,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && rule.TranslatedAddress.Is6() { + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -259,6 +415,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) { + return m.router6.DeleteDNATRule(rule) + } return m.router.DeleteDNATRule(rule) } @@ -267,7 +426,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. @@ -275,6 +453,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -283,6 +464,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index a7c4f67dd5c..61921f7f9d5 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -54,8 +54,10 @@ const ( snatSuffix = "_snat" fwdSuffix = "_fwd" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. + ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 ) type ruleInfo struct { @@ -86,6 +88,7 @@ type router struct { wgIface iFaceMapper legacyManagement bool mtu uint16 + v6 bool stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState @@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1 rules: make(map[string][]string), wgIface: wgIface, mtu: mtu, + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + func (r *router) DeleteRouteRule(rule firewall.Rule) error { ruleKey := rule.ID() @@ -434,6 +443,12 @@ func (r *router) createContainers() error { {chainRTRDR, tableNat}, {chainRTMSSCLAMP, tableMangle}, } { + // Fallback: clear chains that survived an unclean shutdown. + if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok { + if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil { + log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err) + } + } if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } @@ -540,9 +555,12 @@ func (r *router) addPostroutingRules() error { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.v6 { + overhead = ipv6TCPHeaderSize + } + mss := r.mtu - overhead // Add jump rule from FORWARD chain in mangle table to our custom chain jumpRule := []string{ @@ -727,8 +745,13 @@ func (r *router) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.RouteRules = r.rules - currentState.RouteIPsetCounter = r.ipsetCounter + if r.v6 { + currentState.RouteRules6 = r.rules + currentState.RouteIPsetCounter6 = r.ipsetCounter + } else { + currentState.RouteRules = r.rules + currentState.RouteIPsetCounter = r.ipsetCounter + } if err := r.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -856,7 +879,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { } if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { - if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil { + if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil { merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) } delete(r.rules, ruleKey+fwdSuffix) @@ -883,7 +906,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net rule = append(rule, destExp...) if params.Proto != firewall.ProtocolALL { - rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6))) rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--dport", params.DPort)...) } @@ -900,11 +923,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } if network.IsSet() { - if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { + name := r.ipsetName(network.Set.HashedName()) + if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil { return nil, fmt.Errorf("create or get ipset: %w", err) } - return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil + return []string{"-m", "set", matchSet, name, direction}, nil } if network.IsPrefix() { return []string{flag, network.Prefix.String()}, nil @@ -915,19 +939,15 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + name := r.ipsetName(set.HashedName()) var merr *multierror.Error for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { + if err := r.addPrefixToIPSet(name, prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) } } if merr == nil { - log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + log.Debugf("updated set %s with prefixes %v", name, prefixes) } return nberrors.FormatErrorOrNil(merr) @@ -943,7 +963,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol dnatRule := []string{ "-i", r.wgIface.Name(), - "-p", strings.ToLower(string(protocol)), + "-p", strings.ToLower(protoForFamily(protocol, r.v6)), "--dport", strconv.Itoa(int(sourcePort)), "-d", localAddr.String(), "-m", "addrtype", "--dst-type", "LOCAL", @@ -1076,10 +1096,22 @@ func applyPort(flag string, port *firewall.Port) []string { return []string{flag, strconv.Itoa(int(port.Values[0]))} } +// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router +// to avoid collisions since ipsets are global in the kernel. +func (r *router) ipsetName(name string) string { + if r.v6 { + return name + "-v6" + } + return name +} + func (r *router) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if r.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index aa4d2d07900..4f4eab167e7 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -9,6 +9,7 @@ type Rule struct { mangleSpecs []string ip string chain string + v6 bool } // GetRuleID returns the rule id diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index c88774c1f10..6b2e99e31d8 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,8 @@ import ( "fmt" "sync" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -37,6 +39,12 @@ type ShutdownState struct { ACLEntries aclEntries `json:"acl_entries,omitempty"` ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` + + // IPv6 counterparts + RouteRules6 routeRules `json:"route_rules_v6,omitempty"` + RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"` + ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"` + ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"` } func (s *ShutdownState) Name() string { @@ -67,6 +75,28 @@ func (s *ShutdownState) Cleanup() error { ipt.aclMgr.ipsetStore = s.ACLIPsetStore } + // Clean up v6 state even if the current run has no IPv6. + // The previous run may have left ip6tables rules behind. + if !ipt.hasIPv6() { + if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil { + log.Warnf("failed to create v6 components for cleanup: %v", err) + } + } + if ipt.hasIPv6() { + if s.RouteRules6 != nil { + ipt.router6.rules = s.RouteRules6 + } + if s.RouteIPsetCounter6 != nil { + ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6) + } + if s.ACLEntries6 != nil { + ipt.aclMgr6.entries = s.ACLEntries6 + } + if s.ACLIPsetStore6 != nil { + ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6 + } + } + if err := ipt.Close(nil); err != nil { return fmt.Errorf("reset iptables manager: %w", err) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 6aef2ecfdad..10a2b9116b9 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -5,8 +5,10 @@ import ( "os/exec" "syscall" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error { return nil } - if !isFirewallRuleActive(firewallRuleName) { - return nil + var merr *multierror.Error + if isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err)) + } } - if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { - return fmt.Errorf("couldn't remove windows firewall: %w", err) + if isFirewallRuleActive(firewallRuleName + "-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err)) + } } - return nil + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic @@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error { return nil } - if isFirewallRuleActive(firewallRuleName) { - return nil + if !isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+m.wgIface.Address().IP.String(), + ); err != nil { + return err + } + } + + if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+v6.String(), + ); err != nil { + return err + } } - return manageFirewallRule(firewallRuleName, - addRule, - "dir=in", - "enable=yes", - "action=allow", - "profile=any", - "localip="+m.wgIface.Address().IP.String(), - ) + + return nil } func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 7be0dd78f94..88e90317c8f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -1,8 +1,9 @@ package conntrack import ( - "fmt" + "net" "net/netip" + "strconv" "sync/atomic" "time" @@ -64,5 +65,7 @@ type ConnKey struct { } func (c ConnKey) String() string { - return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) + return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) + + " → " + + net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort))) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 788aac7098a..75a02ac6f4f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1615,7 +1615,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { } // traffic to our other local interfaces (not NetBird IP) - always forward - if dstIP != m.wgIface.Address().IP { + addr := m.wgIface.Address() + if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) { return true } diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 5897f56cf71..01e5f97c16b 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -1336,6 +1336,44 @@ func TestShouldForward(t *testing.T) { }, } + // Add IPv6 to the interface and test dual-stack cases + wgIPv6 := netip.MustParseAddr("fd00::1") + otherIPv6 := netip.MustParseAddr("fd00::2") + ifaceMock.AddressFunc = func() wgaddr.Address { + return wgaddr.Address{ + IP: wgIP, + Network: netip.PrefixFrom(wgIP, 24), + IPv6: wgIPv6, + IPv6Net: netip.PrefixFrom(wgIPv6, 64), + } + } + + // Re-create manager to pick up the new address with IPv6 + require.NoError(t, manager.Close(nil)) + manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + + v6Cases := []struct { + name string + dstIP netip.Addr + expected bool + description string + }{ + {"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"}, + {"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"}, + {"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"}, + {"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"}, + } + for _, tt := range v6Cases { + t.Run(tt.name, func(t *testing.T) { + manager.localForwarding = true + manager.netstack = false + decoder := createTCPDecoder(8080) + result := manager.shouldForward(decoder, tt.dstIP) + require.Equal(t, tt.expected, result, tt.description) + }) + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Configure manager diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 42573d16d29..0dc524c41bd 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -202,4 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) { }) } } - diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index e3a96590cd4..9b070aab8a8 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, if err != nil { return fmt.Errorf("failed to parse endpoint address: %w", err) } - addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port)) + addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port)) c.activityRecorder.UpsertAddress(peerKey, addrPort) } return nil diff --git a/client/iface/device/adapter.go b/client/iface/device/adapter.go index 6ebc0539007..e3caaf9305b 100644 --- a/client/iface/device/adapter.go +++ b/client/iface/device/adapter.go @@ -2,7 +2,7 @@ package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { - ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) + ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error) UpdateAddr(address string) error ProtectSocket(fd int32) bool } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 198343fbd3b..cbe88c10c0c 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string searchDomainsToString = "" } - fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return nil, err diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 9ac3ea6df95..5bf670e076b 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -6,7 +6,7 @@ import ( "fmt" "net" "net/netip" - "strings" + "sync" log "github.com/sirupsen/logrus" @@ -196,18 +196,22 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } } -// fakeAddress returns a fake address that is used to as an identifier for the peer. -// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. +// fakeAddress returns a fake address that is used as an identifier for the peer. +// The fake address is in the format of 127.1.x.x where x.x is derived from the +// last two bytes of the peer address (works for both IPv4 and IPv6). func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { - octets := strings.Split(peerAddress.IP.String(), ".") - if len(octets) != 4 { - return nil, fmt.Errorf("invalid IP format") + if peerAddress == nil { + return nil, fmt.Errorf("nil peer address") } - fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) - if err != nil { - return nil, fmt.Errorf("parse new IP: %w", err) + addr, ok := netip.AddrFromSlice(peerAddress.IP) + if !ok { + return nil, fmt.Errorf("invalid IP format") } + addr = addr.Unmap() + + raw := addr.As16() + fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]}) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index 59837c32880..e242b8b1b51 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -430,8 +430,6 @@ func isInCGNATRange(ip net.IP) bool { } func TestAnonymizeFirewallRules(t *testing.T) { - // TODO: Add ipv6 - // Example iptables-save output iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 *filter @@ -467,17 +465,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes) Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) pkts bytes target prot opt in out source destination` - // Example nftables output + // Example ip6tables-save output + ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024 +*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +-A INPUT -s fd00:1234::1/128 -j ACCEPT +-A INPUT -s 2607:f8b0:4005::1/128 -j DROP +-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT +COMMIT` + + // Example nftables output with IPv6 nftablesRules := `table inet filter { chain input { type filter hook input priority filter; policy accept; ip saddr 192.168.1.1 accept ip saddr 44.192.140.1 drop + ip6 saddr 2607:f8b0:4005::1 drop + ip6 saddr fd00:1234::1 accept } chain forward { type filter hook forward priority filter; policy accept; ip saddr 10.0.0.0/8 drop ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept + ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept } }` @@ -540,4 +552,35 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) assert.Contains(t, anonNftables, "table inet filter {") assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") + + // IPv6 public addresses in nftables should be anonymized + assert.NotContains(t, anonNftables, "2607:f8b0:4005::1") + assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e") + assert.NotContains(t, anonNftables, "2001:db8::") + assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range + + // ULA addresses in nftables should remain unchanged (private) + assert.Contains(t, anonNftables, "fd00:1234::1") + + // IPv6 nftables structure preserved + assert.Contains(t, anonNftables, "ip6 saddr") + assert.Contains(t, anonNftables, "ip6 daddr") + + // Test ip6tables-save anonymization + anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave) + + // ULA (private) IPv6 should remain unchanged + assert.Contains(t, anonIp6tablesSave, "fd00:1234::1/128") + + // Public IPv6 addresses should be anonymized + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1") + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e") + assert.NotContains(t, anonIp6tablesSave, "2001:db8::") + assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range + + // Structure should be preserved + assert.Contains(t, anonIp6tablesSave, "*filter") + assert.Contains(t, anonIp6tablesSave, "COMMIT") + assert.Contains(t, anonIp6tablesSave, "-j DROP") + assert.Contains(t, anonIp6tablesSave, "-j ACCEPT") } diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index b4b9409f9e3..551555ad4c5 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -189,10 +189,10 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr { } -// evalListenAddress figure out the listen address for the DNS server -// first check the 53 port availability on WG interface or lo, if not success -// pick a random port on WG interface for eBPF, if not success -// check the 5053 port availability on WG interface or lo without eBPF usage, +// evalListenAddress figures out the listen address for the DNS server. +// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4. +// First checks port 53 on WG interface or lo, then tries eBPF on a random port, +// then falls back to port 5053. func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { if s.customAddr != nil { return s.customAddr.Addr(), s.customAddr.Port(), nil diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 58b88d9ef40..c4c16cd3f8d 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } + // IPv4-only: peers reach the forwarder via its v4 overlay address. localAddr := m.wgIface.Address().IP if localAddr.IsValid() && m.firewall != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 1347161e9e6..16410519bed 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -630,7 +630,7 @@ func (e *Engine) initFirewall() error { rosenpassPort := e.rpManager.GetAddress().Port port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}} - // this rule is static and will be torn down on engine down by the firewall manager + // IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4. if _, err := e.firewall.AddPeerFiltering( nil, net.IP{0, 0, 0, 0}, @@ -682,10 +682,15 @@ func (e *Engine) blockLanAccess() { log.Infof("blocking route LAN access for networks: %v", toBlock) v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0) for _, network := range toBlock { + source := v4 + if network.Addr().Is6() { + source = v6 + } if _, err := e.firewall.AddRouteFiltering( nil, - []netip.Prefix{v4}, + []netip.Prefix{source}, firewallManager.Network{Prefix: network}, firewallManager.ProtocolALL, nil, @@ -2362,8 +2367,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked() ip := prefix.Addr() - // TODO: add IPv6 - if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { continue } diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go index 792d0421550..60b8baadbb5 100644 --- a/client/internal/lazyconn/activity/listener_bind.go +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -57,6 +57,7 @@ func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc // deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. // Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). // It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +// For IPv6-only peers, the last two bytes of the v6 address are used. func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { if len(allowedIPs) == 0 { return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") @@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e ourNetwork := wgIface.Address().Network + // Try v4 first (preferred: deterministic from overlay IP) var peerIP netip.Addr for _, allowedIP := range allowedIPs { ip := allowedIP.Addr() @@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e } } - if !peerIP.IsValid() { - return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + if peerIP.IsValid() { + octets := peerIP.As4() + return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil } - octets := peerIP.As4() - fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) - return fakeIP, nil + // Fallback: use last two bytes of first v6 overlay IP + addr := wgIface.Address() + if addr.IPv6Net.IsValid() { + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if ip.Is6() && addr.IPv6Net.Contains(ip) { + raw := ip.As16() + return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil + } + } + } + + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") } func (d *BindListener) setupLazyConn() error { diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index fbf95de21e7..f4db95c8a0f 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -1055,7 +1055,11 @@ func (d *Status) notifyPeerListChanged() { } func (d *Status) notifyAddressChanged() { - d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) + addr := d.localPeer.IP + if d.localPeer.IPv6 != "" { + addr = addr + "\n" + d.localPeer.IPv6 + } + d.notifier.localAddressChanged(d.localPeer.FQDN, addr) } func (d *Status) numOfPeers() int { diff --git a/client/internal/routemanager/server/server.go b/client/internal/routemanager/server/server.go index e674c80cdfd..d35b44f5b6e 100644 --- a/client/internal/routemanager/server/server.go +++ b/client/internal/routemanager/server/server.go @@ -146,8 +146,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP if useNewDNSRoute { destination.Set = firewall.NewDomainSet(route.Domains) } else { - // TODO: add ipv6 additionally - destination = getDefaultPrefix(destination.Prefix) + destination = getDefaultPrefix(route.Network) } } else { destination.Prefix = route.Network.Masked() diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index c0ca21d22e0..8724ed1bad0 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -107,8 +107,13 @@ func (r *SysOps) validateRoute(prefix netip.Prefix) error { addr.IsInterfaceLocalMulticast(), addr.IsMulticast(), addr.IsUnspecified() && prefix.Bits() != 0, - r.wgInterface.Address().Network.Contains(addr): + r.isOwnAddress(addr): return vars.ErrRouteNotAllowed } return nil } + +func (r *SysOps) isOwnAddress(addr netip.Addr) bool { + wgAddr := r.wgInterface.Address() + return wgAddr.Network.Contains(addr) || (wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr)) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 02f447a2f4b..07bd2c118db 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -228,29 +228,14 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er // as a separate route that the dedicated handler adds. // Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure. if !r.wgInterface.Address().HasIPv6() { - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - log.Warnf("Failed to add v6 blackhole split 1: %s", err) - } else if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - log.Warnf("Failed to add v6 blackhole split 2: %s", err) - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback v6 blackhole: %s", err2) - } + if err := r.addV6SplitDefault(nextHop); err != nil { + log.Warnf("failed to add v6 split-default blackhole: %s", err) } } return nil case vars.Defaultv6: - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil + return r.addV6SplitDefault(nextHop) } return r.addToRouteTable(prefix, nextHop) @@ -272,30 +257,41 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } if !r.wgInterface.Address().HasIPv6() { - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) - } + result = multierror.Append(result, r.removeV6SplitDefault(nextHop)) } return nberrors.FormatErrorOrNil(result) case vars.Defaultv6: - var result *multierror.Error - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) - } - - return nberrors.FormatErrorOrNil(result) + return nberrors.FormatErrorOrNil(r.removeV6SplitDefault(nextHop)) default: return r.removeFromRouteTable(prefix, nextHop) } } +func (r *SysOps) addV6SplitDefault(nextHop Nexthop) error { + if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { + return fmt.Errorf("add split 1: %w", err) + } + if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { + if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback v6 split-default: %s", err2) + } + return fmt.Errorf("add split 2: %w", err) + } + return nil +} + +func (r *SysOps) removeV6SplitDefault(nextHop Nexthop) *multierror.Error { + var result *multierror.Error + if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { + result = multierror.Append(result, err) + } + return result +} + func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 990e03034f0..c73a0dcd1a9 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -50,10 +50,11 @@ type CustomLogger interface { } type selectRoute struct { - NetID string - Network netip.Prefix - Domains domain.List - Selected bool + NetID string + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } func init() { @@ -363,48 +364,60 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } routeManager := engine.GetRouteManager() - routesMap := routeManager.GetClientRoutesWithNetID() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } + routesMap := routeManager.GetClientRoutesWithNetID() routeSelector := routeManager.GetRouteSelector() if routeSelector == nil { return nil, fmt.Errorf("could not get route selector") } + v6ExitMerged := route.V6ExitMergeSet(routesMap) + routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + + return prepareRouteSelectionDetails(routes, resolvedDomains), nil +} + +func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute { var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + if _, ok := v6Merged[id]; ok { + continue + } + + r := &selectRoute{ NetID: string(id), Network: rt[0].Network, Domains: rt[0].Domains, - Selected: routeSelector.IsSelected(id), + Selected: isSelected(id), } - routes = append(routes, route) + + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6Merged[v6ID]; ok { + r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { - iPrefix := routes[i].Network.Bits() - jPrefix := routes[j].Network.Bits() - - if iPrefix == jPrefix { - iAddr := routes[i].Network.Addr() - jAddr := routes[j].Network.Addr() - if iAddr == jAddr { - return routes[i].NetID < routes[j].NetID - } - return iAddr.String() < jAddr.String() + iBits, jBits := routes[i].Network.Bits(), routes[j].Network.Bits() + if iBits != jBits { + return iBits < jBits } - return iPrefix < jPrefix + iAddr, jAddr := routes[i].Network.Addr(), routes[j].Network.Addr() + if iAddr != jAddr { + return iAddr.Less(jAddr) + } + return routes[i].NetID < routes[j].NetID }) - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - return prepareRouteSelectionDetails(routes, resolvedDomains), nil - + return routes } func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { @@ -425,10 +438,15 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom } domainList = append(domainList, domainResp) } + rangeStr := r.Network.String() + for _, extra := range r.extraNetworks { + rangeStr += ", " + extra.String() + } + domainDetails := DomainDetails{items: domainList} routeSelection = append(routeSelection, RoutesSelectionInfo{ ID: r.NetID, - Network: r.Network.String(), + Network: rangeStr, Domains: &domainDetails, Selected: r.Selected, }) @@ -456,7 +474,9 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routesMap)); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } @@ -483,7 +503,9 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } diff --git a/client/server/network.go b/client/server/network.go index bb1cce56c54..4a439d8cf39 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -16,10 +16,11 @@ import ( ) type selectRoute struct { - NetID route.NetID - Network netip.Prefix - Domains domain.List - Selected bool + NetID route.NetID + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } // ListNetworks returns a list of all available networks. @@ -44,18 +45,33 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro routesMap := routeMgr.GetClientRoutesWithNetID() routeSelector := routeMgr.GetRouteSelector() + v6ExitMerged := route.V6ExitMergeSet(routesMap) + var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + // Skip v6 exit nodes that are merged into their v4 counterpart. + if _, ok := v6ExitMerged[id]; ok { + continue + } + + r := &selectRoute{ NetID: id, Network: rt[0].Network, Domains: rt[0].Domains, Selected: routeSelector.IsSelected(id), } - routes = append(routes, route) + + // Merge paired v6 exit node prefix into this entry. + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6ExitMerged[v6ID]; ok { + v6Prefix := routesMap[v6ID][0].Network + r.extraNetworks = []netip.Prefix{v6Prefix} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { @@ -76,9 +92,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() var pbRoutes []*proto.Network for _, route := range routes { + rangeStr := route.Network.String() + for _, extra := range route.extraNetworks { + rangeStr += ", " + extra.String() + } pbRoute := &proto.Network{ ID: string(route.NetID), - Range: route.Network.String(), + Range: rangeStr, Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, @@ -137,7 +157,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ routeSelector.SelectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + netIdRoutes := maps.Keys(routesMap) if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } @@ -183,7 +205,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe routeSelector.DeselectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + netIdRoutes := maps.Keys(routesMap) if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } diff --git a/client/ui/network.go b/client/ui/network.go index 6ae57122ea4..4bb0b76118c 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -195,7 +195,7 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { var filteredRoutes []*proto.Network for _, route := range routes { - if route.Range == "0.0.0.0/0" || route.Range == "::/0" { + if strings.Contains(route.Range, "0.0.0.0/0") || route.Range == "::/0" { filteredRoutes = append(filteredRoutes, route) } } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 0c1a5dc6951..6fa0eeb2acc 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net" + "strconv" "syscall/js" "time" @@ -166,39 +167,58 @@ func createSSHMethod(client *netbird.Client) js.Func { }) } - var jwtToken string - if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { - jwtToken = args[3].String() - } + jwtToken, ipVersion := parseSSHOptions(args) return createPromise(func(resolve, reject js.Value) { - sshClient := ssh.NewClient(client) - - if err := sshClient.Connect(host, port, username, jwtToken); err != nil { - reject.Invoke(err.Error()) - return - } - - if err := sshClient.StartSession(80, 24); err != nil { - if closeErr := sshClient.Close(); closeErr != nil { - log.Errorf("Error closing SSH client: %v", closeErr) - } + jsInterface, err := connectSSH(client, host, port, username, jwtToken, ipVersion) + if err != nil { reject.Invoke(err.Error()) return } - - jsInterface := ssh.CreateJSInterface(sshClient) resolve.Invoke(jsInterface) }) }) } -func performPing(client *netbird.Client, hostname string) { +func parseSSHOptions(args []js.Value) (jwtToken string, ipVersion int) { + if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { + jwtToken = args[3].String() + } + if len(args) > 4 { + ipVersion = jsIPVersion(args[4]) + } + return +} + +func connectSSH(client *netbird.Client, host string, port int, username, jwtToken string, ipVersion int) (js.Value, error) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username, jwtToken, ipVersion); err != nil { + return js.Undefined(), err + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + return js.Undefined(), err + } + + return ssh.CreateJSInterface(sshClient), nil +} + +func performPing(client *netbird.Client, hostname string, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() + // Default to ping4 to avoid dual-stack ICMP endpoint issues in wireguard-go netstack. + network := "ping4" + if ipVersion == 6 { + network = "ping6" + } + start := time.Now() - conn, err := client.Dial(ctx, "ping", hostname) + conn, err := client.Dial(ctx, network, hostname) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err)) return @@ -225,27 +245,39 @@ func performPing(client *netbird.Client, hostname string) { } latency := time.Since(start) - js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())) + remote := conn.RemoteAddr().String() + msg := fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()) + if remote != hostname { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } -func performPingTCP(client *netbird.Client, hostname string, port int) { +func performPingTCP(client *netbird.Client, hostname string, port, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() + network := ipVersionNetwork("tcp", ipVersion) + address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err)) return } latency := time.Since(start) + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { log.Debugf("failed to close TCP connection: %v", err) } - js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())) + msg := fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()) + if remote != address { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } // createPingMethod creates the ping method @@ -262,8 +294,12 @@ func createPingMethod(client *netbird.Client) js.Func { } hostname := args[0].String() + var ipVersion int + if len(args) > 1 { + ipVersion = jsIPVersion(args[1]) + } return createPromise(func(resolve, reject js.Value) { - performPing(client, hostname) + performPing(client, hostname, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -290,8 +326,12 @@ func createPingTCPMethod(client *netbird.Client) js.Func { hostname := args[0].String() port := args[1].Int() + var ipVersion int + if len(args) > 2 { + ipVersion = jsIPVersion(args[2]) + } return createPromise(func(resolve, reject js.Value) { - performPingTCP(client, hostname, port) + performPingTCP(client, hostname, port, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -464,6 +504,31 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func { }) } +// ipVersionNetwork appends "4" or "6" to a base network string (e.g. "tcp" -> "tcp4"). +func ipVersionNetwork(base string, ipVersion int) string { + switch ipVersion { + case 4: + return base + "4" + case 6: + return base + "6" + default: + return base + } +} + +// jsIPVersion extracts an IP version (4 or 6) from a JS string or number. +func jsIPVersion(v js.Value) int { + switch v.Type() { + case js.TypeNumber: + return v.Int() + case js.TypeString: + n, _ := strconv.Atoi(v.String()) + return n + default: + return 0 + } +} + // createPromise is a helper to create JavaScript promises func createPromise(handler func(resolve, reject js.Value)) js.Value { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go index 2f425c614d9..9cfe652669c 100644 --- a/client/wasm/internal/ssh/client.go +++ b/client/wasm/internal/ssh/client.go @@ -46,8 +46,9 @@ func NewClient(nbClient *netbird.Client) *Client { } } -// Connect establishes an SSH connection through NetBird network -func (c *Client) Connect(host string, port int, username, jwtToken string) error { +// Connect establishes an SSH connection through NetBird network. +// ipVersion may be 4, 6, or 0 for automatic selection. +func (c *Client) Connect(host string, port int, username, jwtToken string, ipVersion int) error { addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) logrus.Infof("SSH: Connecting to %s as %s", addr, username) @@ -63,10 +64,18 @@ func (c *Client) Connect(host string, port int, username, jwtToken string) error Timeout: sshDialTimeout, } + network := "tcp" + switch ipVersion { + case 4: + network = "tcp4" + case 6: + network = "tcp6" + } + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) defer cancel() - conn, err := c.nbClient.Dial(ctx, "tcp", addr) + conn, err := c.nbClient.Dial(ctx, network, addr) if err != nil { return fmt.Errorf("dial %s: %w", addr, err) } diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 59f7a6b6590..81879e404b2 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "strconv" + "time" "github.com/spf13/cobra" @@ -57,7 +58,11 @@ var debugSyncCmd = &cobra.Command{ SilenceUsage: true, } -var pingTimeout string +var ( + pingTimeout time.Duration + pingIPv4 bool + pingIPv6 bool +) var debugPingCmd = &cobra.Command{ Use: "ping [port]", @@ -108,7 +113,10 @@ func init() { debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)") debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)") - debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().DurationVar(&pingTimeout, "timeout", 0, "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().BoolVarP(&pingIPv4, "ipv4", "4", false, "Force IPv4") + debugPingCmd.Flags().BoolVarP(&pingIPv6, "ipv6", "6", false, "Force IPv6") + debugPingCmd.MarkFlagsMutuallyExclusive("ipv4", "ipv6") debugCmd.AddCommand(debugHealthCmd) debugCmd.AddCommand(debugClientsCmd) @@ -157,7 +165,14 @@ func runDebugPing(cmd *cobra.Command, args []string) error { } port = p } - return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout) + var ipVersion string + switch { + case pingIPv4: + ipVersion = "4" + case pingIPv6: + ipVersion = "6" + } + return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout, ipVersion) } func runDebugLogLevel(cmd *cobra.Command, args []string) error { diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 01b0bc8e6d6..2ce721eb8d3 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -6,10 +6,12 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" "strings" "time" + ) // StatusFilters contains filter options for status queries. @@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error } // PingTCP performs a TCP ping through a client. -func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error { +// ipVersion may be "4", "6", or "" for automatic. +func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error { params := url.Values{} params.Set("host", host) params.Set("port", fmt.Sprintf("%d", port)) - if timeout != "" { - params.Set("timeout", timeout) + if timeout > 0 { + params.Set("timeout", timeout.String()) + } + if ipVersion != "" { + params.Set("ip_version", ipVersion) } path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode()) @@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, func (c *Client) printPingResult(data map[string]any) { success, _ := data["success"].(bool) + host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"])) if success { - _, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"]) + remote, _ := data["remote"].(string) + if remote != "" && remote != host { + _, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote) + } else { + _, _ = fmt.Fprintf(c.out, "Success: %s\n", host) + } _, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"]) } else { - _, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"]) + _, _ = fmt.Fprintf(c.out, "Failed: %s\n", host) c.printError(data) } } diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index c507cfad9d0..c1d145204ee 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -9,6 +9,7 @@ import ( "fmt" "html/template" "maps" + "net" "net/http" "slices" "strconv" @@ -525,13 +526,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI } } + network := "tcp" + if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" { + network += v + } + ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() - address := fmt.Sprintf("%s:%d", host, port) + address := net.JoinHostPort(host, strconv.Itoa(port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { h.writeJSON(w, map[string]interface{}{ "success": false, @@ -541,18 +547,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI }) return } + + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { h.logger.Debugf("close tcp ping connection: %v", err) } latency := time.Since(start) - h.writeJSON(w, map[string]interface{}{ + resp := map[string]interface{}{ "success": true, "host": host, "port": port, + "remote": remote, "latency_ms": latency.Milliseconds(), "latency": formatDuration(latency), - }) + } + h.writeJSON(w, resp) } func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { diff --git a/route/route.go b/route/route.go index c724e7c7d07..97b9721f619 100644 --- a/route/route.go +++ b/route/route.go @@ -20,6 +20,9 @@ const ( MaxMetric = 9999 // MaxNetIDChar Max Network Identifier MaxNetIDChar = 40 + + // V6ExitSuffix is appended to a v4 exit node NetID to form its v6 counterpart. + V6ExitSuffix = "-v6" ) const ( @@ -215,3 +218,61 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { return IPv4Network, masked, nil } + +var ( + v4Default = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6Default = netip.PrefixFrom(netip.IPv6Unspecified(), 0) +) + +// IsV4DefaultRoute reports whether p is the IPv4 default route (0.0.0.0/0). +func IsV4DefaultRoute(p netip.Prefix) bool { return p == v4Default } + +// IsV6DefaultRoute reports whether p is the IPv6 default route (::/0). +func IsV6DefaultRoute(p netip.Prefix) bool { return p == v6Default } + +// ExpandV6ExitPairs appends the paired "-v6" exit node NetID for any v4 exit +// node (0.0.0.0/0) in ids that has a matching v6 counterpart (::/0) in routesMap. +// It modifies and returns the input slice. +func ExpandV6ExitPairs(ids []NetID, routesMap map[NetID][]*Route) []NetID { + for _, id := range ids { + rt, ok := routesMap[id] + if !ok || len(rt) == 0 || !IsV4DefaultRoute(rt[0].Network) { + continue + } + v6ID := NetID(string(id) + V6ExitSuffix) + if v6Rt, ok := routesMap[v6ID]; ok && len(v6Rt) > 0 && IsV6DefaultRoute(v6Rt[0].Network) { + if !slices.Contains(ids, v6ID) { + ids = append(ids, v6ID) + } + } + } + return ids +} + +// V6ExitMergeSet scans routesMap and returns the set of v6 exit node NetIDs +// that should be hidden from the UI because they are paired with a v4 exit node. +// A v6 ID is paired when it has suffix "-v6", its route is ::/0, and the base +// name (without "-v6") exists with route 0.0.0.0/0. +func V6ExitMergeSet(routesMap map[NetID][]*Route) map[NetID]struct{} { + merged := make(map[NetID]struct{}) + for id, rt := range routesMap { + if len(rt) == 0 { + continue + } + name := string(id) + if !IsV6DefaultRoute(rt[0].Network) || !strings.HasSuffix(name, V6ExitSuffix) { + continue + } + baseName := NetID(strings.TrimSuffix(name, V6ExitSuffix)) + if baseRt, ok := routesMap[baseName]; ok && len(baseRt) > 0 && IsV4DefaultRoute(baseRt[0].Network) { + merged[id] = struct{}{} + } + } + return merged +} + +// HasV6ExitPair reports whether id has a paired v6 exit node in the merge set. +func HasV6ExitPair(id NetID, v6Merged map[NetID]struct{}) bool { + _, ok := v6Merged[NetID(string(id)+"-v6")] + return ok +} diff --git a/route/route_test.go b/route/route_test.go new file mode 100644 index 00000000000..dab707ed35a --- /dev/null +++ b/route/route_test.go @@ -0,0 +1,108 @@ +package route + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExpandV6ExitPairs(t *testing.T) { + v4ExitRoute := &Route{Network: netip.MustParsePrefix("0.0.0.0/0")} + v6ExitRoute := &Route{Network: netip.MustParsePrefix("::/0")} + regularRoute := &Route{Network: netip.MustParsePrefix("10.0.0.0/8")} + + tests := []struct { + name string + ids []NetID + routesMap map[NetID][]*Route + expected []NetID + }{ + { + name: "v4 exit node with matching v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "v4 exit node without v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + }, + expected: []NetID{"exit-node"}, + }, + { + name: "regular route is not expanded", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {v6ExitRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "v6 already included is not duplicated", + ids: []NetID{"exit-node", "exit-node-v6"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "multiple exit nodes expanded independently", + ids: []NetID{"exit-a", "exit-b"}, + routesMap: map[NetID][]*Route{ + "exit-a": {v4ExitRoute}, + "exit-a-v6": {v6ExitRoute}, + "exit-b": {v4ExitRoute}, + "exit-b-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-a", "exit-b", "exit-a-v6", "exit-b-v6"}, + }, + { + name: "v6 suffix but not exit node network", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {regularRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "user-chosen name for exit node with v6 pair", + ids: []NetID{"my-exit"}, + routesMap: map[NetID][]*Route{ + "my-exit": {v4ExitRoute}, + "my-exit-v6": {v6ExitRoute}, + }, + expected: []NetID{"my-exit", "my-exit-v6"}, + }, + { + name: "real-world management-generated IDs", + ids: []NetID{"0.0.0.0/0"}, + routesMap: map[NetID][]*Route{ + "0.0.0.0/0": {v4ExitRoute}, + "0.0.0.0/0-v6": {v6ExitRoute}, + }, + expected: []NetID{"0.0.0.0/0", "0.0.0.0/0-v6"}, + }, + { + name: "empty input", + ids: []NetID{}, + routesMap: map[NetID][]*Route{}, + expected: []NetID{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExpandV6ExitPairs(tt.ids, tt.routesMap) + assert.ElementsMatch(t, tt.expected, result) + }) + } +} diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 2d7b00a8095..7d413d4c11b 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -46,7 +46,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { InitialPacketSize: nbRelay.QUICInitialPacketSize, } - udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0}) if err != nil { log.Errorf("failed to listen on UDP: %s", err) return nil, err