Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions client/internal/dns/mgmt/bypass_resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package mgmt

import (
"context"
"fmt"
"net"
"net/netip"

nbnet "github.com/netbirdio/netbird/client/net"
)

// NewBypassResolver builds a *net.Resolver that sends queries directly to
// the supplied nameservers through a socket that bypasses the NetBird
// overlay interface. This lets the mgmt cache refresh control-plane
// FQDNs (api/signal/relay/stun/turn) even when an exit-node default
// route is installed on the overlay before its peer is live.
//
// Invalid, loopback and unspecified addresses are dropped. The function
// returns nil when nameservers is empty or when no usable address is left
// after filtering, so callers must handle a nil result. Overlay addresses
// (e.g. the overlay listener address) cannot be recognised here; the
// caller must not pass them, since they would defeat the purpose of
// bypassing.
//
// Every remaining nameserver is queried on port 53. The returned resolver
// tries them in the order given and uses the first one that can be dialed.
func NewBypassResolver(nameservers []netip.Addr) *net.Resolver {
	servers := make([]string, 0, len(nameservers))
	for _, ns := range nameservers {
		if !ns.IsValid() || ns.IsLoopback() || ns.IsUnspecified() {
			continue
		}
		servers = append(servers, netip.AddrPortFrom(ns, 53).String())
	}
	// Covers both an empty input and an input that was filtered down to
	// nothing; from here on servers is guaranteed to be non-empty.
	if len(servers) == 0 {
		return nil
	}

	return &net.Resolver{
		// PreferGo makes the Go resolver do the lookup so that Dial below
		// is actually consulted.
		PreferGo: true,
		// The address chosen by the Go resolver is ignored on purpose:
		// only the configured bypass nameservers are ever dialed.
		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
			nbDialer := nbnet.NewDialer()
			var lastErr error
			for _, ns := range servers {
				conn, err := nbDialer.DialContext(ctx, network, ns)
				if err == nil {
					return conn, nil
				}
				lastErr = err
				// Once the context is done no further dial can succeed.
				if ctx.Err() != nil {
					break
				}
			}
			// servers is non-empty, so at least one dial ran and lastErr
			// is always set here.
			return nil, fmt.Errorf("dial bypass nameservers: %w", lastErr)
		},
	}
}
207 changes: 196 additions & 11 deletions client/internal/dns/mgmt/mgmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"os"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -71,6 +72,14 @@ type Resolver struct {
refreshing map[dns.Question]*atomic.Bool

cacheTTL time.Duration

// bypassResolver, when non-nil, is used by osLookup instead of
// net.DefaultResolver. It is constructed by the caller to dial the
// original (pre-NetBird) system nameservers through a socket that
// bypasses the overlay interface (control-plane fwmark / bound iface),
// so that a refresh does not fail with ENOKEY when an exit-node default
// route is installed before its peer has completed the handshake.
bypassResolver *net.Resolver
}

// NewResolver creates a new management domains cache resolver.
Expand Down Expand Up @@ -98,8 +107,28 @@ func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Unlock()
}

// SetBypassResolver stores r as the resolver that osLookup uses in place
// of net.DefaultResolver; a nil r switches osLookup back to
// net.DefaultResolver.
//
// The resolver is meant to dial the original (pre-NetBird) system
// nameservers over a socket that does not follow the overlay default
// route. That way a refresh started while an exit node is active, but
// before its WireGuard peer has been installed, cannot deadlock on ENOKEY.
func (m *Resolver) SetBypassResolver(r *net.Resolver) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.bypassResolver = r
}

// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
//
// If the query name is not in the cache but falls under a pool-root
// domain (a domain the mgmt advertised in ServerDomains.Relay, whose
// instance subdomains like streamline-de-fra1-0.relay.netbird.io are
// part of the relay pool), resolve it on demand through the bypass
// resolver and cache the result. This is what lets the daemon reach
// a foreign relay FQDN after an exit-node default route has been
// installed on the overlay before its peer is live.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
Expand All @@ -126,6 +155,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.mutex.RUnlock()

