Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 55 additions & 24 deletions client/internal/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,83 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)

func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip, err := netip.ParseAddr(aRecord.RData)
func createPTRRecord(record nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip, err := netip.ParseAddr(record.RData)
if err != nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
log.Warnf("failed to parse IP address %s: %v", record.RData, err)
return nbdns.SimpleRecord{}, false
}

ip = ip.Unmap()
if !prefix.Contains(ip) {
return nbdns.SimpleRecord{}, false
}

ipOctets := strings.Split(ip.String(), ".")
slices.Reverse(ipOctets)
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
var rdnsName string
if ip.Is4() {
octets := strings.Split(ip.String(), ".")
slices.Reverse(octets)
rdnsName = dns.Fqdn(strings.Join(octets, ".") + ".in-addr.arpa")
} else {
// Expand to full 32 nibbles in reverse order (LSB first) per RFC 3596.
raw := ip.As16()
nibbles := make([]string, 32)
for i := 0; i < 16; i++ {
nibbles[31-i*2] = fmt.Sprintf("%x", raw[i]>>4)
nibbles[31-i*2-1] = fmt.Sprintf("%x", raw[i]&0x0f)
}
rdnsName = dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
}

return nbdns.SimpleRecord{
Name: rdnsName,
Type: int(dns.TypePTR),
Class: aRecord.Class,
TTL: aRecord.TTL,
RData: dns.Fqdn(aRecord.Name),
Class: record.Class,
TTL: record.TTL,
RData: dns.Fqdn(record.Name),
}, true
}

// generateReverseZoneName creates the reverse DNS zone name for a given network
// generateReverseZoneName creates the reverse DNS zone name for a given network.
// For IPv4 it produces an in-addr.arpa name, for IPv6 an ip6.arpa name.
func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := network.Masked().Addr()
networkIP := network.Masked().Addr().Unmap()
bits := network.Bits()

if networkIP.Is4() {
// Round up to nearest byte.
octetsToUse := (bits + 7) / 8

octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", bits)
}

reverseOctets := make([]string, octetsToUse)
for i := 0; i < octetsToUse; i++ {
reverseOctets[octetsToUse-1-i] = octets[i]
}

if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
}

// round up to nearest byte
octetsToUse := (network.Bits() + 7) / 8
// IPv6: round up to nearest nibble (4-bit boundary).
nibblesToUse := (bits + 3) / 4

octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
raw := networkIP.As16()
allNibbles := make([]string, 32)
for i := 0; i < 16; i++ {
allNibbles[i*2] = fmt.Sprintf("%x", raw[i]>>4)
allNibbles[i*2+1] = fmt.Sprintf("%x", raw[i]&0x0f)
}

reverseOctets := make([]string, octetsToUse)
for i := 0; i < octetsToUse; i++ {
reverseOctets[octetsToUse-1-i] = octets[i]
// Take the first nibblesToUse nibbles (network portion), reverse them.
used := make([]string, nibblesToUse)
for i := 0; i < nibblesToUse; i++ {
used[nibblesToUse-1-i] = allNibbles[i]
}

return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
return dns.Fqdn(strings.Join(used, ".") + ".ip6.arpa"), nil
}

// zoneExists checks if a zone with the given name already exists in the configuration
Expand All @@ -71,7 +102,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
return false
}

// collectPTRRecords gathers all PTR records for the given network from A records
// collectPTRRecords gathers all PTR records for the given network from A and AAAA records.
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord

Expand All @@ -80,7 +111,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
continue
}
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) {
if record.Type != int(dns.TypeA) && record.Type != int(dns.TypeAAAA) {
continue
}

Expand Down
1 change: 1 addition & 0 deletions client/internal/dns/host_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
ip = ip.Unmap()
serverAddresses = append(serverAddresses, ip)
// Prefer the first IPv4 server as ServerIP since our DNS listener is IPv4.
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
Expand Down
15 changes: 11 additions & 4 deletions client/internal/dns/network_manager_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,15 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st

connSettings.cleanDeprecatedSettings()

convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice())
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
ipKey := networkManagerDbusIPv4Key
if config.ServerIP.Is6() {
ipKey = networkManagerDbusIPv6Key
raw := config.ServerIP.As16()
connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([][]byte{raw[:]})
} else {
convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice())
connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
}
var (
searchDomains []string
matchDomains []string
Expand Down Expand Up @@ -146,8 +153,8 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
n.routingAll = false
}

connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
connSettings[ipKey][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
connSettings[ipKey][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)

state := &ShutdownState{
ManagerType: networkManager,
Expand Down
6 changes: 5 additions & 1 deletion client/internal/dns/systemd_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,12 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
}

func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
family := int32(unix.AF_INET)
if config.ServerIP.Is6() {
family = unix.AF_INET6
}
defaultLinkInput := systemdDbusDNSInput{
Family: unix.AF_INET,
Family: family,
Address: config.ServerIP.AsSlice(),
}
if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
Expand Down
35 changes: 26 additions & 9 deletions client/internal/dns/upstream_ios.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type upstreamResolverIOS struct {
*upstreamResolverBase
lIP netip.Addr
lNet netip.Prefix
lIPv6 netip.Addr
lNetV6 netip.Prefix
interfaceName string
}

Expand All @@ -37,6 +39,8 @@ func newUpstreamResolver(
upstreamResolverBase: upstreamResolverBase,
lIP: wgIface.Address().IP,
lNet: wgIface.Address().Network,
lIPv6: wgIface.Address().IPv6,
lNetV6: wgIface.Address().IPv6Net,
interfaceName: wgIface.Name(),
}
ios.upstreamClient = ios
Expand Down Expand Up @@ -66,12 +70,23 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP = upstreamIP.Unmap()
}
needsPrivate := u.lNet.Contains(upstreamIP) ||
u.lNetV6.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("create private client: %s", err)
var bindIP netip.Addr
switch {
case upstreamIP.Is6() && u.lIPv6.IsValid():
bindIP = u.lIPv6
case upstreamIP.Is4() && u.lIP.IsValid():
bindIP = u.lIP
}

