diff --git a/client/internal/dns/mgmt/bypass_resolver.go b/client/internal/dns/mgmt/bypass_resolver.go new file mode 100644 index 00000000000..5a4c4442ce8 --- /dev/null +++ b/client/internal/dns/mgmt/bypass_resolver.go @@ -0,0 +1,55 @@ +package mgmt + +import ( + "context" + "fmt" + "net" + "net/netip" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +// NewBypassResolver builds a *net.Resolver that sends queries directly to +// the supplied nameservers through a socket that bypasses the NetBird +// overlay interface. This lets the mgmt cache refresh control-plane +// FQDNs (api/signal/relay/stun/turn) even when an exit-node default +// route is installed on the overlay before its peer is live. +// +// Returns nil if nameservers is empty or every entry is unusable +// (invalid, loopback, or unspecified). The caller must not pass +// loopback/overlay IPs (e.g. 127.0.0.1, the overlay listener address); +// those would defeat the purpose of bypassing. +func NewBypassResolver(nameservers []netip.Addr) *net.Resolver { + if len(nameservers) == 0 { + return nil + } + + servers := make([]string, 0, len(nameservers)) + for _, ns := range nameservers { + if !ns.IsValid() || ns.IsLoopback() || ns.IsUnspecified() { + continue + } + servers = append(servers, netip.AddrPortFrom(ns, 53).String()) + } + if len(servers) == 0 { + return nil + } + + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, _ string) (net.Conn, error) { + nbDialer := nbnet.NewDialer() + var lastErr error + for _, ns := range servers { + conn, err := nbDialer.DialContext(ctx, network, ns) + if err == nil { + return conn, nil + } + lastErr = err + } + if lastErr == nil { + return nil, fmt.Errorf("no bypass nameservers configured") + } + return nil, fmt.Errorf("dial bypass nameservers: %w", lastErr) + }, + } +} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 988e427fb59..ee40720c1b5 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -8,6 +8,7 @@ import ( 
"net/url" "os" "slices" + "strconv" "strings" "sync" "sync/atomic" @@ -71,6 +72,14 @@ type Resolver struct { refreshing map[dns.Question]*atomic.Bool cacheTTL time.Duration + + // bypassResolver, when non-nil, is used by osLookup instead of + // net.DefaultResolver. It is constructed by the caller to dial the + // original (pre-NetBird) system nameservers through a socket that + // bypasses the overlay interface (control-plane fwmark / bound iface), + // so that when an exit-node default route is installed before a peer + // is handshaked the refresh does not fail with ENOKEY. + bypassResolver *net.Resolver } // NewResolver creates a new management domains cache resolver. @@ -98,8 +107,28 @@ func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) { m.mutex.Unlock() } +// SetBypassResolver installs a resolver that osLookup uses instead of +// net.DefaultResolver. It is intended to dial the original (pre-NetBird) +// system nameservers through a socket that does not follow the overlay +// default route, so that a refresh initiated while an exit node is active +// but its WireGuard peer is not yet installed cannot deadlock on ENOKEY. +// Passing nil restores use of net.DefaultResolver. +func (m *Resolver) SetBypassResolver(r *net.Resolver) { + m.mutex.Lock() + m.bypassResolver = r + m.mutex.Unlock() +} + // ServeDNS serves cached A/AAAA records. Stale entries are returned // immediately and refreshed asynchronously (stale-while-revalidate). +// +// If the query name is not in the cache but falls under a pool-root +// domain (a domain the mgmt advertised in ServerDomains.Relay, whose +// instance subdomains like streamline-de-fra1-0.relay.netbird.io are +// part of the relay pool), resolve it on demand through the bypass +// resolver and cache the result. This is what lets the daemon reach +// a foreign relay FQDN after an exit-node default route has been +// installed on the overlay before its peer is live. 
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { m.continueToNext(w, r) @@ -126,6 +155,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { m.mutex.RUnlock() if !found { + if m.isUnderPoolRoot(question.Name) { + m.resolveOnDemand(w, r, question) + return + } m.continueToNext(w, r) return } @@ -155,12 +188,117 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } -// MatchSubdomains returns false since this resolver only handles exact domain matches -// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains. +// MatchSubdomains returns false by default: the bare resolver is registered +// against exact domains. Pool-root domains (currently Relay entries from +// ServerDomains) are registered through a subdomain-matching wrapper at +// the call site instead, so instance subdomains hit this handler and get +// the on-demand resolve path in ServeDNS. func (m *Resolver) MatchSubdomains() bool { return false } +// isUnderPoolRoot reports whether fqdn is an instance subdomain under any +// pool-root domain advertised by the mgmt (currently ServerDomains.Relay), +// e.g. "streamline-de-fra1-0.relay.netbird.io." is under "relay.netbird.io". +// The pool-root itself is not considered a subdomain (it matches the exact +// cache entry populated by AddDomain instead). +// +// Canonicalization mirrors server.toZone — lowercase, strip trailing dot, +// and strip a leading "*." wildcard (via canonicalizePoolDomain) — so the +// membership check is consistent with the handler-chain registration that +// runs the same set through toZone. toZone itself lives in the parent dns +// package and cannot be imported from here without a cycle. 
+func (m *Resolver) isUnderPoolRoot(fqdn string) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + if m.serverDomains == nil { + return false + } + fqdn = canonicalizePoolDomain(fqdn) + if fqdn == "" { + return false + } + for _, root := range m.serverDomains.Relay { + r := canonicalizePoolDomain(root.PunycodeString()) + if r == "" || fqdn == r { + continue + } + if strings.HasSuffix(fqdn, "."+r) { + return true + } + } + return false +} + +// canonicalizePoolDomain normalizes a domain for pool-root membership +// comparison: lowercase, trailing dot stripped, leading "*." wildcard +// stripped. Matches the transformation server.toZone applies on the +// handler-registration side (modulo trailing-dot orientation, which is +// self-consistent within this file). +func canonicalizePoolDomain(s string) string { + s = strings.ToLower(strings.TrimSuffix(s, ".")) + s = strings.TrimPrefix(s, "*.") + return s +} + +// resolveOnDemand resolves an uncached pool-root subdomain (e.g. a relay +// instance FQDN) through the bypass resolver path, caches the result, and +// writes it back to w. Falls through to the next handler on error so the +// normal chain can still attempt the resolve. +func (m *Resolver) resolveOnDemand(w dns.ResponseWriter, r *dns.Msg, question dns.Question) { + d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) + if err != nil { + log.Debugf("on-demand resolve: parse domain %q: %v", question.Name, err) + m.continueToNext(w, r) + return + } + + // Collapse concurrent on-demand lookups for the same (name, qtype) into + // a single upstream query via singleflight. A burst of parallel queries + // for a freshly-learned pool-root subdomain (e.g. multiple peer workers + // dialing the same foreign relay, or A + AAAA racing each other) would + // otherwise each hit the bypass resolver independently. The prefix + // namespaces this key off scheduleRefresh's keyspace so the two paths + // can coexist without collisions. 
+ key := "ondemand:" + question.Name + ":" + strconv.Itoa(int(question.Qtype)) + result, err, _ := m.refreshGroup.Do(key, func() (any, error) { + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + return m.lookupRecords(ctx, d, question) + }) + if err != nil { + log.Debugf("on-demand resolve %s type=%s: %v", + d.SafeString(), dns.TypeToString[question.Qtype], err) + m.continueToNext(w, r) + return + } + records, _ := result.([]dns.RR) + if len(records) == 0 { + m.continueToNext(w, r) + return + } + + now := time.Now() + m.mutex.Lock() + if _, exists := m.records[question]; !exists { + m.records[question] = &cachedRecord{records: records, cachedAt: now} + } + m.mutex.Unlock() + + resp := &dns.Msg{} + resp.SetReply(r) + resp.Authoritative = false + resp.RecursionAvailable = true + resp.Answer = cloneRecordsWithTTL(records, uint32(m.cacheTTL.Seconds())) + + log.Debugf("on-demand resolved %d records for domain=%s", len(resp.Answer), question.Name) + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write on-demand response: %v", err) + } +} + + // continueToNext signals the handler chain to continue to the next handler. func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { resp := &dns.Msg{} @@ -315,14 +453,29 @@ func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedReco return c.consecFailures } -// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let -// callers tell records, NODATA (nil err, no records), and failure apart. +// lookupBoth resolves A and AAAA via bypass resolver, chain, or OS. +// Per-family errors let callers tell records, NODATA (nil err, no records), +// and failure apart. +// +// Preference order: +// 1. bypassResolver (direct, overlay-bypassing dial to original system +// nameservers; immune to the exit-node ENOKEY race). +// 2. chain (handler chain; used when NetBird is the system resolver and +// no bypass resolver is installed). +// 3. 
net.DefaultResolver via osLookup (legacy fallback). func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) { m.mutex.RLock() chain := m.chain maxPriority := m.chainMaxPriority + bypass := m.bypassResolver m.mutex.RUnlock() + if bypass != nil { + aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA) + return + } + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) @@ -337,15 +490,22 @@ func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName stri return } -// lookupRecords resolves a single record type via chain or OS. The OS branch -// arms the loop detector for the duration of its call so that ServeDNS can -// spot the OS resolver routing the recursive query back to us. +// lookupRecords resolves a single record type. See lookupBoth for the +// preference order. The OS branch arms the loop detector for the duration +// of its call so that ServeDNS can spot the OS resolver routing the +// recursive query back to us; the bypass branch skips the loop detector +// because its dial does not enter the system resolver. 
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) { m.mutex.RLock() chain := m.chain maxPriority := m.chainMaxPriority + bypass := m.bypassResolver m.mutex.RUnlock() + if bypass != nil { + return m.osLookup(ctx, d, q.Name, q.Qtype) + } + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) } @@ -394,9 +554,9 @@ func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxP return filtered, nil } -// osLookup resolves a single family via net.DefaultResolver using resutil, -// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA -// returns (nil, nil). +// osLookup resolves a single family via the bypass resolver (if configured) +// or net.DefaultResolver using resutil, which disambiguates NODATA from +// NXDOMAIN and Unmaps v4-mapped-v6. NODATA returns (nil, nil). func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { network := resutil.NetworkForQtype(qtype) if network == "" { @@ -406,7 +566,14 @@ func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) - result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) + m.mutex.RLock() + resolver := m.bypassResolver + m.mutex.RUnlock() + if resolver == nil { + resolver = net.DefaultResolver + } + + result := resutil.LookupIP(ctx, resolver, network, d.PunycodeString(), qtype) if result.Rcode == dns.RcodeSuccess { return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil } @@ -467,6 +634,24 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { return nil } +// GetPoolRootDomains returns the set of domains that should 
be registered +// with subdomain matching (currently the Relay entries from ServerDomains). +// Instance subdomains under these roots are resolved on demand in ServeDNS. +func (m *Resolver) GetPoolRootDomains() domain.List { + m.mutex.RLock() + defer m.mutex.RUnlock() + if m.serverDomains == nil { + return nil + } + out := make(domain.List, 0, len(m.serverDomains.Relay)) + for _, d := range m.serverDomains.Relay { + if d != "" { + out = append(out, d) + } + } + return out +} + // GetCachedDomains returns a list of all cached domains. func (m *Resolver) GetCachedDomains() domain.List { m.mutex.RLock() diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index d4f54dec581..736d0506ce2 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -31,6 +31,28 @@ import ( const envSkipDNSProbe = "NB_SKIP_DNS_PROBE" +// subdomainMatchHandler is a thin wrapper used to register a handler under +// a pool-root domain (e.g. a relay URL advertised by the mgmt) with +// subdomain matching enabled. The underlying handler's own MatchSubdomains +// is left untouched so that exact-match registrations keep their +// semantics. +type subdomainMatchHandler struct { + dns.Handler +} + +// MatchSubdomains lets the handler chain route any instance subdomain +// (e.g. streamline-de-fra1-0.relay.netbird.io) to the wrapped handler. +func (subdomainMatchHandler) MatchSubdomains() bool { return true } + +// String returns a debug-friendly name; the chain uses fmt.Stringer for +// its "registering handler X" logs. 
+func (h subdomainMatchHandler) String() string { + if s, ok := h.Handler.(fmt.Stringer); ok { + return s.String() + "[subdomains]" + } + return "subdomainMatchHandler" +} + // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes type ReadyListener interface { OnReady() @@ -95,6 +117,11 @@ type DefaultServer struct { batchMode bool mgmtCacheResolver *mgmt.Resolver + // mgmtPoolRoots tracks pool-root domains currently contributed to + // extraDomains by the mgmt cache, so the next UpdateServerConfig can + // decrement the old set before incrementing the new one without + // disturbing unrelated registerHandler callers. + mgmtPoolRoots map[domain.Domain]struct{} // permanent related properties permanent bool @@ -229,6 +256,7 @@ func newDefaultServer( hostsDNSHolder: newHostsDNSHolder(), hostManager: &noopHostConfigurator{}, mgmtCacheResolver: mgmtCacheResolver, + mgmtPoolRoots: make(map[domain.Domain]struct{}), currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied } @@ -587,25 +615,92 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro s.mux.Lock() defer s.mux.Unlock() - if s.mgmtCacheResolver != nil { - removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains) - if err != nil { - return fmt.Errorf("update management cache resolver: %w", err) - } + if s.mgmtCacheResolver == nil { + return nil + } - if len(removedDomains) > 0 { - s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) - } + removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains) + if err != nil { + return fmt.Errorf("update management cache resolver: %w", err) + } - newDomains := s.mgmtCacheResolver.GetCachedDomains() - if len(newDomains) > 0 { - s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) - } + if len(removedDomains) > 0 { + 
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) } + poolRoots := s.mgmtCacheResolver.GetPoolRootDomains() + s.registerMgmtCacheHandlers(poolRoots) + s.reconcileMgmtPoolRoots(poolRoots) + + if !s.batchMode { + s.applyHostConfig() + } return nil } +// registerMgmtCacheHandlers wires the mgmt cache resolver into the handler +// chain for the current set of cached domains. Pool-root domains (advertised +// by the mgmt as Relay URLs) go through a thin subdomain-matching wrapper so +// a query like "streamline-de-fra1-0.relay.netbird.io" routes to the mgmt +// cache resolver, which resolves it on demand through the bypass resolver +// instead of falling through to the overlay-routed upstream handler. +// +// Canonicalize with toZone on both sides of the pool-root membership check so +// the comparison is independent of each source's canonical form: +// GetPoolRootDomains returns what the extractor stored; GetCachedDomains +// strips the trailing dot from question names. +func (s *DefaultServer) registerMgmtCacheHandlers(poolRoots domain.List) { + poolRootSet := make(map[domain.Domain]struct{}, len(poolRoots)) + for _, d := range poolRoots { + poolRootSet[toZone(d)] = struct{}{} + } + + if len(poolRoots) > 0 { + s.registerHandler(poolRoots.ToPunycodeList(), subdomainMatchHandler{Handler: s.mgmtCacheResolver}, PriorityMgmtCache) + } + + var exactDomains domain.List + for _, d := range s.mgmtCacheResolver.GetCachedDomains() { + if _, isPool := poolRootSet[toZone(d)]; isPool { + continue + } + exactDomains = append(exactDomains, d) + } + if len(exactDomains) > 0 { + s.registerHandler(exactDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) + } +} + +// reconcileMgmtPoolRoots keeps extraDomains in sync with the current mgmt +// pool-root set. These entries show up as *match* domains for the host DNS +// manager (systemd-resolved, NetworkManager, etc.) 
so instance subdomain +// queries like streamline-* are delegated to the wt0 link where the daemon's +// DNS listener sits. Without this, systemd-resolved answers them from the +// host's global upstream, skipping our handler chain entirely. +// +// Uses s.mgmtPoolRoots as a dedicated tracking map so increments/decrements +// here don't collide with RegisterHandler's refcounting. +func (s *DefaultServer) reconcileMgmtPoolRoots(poolRoots domain.List) { + newPoolRoots := make(map[domain.Domain]struct{}, len(poolRoots)) + for _, d := range poolRoots { + zone := toZone(d) + newPoolRoots[zone] = struct{}{} + if _, already := s.mgmtPoolRoots[zone]; !already { + s.extraDomains[zone]++ + } + } + for zone := range s.mgmtPoolRoots { + if _, keep := newPoolRoots[zone]; keep { + continue + } + s.extraDomains[zone]-- + if s.extraDomains[zone] <= 0 { + delete(s.extraDomains, zone) + } + } + s.mgmtPoolRoots = newPoolRoots +} + func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver if update.ServiceEnable { @@ -759,6 +854,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { originalNameservers := hostMgrWithNS.getOriginalNameservers() if len(originalNameservers) == 0 { s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + if s.mgmtCacheResolver != nil { + s.mgmtCacheResolver.SetBypassResolver(nil) + } return } @@ -777,6 +875,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { } handler.routeMatch = s.routeMatch + var bypassNameservers []netip.Addr for _, ns := range originalNameservers { if ns == config.ServerIP { log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) @@ -785,11 +884,22 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { addrPort := netip.AddrPortFrom(ns, DefaultPort) handler.upstreamServers = append(handler.upstreamServers, addrPort) + bypassNameservers = 
append(bypassNameservers, ns) } handler.deactivate = func(error) { /* always active */ } handler.reactivate = func() { /* always active */ } s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) + + // Wire a bypass resolver into the mgmt cache so its refresh path dials + // the original nameservers directly over a fwmarked socket, avoiding + // the ENOKEY deadlock that occurs when an exit-node default route is + // installed on the overlay before its peer has handshaked. Scoped to + // the mgmt cache only: ordinary user DNS still flows through the + // normal upstream path. + if s.mgmtCacheResolver != nil { + s.mgmtCacheResolver.SetBypassResolver(mgmt.NewBypassResolver(bypassNameservers)) + } } func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {