From da6f61039a413b0607af1e00308a536a3940cffc Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 24 Mar 2026 07:39:51 +0100 Subject: [PATCH 1/3] Add IPv6 routing support with forwarding, fake IPs, and exit node handling --- client/anonymize/anonymize.go | 5 +- client/anonymize/anonymize_test.go | 14 +- .../firewall/uspfilter/forwarder/endpoint.go | 7 +- .../firewall/uspfilter/forwarder/forwarder.go | 18 +- client/firewall/uspfilter/forwarder/icmp.go | 18 +- client/firewall/uspfilter/forwarder/tcp.go | 4 +- client/firewall/uspfilter/forwarder/udp.go | 3 +- client/internal/dns/upstream.go | 7 + client/internal/dns/upstream_android.go | 2 +- client/internal/dns/upstream_general.go | 2 +- client/internal/dns/upstream_ios.go | 57 +++--- client/internal/routemanager/client/client.go | 5 +- .../routemanager/dnsinterceptor/handler.go | 2 +- client/internal/routemanager/dynamic/route.go | 4 +- .../routemanager/dynamic/route_ios.go | 37 ++-- client/internal/routemanager/fakeip/fakeip.go | 144 ++++++++++----- .../routemanager/fakeip/fakeip_test.go | 169 +++++++++++++----- client/internal/routemanager/manager.go | 20 ++- .../systemops/systemops_generic.go | 34 ++-- .../routemanager/systemops/systemops_linux.go | 23 ++- 20 files changed, 366 insertions(+), 209 deletions(-) diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 89e653300ba..52beaaa47c1 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -26,8 +26,9 @@ type Anonymizer struct { } func DefaultAddresses() (netip.Addr, netip.Addr) { - // 198.51.100.0, 100:: - return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) + // 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48) + // The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android. + return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::") } func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index ff2e4886943..45e20583467 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -13,7 +13,7 @@ import ( func TestAnonymizeIP(t *testing.T) { startIPv4 := netip.MustParseAddr("198.51.100.0") - startIPv6 := netip.MustParseAddr("100::") + startIPv6 := netip.MustParseAddr("2001:db8:ffff::") anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6) tests := []struct { @@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) { {"Second Public IPv4", "4.3.2.1", "198.51.100.1"}, {"Repeated IPv4", "1.2.3.4", "198.51.100.0"}, {"Private IPv4", "192.168.1.1", "192.168.1.1"}, - {"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"}, - {"Second Public IPv6", "a::b", "100::1"}, - {"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"}, + {"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, + {"Second Public IPv6", "a::b", "2001:db8:ffff::1"}, + {"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, {"Private IPv6", "fe80::1", "fe80::1"}, {"In Range IPv4", "198.51.100.2", "198.51.100.2"}, } @@ -274,17 +274,17 @@ func TestAnonymizeString_IPAddresses(t *testing.T) { { name: "IPv6 Address", input: "Access attempted from 2001:db8::ff00:42", - expect: "Access attempted from 100::", + expect: "Access attempted from 2001:db8:ffff::", }, { name: "IPv6 Address with Port", input: "Access attempted from [2001:db8::ff00:42]:8080", - expect: "Access attempted from [100::]:8080", + expect: "Access attempted from [2001:db8:ffff::]:8080", }, { name: "Both IPv4 and IPv6", input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43", - expect: "IPv4: 198.51.100.1 and IPv6: 100::1", + expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1", }, } diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 6dd2e8e8413..3f1ba6a449c 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -1,7 +1,8 @@ package forwarder import ( - "fmt" + "net" + "strconv" "sync/atomic" wgdevice "golang.zx2c4.com/wireguard/device" @@ -106,5 +107,7 @@ type epID stack.TransportEndpointID func (i epID) String() string { // src and remote is swapped - return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) + return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) + + " → " + + net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort))) } diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 84cc7ec4447..bbacb95f639 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -158,6 +158,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) + // TODO: gvisor's IPv6 network layer (ipv6/icmp.go) replies to ICMPv6 echo + // requests at the network layer before our transport handler fires. Unlike + // IPv4, it has no localAddressTemporary check or DeliverTransportPacket call + // before replying. With promiscuous mode, this means gvisor replies to ALL + // ICMPv6 echo (including routed traffic) with local latency. + // Not fixed as of gvisor 20260320. + // Fix: handle ICMPv6 echo in the USP filter before passing to the forwarder, + // similar to how v4 ICMP worked before the forwarder existed. The forwarder + // is needed for TCP (full proxy) and UDP (endpoint tracking), but ICMP can + // be handled directly since it's stateless request/reply. s.SetTransportProtocolHandler(icmp.ProtocolNumber6, f.handleICMPv6) f.checkICMPCapability() @@ -210,14 +220,14 @@ func (f *Forwarder) Stop() { f.stack.Wait() } -func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { +func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr { if f.netstack && f.ip.Equal(addr) { - return net.IPv4(127, 0, 0, 1) + return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } if f.netstack && f.ipv6.Equal(addr) { - return net.IPv6loopback + return netip.IPv6Loopback() } - return addr.AsSlice() + return addrToNetipAddr(addr) } func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 7df2fec8219..c60c22a3856 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "os/exec" "runtime" "time" @@ -82,7 +83,7 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by } dstIP := f.determineDialAddr(id.LocalAddress) - dst := &net.IPAddr{IP: dstIP} + dst := &net.IPAddr{IP: dstIP.AsSlice()} if _, err = conn.WriteTo(payload, dst); err != nil { if closeErr := conn.Close(); closeErr != nil { @@ -273,14 +274,11 @@ func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmp replyHdr := header.ICMPv6(replyICMP) replyHdr.SetType(header.ICMPv6EchoReply) replyHdr.SetChecksum(0) - // ICMPv6 checksum requires a pseudo-header - psum := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, id.LocalAddress, id.RemoteAddress, uint16(len(replyICMP))) + // ICMPv6 checksum includes a pseudo-header computed internally by ICMPv6Checksum replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: replyHdr, - Src: id.LocalAddress, - Dst: id.RemoteAddress, - PayloadCsum: psum, - PayloadLen: len(replyICMP) - header.ICMPv6MinimumSize, + Header: replyHdr, + Src: id.LocalAddress, + Dst: id.RemoteAddress, })) return f.injectICMPv6Reply(id, replyICMP) @@ -317,13 +315,13 @@ const ( // buildPingCommand creates a platform-specific ping command. // Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6. -func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd { +func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd { timeoutSec := int(timeout.Seconds()) if timeoutSec < 1 { timeoutSec = 1 } - isV6 := target.To4() == nil + isV6 := target.Is6() timeoutStr := fmt.Sprintf("%d", timeoutSec) switch runtime.GOOS { diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 5969c0e1197..8844463f556 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -2,9 +2,9 @@ package forwarder import ( "context" - "fmt" "io" "net" + "strconv" "sync" "github.com/google/uuid" @@ -32,7 +32,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { } }() - dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort))) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index d3313db9bbc..c92fa1f326b 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "strconv" "sync" "sync/atomic" "time" @@ -157,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { } }() - dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort))) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 5b813513272..516383745f5 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -21,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" @@ -29,6 +30,12 @@ import ( var currentMTU uint16 = iface.DefaultMTU +// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate. +type privateClientIface interface { + Name() string + Address() wgaddr.Address +} + func SetCurrentMTU(mtu uint16) { currentMTU = mtu } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index d7cff377bf0..e28fe79cd85 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool { return false } -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { return &dns.Client{ Timeout: dialTimeout, Net: "udp", diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 1143b6c514f..910c3779ebf 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns return ExchangeWithFallback(ctx, client, r, upstream) } -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { return &dns.Client{ Timeout: dialTimeout, Net: "udp", diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 26b19dac3ed..0e04742a081 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -19,11 +19,7 @@ import ( type upstreamResolverIOS struct { *upstreamResolverBase - lIP netip.Addr - lNet netip.Prefix - lIPv6 netip.Addr - lNetV6 netip.Prefix - interfaceName string + wgIface WGIface } func newUpstreamResolver( @@ -37,11 +33,7 @@ func newUpstreamResolver( ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, - lIP: wgIface.Address().IP, - lNet: wgIface.Address().Network, - lIPv6: wgIface.Address().IPv6, - lNetV6: wgIface.Address().IPv6Net, - interfaceName: wgIface.Name(), + wgIface: wgIface, } ios.upstreamClient = ios @@ -69,24 +61,15 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } else { upstreamIP = upstreamIP.Unmap() } - needsPrivate := u.lNet.Contains(upstreamIP) || - u.lNetV6.Contains(upstreamIP) || + addr := u.wgIface.Address() + needsPrivate := addr.Network.Contains(upstreamIP) || + addr.IPv6Net.Contains(upstreamIP) || (u.routeMatch != nil && u.routeMatch(upstreamIP)) if needsPrivate { - var bindIP netip.Addr - switch { - case upstreamIP.Is6() && u.lIPv6.IsValid(): - bindIP = u.lIPv6 - case upstreamIP.Is4() && u.lIP.IsValid(): - bindIP = u.lIP - } - - if bindIP.IsValid() { - log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) - client, err = GetClientPrivate(bindIP, u.interfaceName, timeout) - if err != nil { - return nil, 0, fmt.Errorf("create private client: %s", err) - } + log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) + client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout) + if err != nil { + return nil, 0, fmt.Errorf("create private client: %s", err) } } @@ -94,23 +77,29 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * return ExchangeWithFallback(nil, client, r, upstream) } -// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface -// This method is needed for iOS -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { - index, err := getInterfaceIndex(interfaceName) +// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface. +// It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4. +func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { + index, err := getInterfaceIndex(iface.Name()) if err != nil { - log.Debugf("unable to get interface index for %s: %s", interfaceName, err) + log.Debugf("unable to get interface index for %s: %s", iface.Name(), err) return nil, err } + addr := iface.Address() + bindIP := addr.IP + if upstreamIP.Is6() && addr.HasIPv6() { + bindIP = addr.IPv6 + } + proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF - if ip.Is6() { + if bindIP.Is6() { proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF } dialer := &net.Dialer{ - LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, 0)), - Timeout: dialTimeout, + LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)), + Timeout: dialTimeout, Control: func(network, address string, c syscall.RawConn) error { var operr error fn := func(s uintptr) { diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index e6ef8b87655..c691c54f8a7 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -3,9 +3,8 @@ package client import ( "context" "fmt" - "net" + "net/netip" "reflect" - "strconv" "time" log "github.com/sirupsen/logrus" @@ -566,7 +565,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler { return dnsinterceptor.New(params) case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(params.WgInterface) - dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort())) + dnsAddr := netip.AddrPortFrom(dns.RuntimeIP(), uint16(dns.RuntimePort())) return dynamic.NewRoute(params, dnsAddr) default: return static.NewRoute(params) diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 64f2a878982..e25cc2a5ccf 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -582,7 +582,7 @@ func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWri if nsNet != nil { reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream) } else { - client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) + client, clientErr := nbdns.GetClientPrivate(d.wgInterface, upstreamIP, dnsTimeout) if clientErr != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr)) return nil diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 8d1398a7a37..f0efd7b2280 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -50,10 +50,10 @@ type Route struct { cancel context.CancelFunc statusRecorder *peer.Status wgInterface iface.WGIface - resolverAddr string + resolverAddr netip.AddrPort } -func NewRoute(params common.HandlerParams, resolverAddr string) *Route { +func NewRoute(params common.HandlerParams, resolverAddr netip.AddrPort) *Route { return &Route{ route: params.Route, routeRefCounter: params.RouteRefCounter, diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go index 8fed1c8f997..6498bd77f66 100644 --- a/client/internal/routemanager/dynamic/route_ios.go +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -17,33 +17,36 @@ import ( const dialTimeout = 10 * time.Second func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { - privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout) + privateClient, err := nbdns.GetClientPrivate(r.wgInterface, r.resolverAddr.Addr(), dialTimeout) if err != nil { return nil, fmt.Errorf("error while creating private client: %s", err) } - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA) - + fqdn := dns.Fqdn(domain.PunycodeString()) startTime := time.Now() - response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr) - if err != nil { - return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err) - } + var ips []net.IP - if response.Rcode != dns.RcodeSuccess { - return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode]) - } + for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} { + msg := new(dns.Msg) + msg.SetQuestion(fqdn, qtype) - ips := make([]net.IP, 0) + response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr.String()) + if err != nil { + return nil, fmt.Errorf("DNS query for %s (type %d) after %s: %s", domain.SafeString(), qtype, time.Since(startTime), err) + } - for _, answ := range response.Answer { - if aRecord, ok := answ.(*dns.A); ok { - ips = append(ips, aRecord.A) + if response.Rcode != dns.RcodeSuccess { + continue } - if aaaaRecord, ok := answ.(*dns.AAAA); ok { - ips = append(ips, aaaaRecord.AAAA) + + for _, answ := range response.Answer { + if aRecord, ok := answ.(*dns.A); ok { + ips = append(ips, aRecord.A) + } + if aaaaRecord, ok := answ.(*dns.AAAA); ok { + ips = append(ips, aaaaRecord.AAAA) + } } } diff --git a/client/internal/routemanager/fakeip/fakeip.go b/client/internal/routemanager/fakeip/fakeip.go index 1592045d20e..5be4ca12e4c 100644 --- a/client/internal/routemanager/fakeip/fakeip.go +++ b/client/internal/routemanager/fakeip/fakeip.go @@ -1,93 +1,145 @@ package fakeip import ( + "errors" "fmt" "net/netip" "sync" ) -// Manager manages allocation of fake IPs from the 240.0.0.0/8 block -type Manager struct { - mu sync.Mutex - nextIP netip.Addr // Next IP to allocate +var ( + // 240.0.0.1 - 240.255.255.254, block 240.0.0.0/8 (reserved, RFC 1112) + v4Base = netip.AddrFrom4([4]byte{240, 0, 0, 1}) + v4Max = netip.AddrFrom4([4]byte{240, 255, 255, 254}) + v4Block = netip.PrefixFrom(netip.AddrFrom4([4]byte{240, 0, 0, 0}), 8) + + // 0100::1 - 0100::ffff:ffff:ffff:fffe, block 0100::/64 (discard, RFC 6666) + v6Base = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) + v6Max = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}) + v6Block = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x01, 0x00}), 64) +) + +// fakeIPPool holds the allocation state for a single address family. +type fakeIPPool struct { + nextIP netip.Addr + baseIP netip.Addr + maxIP netip.Addr + block netip.Prefix allocated map[netip.Addr]netip.Addr // real IP -> fake IP fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP - baseIP netip.Addr // First usable IP: 240.0.0.1 - maxIP netip.Addr // Last usable IP: 240.255.255.254 } -// NewManager creates a new fake IP manager using 240.0.0.0/8 block -func NewManager() *Manager { - baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1}) - maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254}) - - return &Manager{ - nextIP: baseIP, +func newPool(base, maxAddr netip.Addr, block netip.Prefix) *fakeIPPool { + return &fakeIPPool{ + nextIP: base, + baseIP: base, + maxIP: maxAddr, + block: block, allocated: make(map[netip.Addr]netip.Addr), fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: baseIP, - maxIP: maxIP, } } -// AllocateFakeIP allocates a fake IP for the given real IP -// Returns the fake IP, or existing fake IP if already allocated -func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { - if !realIP.Is4() { - return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported") - } - - m.mu.Lock() - defer m.mu.Unlock() - - if fakeIP, exists := m.allocated[realIP]; exists { +// allocate allocates a fake IP for the given real IP. +// Returns the existing fake IP if already allocated. +func (p *fakeIPPool) allocate(realIP netip.Addr) (netip.Addr, error) { + if fakeIP, exists := p.allocated[realIP]; exists { return fakeIP, nil } - startIP := m.nextIP + startIP := p.nextIP for { - currentIP := m.nextIP + currentIP := p.nextIP // Advance to next IP, wrapping at boundary - if m.nextIP.Compare(m.maxIP) >= 0 { - m.nextIP = m.baseIP + if p.nextIP.Compare(p.maxIP) >= 0 { + p.nextIP = p.baseIP } else { - m.nextIP = m.nextIP.Next() + p.nextIP = p.nextIP.Next() } - // Check if current IP is available - if _, inUse := m.fakeToReal[currentIP]; !inUse { - m.allocated[realIP] = currentIP - m.fakeToReal[currentIP] = realIP + if _, inUse := p.fakeToReal[currentIP]; !inUse { + p.allocated[realIP] = currentIP + p.fakeToReal[currentIP] = realIP return currentIP, nil } - // Prevent infinite loop if all IPs exhausted - if m.nextIP.Compare(startIP) == 0 { - return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block") + if p.nextIP.Compare(startIP) == 0 { + return netip.Addr{}, fmt.Errorf("no more fake IPs available in %s block", p.block) } } } -// GetFakeIP returns the fake IP for a real IP if it exists +// Manager manages allocation of fake IPs for dynamic DNS routes. +// IPv4 uses 240.0.0.0/8 (reserved), IPv6 uses 0100::/64 (discard, RFC 6666). +type Manager struct { + mu sync.Mutex + v4 *fakeIPPool + v6 *fakeIPPool +} + +// NewManager creates a new fake IP manager. +func NewManager() *Manager { + return &Manager{ + v4: newPool(v4Base, v4Max, v4Block), + v6: newPool(v6Base, v6Max, v6Block), + } +} + +func (m *Manager) pool(ip netip.Addr) *fakeIPPool { + if ip.Is6() { + return m.v6 + } + return m.v4 +} + +// AllocateFakeIP allocates a fake IP for the given real IP. +func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { + realIP = realIP.Unmap() + if !realIP.IsValid() { + return netip.Addr{}, errors.New("invalid IP address") + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.pool(realIP).allocate(realIP) +} + +// GetFakeIP returns the fake IP for a real IP if it exists. func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) { + realIP = realIP.Unmap() + if !realIP.IsValid() { + return netip.Addr{}, false + } + m.mu.Lock() defer m.mu.Unlock() - fakeIP, exists := m.allocated[realIP] - return fakeIP, exists + fakeIP, ok := m.pool(realIP).allocated[realIP] + return fakeIP, ok } -// GetRealIP returns the real IP for a fake IP if it exists, otherwise false +// GetRealIP returns the real IP for a fake IP if it exists. func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) { + fakeIP = fakeIP.Unmap() + if !fakeIP.IsValid() { + return netip.Addr{}, false + } + m.mu.Lock() defer m.mu.Unlock() - realIP, exists := m.fakeToReal[fakeIP] - return realIP, exists + realIP, ok := m.pool(fakeIP).fakeToReal[fakeIP] + return realIP, ok } -// GetFakeIPBlock returns the fake IP block used by this manager +// GetFakeIPBlock returns the v4 fake IP block used by this manager. func (m *Manager) GetFakeIPBlock() netip.Prefix { - return netip.MustParsePrefix("240.0.0.0/8") + return m.v4.block +} + +// GetFakeIPv6Block returns the v6 fake IP block used by this manager. +func (m *Manager) GetFakeIPv6Block() netip.Prefix { + return m.v6.block } diff --git a/client/internal/routemanager/fakeip/fakeip_test.go b/client/internal/routemanager/fakeip/fakeip_test.go index ad3e4bd4e60..f554f970d8c 100644 --- a/client/internal/routemanager/fakeip/fakeip_test.go +++ b/client/internal/routemanager/fakeip/fakeip_test.go @@ -9,16 +9,16 @@ import ( func TestNewManager(t *testing.T) { manager := NewManager() - if manager.baseIP.String() != "240.0.0.1" { - t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String()) + if manager.v4.baseIP.String() != "240.0.0.1" { + t.Errorf("Expected v4 base IP 240.0.0.1, got %s", manager.v4.baseIP.String()) } - if manager.maxIP.String() != "240.255.255.254" { - t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String()) + if manager.v4.maxIP.String() != "240.255.255.254" { + t.Errorf("Expected v4 max IP 240.255.255.254, got %s", manager.v4.maxIP.String()) } - if manager.nextIP.Compare(manager.baseIP) != 0 { - t.Errorf("Expected nextIP to start at baseIP") + if manager.v6.baseIP.String() != "100::1" { + t.Errorf("Expected v6 base IP 100::1, got %s", manager.v6.baseIP.String()) } } @@ -35,7 +35,6 @@ func TestAllocateFakeIP(t *testing.T) { t.Error("Fake IP should be IPv4") } - // Check it's in the correct range if fakeIP.As4()[0] != 240 { t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String()) } @@ -51,13 +50,31 @@ func TestAllocateFakeIP(t *testing.T) { } } -func TestAllocateFakeIPIPv6Rejection(t *testing.T) { +func TestAllocateFakeIPv6(t *testing.T) { manager := NewManager() - realIPv6 := netip.MustParseAddr("2001:db8::1") + realIP := netip.MustParseAddr("2001:db8::1") - _, err := manager.AllocateFakeIP(realIPv6) - if err == nil { - t.Error("Expected error for IPv6 address") + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IPv6: %v", err) + } + + if !fakeIP.Is6() { + t.Error("Fake IP should be IPv6") + } + + if !netip.MustParsePrefix("100::/64").Contains(fakeIP) { + t.Errorf("Fake IP should be in 100::/64 range, got %s", fakeIP.String()) + } + + // Should return same fake IP for same real IP + fakeIP2, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to get existing fake IPv6: %v", err) + } + + if fakeIP.Compare(fakeIP2) != 0 { + t.Errorf("Expected same fake IP, got %s and %s", fakeIP.String(), fakeIP2.String()) } } @@ -65,13 +82,11 @@ func TestGetFakeIP(t *testing.T) { manager := NewManager() realIP := netip.MustParseAddr("1.1.1.1") - // Should not exist initially _, exists := manager.GetFakeIP(realIP) if exists { t.Error("Fake IP should not exist before allocation") } - // Allocate and check expectedFakeIP, err := manager.AllocateFakeIP(realIP) if err != nil { t.Fatalf("Failed to allocate: %v", err) @@ -87,12 +102,30 @@ func TestGetFakeIP(t *testing.T) { } } +func TestGetRealIPv6(t *testing.T) { + manager := NewManager() + realIP := netip.MustParseAddr("2001:db8::1") + + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate: %v", err) + } + + gotReal, exists := manager.GetRealIP(fakeIP) + if !exists { + t.Error("Real IP should exist for allocated fake IP") + } + + if gotReal.Compare(realIP) != 0 { + t.Errorf("Expected real IP %s, got %s", realIP, gotReal) + } +} + func TestMultipleAllocations(t *testing.T) { manager := NewManager() allocations := make(map[netip.Addr]netip.Addr) - // Allocate multiple IPs for i := 1; i <= 100; i++ { realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) fakeIP, err := manager.AllocateFakeIP(realIP) @@ -100,7 +133,6 @@ func TestMultipleAllocations(t *testing.T) { t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err) } - // Check for duplicates for _, existingFake := range allocations { if fakeIP.Compare(existingFake) == 0 { t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String()) @@ -110,7 +142,6 @@ func TestMultipleAllocations(t *testing.T) { allocations[realIP] = fakeIP } - // Verify all allocations can be retrieved for realIP, expectedFake := range allocations { actualFake, exists := manager.GetFakeIP(realIP) if !exists { @@ -124,11 +155,13 @@ func TestMultipleAllocations(t *testing.T) { func TestGetFakeIPBlock(t *testing.T) { manager := NewManager() - block := manager.GetFakeIPBlock() - expected := "240.0.0.0/8" - if block.String() != expected { - t.Errorf("Expected %s, got %s", expected, block.String()) + if block := manager.GetFakeIPBlock(); block.String() != "240.0.0.0/8" { + t.Errorf("Expected 240.0.0.0/8, got %s", block.String()) + } + + if block := manager.GetFakeIPv6Block(); block.String() != "100::/64" { + t.Errorf("Expected 100::/64, got %s", block.String()) } } @@ -141,7 +174,6 @@ func TestConcurrentAccess(t *testing.T) { var wg sync.WaitGroup results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine) - // Concurrent allocations for i := 0; i < numGoroutines; i++ { wg.Add(1) go func(goroutineID int) { @@ -161,7 +193,6 @@ func TestConcurrentAccess(t *testing.T) { wg.Wait() close(results) - // Check for duplicates seen := make(map[netip.Addr]bool) count := 0 for fakeIP := range results { @@ -178,47 +209,61 @@ func TestConcurrentAccess(t *testing.T) { } func TestIPExhaustion(t *testing.T) { - // Create a manager with limited range for testing manager := &Manager{ - nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - allocated: make(map[netip.Addr]netip.Addr), - fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available + v4: newPool( + netip.AddrFrom4([4]byte{240, 0, 0, 1}), + netip.AddrFrom4([4]byte{240, 0, 0, 3}), + netip.MustParsePrefix("240.0.0.0/8"), + ), + v6: newPool( + netip.MustParseAddr("100::1"), + netip.MustParseAddr("100::3"), + netip.MustParsePrefix("100::/64"), + ), } - // Allocate all available IPs - realIPs := []netip.Addr{ - netip.MustParseAddr("1.0.0.1"), - netip.MustParseAddr("1.0.0.2"), - netip.MustParseAddr("1.0.0.3"), - } - - for _, realIP := range realIPs { - _, err := manager.AllocateFakeIP(realIP) + for _, realIP := range []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"} { + _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP)) if err != nil { t.Fatalf("Failed to allocate fake IP: %v", err) } } - // Try to allocate one more - should fail _, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4")) if err == nil { - t.Error("Expected exhaustion error") + t.Error("Expected v4 exhaustion error") + } + + // Same for v6 + for _, realIP := range []string{"2001:db8::1", "2001:db8::2", "2001:db8::3"} { + _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP)) + if err != nil { + t.Fatalf("Failed to allocate fake IPv6: %v", err) + } + } + + _, err = manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::4")) + if err == nil { + t.Error("Expected v6 exhaustion error") } } func TestWrapAround(t *testing.T) { - // Create manager starting near the end of range manager := &Manager{ - nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), - allocated: make(map[netip.Addr]netip.Addr), - fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), + v4: newPool( + netip.AddrFrom4([4]byte{240, 0, 0, 1}), + netip.AddrFrom4([4]byte{240, 0, 0, 254}), + netip.MustParsePrefix("240.0.0.0/8"), + ), + v6: newPool( + netip.MustParseAddr("100::1"), + netip.MustParseAddr("100::ffff:ffff:ffff:fffe"), + netip.MustParsePrefix("100::/64"), + ), } + // Start near the end + manager.v4.nextIP = netip.AddrFrom4([4]byte{240, 0, 0, 254}) - // Allocate the last IP fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1")) if err != nil { t.Fatalf("Failed to allocate first IP: %v", err) @@ -228,7 +273,6 @@ func TestWrapAround(t *testing.T) { t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String()) } - // Next allocation should wrap around to the beginning fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2")) if err != nil { t.Fatalf("Failed to allocate second IP: %v", err) @@ -238,3 +282,32 @@ func TestWrapAround(t *testing.T) { t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String()) } } + +func TestMixedV4V6(t *testing.T) { + manager := NewManager() + + v4Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("8.8.8.8")) + if err != nil { + t.Fatalf("Failed to allocate v4: %v", err) + } + + v6Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::1")) + if err != nil { + t.Fatalf("Failed to allocate v6: %v", err) + } + + if !v4Fake.Is4() || !v6Fake.Is6() { + t.Errorf("Wrong families: v4=%s v6=%s", v4Fake, v6Fake) + } + + // Reverse lookups should work for both + gotV4, ok := manager.GetRealIP(v4Fake) + if !ok || gotV4.String() != "8.8.8.8" { + t.Errorf("v4 reverse lookup failed: got %s, ok=%v", gotV4, ok) + } + + gotV6, ok := manager.GetRealIP(v6Fake) + if !ok || gotV6.String() != "2001:db8::1" { + t.Errorf("v6 reverse lookup failed: got %s, ok=%v", gotV6, ok) + } +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 9afe2049d00..390d04b5403 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -158,15 +158,23 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { if config.DNSFeatureFlag { m.fakeIPManager = fakeip.NewManager() - id := uuid.NewString() - fakeIPRoute := &route.Route{ - ID: route.ID(id), + v4ID := uuid.NewString() + cr = append(cr, &route.Route{ + ID: route.ID(v4ID), Network: m.fakeIPManager.GetFakeIPBlock(), - NetID: route.NetID(id), + NetID: route.NetID(v4ID), Peer: m.pubKey, NetworkType: route.IPv4Network, - } - cr = append(cr, fakeIPRoute) + }) + + v6ID := uuid.NewString() + cr = append(cr, &route.Route{ + ID: route.ID(v6ID), + Network: m.fakeIPManager.GetFakeIPv6Block(), + NetID: route.NetID(v6ID), + Peer: m.pubKey, + NetworkType: route.IPv6Network, + }) } m.notifier.SetInitialClientRoutes(cr, routesForComparison) diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index ec219c7feeb..02f447a2f4b 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -222,15 +222,20 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er return err } - // TODO: remove once IPv6 is supported on the interface - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) + // When the interface has no v6, add v6 split-default as blackhole so + // unroutable v6 goes to WG (dropped, no AllowedIPs) instead of leaking + // to the system default route. When v6 is active, management sends ::/0 + // as a separate route that the dedicated handler adds. + // Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure. + if !r.wgInterface.Address().HasIPv6() { + if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { + log.Warnf("Failed to add v6 blackhole split 1: %s", err) + } else if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { + log.Warnf("Failed to add v6 blackhole split 2: %s", err) + if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback v6 blackhole: %s", err2) + } } - return fmt.Errorf("add unreachable route split 2: %w", err) } return nil @@ -266,12 +271,13 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) result = multierror.Append(result, err) } - // TODO: remove once IPv6 is supported on the interface - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) + if !r.wgInterface.Address().HasIPv6() { + if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { + result = multierror.Append(result, err) + } } return nberrors.FormatErrorOrNil(result) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index bd10f131f07..55e45279c0b 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -53,6 +53,8 @@ const ( // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. ipv4ForwardingPath = "net.ipv4.ip_forward" + // ipv6ForwardingPath is the path to the file containing the IPv6 forwarding setting. + ipv6ForwardingPath = "net.ipv6.conf.all.forwarding" ) var ErrTableIDExists = errors.New("ID exists with different name") @@ -185,10 +187,11 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 - // TODO remove this once we have ipv6 support - if prefix == vars.Defaultv4 { + // When the peer has no IPv6, blackhole v6 to prevent leaking. + // When IPv6 is enabled, management sends ::/0 as a separate route. + if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) { if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add blackhole: %w", err) + return fmt.Errorf("add v6 blackhole: %w", err) } } if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { @@ -206,10 +209,9 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error return r.genericRemoveVPNRoute(prefix, intf) } - // TODO remove this once we have ipv6 support - if prefix == vars.Defaultv4 { + if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) { if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove unreachable route: %w", err) + log.Debugf("remove v6 blackhole: %v", err) } } if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { @@ -762,8 +764,13 @@ func flushRoutes(tableID, family int) error { } func EnableIPForwarding() error { - _, err := sysctl.Set(ipv4ForwardingPath, 1, false) - return err + if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil { + return err + } + if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil { + log.Warnf("failed to enable IPv6 forwarding: %v", err) + } + return nil } // entryExists checks if the specified ID or name already exists in the rt_tables file From 7dff0e2f70e40a2461286bfb22a98e3a8bfbd5e7 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 7 Apr 2026 18:40:09 +0200 Subject: [PATCH 2/3] Keep successful DNS answers when only one address family query fails --- client/internal/routemanager/dynamic/route_ios.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go index 6498bd77f66..1ae281d5688 100644 --- a/client/internal/routemanager/dynamic/route_ios.go +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -26,6 +26,7 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { startTime := time.Now() var ips []net.IP + var queryErr error for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} { msg := new(dns.Msg) @@ -33,7 +34,10 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr.String()) if err != nil { - return nil, fmt.Errorf("DNS query for %s (type %d) after %s: %s", domain.SafeString(), qtype, time.Since(startTime), err) + if queryErr == nil { + queryErr = fmt.Errorf("DNS query for %s (type %d) after %s: %w", domain.SafeString(), qtype, time.Since(startTime), err) + } + continue } if response.Rcode != dns.RcodeSuccess { @@ -51,6 +55,9 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { } if len(ips) == 0 { + if queryErr != nil { + return nil, queryErr + } return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString()) } From 954c2c1d2eaeefde28f03b7db2503257c5a30e09 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 9 Apr 2026 16:55:22 +0800 Subject: [PATCH 3/3] [client] Add dual-stack nftables manager with IPv6 table support (#5707) --- client/android/client.go | 79 +++-- client/android/route_command.go | 7 +- client/anonymize/anonymize.go | 6 + client/cmd/ssh.go | 4 +- client/cmd/ssh_test.go | 4 +- client/firewall/iptables/acl_linux.go | 31 +- client/firewall/iptables/manager_linux.go | 236 ++++++++++++-- client/firewall/iptables/router_linux.go | 68 +++-- client/firewall/iptables/rule.go | 1 + client/firewall/iptables/state_linux.go | 30 ++ client/firewall/nftables/acl_linux.go | 44 ++- client/firewall/nftables/addr_family_linux.go | 81 +++++ client/firewall/nftables/manager_linux.go | 287 ++++++++++++++++-- .../firewall/nftables/manager_linux_test.go | 124 ++++++++ client/firewall/nftables/router_linux.go | 165 ++++++---- client/firewall/nftables/router_linux_test.go | 189 +++++++++++- .../uspfilter/allow_netbird_windows.go | 53 +++- client/firewall/uspfilter/conntrack/common.go | 7 +- client/firewall/uspfilter/filter.go | 3 +- client/firewall/uspfilter/filter_test.go | 38 +++ client/firewall/uspfilter/localip_test.go | 1 - client/iface/configurer/usp.go | 2 +- client/iface/device/adapter.go | 2 +- client/iface/device/device_android.go | 2 +- client/iface/wgproxy/bind/proxy.go | 22 +- client/internal/debug/debug_test.go | 49 ++- client/internal/dns/service_listener.go | 8 +- client/internal/dnsfwd/manager.go | 1 + client/internal/engine.go | 12 +- .../lazyconn/activity/listener_bind.go | 23 +- client/internal/peer/status.go | 6 +- .../routemanager/ipfwdstate/ipfwdstate.go | 6 +- client/internal/routemanager/server/server.go | 3 +- .../routemanager/systemops/systemops.go | 7 +- .../systemops/systemops_generic.go | 62 ++-- client/ios/NetBirdSDK/client.go | 74 +++-- client/server/network.go | 42 ++- client/ui/network.go | 2 +- client/wasm/cmd/main.go | 115 +++++-- client/wasm/internal/ssh/client.go | 15 +- proxy/cmd/proxy/cmd/debug.go | 21 +- proxy/internal/debug/client.go | 22 +- proxy/internal/debug/handler.go | 18 +- route/route.go | 61 ++++ route/route_test.go | 108 +++++++ shared/relay/client/dialer/quic/quic.go | 2 +- 46 files changed, 1790 insertions(+), 353 deletions(-) create mode 100644 client/firewall/nftables/addr_family_linux.go create mode 100644 route/route_test.go diff --git a/client/android/client.go b/client/android/client.go index 70ebc0011d3..a8766afd243 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -238,41 +238,82 @@ func (c *Client) Networks() *NetworkArray { return nil } + routesMap := routeManager.GetClientRoutesWithNetID() + v6Merged := route.V6ExitMergeSet(routesMap) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + networkArray := &NetworkArray{ items: make([]Network, 0), } - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - for id, routes := range routeManager.GetClientRoutesWithNetID() { + for id, routes := range routesMap { if len(routes) == 0 { continue } + if _, skip := v6Merged[id]; skip { + continue + } + + network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged) + if network == nil { + continue + } + networkArray.Add(*network) + } + return networkArray +} - r := routes[0] - domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) - netStr := r.Network.String() +func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network { + r := routes[0] + netStr := r.Network.String() + if r.IsDynamic() { + netStr = r.Domains.SafeString() + } - if r.IsDynamic() { - netStr = r.Domains.SafeString() + routePeer, err := c.findBestRoutePeer(routes) + if err != nil { + log.Errorf("could not get peer info for route %s: %v", id, err) + return nil + } + + network := &Network{ + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: selected, + Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains), + } + + if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) { + network.Network = "0.0.0.0/0, ::/0" + } + + return network +} + +// findBestRoutePeer returns the peer actively routing traffic for the given +// HA route group. Falls back to the first connected peer, then the first peer. +func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) { + netStr := routes[0].Network.String() + + fullStatus := c.recorder.GetFullStatus() + for _, p := range fullStatus.Peers { + if _, ok := p.GetRoutes()[netStr]; ok { + return p, nil } + } - routePeer, err := c.recorder.GetPeer(routes[0].Peer) + for _, r := range routes { + p, err := c.recorder.GetPeer(r.Peer) if err != nil { - log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) continue } - network := Network{ - Name: string(id), - Network: netStr, - Peer: routePeer.FQDN, - Status: routePeer.ConnStatus.String(), - IsSelected: routeSelector.IsSelected(id), - Domains: domains, + if p.ConnStatus == peer.StatusConnected { + return p, nil } - networkArray.Add(network) } - return networkArray + return c.recorder.GetPeer(routes[0].Peer) } // OnUpdatedHostDNS update the DNS servers addresses for root zones diff --git a/client/android/route_command.go b/client/android/route_command.go index b47d5ca6ce0..5e735733574 100644 --- a/client/android/route_command.go +++ b/client/android/route_command.go @@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager, netID := route.NetID(id) routes := []route.NetID{netID} - log.Debugf("%s with id: %s", operationName, id) + routesMap := manager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) - if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("%s with ids: %v", operationName, routes) + + if err := routeOperation(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when %s: %s", operationName, err) return fmt.Errorf("error %s: %w", operationName, err) } diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 52beaaa47c1..b7b6a20dd60 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -9,6 +9,7 @@ import ( "net/url" "regexp" "slices" + "strconv" "strings" ) @@ -97,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { } func (a *Anonymizer) AnonymizeIPString(ip string) string { + // Handle CIDR notation (e.g. "2001:db8::/32") + if prefix, err := netip.ParsePrefix(ip); err == nil { + return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits()) + } + addr, err := netip.ParseAddr(ip) if err != nil { return ip diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 0acf0b13334..de5150b1f0f 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -787,10 +787,10 @@ func isUnixSocket(path string) bool { return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") } -// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. +// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack). func normalizeLocalHost(host string) string { if host == "*" { - return "0.0.0.0" + return "" } return host } diff --git a/client/cmd/ssh_test.go b/client/cmd/ssh_test.go index 43291fa87c1..16ffadb90e5 100644 --- a/client/cmd/ssh_test.go +++ b/client/cmd/ssh_test.go @@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) { { name: "wildcard bind all interfaces", spec: "*:8080:localhost:80", - expectedLocal: "0.0.0.0:8080", + expectedLocal: ":8080", expectedRemote: "localhost:80", expectError: false, - description: "Wildcard * should bind to all interfaces (0.0.0.0)", + description: "Wildcard * should bind to all interfaces (dual-stack)", }, { name: "wildcard for port only", diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d83798f0981..4740c41273d 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -36,6 +36,7 @@ type aclManager struct { entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore + v6 bool stateManager *statemanager.Manager } @@ -47,6 +48,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, }, nil } @@ -81,7 +83,11 @@ func (m *aclManager) AddPeerFiltering( chain := chainNameInputRules ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) - specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) + if m.v6 && ipsetName != "" { + ipsetName += "-v6" + } + proto := protoForFamily(protocol, m.v6) + specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName) mangleSpecs := slices.Clone(specs) mangleSpecs = append(mangleSpecs, @@ -105,6 +111,7 @@ func (m *aclManager) AddPeerFiltering( ip: ip.String(), chain: chain, specs: specs, + v6: m.v6, }}, nil } @@ -157,6 +164,7 @@ func (m *aclManager) AddPeerFiltering( ipsetName: ipsetName, ip: ip.String(), chain: chain, + v6: m.v6, } m.updateState() @@ -376,8 +384,13 @@ func (m *aclManager) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.ACLEntries = m.entries - currentState.ACLIPsetStore = m.ipsetStore + if m.v6 { + currentState.ACLEntries6 = m.entries + currentState.ACLIPsetStore6 = m.ipsetStore + } else { + currentState.ACLEntries = m.entries + currentState.ACLIPsetStore = m.ipsetStore + } if err := m.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -385,6 +398,15 @@ func (m *aclManager) updateState() { } // filterRuleSpecs returns the specs of a filtering rule +// protoForFamily translates ICMP to ICMPv6 for ip6tables. +// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp". +func protoForFamily(protocol firewall.Protocol, v6 bool) string { + if v6 && protocol == firewall.ProtocolICMP { + return "ipv6-icmp" + } + return string(protocol) +} + func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { // don't use IP matching if IP is 0.0.0.0 matchByIP := !ip.IsUnspecified() @@ -437,6 +459,9 @@ func (m *aclManager) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if m.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2fc6f8ec8dc..c278924f250 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -17,6 +17,10 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) +type resetter interface { + Reset() error +} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex @@ -27,6 +31,11 @@ type Manager struct { aclMgr *aclManager router *router rawSupported bool + + // IPv6 counterparts, nil when no v6 overlay + ipv6Client *iptables.IPTables + aclMgr6 *aclManager + router6 *router } // iFaceMapper defines subset methods of interface required for manager @@ -58,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + return m, nil } +func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error { + ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return fmt.Errorf("init ip6tables: %w", err) + } + m.ipv6Client = ip6Client + + m.router6, err = newRouter(ip6Client, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclMgr6, err = newAclManager(ip6Client, wgIface) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +func (m *Manager) hasIPv6() bool { + return m.ipv6Client != nil +} + func (m *Manager) Init(stateManager *statemanager.Manager) error { state := &ShutdownState{ InterfaceState: &InterfaceState{ @@ -75,13 +118,8 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - if err := m.router.init(stateManager); err != nil { - return fmt.Errorf("router init: %w", err) - } - - if err := m.aclMgr.init(stateManager); err != nil { - // TODO: cleanup router - return fmt.Errorf("acl manager init: %w", err) + if err := m.initChains(stateManager); err != nil { + return err } if err := m.initNoTrackChain(); err != nil { @@ -98,6 +136,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return nil } +// initChains initializes router and ACL chains for both address families, +// rolling back on failure. +func (m *Manager) initChains(stateManager *statemanager.Manager) error { + type initStep struct { + name string + init func(*statemanager.Manager) error + mgr resetter + } + + steps := []initStep{ + {"router", m.router.init, m.router}, + {"acl manager", m.aclMgr.init, m.aclMgr}, + } + if m.hasIPv6() { + steps = append(steps, + initStep{"v6 router", m.router6.init, m.router6}, + initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6}, + ) + } + + var initialized []initStep + for _, s := range steps { + if err := s.init(stateManager); err != nil { + for i := len(initialized) - 1; i >= 0; i-- { + if rerr := initialized[i].mgr.Reset(); rerr != nil { + log.Warnf("rollback %s: %v", initialized[i].name, rerr) + } + } + return fmt.Errorf("%s init: %w", s.name, err) + } + initialized = append(initialized, s) + } + return nil +} + // AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported @@ -113,7 +186,13 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if ip.To4() != nil { + return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + } + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip) + } + return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -127,25 +206,48 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule") + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + // DeletePeerRule from the firewall by rule definition func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6IptRule(rule) { + return m.aclMgr6.DeletePeerRule(rule) + } return m.aclMgr.DeletePeerRule(rule) } +func isIPv6IptRule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.v6 +} + +// DeleteRouteRule deletes a routing rule. +// Route rules are keyed by content hash. Check v4 first, try v6 if not found. func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()) { + return m.router6.DeleteRouteRule(rule) + } return m.router.DeleteRouteRule(rule) } @@ -161,18 +263,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("IPv6 not initialized, cannot add NAT rule") + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes need NAT in both tables + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + if err := m.router.RemoveNatRule(pair); err != nil { + return err + } + + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + return fmt.Errorf("remove v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Reset firewall to the default state @@ -186,6 +333,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) } + if m.hasIPv6() { + if err := m.aclMgr6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err)) + } + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err)) + } + } + if err := m.aclMgr.Reset(); err != nil { merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) } @@ -209,19 +365,16 @@ func (m *Manager) AllowNetbird() error { return nil } - _, err := m.AddPeerFiltering( - nil, - net.IP{0, 0, 0, 0}, - firewall.ProtocolALL, - nil, - nil, - firewall.ActionAccept, - "", - ) - if err != nil { - return fmt.Errorf("allow netbird interface traffic: %w", err) + var merr *multierror.Error + if _, err := m.aclMgr.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow netbird interface traffic: %w", err)) } - return nil + if m.hasIPv6() { + if _, err := m.aclMgr6.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow v6 netbird interface traffic: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } // Flush doesn't need to be implemented for this manager @@ -251,6 +404,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && rule.TranslatedAddress.Is6() { + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -259,6 +415,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) { + return m.router6.DeleteDNATRule(rule) + } return m.router.DeleteDNATRule(rule) } @@ -267,7 +426,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. @@ -275,6 +453,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -283,6 +464,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index a7c4f67dd5c..61921f7f9d5 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -54,8 +54,10 @@ const ( snatSuffix = "_snat" fwdSuffix = "_fwd" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. + ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 ) type ruleInfo struct { @@ -86,6 +88,7 @@ type router struct { wgIface iFaceMapper legacyManagement bool mtu uint16 + v6 bool stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState @@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1 rules: make(map[string][]string), wgIface: wgIface, mtu: mtu, + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + func (r *router) DeleteRouteRule(rule firewall.Rule) error { ruleKey := rule.ID() @@ -434,6 +443,12 @@ func (r *router) createContainers() error { {chainRTRDR, tableNat}, {chainRTMSSCLAMP, tableMangle}, } { + // Fallback: clear chains that survived an unclean shutdown. + if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok { + if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil { + log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err) + } + } if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } @@ -540,9 +555,12 @@ func (r *router) addPostroutingRules() error { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.v6 { + overhead = ipv6TCPHeaderSize + } + mss := r.mtu - overhead // Add jump rule from FORWARD chain in mangle table to our custom chain jumpRule := []string{ @@ -727,8 +745,13 @@ func (r *router) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.RouteRules = r.rules - currentState.RouteIPsetCounter = r.ipsetCounter + if r.v6 { + currentState.RouteRules6 = r.rules + currentState.RouteIPsetCounter6 = r.ipsetCounter + } else { + currentState.RouteRules = r.rules + currentState.RouteIPsetCounter = r.ipsetCounter + } if err := r.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -856,7 +879,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { } if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { - if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil { + if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil { merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) } delete(r.rules, ruleKey+fwdSuffix) @@ -883,7 +906,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net rule = append(rule, destExp...) if params.Proto != firewall.ProtocolALL { - rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6))) rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--dport", params.DPort)...) } @@ -900,11 +923,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } if network.IsSet() { - if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { + name := r.ipsetName(network.Set.HashedName()) + if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil { return nil, fmt.Errorf("create or get ipset: %w", err) } - return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil + return []string{"-m", "set", matchSet, name, direction}, nil } if network.IsPrefix() { return []string{flag, network.Prefix.String()}, nil @@ -915,19 +939,15 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + name := r.ipsetName(set.HashedName()) var merr *multierror.Error for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { + if err := r.addPrefixToIPSet(name, prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) } } if merr == nil { - log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + log.Debugf("updated set %s with prefixes %v", name, prefixes) } return nberrors.FormatErrorOrNil(merr) @@ -943,7 +963,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol dnatRule := []string{ "-i", r.wgIface.Name(), - "-p", strings.ToLower(string(protocol)), + "-p", strings.ToLower(protoForFamily(protocol, r.v6)), "--dport", strconv.Itoa(int(sourcePort)), "-d", localAddr.String(), "-m", "addrtype", "--dst-type", "LOCAL", @@ -1076,10 +1096,22 @@ func applyPort(flag string, port *firewall.Port) []string { return []string{flag, strconv.Itoa(int(port.Values[0]))} } +// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router +// to avoid collisions since ipsets are global in the kernel. +func (r *router) ipsetName(name string) string { + if r.v6 { + return name + "-v6" + } + return name +} + func (r *router) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if r.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index aa4d2d07900..4f4eab167e7 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -9,6 +9,7 @@ type Rule struct { mangleSpecs []string ip string chain string + v6 bool } // GetRuleID returns the rule id diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index c88774c1f10..6b2e99e31d8 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,8 @@ import ( "fmt" "sync" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -37,6 +39,12 @@ type ShutdownState struct { ACLEntries aclEntries `json:"acl_entries,omitempty"` ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` + + // IPv6 counterparts + RouteRules6 routeRules `json:"route_rules_v6,omitempty"` + RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"` + ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"` + ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"` } func (s *ShutdownState) Name() string { @@ -67,6 +75,28 @@ func (s *ShutdownState) Cleanup() error { ipt.aclMgr.ipsetStore = s.ACLIPsetStore } + // Clean up v6 state even if the current run has no IPv6. + // The previous run may have left ip6tables rules behind. + if !ipt.hasIPv6() { + if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil { + log.Warnf("failed to create v6 components for cleanup: %v", err) + } + } + if ipt.hasIPv6() { + if s.RouteRules6 != nil { + ipt.router6.rules = s.RouteRules6 + } + if s.RouteIPsetCounter6 != nil { + ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6) + } + if s.ACLEntries6 != nil { + ipt.aclMgr6.entries = s.ACLEntries6 + } + if s.ACLIPsetStore6 != nil { + ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6 + } + } + if err := ipt.Close(nil); err != nil { return fmt.Errorf("reset iptables manager: %w", err) } diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index a9d066e2fcc..9d2ea726496 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -33,15 +33,12 @@ const ( const flushError = "flush: %w" -var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} -) - type AclManager struct { rConn *nftables.Conn sConn *nftables.Conn wgIface iFaceMapper routingFwChainName string + af addrFamily workTable *nftables.Table chainInputRules *nftables.Chain @@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam wgIface: wgIface, workTable: table, routingFwChainName: routingFwChainName, + af: familyForAddr(table.Family == nftables.TableFamilyIPv4), ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { } if _, ok := ips[r.ip.String()]; ok { - err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) + err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}}) if err != nil { log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) } @@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering( expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), + Offset: m.af.protoOffset, Len: uint32(1), }) - protoData, err := protoToInt(proto) + protoData, err := m.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %v", err) } @@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering( }) } - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) // check if rawIP contains zeroed IPv4 0.0.0.0 value // in that case not add IP match expression into the rule definition - if !bytes.HasPrefix(anyIP, rawIP) { - // source address position - addrOffset := uint32(12) - + if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: addrOffset, - Len: 4, + Offset: m.af.srcAddrOffset, + Len: m.af.addrLen, }, ) // add individual IP for match if no ipset defined @@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) if err != nil { if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { return nil, fmt.Errorf("get set name: %v", err) @@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se Name: name, Table: table, Dynamic: true, - KeyType: nftables.TypeIPAddr, + KeyType: m.af.setKeyType, } if err := m.rConn.AddSet(ipset, nil); err != nil { @@ -707,15 +702,12 @@ func ifname(n string) []byte { return b } -func protoToInt(protocol firewall.Protocol) (uint8, error) { - switch protocol { - case firewall.ProtocolTCP: - return unix.IPPROTO_TCP, nil - case firewall.ProtocolUDP: - return unix.IPPROTO_UDP, nil - case firewall.ProtocolICMP: - return unix.IPPROTO_ICMP, nil - } - return 0, fmt.Errorf("unsupported protocol: %s", protocol) +// ipToBytes converts net.IP to the correct byte length for the address family. +func ipToBytes(ip net.IP, af addrFamily) []byte { + if af.addrLen == 4 { + return ip.To4() + } + return ip.To16() } + diff --git a/client/firewall/nftables/addr_family_linux.go b/client/firewall/nftables/addr_family_linux.go new file mode 100644 index 00000000000..0c90d704a5e --- /dev/null +++ b/client/firewall/nftables/addr_family_linux.go @@ -0,0 +1,81 @@ +package nftables + +import ( + "fmt" + "net" + + "github.com/google/nftables" + "golang.org/x/sys/unix" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +var ( + // afIPv4 defines IPv4 header layout and nftables types. + afIPv4 = addrFamily{ + protoOffset: 9, + srcAddrOffset: 12, + dstAddrOffset: 16, + addrLen: net.IPv4len, + totalBits: 8 * net.IPv4len, + setKeyType: nftables.TypeIPAddr, + tableFamily: nftables.TableFamilyIPv4, + icmpProto: unix.IPPROTO_ICMP, + } + // afIPv6 defines IPv6 header layout and nftables types. + afIPv6 = addrFamily{ + protoOffset: 6, + srcAddrOffset: 8, + dstAddrOffset: 24, + addrLen: net.IPv6len, + totalBits: 8 * net.IPv6len, + setKeyType: nftables.TypeIP6Addr, + tableFamily: nftables.TableFamilyIPv6, + icmpProto: unix.IPPROTO_ICMPV6, + } +) + +// addrFamily holds protocol-specific constants for nftables expression building. +type addrFamily struct { + // protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6) + protoOffset uint32 + // srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6) + srcAddrOffset uint32 + // dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6) + dstAddrOffset uint32 + // addrLen is the byte length of addresses (4 for v4, 16 for v6) + addrLen uint32 + // totalBits is the address size in bits (32 for v4, 128 for v6) + totalBits int + // setKeyType is the nftables set data type for addresses + setKeyType nftables.SetDatatype + // tableFamily is the nftables table family + tableFamily nftables.TableFamily + // icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6) + icmpProto uint8 +} + +// familyForAddr returns the address family for the given IP. +func familyForAddr(is4 bool) addrFamily { + if is4 { + return afIPv4 + } + return afIPv6 +} + +// protoNum converts a firewall protocol to the IP protocol number, +// using the correct ICMP variant for the address family. +func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return af.icmpProto, nil + case firewall.ProtocolALL: + return 0, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index beb5b70a79b..c3c1c1a6590 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -11,9 +11,11 @@ import ( "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -49,8 +51,13 @@ type Manager struct { rConn *nftables.Conn wgIface iFaceMapper - router *router - aclManager *AclManager + router *router + aclManager *AclManager + + // IPv6 counterparts, nil when no v6 overlay + router6 *router + aclManager6 *AclManager + notrackOutputChain *nftables.Chain notrackPreroutingChain *nftables.Chain } @@ -62,7 +69,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { wgIface: wgIface, } - workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} + tableName := getTableName() + workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4} var err error m.router, err = newRouter(workTable, wgIface, mtu) @@ -75,11 +83,70 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + return m, nil } +func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error { + workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6} + + var err error + m.router6, err = newRouter(workTable6, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +// hasIPv6 reports whether the manager has IPv6 components initialized. +func (m *Manager) hasIPv6() bool { + return m.router6 != nil +} + +func (m *Manager) initIPv6() error { + workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6) + if err != nil { + return fmt.Errorf("create v6 work table: %w", err) + } + + if err := m.router6.init(workTable6); err != nil { + return fmt.Errorf("v6 router init: %w", err) + } + + if err := m.aclManager6.init(workTable6); err != nil { + return fmt.Errorf("v6 acl manager init: %w", err) + } + + return nil +} + // Init nftables firewall manager func (m *Manager) Init(stateManager *statemanager.Manager) error { + if err := m.initFirewall(); err != nil { + return err + } + + m.persistState(stateManager) + + return nil +} + +func (m *Manager) initFirewall() error { workTable, err := m.createWorkTable() if err != nil { return fmt.Errorf("create work table: %w", err) @@ -90,20 +157,32 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } if err := m.aclManager.init(workTable); err != nil { - // TODO: cleanup router + m.rollbackInit() return fmt.Errorf("acl manager init: %w", err) } + if m.hasIPv6() { + if err := m.initIPv6(); err != nil { + // Peer has a v6 address: v6 firewall MUST work or we risk fail-open. + m.rollbackInit() + return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err) + } + } + if err := m.initNoTrackChains(workTable); err != nil { log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) } + return nil +} + +// persistState saves the current interface state for potential recreation on restart. +// Unlike iptables, which requires tracking individual rules, nftables maintains +// a known state (our netbird table plus a few static rules). This allows for easy +// cleanup using Close() without needing to store specific rules. +func (m *Manager) persistState(stateManager *statemanager.Manager) { stateManager.RegisterState(&ShutdownState{}) - // We only need to record minimal interface state for potential recreation. - // Unlike iptables, which requires tracking individual rules, nftables maintains - // a known state (our netbird table plus a few static rules). This allows for easy - // cleanup using Close() without needing to store specific rules. if err := stateManager.UpdateState(&ShutdownState{ InterfaceState: &InterfaceState{ NameStr: m.wgIface.Name(), @@ -115,14 +194,29 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - // persist early go func() { if err := stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } }() +} - return nil +// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through. +func (m *Manager) rollbackInit() { + if err := m.router.Reset(); err != nil { + log.Warnf("rollback router: %v", err) + } + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + log.Warnf("rollback v6 router: %v", err) + } + } + if err := m.cleanupNetbirdTables(); err != nil { + log.Warnf("cleanup tables: %v", err) + } + if err := m.rConn.Flush(); err != nil { + log.Warnf("flush: %v", err) + } } // AddPeerFiltering rule to the firewall @@ -141,12 +235,14 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - rawIP := ip.To4() - if rawIP == nil { - return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) + if ip.To4() != nil { + return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } - return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add rule for %s", ip) + } + return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -160,8 +256,11 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("IPv6 not initialized, cannot add route rule") + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -172,14 +271,38 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6Rule(rule) { + return m.aclManager6.DeletePeerRule(rule) + } return m.aclManager.DeletePeerRule(rule) } -// DeleteRouteRule deletes a routing rule +func isIPv6Rule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6 +} + +// isIPv6RouteRule determines whether a route rule belongs to the v6 table. +// For static routes, the destination prefix determines the family. For dynamic +// routes (DomainSet), the sources determine the family since management +// duplicates dynamic rules per family. +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + +// DeleteRouteRule deletes a routing rule. +// Route rules are keyed by content hash, so the rule exists in exactly one +// router. We check v4 first; if the key isn't there, try v6. func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()) { + return m.router6.DeleteRouteRule(rule) + } return m.router.DeleteRouteRule(rule) } @@ -195,17 +318,63 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("IPv6 not initialized, cannot add NAT rule") + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes (DomainSet) need NAT in both tables since resolved IPs + // can be either v4 or v6. + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + if err := m.router.RemoveNatRule(pair); err != nil { + return err + } + + if m.hasIPv6() && pair.Destination.IsSet() { + v6Pair := pair + v6Pair.Source = firewall.Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + return fmt.Errorf("remove v6 NAT rule: %w", err) + } + } + + return nil } -// AllowNetbird allows netbird interface traffic +// AllowNetbird allows netbird interface traffic. +// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains, +// which doesn't override DROP rules in external tables (e.g. firewalld). +// Should add passthrough rules to external chains (like the native mode router's +// addExternalChainsRules does) for both the netbird table family and inet tables. +// The netbird table itself is fine (routing chains already exist there), but +// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic. func (m *Manager) AllowNetbird() error { if !m.wgIface.IsUserspaceBind() { return nil @@ -217,6 +386,11 @@ func (m *Manager) AllowNetbird() error { if err := m.aclManager.createDefaultAllowRules(); err != nil { return fmt.Errorf("create default allow rules: %w", err) } + if m.hasIPv6() { + if err := m.aclManager6.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create v6 default allow rules: %w", err) + } + } if err := m.rConn.Flush(); err != nil { return fmt.Errorf("flush allow input netbird rules: %w", err) } @@ -226,7 +400,13 @@ func (m *Manager) AllowNetbird() error { // SetLegacyManagement sets the route manager to use legacy management func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Close closes the firewall manager @@ -234,23 +414,31 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() + var merr *multierror.Error + if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset router: %v", err) + merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err)) + } + + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err)) + } } if err := m.cleanupNetbirdTables(); err != nil { - return fmt.Errorf("cleanup netbird tables: %v", err) + merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err)) } if err := m.rConn.Flush(); err != nil { - return fmt.Errorf(flushError, err) + merr = multierror.Append(merr, fmt.Errorf(flushError, err)) } if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - return fmt.Errorf("delete state: %v", err) + merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err)) } - return nil + return nberrors.FormatErrorOrNil(merr) } func (m *Manager) cleanupNetbirdTables() error { @@ -299,6 +487,12 @@ func (m *Manager) Flush() error { return err } + if m.hasIPv6() { + if err := m.aclManager6.Flush(); err != nil { + return fmt.Errorf("flush v6 acl: %w", err) + } + } + if err := m.refreshNoTrackChains(); err != nil { log.Errorf("failed to refresh notrack chains: %v", err) } @@ -311,6 +505,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && rule.TranslatedAddress.Is6() { + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -319,6 +516,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasDNATRule(rule.ID()) { + return m.router6.DeleteDNATRule(rule) + } return m.router.DeleteDNATRule(rule) } @@ -327,7 +527,26 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. @@ -335,6 +554,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -343,6 +565,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && localAddr.Is6() { + return m.router6.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + } return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } @@ -533,7 +758,11 @@ func (m *Manager) refreshNoTrackChains() error { } func (m *Manager) createWorkTable() (*nftables.Table, error) { - tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) + return m.createWorkTableFamily(nftables.TableFamilyIPv4) +} + +func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) { + tables, err := m.rConn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } @@ -545,7 +774,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } } - table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family}) err = m.rConn.Flush() return table, err } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 75b1e2b6cd9..d925f3ef3ca 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -385,10 +385,134 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { err = manager.AddNatRule(pair) require.NoError(t, err, "failed to add NAT rule") + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("100.96.0.2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "failed to add DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule") + }) + stdout, stderr = runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) } +func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} { + if _, err := exec.LookPath(bin); err != nil { + t.Skipf("%s not available on this system: %v", bin, err) + } + } + + // Seed ip6 tables in the nft backend. Docker may not create them. + seedIp6tables(t) + + ifaceMockV6 := &iFaceMock{ + NameFunc: func() string { return "wt-test" }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.96.0.1"), + Network: netip.MustParsePrefix("100.96.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + } + }, + } + + manager, err := Create(ifaceMockV6, iface.DefaultMTU) + require.NoError(t, err, "create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + require.NoError(t, manager.Close(nil), "close manager") + + stdout, stderr := runIp6tablesSave(t) + verifyIp6tablesOutput(t, stdout, stderr) + }) + + ip := netip.MustParseAddr("fd00::2") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") + require.NoError(t, err, "add v6 peer filtering rule") + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00:1::/64")}, + fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "add v6 route filtering rule") + + err = manager.AddNatRule(fw.RouterPair{ + Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")}, + Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + Masquerade: true, + }) + require.NoError(t, err, "add v6 NAT rule") + + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("fd00::2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "add v6 DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule") + }) + + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + stdout, stderr = runIp6tablesSave(t) + verifyIp6tablesOutput(t, stdout, stderr) +} + +func seedIp6tables(t *testing.T) { + t.Helper() + for _, tc := range []struct{ table, chain string }{ + {"filter", "FORWARD"}, + {"nat", "POSTROUTING"}, + {"mangle", "FORWARD"}, + } { + add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT") + require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table) + del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT") + require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table) + } +} + +func runIp6tablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("ip6tables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + require.NoError(t, cmd.Run(), "ip6tables-save failed") + return stdout.String(), stderr.String() +} + +func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + require.NotContains(t, stdout, "Table `nat' is incompatible", + "ip6tables-save: nat table incompatible. Full output: %s", stdout) + require.NotContains(t, stdout, "Table `mangle' is incompatible", + "ip6tables-save: mangle table incompatible. Full output: %s", stdout) + require.NotContains(t, stdout, "Table `filter' is incompatible", + "ip6tables-save: filter table incompatible. Full output: %s", stdout) +} + func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this system") diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 904daf7cb68..02f8288fec3 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -47,8 +47,10 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. + ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 // maxPrefixesSet 1638 prefixes start to fail, taking some margin maxPrefixesSet = 1500 @@ -73,6 +75,7 @@ type router struct { rules map[string]*nftables.Rule ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] + af addrFamily wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool @@ -85,6 +88,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou workTable: workTable, chains: make(map[string]*nftables.Chain), rules: make(map[string]*nftables.Rule), + af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), mtu: mtu, @@ -143,7 +147,7 @@ func (r *router) Reset() error { func (r *router) removeNatPreroutingRules() error { table := &nftables.Table{ Name: tableNat, - Family: nftables.TableFamilyIPv4, + Family: r.af.tableFamily, } chain := &nftables.Chain{ Name: chainNameNatPrerouting, @@ -176,7 +180,7 @@ func (r *router) removeNatPreroutingRules() error { } func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily) if err != nil { return nil, fmt.Errorf("list tables: %w", err) } @@ -408,7 +412,7 @@ func (r *router) AddRouteFiltering( // Handle protocol if proto != firewall.ProtocolALL { - protoNum, err := protoToInt(proto) + protoNum, err := r.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -468,7 +472,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo return nil, fmt.Errorf("create or get ipset: %w", err) } - return getIpSetExprs(ref, isSource) + return r.getIpSetExprs(ref, isSource) +} + +func (r *router) iptablesProto() iptables.Protocol { + if r.af.tableFamily == nftables.TableFamilyIPv6 { + return iptables.ProtocolIPv6 + } + return iptables.ProtocolIPv4 +} + +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + +func (r *router) hasDNATRule(id string) bool { + _, ok := r.rules[id+dnatSuffix] + return ok } func (r *router) DeleteRouteRule(rule firewall.Rule) error { @@ -517,10 +538,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err Table: r.workTable, // required for prefixes Interval: true, - KeyType: nftables.TypeIPAddr, + KeyType: r.af.setKeyType, } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) nElements := len(elements) maxElements := maxPrefixesSet * 2 @@ -553,23 +574,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err return nfset, nil } -func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { +func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { var elements []nftables.SetElement for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - // nftables needs half-open intervals [firstIP, lastIP) for prefixes // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc firstIP := prefix.Addr() lastIP := calculateLastIP(prefix).Next() elements = append(elements, - // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 - // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + // the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true}, nftables.SetElement{Key: firstIP.AsSlice()}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, ) @@ -579,10 +594,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { // calculateLastIP determines the last IP in a given prefix. func calculateLastIP(prefix netip.Prefix) netip.Addr { - hostMask := ^uint32(0) >> prefix.Masked().Bits() - lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + masked := prefix.Masked() + if masked.Addr().Is4() { + hostMask := ^uint32(0) >> masked.Bits() + lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask + return netip.AddrFrom4(uint32ToBytes(lastIP)) + } - return netip.AddrFrom4(uint32ToBytes(lastIP)) + // IPv6: set host bits to all 1s + b := masked.Addr().As16() + bits := masked.Bits() + for i := bits; i < 128; i++ { + b[i/8] |= 1 << (7 - i%8) + } + return netip.AddrFrom16(b) } // Utility function to convert netip.Addr to uint32. @@ -834,9 +859,12 @@ func (r *router) addPostroutingRules() { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.af.tableFamily == nftables.TableFamilyIPv6 { + overhead = ipv6TCPHeaderSize + } + mss := r.mtu - overhead exprsOut := []expr.Any{ &expr.Meta{ @@ -1043,17 +1071,22 @@ func (r *router) acceptFilterTableRules() error { log.Debugf("Used %s to add accept forward and input rules", fw) }() - // Try iptables first and fallback to nftables if iptables is not available - ipt, err := iptables.New() + // Try iptables first and fallback to nftables if iptables is not available. + // Use the correct protocol (iptables vs ip6tables) for the address family. + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { - // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" return r.acceptFilterRulesNftables(r.filterTable) } - return r.acceptFilterRulesIptables(ipt) + if err := r.acceptFilterRulesIptables(ipt); err != nil { + log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err) + fw = "nftables" + return r.acceptFilterRulesNftables(r.filterTable) + } + return nil } func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { @@ -1222,13 +1255,17 @@ func (r *router) removeFilterTableRules() error { return nil } - ipt, err := iptables.New() + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) return r.removeAcceptRulesFromTable(r.filterTable) } - return r.removeAcceptFilterRulesIptables(ipt) + if err := r.removeAcceptFilterRulesIptables(ipt); err != nil { + log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) + } + return nil } func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { @@ -1295,7 +1332,7 @@ func (r *router) removeExternalChainsRules() error { func (r *router) findExternalChains() []*nftables.Chain { var chains []*nftables.Chain - families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet} for _, family := range families { allChains, err := r.conn.ListChainsOfTableFamily(family) @@ -1319,8 +1356,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool { return false } - // Skip all iptables-managed tables in the ip family - if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + // Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat) + if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) { return false } @@ -1461,7 +1498,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { return rule, nil } - protoNum, err := protoToInt(rule.Protocol) + protoNum, err := r.af.protoNum(rule.Protocol) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -1524,7 +1561,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule dnatExprs = append(dnatExprs, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: regProtoMin, RegProtoMax: regProtoMax, @@ -1620,7 +1657,7 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f dnatRule := &nftables.Rule{ Table: &nftables.Table{ Name: tableNat, - Family: nftables.TableFamilyIPv4, + Family: r.af.tableFamily, }, Chain: &nftables.Chain{ Name: chainNameNatPrerouting, @@ -1655,8 +1692,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, + Offset: r.af.dstAddrOffset, + Len: r.af.addrLen, }, &expr.Cmp{ Op: expr.CmpOpEq, @@ -1734,7 +1771,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return fmt.Errorf("get set %s: %w", set.HashedName(), err) } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) if err := r.conn.SetAddElements(nfset, elements); err != nil { return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) } @@ -1756,7 +1793,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol return nil } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1787,7 +1824,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1800,7 +1841,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, RegProtoMax: 0, @@ -1887,7 +1928,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, return err } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1912,7 +1953,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1925,7 +1970,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, }, @@ -1993,45 +2038,44 @@ func (r *router) applyNetwork( } if network.IsPrefix() { - return applyPrefix(network.Prefix, isSource), nil + return r.applyPrefix(network.Prefix, isSource), nil } return nil, nil } // applyPrefix generates nftables expressions for a CIDR prefix -func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { - // dst offset - offset := uint32(16) +func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } ones := prefix.Bits() - // 0.0.0.0/0 doesn't need extra expressions + // unspecified address (/0) doesn't need extra expressions if ones == 0 { return nil } - mask := net.CIDRMask(ones, 32) + mask := net.CIDRMask(ones, r.af.totalBits) + xor := make([]byte, r.af.addrLen) return []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, - // netmask &expr.Bitwise{ DestRegister: 1, SourceRegister: 1, - Len: 4, + Len: r.af.addrLen, Mask: mask, - Xor: []byte{0, 0, 0, 0}, + Xor: xor, }, - // net address &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, @@ -2114,13 +2158,12 @@ func getCtNewExprs() []expr.Any { } } -func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { - - // dst offset - offset := uint32(16) +func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } return []expr.Any{ @@ -2128,7 +2171,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, &expr.Lookup{ SourceRegister: 1, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index f0e34d211f5..c5d6729d984 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) { } // Build CIDR matching expressions - sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) - destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) + testRouter := &router{af: afIPv4} + sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true) + destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false) // Combine all expressions in the correct order // nolint:gocritic @@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) { } } +func TestNftablesCreateIpSet_IPv6(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTableIPv6() + require.NoError(t, err, "Failed to create v6 work table") + defer deleteWorkTableIPv6() + + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) + require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) + defer func() { + require.NoError(t, r.Reset(), "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Single IPv6", + sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")}, + }, + { + name: "Multiple IPv6 Subnets", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::/48"), + netip.MustParsePrefix("fe80::/10"), + }, + }, + { + name: "Overlapping IPv6", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("fd00::1/128"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + }, + }, + { + name: "Mixed prefix lengths", + sources: []netip.Prefix{ + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::1/128"), + netip.MustParsePrefix("fd00:abcd::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setName := firewall.NewPrefixSet(tt.sources).HashedName() + set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) + require.NoError(t, err, "Failed to create IPv6 set") + require.NotNil(t, set) + + assert.Equal(t, setName, set.Name) + assert.True(t, set.Interval) + assert.Equal(t, nftables.TypeIP6Addr, set.KeyType) + + fetchedSet, err := r.conn.GetSetByName(r.workTable, setName) + require.NoError(t, err, "Failed to fetch created set") + + elements, err := r.conn.GetSetElements(fetchedSet) + require.NoError(t, err, "Failed to get set elements") + + uniquePrefixes := make(map[string]bool) + for _, elem := range elements { + if !elem.IntervalEnd && len(elem.Key) == 16 { + ip := netip.AddrFrom16([16]byte(elem.Key)) + uniquePrefixes[ip.String()] = true + } + } + + expectedCount := len(tt.expected) + if expectedCount == 0 { + expectedCount = len(tt.sources) + } + assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch") + + r.conn.DelSet(set) + require.NoError(t, r.conn.Flush()) + }) + } +} + +func createWorkTableIPv6() (*nftables.Table, error) { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return nil, err + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return nil, err + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + } + } + + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6}) + err = sConn.Flush() + return table, err +} + +func deleteWorkTableIPv6() { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + _ = sConn.Flush() + } + } +} + func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { t.Helper() @@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { var metaFound, cmpFound bool - expectedProto, _ := protoToInt(proto) + expectedProto, _ := afIPv4.protoNum(proto) for _, e := range exprs { switch ex := e.(type) { case *expr.Meta: @@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) { } assert.Equal(t, 1, found, "NAT rule should exist in kernel") } + +func TestCalculateLastIP(t *testing.T) { + tests := []struct { + prefix string + want string + }{ + {"10.0.0.0/24", "10.0.0.255"}, + {"10.0.0.0/32", "10.0.0.0"}, + {"0.0.0.0/0", "255.255.255.255"}, + {"192.168.1.0/28", "192.168.1.15"}, + {"fd00::/64", "fd00::ffff:ffff:ffff:ffff"}, + {"fd00::/128", "fd00::"}, + {"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"}, + {"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}, + } + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + prefix := netip.MustParsePrefix(tt.prefix) + got := calculateLastIP(prefix) + assert.Equal(t, tt.want, got.String()) + }) + } +} + +func TestConvertPrefixesToSet_IPv6(t *testing.T) { + r := &router{af: afIPv6} + prefixes := []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::1/128"), + } + + elements := r.convertPrefixesToSet(prefixes) + + // Each prefix produces 2 elements (start + end) + require.Len(t, elements, 4) + + // fd00::/64 start + assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key)) + assert.False(t, elements[0].IntervalEnd) + + // fd00::/64 end (fd00:0:0:1::, one past the last) + assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key)) + assert.True(t, elements[1].IntervalEnd) + + // 2001:db8::1/128 start + assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key)) + assert.False(t, elements[2].IntervalEnd) + + // 2001:db8::1/128 end (2001:db8::2) + assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key)) + assert.True(t, elements[3].IntervalEnd) +} diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 6aef2ecfdad..10a2b9116b9 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -5,8 +5,10 @@ import ( "os/exec" "syscall" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error { return nil } - if !isFirewallRuleActive(firewallRuleName) { - return nil + var merr *multierror.Error + if isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err)) + } } - if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { - return fmt.Errorf("couldn't remove windows firewall: %w", err) + if isFirewallRuleActive(firewallRuleName + "-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err)) + } } - return nil + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic @@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error { return nil } - if isFirewallRuleActive(firewallRuleName) { - return nil + if !isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+m.wgIface.Address().IP.String(), + ); err != nil { + return err + } + } + + if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+v6.String(), + ); err != nil { + return err + } } - return manageFirewallRule(firewallRuleName, - addRule, - "dir=in", - "enable=yes", - "action=allow", - "profile=any", - "localip="+m.wgIface.Address().IP.String(), - ) + + return nil } func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 7be0dd78f94..88e90317c8f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -1,8 +1,9 @@ package conntrack import ( - "fmt" + "net" "net/netip" + "strconv" "sync/atomic" "time" @@ -64,5 +65,7 @@ type ConnKey struct { } func (c ConnKey) String() string { - return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) + return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) + + " → " + + net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort))) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 788aac7098a..75a02ac6f4f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1615,7 +1615,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { } // traffic to our other local interfaces (not NetBird IP) - always forward - if dstIP != m.wgIface.Address().IP { + addr := m.wgIface.Address() + if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) { return true } diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 5897f56cf71..01e5f97c16b 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -1336,6 +1336,44 @@ func TestShouldForward(t *testing.T) { }, } + // Add IPv6 to the interface and test dual-stack cases + wgIPv6 := netip.MustParseAddr("fd00::1") + otherIPv6 := netip.MustParseAddr("fd00::2") + ifaceMock.AddressFunc = func() wgaddr.Address { + return wgaddr.Address{ + IP: wgIP, + Network: netip.PrefixFrom(wgIP, 24), + IPv6: wgIPv6, + IPv6Net: netip.PrefixFrom(wgIPv6, 64), + } + } + + // Re-create manager to pick up the new address with IPv6 + require.NoError(t, manager.Close(nil)) + manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + + v6Cases := []struct { + name string + dstIP netip.Addr + expected bool + description string + }{ + {"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"}, + {"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"}, + {"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"}, + {"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"}, + } + for _, tt := range v6Cases { + t.Run(tt.name, func(t *testing.T) { + manager.localForwarding = true + manager.netstack = false + decoder := createTCPDecoder(8080) + result := manager.shouldForward(decoder, tt.dstIP) + require.Equal(t, tt.expected, result, tt.description) + }) + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Configure manager diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 42573d16d29..0dc524c41bd 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -202,4 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) { }) } } - diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index e3a96590cd4..9b070aab8a8 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, if err != nil { return fmt.Errorf("failed to parse endpoint address: %w", err) } - addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port)) + addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port)) c.activityRecorder.UpsertAddress(peerKey, addrPort) } return nil diff --git a/client/iface/device/adapter.go b/client/iface/device/adapter.go index 6ebc0539007..e3caaf9305b 100644 --- a/client/iface/device/adapter.go +++ b/client/iface/device/adapter.go @@ -2,7 +2,7 @@ package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { - ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) + ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error) UpdateAddr(address string) error ProtectSocket(fd int32) bool } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 198343fbd3b..cbe88c10c0c 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string searchDomainsToString = "" } - fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return nil, err diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 9ac3ea6df95..5bf670e076b 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -6,7 +6,7 @@ import ( "fmt" "net" "net/netip" - "strings" + "sync" log "github.com/sirupsen/logrus" @@ -196,18 +196,22 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } } -// fakeAddress returns a fake address that is used to as an identifier for the peer. -// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. +// fakeAddress returns a fake address that is used as an identifier for the peer. +// The fake address is in the format of 127.1.x.x where x.x is derived from the +// last two bytes of the peer address (works for both IPv4 and IPv6). func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { - octets := strings.Split(peerAddress.IP.String(), ".") - if len(octets) != 4 { - return nil, fmt.Errorf("invalid IP format") + if peerAddress == nil { + return nil, fmt.Errorf("nil peer address") } - fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) - if err != nil { - return nil, fmt.Errorf("parse new IP: %w", err) + addr, ok := netip.AddrFromSlice(peerAddress.IP) + if !ok { + return nil, fmt.Errorf("invalid IP format") } + addr = addr.Unmap() + + raw := addr.As16() + fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]}) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index 59837c32880..e242b8b1b51 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -430,8 +430,6 @@ func isInCGNATRange(ip net.IP) bool { } func TestAnonymizeFirewallRules(t *testing.T) { - // TODO: Add ipv6 - // Example iptables-save output iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 *filter @@ -467,17 +465,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes) Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) pkts bytes target prot opt in out source destination` - // Example nftables output + // Example ip6tables-save output + ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024 +*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +-A INPUT -s fd00:1234::1/128 -j ACCEPT +-A INPUT -s 2607:f8b0:4005::1/128 -j DROP +-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT +COMMIT` + + // Example nftables output with IPv6 nftablesRules := `table inet filter { chain input { type filter hook input priority filter; policy accept; ip saddr 192.168.1.1 accept ip saddr 44.192.140.1 drop + ip6 saddr 2607:f8b0:4005::1 drop + ip6 saddr fd00:1234::1 accept } chain forward { type filter hook forward priority filter; policy accept; ip saddr 10.0.0.0/8 drop ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept + ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept } }` @@ -540,4 +552,35 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) assert.Contains(t, anonNftables, "table inet filter {") assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") + + // IPv6 public addresses in nftables should be anonymized + assert.NotContains(t, anonNftables, "2607:f8b0:4005::1") + assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e") + assert.NotContains(t, anonNftables, "2001:db8::") + assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range + + // ULA addresses in nftables should remain unchanged (private) + assert.Contains(t, anonNftables, "fd00:1234::1") + + // IPv6 nftables structure preserved + assert.Contains(t, anonNftables, "ip6 saddr") + assert.Contains(t, anonNftables, "ip6 daddr") + + // Test ip6tables-save anonymization + anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave) + + // ULA (private) IPv6 should remain unchanged + assert.Contains(t, anonIp6tablesSave, "fd00:1234::1/128") + + // Public IPv6 addresses should be anonymized + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1") + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e") + assert.NotContains(t, anonIp6tablesSave, "2001:db8::") + assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range + + // Structure should be preserved + assert.Contains(t, anonIp6tablesSave, "*filter") + assert.Contains(t, anonIp6tablesSave, "COMMIT") + assert.Contains(t, anonIp6tablesSave, "-j DROP") + assert.Contains(t, anonIp6tablesSave, "-j ACCEPT") } diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index b4b9409f9e3..551555ad4c5 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -189,10 +189,10 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr { } -// evalListenAddress figure out the listen address for the DNS server -// first check the 53 port availability on WG interface or lo, if not success -// pick a random port on WG interface for eBPF, if not success -// check the 5053 port availability on WG interface or lo without eBPF usage, +// evalListenAddress figures out the listen address for the DNS server. +// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4. +// First checks port 53 on WG interface or lo, then tries eBPF on a random port, +// then falls back to port 5053. func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { if s.customAddr != nil { return s.customAddr.Addr(), s.customAddr.Port(), nil diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 58b88d9ef40..c4c16cd3f8d 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } + // IPv4-only: peers reach the forwarder via its v4 overlay address. localAddr := m.wgIface.Address().IP if localAddr.IsValid() && m.firewall != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 1347161e9e6..16410519bed 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -630,7 +630,7 @@ func (e *Engine) initFirewall() error { rosenpassPort := e.rpManager.GetAddress().Port port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}} - // this rule is static and will be torn down on engine down by the firewall manager + // IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4. if _, err := e.firewall.AddPeerFiltering( nil, net.IP{0, 0, 0, 0}, @@ -682,10 +682,15 @@ func (e *Engine) blockLanAccess() { log.Infof("blocking route LAN access for networks: %v", toBlock) v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0) for _, network := range toBlock { + source := v4 + if network.Addr().Is6() { + source = v6 + } if _, err := e.firewall.AddRouteFiltering( nil, - []netip.Prefix{v4}, + []netip.Prefix{source}, firewallManager.Network{Prefix: network}, firewallManager.ProtocolALL, nil, @@ -2362,8 +2367,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked() ip := prefix.Addr() - // TODO: add IPv6 - if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { continue } diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go index 792d0421550..60b8baadbb5 100644 --- a/client/internal/lazyconn/activity/listener_bind.go +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -57,6 +57,7 @@ func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc // deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. // Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). // It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +// For IPv6-only peers, the last two bytes of the v6 address are used. func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { if len(allowedIPs) == 0 { return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") @@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e ourNetwork := wgIface.Address().Network + // Try v4 first (preferred: deterministic from overlay IP) var peerIP netip.Addr for _, allowedIP := range allowedIPs { ip := allowedIP.Addr() @@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e } } - if !peerIP.IsValid() { - return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + if peerIP.IsValid() { + octets := peerIP.As4() + return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil } - octets := peerIP.As4() - fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) - return fakeIP, nil + // Fallback: use last two bytes of first v6 overlay IP + addr := wgIface.Address() + if addr.IPv6Net.IsValid() { + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if ip.Is6() && addr.IPv6Net.Contains(ip) { + raw := ip.As16() + return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil + } + } + } + + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") } func (d *BindListener) setupLazyConn() error { diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index fbf95de21e7..f4db95c8a0f 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -1055,7 +1055,11 @@ func (d *Status) notifyPeerListChanged() { } func (d *Status) notifyAddressChanged() { - d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP) + addr := d.localPeer.IP + if d.localPeer.IPv6 != "" { + addr = addr + "\n" + d.localPeer.IPv6 + } + d.notifier.localAddressChanged(d.localPeer.FQDN, addr) } func (d *Status) numOfPeers() int { diff --git a/client/internal/routemanager/ipfwdstate/ipfwdstate.go b/client/internal/routemanager/ipfwdstate/ipfwdstate.go index da81c18f9cd..2be1c2ae797 100644 --- a/client/internal/routemanager/ipfwdstate/ipfwdstate.go +++ b/client/internal/routemanager/ipfwdstate/ipfwdstate.go @@ -9,7 +9,11 @@ import ( ) // IPForwardingState is a struct that keeps track of the IP forwarding state. -// todo: read initial state of the IP forwarding from the system and reset the state based on it +// todo: read initial state of the IP forwarding from the system and reset the state based on it. +// todo: separate v4/v6 forwarding state, since the sysctls are independent +// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). Currently the nftables +// manager shares one instance between both routers, which works only because +// EnableIPForwarding enables both sysctls in a single call. type IPForwardingState struct { enabledCounter int } diff --git a/client/internal/routemanager/server/server.go b/client/internal/routemanager/server/server.go index e674c80cdfd..d35b44f5b6e 100644 --- a/client/internal/routemanager/server/server.go +++ b/client/internal/routemanager/server/server.go @@ -146,8 +146,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP if useNewDNSRoute { destination.Set = firewall.NewDomainSet(route.Domains) } else { - // TODO: add ipv6 additionally - destination = getDefaultPrefix(destination.Prefix) + destination = getDefaultPrefix(route.Network) } } else { destination.Prefix = route.Network.Masked() diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index c0ca21d22e0..8724ed1bad0 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -107,8 +107,13 @@ func (r *SysOps) validateRoute(prefix netip.Prefix) error { addr.IsInterfaceLocalMulticast(), addr.IsMulticast(), addr.IsUnspecified() && prefix.Bits() != 0, - r.wgInterface.Address().Network.Contains(addr): + r.isOwnAddress(addr): return vars.ErrRouteNotAllowed } return nil } + +func (r *SysOps) isOwnAddress(addr netip.Addr) bool { + wgAddr := r.wgInterface.Address() + return wgAddr.Network.Contains(addr) || (wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr)) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 02f447a2f4b..07bd2c118db 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -228,29 +228,14 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er // as a separate route that the dedicated handler adds. // Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure. if !r.wgInterface.Address().HasIPv6() { - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - log.Warnf("Failed to add v6 blackhole split 1: %s", err) - } else if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - log.Warnf("Failed to add v6 blackhole split 2: %s", err) - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback v6 blackhole: %s", err2) - } + if err := r.addV6SplitDefault(nextHop); err != nil { + log.Warnf("failed to add v6 split-default blackhole: %s", err) } } return nil case vars.Defaultv6: - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil + return r.addV6SplitDefault(nextHop) } return r.addToRouteTable(prefix, nextHop) @@ -272,30 +257,41 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } if !r.wgInterface.Address().HasIPv6() { - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) - } + result = multierror.Append(result, r.removeV6SplitDefault(nextHop)) } return nberrors.FormatErrorOrNil(result) case vars.Defaultv6: - var result *multierror.Error - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) - } - - return nberrors.FormatErrorOrNil(result) + return nberrors.FormatErrorOrNil(r.removeV6SplitDefault(nextHop)) default: return r.removeFromRouteTable(prefix, nextHop) } } +func (r *SysOps) addV6SplitDefault(nextHop Nexthop) error { + if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { + return fmt.Errorf("add split 1: %w", err) + } + if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { + if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback v6 split-default: %s", err2) + } + return fmt.Errorf("add split 2: %w", err) + } + return nil +} + +func (r *SysOps) removeV6SplitDefault(nextHop Nexthop) *multierror.Error { + var result *multierror.Error + if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { + result = multierror.Append(result, err) + } + return result +} + func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 990e03034f0..c73a0dcd1a9 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -50,10 +50,11 @@ type CustomLogger interface { } type selectRoute struct { - NetID string - Network netip.Prefix - Domains domain.List - Selected bool + NetID string + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } func init() { @@ -363,48 +364,60 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } routeManager := engine.GetRouteManager() - routesMap := routeManager.GetClientRoutesWithNetID() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } + routesMap := routeManager.GetClientRoutesWithNetID() routeSelector := routeManager.GetRouteSelector() if routeSelector == nil { return nil, fmt.Errorf("could not get route selector") } + v6ExitMerged := route.V6ExitMergeSet(routesMap) + routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + + return prepareRouteSelectionDetails(routes, resolvedDomains), nil +} + +func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute { var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + if _, ok := v6Merged[id]; ok { + continue + } + + r := &selectRoute{ NetID: string(id), Network: rt[0].Network, Domains: rt[0].Domains, - Selected: routeSelector.IsSelected(id), + Selected: isSelected(id), } - routes = append(routes, route) + + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6Merged[v6ID]; ok { + r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { - iPrefix := routes[i].Network.Bits() - jPrefix := routes[j].Network.Bits() - - if iPrefix == jPrefix { - iAddr := routes[i].Network.Addr() - jAddr := routes[j].Network.Addr() - if iAddr == jAddr { - return routes[i].NetID < routes[j].NetID - } - return iAddr.String() < jAddr.String() + iBits, jBits := routes[i].Network.Bits(), routes[j].Network.Bits() + if iBits != jBits { + return iBits < jBits } - return iPrefix < jPrefix + iAddr, jAddr := routes[i].Network.Addr(), routes[j].Network.Addr() + if iAddr != jAddr { + return iAddr.Less(jAddr) + } + return routes[i].NetID < routes[j].NetID }) - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - return prepareRouteSelectionDetails(routes, resolvedDomains), nil - + return routes } func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { @@ -425,10 +438,15 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom } domainList = append(domainList, domainResp) } + rangeStr := r.Network.String() + for _, extra := range r.extraNetworks { + rangeStr += ", " + extra.String() + } + domainDetails := DomainDetails{items: domainList} routeSelection = append(routeSelection, RoutesSelectionInfo{ ID: r.NetID, - Network: r.Network.String(), + Network: rangeStr, Domains: &domainDetails, Selected: r.Selected, }) @@ -456,7 +474,9 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routesMap)); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } @@ -483,7 +503,9 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } diff --git a/client/server/network.go b/client/server/network.go index bb1cce56c54..4a439d8cf39 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -16,10 +16,11 @@ import ( ) type selectRoute struct { - NetID route.NetID - Network netip.Prefix - Domains domain.List - Selected bool + NetID route.NetID + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } // ListNetworks returns a list of all available networks. @@ -44,18 +45,33 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro routesMap := routeMgr.GetClientRoutesWithNetID() routeSelector := routeMgr.GetRouteSelector() + v6ExitMerged := route.V6ExitMergeSet(routesMap) + var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + // Skip v6 exit nodes that are merged into their v4 counterpart. + if _, ok := v6ExitMerged[id]; ok { + continue + } + + r := &selectRoute{ NetID: id, Network: rt[0].Network, Domains: rt[0].Domains, Selected: routeSelector.IsSelected(id), } - routes = append(routes, route) + + // Merge paired v6 exit node prefix into this entry. + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6ExitMerged[v6ID]; ok { + v6Prefix := routesMap[v6ID][0].Network + r.extraNetworks = []netip.Prefix{v6Prefix} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { @@ -76,9 +92,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() var pbRoutes []*proto.Network for _, route := range routes { + rangeStr := route.Network.String() + for _, extra := range route.extraNetworks { + rangeStr += ", " + extra.String() + } pbRoute := &proto.Network{ ID: string(route.NetID), - Range: route.Network.String(), + Range: rangeStr, Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, @@ -137,7 +157,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ routeSelector.SelectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + netIdRoutes := maps.Keys(routesMap) if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } @@ -183,7 +205,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe routeSelector.DeselectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + netIdRoutes := maps.Keys(routesMap) if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } diff --git a/client/ui/network.go b/client/ui/network.go index 6ae57122ea4..4bb0b76118c 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -195,7 +195,7 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { var filteredRoutes []*proto.Network for _, route := range routes { - if route.Range == "0.0.0.0/0" || route.Range == "::/0" { + if strings.Contains(route.Range, "0.0.0.0/0") || route.Range == "::/0" { filteredRoutes = append(filteredRoutes, route) } } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 0c1a5dc6951..6fa0eeb2acc 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net" + "strconv" "syscall/js" "time" @@ -166,39 +167,58 @@ func createSSHMethod(client *netbird.Client) js.Func { }) } - var jwtToken string - if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { - jwtToken = args[3].String() - } + jwtToken, ipVersion := parseSSHOptions(args) return createPromise(func(resolve, reject js.Value) { - sshClient := ssh.NewClient(client) - - if err := sshClient.Connect(host, port, username, jwtToken); err != nil { - reject.Invoke(err.Error()) - return - } - - if err := sshClient.StartSession(80, 24); err != nil { - if closeErr := sshClient.Close(); closeErr != nil { - log.Errorf("Error closing SSH client: %v", closeErr) - } + jsInterface, err := connectSSH(client, host, port, username, jwtToken, ipVersion) + if err != nil { reject.Invoke(err.Error()) return } - - jsInterface := ssh.CreateJSInterface(sshClient) resolve.Invoke(jsInterface) }) }) } -func performPing(client *netbird.Client, hostname string) { +func parseSSHOptions(args []js.Value) (jwtToken string, ipVersion int) { + if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { + jwtToken = args[3].String() + } + if len(args) > 4 { + ipVersion = jsIPVersion(args[4]) + } + return +} + +func connectSSH(client *netbird.Client, host string, port int, username, jwtToken string, ipVersion int) (js.Value, error) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username, jwtToken, ipVersion); err != nil { + return js.Undefined(), err + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + return js.Undefined(), err + } + + return ssh.CreateJSInterface(sshClient), nil +} + +func performPing(client *netbird.Client, hostname string, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() + // Default to ping4 to avoid dual-stack ICMP endpoint issues in wireguard-go netstack. + network := "ping4" + if ipVersion == 6 { + network = "ping6" + } + start := time.Now() - conn, err := client.Dial(ctx, "ping", hostname) + conn, err := client.Dial(ctx, network, hostname) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err)) return @@ -225,27 +245,39 @@ func performPing(client *netbird.Client, hostname string) { } latency := time.Since(start) - js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())) + remote := conn.RemoteAddr().String() + msg := fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()) + if remote != hostname { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } -func performPingTCP(client *netbird.Client, hostname string, port int) { +func performPingTCP(client *netbird.Client, hostname string, port, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() + network := ipVersionNetwork("tcp", ipVersion) + address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err)) return } latency := time.Since(start) + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { log.Debugf("failed to close TCP connection: %v", err) } - js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())) + msg := fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()) + if remote != address { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } // createPingMethod creates the ping method @@ -262,8 +294,12 @@ func createPingMethod(client *netbird.Client) js.Func { } hostname := args[0].String() + var ipVersion int + if len(args) > 1 { + ipVersion = jsIPVersion(args[1]) + } return createPromise(func(resolve, reject js.Value) { - performPing(client, hostname) + performPing(client, hostname, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -290,8 +326,12 @@ func createPingTCPMethod(client *netbird.Client) js.Func { hostname := args[0].String() port := args[1].Int() + var ipVersion int + if len(args) > 2 { + ipVersion = jsIPVersion(args[2]) + } return createPromise(func(resolve, reject js.Value) { - performPingTCP(client, hostname, port) + performPingTCP(client, hostname, port, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -464,6 +504,31 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func { }) } +// ipVersionNetwork appends "4" or "6" to a base network string (e.g. "tcp" -> "tcp4"). +func ipVersionNetwork(base string, ipVersion int) string { + switch ipVersion { + case 4: + return base + "4" + case 6: + return base + "6" + default: + return base + } +} + +// jsIPVersion extracts an IP version (4 or 6) from a JS string or number. +func jsIPVersion(v js.Value) int { + switch v.Type() { + case js.TypeNumber: + return v.Int() + case js.TypeString: + n, _ := strconv.Atoi(v.String()) + return n + default: + return 0 + } +} + // createPromise is a helper to create JavaScript promises func createPromise(handler func(resolve, reject js.Value)) js.Value { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go index 2f425c614d9..9cfe652669c 100644 --- a/client/wasm/internal/ssh/client.go +++ b/client/wasm/internal/ssh/client.go @@ -46,8 +46,9 @@ func NewClient(nbClient *netbird.Client) *Client { } } -// Connect establishes an SSH connection through NetBird network -func (c *Client) Connect(host string, port int, username, jwtToken string) error { +// Connect establishes an SSH connection through NetBird network. +// ipVersion may be 4, 6, or 0 for automatic selection. +func (c *Client) Connect(host string, port int, username, jwtToken string, ipVersion int) error { addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) logrus.Infof("SSH: Connecting to %s as %s", addr, username) @@ -63,10 +64,18 @@ func (c *Client) Connect(host string, port int, username, jwtToken string) error Timeout: sshDialTimeout, } + network := "tcp" + switch ipVersion { + case 4: + network = "tcp4" + case 6: + network = "tcp6" + } + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) defer cancel() - conn, err := c.nbClient.Dial(ctx, "tcp", addr) + conn, err := c.nbClient.Dial(ctx, network, addr) if err != nil { return fmt.Errorf("dial %s: %w", addr, err) } diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 59f7a6b6590..81879e404b2 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "strconv" + "time" "github.com/spf13/cobra" @@ -57,7 +58,11 @@ var debugSyncCmd = &cobra.Command{ SilenceUsage: true, } -var pingTimeout string +var ( + pingTimeout time.Duration + pingIPv4 bool + pingIPv6 bool +) var debugPingCmd = &cobra.Command{ Use: "ping [port]", @@ -108,7 +113,10 @@ func init() { debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)") debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)") - debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().DurationVar(&pingTimeout, "timeout", 0, "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().BoolVarP(&pingIPv4, "ipv4", "4", false, "Force IPv4") + debugPingCmd.Flags().BoolVarP(&pingIPv6, "ipv6", "6", false, "Force IPv6") + debugPingCmd.MarkFlagsMutuallyExclusive("ipv4", "ipv6") debugCmd.AddCommand(debugHealthCmd) debugCmd.AddCommand(debugClientsCmd) @@ -157,7 +165,14 @@ func runDebugPing(cmd *cobra.Command, args []string) error { } port = p } - return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout) + var ipVersion string + switch { + case pingIPv4: + ipVersion = "4" + case pingIPv6: + ipVersion = "6" + } + return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout, ipVersion) } func runDebugLogLevel(cmd *cobra.Command, args []string) error { diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 01b0bc8e6d6..2ce721eb8d3 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -6,10 +6,12 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" "strings" "time" + ) // StatusFilters contains filter options for status queries. @@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error } // PingTCP performs a TCP ping through a client. -func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error { +// ipVersion may be "4", "6", or "" for automatic. +func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error { params := url.Values{} params.Set("host", host) params.Set("port", fmt.Sprintf("%d", port)) - if timeout != "" { - params.Set("timeout", timeout) + if timeout > 0 { + params.Set("timeout", timeout.String()) + } + if ipVersion != "" { + params.Set("ip_version", ipVersion) } path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode()) @@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, func (c *Client) printPingResult(data map[string]any) { success, _ := data["success"].(bool) + host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"])) if success { - _, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"]) + remote, _ := data["remote"].(string) + if remote != "" && remote != host { + _, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote) + } else { + _, _ = fmt.Fprintf(c.out, "Success: %s\n", host) + } _, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"]) } else { - _, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"]) + _, _ = fmt.Fprintf(c.out, "Failed: %s\n", host) c.printError(data) } } diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index c507cfad9d0..c1d145204ee 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -9,6 +9,7 @@ import ( "fmt" "html/template" "maps" + "net" "net/http" "slices" "strconv" @@ -525,13 +526,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI } } + network := "tcp" + if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" { + network += v + } + ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() - address := fmt.Sprintf("%s:%d", host, port) + address := net.JoinHostPort(host, strconv.Itoa(port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { h.writeJSON(w, map[string]interface{}{ "success": false, @@ -541,18 +547,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI }) return } + + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { h.logger.Debugf("close tcp ping connection: %v", err) } latency := time.Since(start) - h.writeJSON(w, map[string]interface{}{ + resp := map[string]interface{}{ "success": true, "host": host, "port": port, + "remote": remote, "latency_ms": latency.Milliseconds(), "latency": formatDuration(latency), - }) + } + h.writeJSON(w, resp) } func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { diff --git a/route/route.go b/route/route.go index c724e7c7d07..97b9721f619 100644 --- a/route/route.go +++ b/route/route.go @@ -20,6 +20,9 @@ const ( MaxMetric = 9999 // MaxNetIDChar Max Network Identifier MaxNetIDChar = 40 + + // V6ExitSuffix is appended to a v4 exit node NetID to form its v6 counterpart. + V6ExitSuffix = "-v6" ) const ( @@ -215,3 +218,61 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { return IPv4Network, masked, nil } + +var ( + v4Default = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6Default = netip.PrefixFrom(netip.IPv6Unspecified(), 0) +) + +// IsV4DefaultRoute reports whether p is the IPv4 default route (0.0.0.0/0). +func IsV4DefaultRoute(p netip.Prefix) bool { return p == v4Default } + +// IsV6DefaultRoute reports whether p is the IPv6 default route (::/0). +func IsV6DefaultRoute(p netip.Prefix) bool { return p == v6Default } + +// ExpandV6ExitPairs appends the paired "-v6" exit node NetID for any v4 exit +// node (0.0.0.0/0) in ids that has a matching v6 counterpart (::/0) in routesMap. +// It modifies and returns the input slice. +func ExpandV6ExitPairs(ids []NetID, routesMap map[NetID][]*Route) []NetID { + for _, id := range ids { + rt, ok := routesMap[id] + if !ok || len(rt) == 0 || !IsV4DefaultRoute(rt[0].Network) { + continue + } + v6ID := NetID(string(id) + V6ExitSuffix) + if v6Rt, ok := routesMap[v6ID]; ok && len(v6Rt) > 0 && IsV6DefaultRoute(v6Rt[0].Network) { + if !slices.Contains(ids, v6ID) { + ids = append(ids, v6ID) + } + } + } + return ids +} + +// V6ExitMergeSet scans routesMap and returns the set of v6 exit node NetIDs +// that should be hidden from the UI because they are paired with a v4 exit node. +// A v6 ID is paired when it has suffix "-v6", its route is ::/0, and the base +// name (without "-v6") exists with route 0.0.0.0/0. +func V6ExitMergeSet(routesMap map[NetID][]*Route) map[NetID]struct{} { + merged := make(map[NetID]struct{}) + for id, rt := range routesMap { + if len(rt) == 0 { + continue + } + name := string(id) + if !IsV6DefaultRoute(rt[0].Network) || !strings.HasSuffix(name, V6ExitSuffix) { + continue + } + baseName := NetID(strings.TrimSuffix(name, V6ExitSuffix)) + if baseRt, ok := routesMap[baseName]; ok && len(baseRt) > 0 && IsV4DefaultRoute(baseRt[0].Network) { + merged[id] = struct{}{} + } + } + return merged +} + +// HasV6ExitPair reports whether id has a paired v6 exit node in the merge set. +func HasV6ExitPair(id NetID, v6Merged map[NetID]struct{}) bool { + _, ok := v6Merged[NetID(string(id)+"-v6")] + return ok +} diff --git a/route/route_test.go b/route/route_test.go new file mode 100644 index 00000000000..dab707ed35a --- /dev/null +++ b/route/route_test.go @@ -0,0 +1,108 @@ +package route + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExpandV6ExitPairs(t *testing.T) { + v4ExitRoute := &Route{Network: netip.MustParsePrefix("0.0.0.0/0")} + v6ExitRoute := &Route{Network: netip.MustParsePrefix("::/0")} + regularRoute := &Route{Network: netip.MustParsePrefix("10.0.0.0/8")} + + tests := []struct { + name string + ids []NetID + routesMap map[NetID][]*Route + expected []NetID + }{ + { + name: "v4 exit node with matching v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "v4 exit node without v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + }, + expected: []NetID{"exit-node"}, + }, + { + name: "regular route is not expanded", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {v6ExitRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "v6 already included is not duplicated", + ids: []NetID{"exit-node", "exit-node-v6"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "multiple exit nodes expanded independently", + ids: []NetID{"exit-a", "exit-b"}, + routesMap: map[NetID][]*Route{ + "exit-a": {v4ExitRoute}, + "exit-a-v6": {v6ExitRoute}, + "exit-b": {v4ExitRoute}, + "exit-b-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-a", "exit-b", "exit-a-v6", "exit-b-v6"}, + }, + { + name: "v6 suffix but not exit node network", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {regularRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "user-chosen name for exit node with v6 pair", + ids: []NetID{"my-exit"}, + routesMap: map[NetID][]*Route{ + "my-exit": {v4ExitRoute}, + "my-exit-v6": {v6ExitRoute}, + }, + expected: []NetID{"my-exit", "my-exit-v6"}, + }, + { + name: "real-world management-generated IDs", + ids: []NetID{"0.0.0.0/0"}, + routesMap: map[NetID][]*Route{ + "0.0.0.0/0": {v4ExitRoute}, + "0.0.0.0/0-v6": {v6ExitRoute}, + }, + expected: []NetID{"0.0.0.0/0", "0.0.0.0/0-v6"}, + }, + { + name: "empty input", + ids: []NetID{}, + routesMap: map[NetID][]*Route{}, + expected: []NetID{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExpandV6ExitPairs(tt.ids, tt.routesMap) + assert.ElementsMatch(t, tt.expected, result) + }) + } +} diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 2d7b00a8095..7d413d4c11b 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -46,7 +46,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { InitialPacketSize: nbRelay.QUICInitialPacketSize, } - udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0}) if err != nil { log.Errorf("failed to listen on UDP: %s", err) return nil, err