if bindIP.IsValid() {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If have validation issue with lIPv6 or lIP then bindIP will be invalid.
The old code always used the private client for matching upstreams. This is a behavior change. If this i the expected then fine.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is a subtle intentional change. If neither lIP nor lIPv6 is valid for the upstream's address family, we fall through to the default (non-private) client instead of binding to an invalid address. In practice, if needsPrivate is true but the corresponding local IP is invalid, the configuration is broken and there's nothing useful to bind to.

log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(bindIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("create private client: %s", err)
}
}
}

Expand All @@ -88,16 +103,18 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura
return nil, err
}

proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF
if ip.Is6() {
proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
}

dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: ip.AsSlice(),
Port: 0, // Let the OS pick a free port
},
LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, 0)),
Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
operr = unix.SetsockoptInt(int(s), proto, opt, index)
}

if err := c.Control(fn); err != nil {
Expand Down
138 changes: 138 additions & 0 deletions client/internal/dns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package internal

import (
"net/netip"
"testing"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

nbdns "github.com/netbirdio/netbird/dns"
)

func TestCreatePTRRecord_IPv4(t *testing.T) {
record := nbdns.SimpleRecord{
Name: "peer1.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "100.64.0.5",
}
prefix := netip.MustParsePrefix("100.64.0.0/16")

ptr, ok := createPTRRecord(record, prefix)
require.True(t, ok)
assert.Equal(t, "5.0.64.100.in-addr.arpa.", ptr.Name)
assert.Equal(t, int(dns.TypePTR), ptr.Type)
assert.Equal(t, "peer1.netbird.cloud.", ptr.RData)
}

func TestCreatePTRRecord_IPv6(t *testing.T) {
record := nbdns.SimpleRecord{
Name: "peer1.netbird.cloud.",
Type: int(dns.TypeAAAA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "fd00:1234:5678::1",
}
prefix := netip.MustParsePrefix("fd00:1234:5678::/48")

ptr, ok := createPTRRecord(record, prefix)
require.True(t, ok)
assert.Equal(t, "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", ptr.Name)
assert.Equal(t, int(dns.TypePTR), ptr.Type)
assert.Equal(t, "peer1.netbird.cloud.", ptr.RData)
}

func TestCreatePTRRecord_OutOfRange(t *testing.T) {
record := nbdns.SimpleRecord{
Name: "peer1.netbird.cloud.",
Type: int(dns.TypeA),
RData: "10.0.0.1",
}
prefix := netip.MustParsePrefix("100.64.0.0/16")

_, ok := createPTRRecord(record, prefix)
assert.False(t, ok)
}

func TestGenerateReverseZoneName_IPv4(t *testing.T) {
tests := []struct {
prefix string
expected string
}{
{"100.64.0.0/16", "64.100.in-addr.arpa."},
{"10.0.0.0/8", "10.in-addr.arpa."},
{"192.168.1.0/24", "1.168.192.in-addr.arpa."},
}

for _, tt := range tests {
t.Run(tt.prefix, func(t *testing.T) {
zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix))
require.NoError(t, err)
assert.Equal(t, tt.expected, zone)
})
}
}

func TestGenerateReverseZoneName_IPv6(t *testing.T) {
tests := []struct {
prefix string
expected string
}{
{"fd00:1234:5678::/48", "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa."},
{"fd00::/16", "0.0.d.f.ip6.arpa."},
{"fd12:3456:789a:bcde::/64", "e.d.c.b.a.9.8.7.6.5.4.3.2.1.d.f.ip6.arpa."},
}

for _, tt := range tests {
t.Run(tt.prefix, func(t *testing.T) {
zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix))
require.NoError(t, err)
assert.Equal(t, tt.expected, zone)
})
}
}

func TestCollectPTRRecords_BothFamilies(t *testing.T) {
config := &nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.1"},
{Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00::1"},
{Name: "peer2.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.2"},
},
},
},
}

v4Records := collectPTRRecords(config, netip.MustParsePrefix("100.64.0.0/16"))
assert.Len(t, v4Records, 2, "should collect 2 A record PTRs for the v4 prefix")

v6Records := collectPTRRecords(config, netip.MustParsePrefix("fd00::/64"))
assert.Len(t, v6Records, 1, "should collect 1 AAAA record PTR for the v6 prefix")
}

func TestAddReverseZone_IPv6(t *testing.T) {
config := &nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00:1234:5678::1"},
},
},
},
}

addReverseZone(config, netip.MustParsePrefix("fd00:1234:5678::/48"))

require.Len(t, config.CustomZones, 2)
reverseZone := config.CustomZones[1]
assert.Equal(t, "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", reverseZone.Domain)
assert.Len(t, reverseZone.Records, 1)
assert.Equal(t, int(dns.TypePTR), reverseZone.Records[0].Type)
}
Loading
Loading