if !found {
if m.isUnderPoolRoot(question.Name) {
m.resolveOnDemand(w, r, question)
return
}
m.continueToNext(w, r)
return
}
Expand Down Expand Up @@ -155,12 +188,117 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}

// MatchSubdomains always returns false: the bare resolver is registered
// against exact NetBird infrastructure domains (signal, relay, flow, etc.),
// not their subdomains.
//
// Pool-root domains (currently the Relay entries from ServerDomains) are
// registered through a subdomain-matching wrapper at the call site
// instead, so instance subdomains still reach this handler and get the
// on-demand resolve path in ServeDNS.
func (m *Resolver) MatchSubdomains() bool {
return false
}

// isUnderPoolRoot tells whether fqdn is a strict subdomain of one of the
// pool-root domains held in m.serverDomains.Relay. For example
// "streamline-de-fra1-0.relay.netbird.io." is under "relay.netbird.io".
// A name equal to a pool root is not a match; the root itself is expected
// to be served from the exact cache entry populated by AddDomain.
//
// Both sides are normalized with canonicalizePoolDomain (lowercase, no
// trailing dot, no leading "*.") before comparison. This is intended to
// mirror server.toZone, which handles the handler-chain registration of
// the same set but lives in the parent dns package and cannot be imported
// from here without a cycle.
//
// It returns false when no server domains are set or when fqdn normalizes
// to the empty string. The read lock is held for the whole check.
func (m *Resolver) isUnderPoolRoot(fqdn string) bool {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	if m.serverDomains == nil {
		return false
	}

	name := canonicalizePoolDomain(fqdn)
	if name == "" {
		return false
	}

	for _, relay := range m.serverDomains.Relay {
		root := canonicalizePoolDomain(relay.PunycodeString())
		// An empty root would match every name, and a name equal to the
		// root is not a subdomain; neither can qualify.
		if root != "" && name != root && strings.HasSuffix(name, "."+root) {
			return true
		}
	}
	return false
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// canonicalizePoolDomain returns s in the form used for pool-root
// membership checks: one trailing dot removed, lowercased, and one leading
// "*." wildcard label removed, in that order. It is meant to match the
// transformation server.toZone applies on the handler-registration side
// (modulo trailing-dot orientation, which is self-consistent within this
// file).
func canonicalizePoolDomain(s string) string {
	withoutDot := strings.TrimSuffix(s, ".")
	lowered := strings.ToLower(withoutDot)
	return strings.TrimPrefix(lowered, "*.")
}

// resolveOnDemand answers a query for a pool-root subdomain that is not in
// the cache yet (e.g. a relay instance FQDN). It resolves the name with
// lookupRecords, stores the records in the cache if no entry exists for
// the question, and writes the answer to w with the TTL set to cacheTTL.
//
// If the name cannot be parsed, the lookup fails, or no records come back,
// the query is handed to the next handler via continueToNext so the normal
// chain can still attempt the resolve.
func (m *Resolver) resolveOnDemand(w dns.ResponseWriter, r *dns.Msg, question dns.Question) {
	name := question.Name

	parsed, parseErr := domain.FromString(strings.TrimSuffix(name, "."))
	if parseErr != nil {
		log.Debugf("on-demand resolve: parse domain %q: %v", name, parseErr)
		m.continueToNext(w, r)
		return
	}

	// Each lookup runs on its own timeout-bound background context.
	lookup := func() (any, error) {
		ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
		defer cancel()
		return m.lookupRecords(ctx, parsed, question)
	}

	// Concurrent lookups for the same (name, qtype) are collapsed into one
	// upstream query through refreshGroup, so a burst of parallel queries
	// for a freshly learned subdomain (several peer workers dialing the
	// same foreign relay, or A and AAAA racing) does not hit the resolver
	// once per caller. The "ondemand:" prefix keeps these keys apart from
	// the ones scheduleRefresh uses on the same group.
	flightKey := "ondemand:" + name + ":" + strconv.Itoa(int(question.Qtype))
	shared, lookupErr, _ := m.refreshGroup.Do(flightKey, lookup)
	if lookupErr != nil {
		log.Debugf("on-demand resolve %s type=%s: %v",
			parsed.SafeString(), dns.TypeToString[question.Qtype], lookupErr)
		m.continueToNext(w, r)
		return
	}

	answers, _ := shared.([]dns.RR)
	if len(answers) == 0 {
		m.continueToNext(w, r)
		return
	}

	// Never overwrite an entry that appeared while the lookup was running.
	fetchedAt := time.Now()
	m.mutex.Lock()
	if _, cached := m.records[question]; !cached {
		m.records[question] = &cachedRecord{records: answers, cachedAt: fetchedAt}
	}
	m.mutex.Unlock()

	reply := new(dns.Msg)
	reply.SetReply(r)
	reply.Authoritative = false
	reply.RecursionAvailable = true
	reply.Answer = cloneRecordsWithTTL(answers, uint32(m.cacheTTL.Seconds()))

	log.Debugf("on-demand resolved %d records for domain=%s", len(reply.Answer), name)

	if writeErr := w.WriteMsg(reply); writeErr != nil {
		log.Errorf("failed to write on-demand response: %v", writeErr)
	}
}
Comment thread
pappz marked this conversation as resolved.


// continueToNext signals the handler chain to continue to the next handler.
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
Expand Down Expand Up @@ -315,14 +453,29 @@ func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedReco
return c.consecFailures
}

// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
// lookupBoth resolves A and AAAA via bypass resolver, chain, or OS.
// Per-family errors let callers tell records, NODATA (nil err, no records),
// and failure apart.
//
// Preference order:
// 1. bypassResolver (direct, overlay-bypassing dial to original system
// nameservers; immune to the exit-node ENOKEY race).
// 2. chain (handler chain; used when NetBird is the system resolver and
// no bypass resolver is installed).
// 3. net.DefaultResolver via osLookup (legacy fallback).
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
bypass := m.bypassResolver
m.mutex.RUnlock()

if bypass != nil {
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}

if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
Expand All @@ -337,15 +490,22 @@ func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName stri
return
}

// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
// lookupRecords resolves a single record type. See lookupBoth for the
// preference order. The OS branch arms the loop detector for the duration
// of its call so that ServeDNS can spot the OS resolver routing the
// recursive query back to us; the bypass branch skips the loop detector
// because its dial does not enter the system resolver.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
bypass := m.bypassResolver
m.mutex.RUnlock()

if bypass != nil {
return m.osLookup(ctx, d, q.Name, q.Qtype)
}

if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
Expand Down Expand Up @@ -394,9 +554,9 @@ func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxP
return filtered, nil
}

// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
// osLookup resolves a single family via the bypass resolver (if configured)
// or net.DefaultResolver using resutil, which disambiguates NODATA from
// NXDOMAIN and Unmaps v4-mapped-v6. NODATA returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
Expand All @@ -406,7 +566,14 @@ func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])

result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
m.mutex.RLock()
resolver := m.bypassResolver
m.mutex.RUnlock()
if resolver == nil {
resolver = net.DefaultResolver
}

result := resutil.LookupIP(ctx, resolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
Expand Down Expand Up @@ -467,6 +634,24 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
return nil
}

// GetPoolRootDomains returns the domains that should be registered with
// subdomain matching, which currently means the non-empty Relay entries of
// the stored server domains. Instance subdomains under these roots are
// resolved on demand in ServeDNS.
//
// The result is nil when no server domains are set; otherwise it is a
// freshly allocated list, so callers may modify it freely.
func (m *Resolver) GetPoolRootDomains() domain.List {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	sd := m.serverDomains
	if sd == nil {
		return nil
	}

	roots := make(domain.List, 0, len(sd.Relay))
	for _, relay := range sd.Relay {
		if relay == "" {
			continue
		}
		roots = append(roots, relay)
	}
	return roots
}

// GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock()
Expand Down
Loading
Loading