Skip to content

Commit 95d776b

Browse files
committed
wgengine/magicsock: only cache N most recent endpoints per-Addr
If a node is flapping or otherwise generating lots of STUN endpoints, we can end up caching a ton of useless values and sending them to peers. Instead, let's apply a fixed per-Addr limit of endpoints that we cache, so that we're only sending peers up to the N most recent. Updates tailscale/corp#13890 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I8079a05b44220c46da55016c0e5fc96dd2135ef8
1 parent 9c4364e commit 95d776b

File tree

5 files changed

+438
-191
lines changed

5 files changed

+438
-191
lines changed

cmd/tailscaled/depaware.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
292292
tailscale.com/tailcfg from tailscale.com/client/tailscale/apitype+
293293
💣 tailscale.com/tempfork/device from tailscale.com/net/tstun/table
294294
LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh
295+
tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock
295296
tailscale.com/tka from tailscale.com/ipn/ipnlocal+
296297
W tailscale.com/tsconst from tailscale.com/net/interfaces
297298
tailscale.com/tsd from tailscale.com/cmd/tailscaled+
@@ -411,6 +412,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
411412
golang.org/x/time/rate from gvisor.dev/gvisor/pkg/tcpip/stack+
412413
bufio from compress/flate+
413414
bytes from bufio+
415+
cmp from slices
414416
compress/flate from compress/gzip+
415417
compress/gzip from golang.org/x/net/http2+
416418
W compress/zlib from debug/pe
@@ -495,6 +497,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
495497
runtime/debug from github.com/klauspost/compress/zstd+
496498
runtime/pprof from tailscale.com/log/logheap+
497499
runtime/trace from net/http/pprof
500+
slices from tailscale.com/wgengine/magicsock
498501
sort from compress/flate+
499502
strconv from compress/flate+
500503
strings from bufio+
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Copyright (c) Tailscale Inc & AUTHORS
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package magicsock
5+
6+
import (
7+
"net/netip"
8+
"slices"
9+
"sync"
10+
"time"
11+
12+
"tailscale.com/tailcfg"
13+
"tailscale.com/tempfork/heap"
14+
"tailscale.com/util/mak"
15+
"tailscale.com/util/set"
16+
)
17+
18+
const (
	// endpointTrackerLifetime is how long we continue advertising an
	// endpoint after we last see it. This is intentionally chosen to be
	// slightly longer than a full netcheck period, so an endpoint that
	// misses a single netcheck is not withdrawn.
	endpointTrackerLifetime = 5*time.Minute + 10*time.Second

	// endpointTrackerMaxPerAddr is how many cached endpoints we track for
	// a given netip.Addr. This allows e.g. restricting the number of STUN
	// endpoints we cache (which usually have the same netip.Addr but
	// different ports).
	//
	// The value of 6 is chosen because we can advertise up to 3 endpoints
	// based on the STUN IP:
	//    1. The STUN endpoint itself (EndpointSTUN)
	//    2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort)
	//    3. The STUN IP with a portmapped port (EndpointPortmapped)
	//
	// Storing 6 endpoints in the cache means we can store up to 2 previous
	// sets of endpoints.
	endpointTrackerMaxPerAddr = 6
)
39+
40+
// endpointTrackerEntry is an entry in an endpointHeap that stores the state of
// a given cached endpoint.
type endpointTrackerEntry struct {
	// endpoint is the cached endpoint.
	endpoint tailcfg.Endpoint
	// until is the time until which this endpoint is being cached; after
	// this instant the entry is eligible for removal.
	until time.Time
	// index is the index within the containing endpointHeap; it is kept
	// up to date by the heap operations (Swap/Push/Pop) so that
	// heap.Fix can be called after mutating 'until'.
	index int
}
50+
51+
// endpointHeap is an ordered min-heap of endpointTrackerEntry structs, ordered
// in ascending order by the 'until' expiry time (i.e. oldest first).
type endpointHeap []*endpointTrackerEntry

// Compile-time assertion that *endpointHeap implements the generic
// heap.Interface from tempfork/heap.
var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil)
56+
57+
// Len implements heap.Interface.
58+
func (eh endpointHeap) Len() int { return len(eh) }
59+
60+
// Less implements heap.Interface.
61+
func (eh endpointHeap) Less(i, j int) bool {
62+
// We want to store items so that the lowest item in the heap is the
63+
// oldest, so that heap.Pop()-ing from the endpointHeap will remove the
64+
// oldest entry.
65+
return eh[i].until.Before(eh[j].until)
66+
}
67+
68+
// Swap implements heap.Interface.
69+
func (eh endpointHeap) Swap(i, j int) {
70+
eh[i], eh[j] = eh[j], eh[i]
71+
eh[i].index = i
72+
eh[j].index = j
73+
}
74+
75+
// Push implements heap.Interface.
76+
func (eh *endpointHeap) Push(item *endpointTrackerEntry) {
77+
n := len(*eh)
78+
item.index = n
79+
*eh = append(*eh, item)
80+
}
81+
82+
// Pop implements heap.Interface.
83+
func (eh *endpointHeap) Pop() *endpointTrackerEntry {
84+
old := *eh
85+
n := len(old)
86+
item := old[n-1]
87+
old[n-1] = nil // avoid memory leak
88+
item.index = -1 // for safety
89+
*eh = old[0 : n-1]
90+
return item
91+
}
92+
93+
// Min returns a pointer to the minimum element in the heap, without removing
94+
// it. Since this is a min-heap ordered by the 'until' field, this returns the
95+
// chronologically "earliest" element in the heap.
96+
//
97+
// Len() must be non-zero.
98+
func (eh endpointHeap) Min() *endpointTrackerEntry {
99+
return eh[0]
100+
}
101+
102+
// endpointTracker caches endpoints that are advertised to peers. This allows
// peers to still reach this node if there's a temporary endpoint flap; rather
// than withdrawing an endpoint and then re-advertising it the next time we run
// a netcheck, we keep advertising the endpoint until it's not present for a
// defined timeout.
//
// See tailscale/tailscale#7877 for more information.
type endpointTracker struct {
	// mu guards endpoints.
	mu sync.Mutex
	// endpoints maps an IP address to the heap of cached endpoints
	// (same IP, typically different ports) sharing that address,
	// ordered by expiry time.
	endpoints map[netip.Addr]*endpointHeap
}
113+
114+
// update takes as input the current sent of discovered endpoints and the
115+
// current time, and returns the set of endpoints plus any previous-cached and
116+
// non-expired endpoints that should be advertised to peers.
117+
func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
118+
var inputEps set.Slice[netip.AddrPort]
119+
for _, ep := range eps {
120+
inputEps.Add(ep.Addr)
121+
}
122+
123+
et.mu.Lock()
124+
defer et.mu.Unlock()
125+
126+
// Extend endpoints that already exist in the cache. We do this before
127+
// we remove expired endpoints, below, so we don't remove something
128+
// that would otherwise have survived by extending.
129+
until := now.Add(endpointTrackerLifetime)
130+
for _, ep := range eps {
131+
et.extendLocked(ep, until)
132+
}
133+
134+
// Now that we've extended existing endpoints, remove everything that
135+
// has expired.
136+
et.removeExpiredLocked(now)
137+
138+
// Add entries from the input set of endpoints into the cache; we do
139+
// this after removing expired ones so that we can store as many as
140+
// possible, with space freed by the entries removed after expiry.
141+
for _, ep := range eps {
142+
et.addLocked(now, ep, until)
143+
}
144+
145+
// Finally, add entries to the return array that aren't already there.
146+
epsPlusCached = eps
147+
for _, heap := range et.endpoints {
148+
for _, ep := range *heap {
149+
// If the endpoint was in the input list, or has expired, skip it.
150+
if inputEps.Contains(ep.endpoint.Addr) {
151+
continue
152+
} else if now.After(ep.until) {
153+
// Defense-in-depth; should never happen since
154+
// we removed expired entries above, but ignore
155+
// it anyway.
156+
continue
157+
}
158+
159+
// We haven't seen this endpoint; add to the return array
160+
epsPlusCached = append(epsPlusCached, ep.endpoint)
161+
}
162+
}
163+
164+
return epsPlusCached
165+
}
166+
167+
// extendLocked will update the expiry time of the provided endpoint in the
168+
// cache, if it is present. If it is not present, nothing will be done.
169+
//
170+
// et.mu must be held.
171+
func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) {
172+
key := ep.Addr.Addr()
173+
epHeap, found := et.endpoints[key]
174+
if !found {
175+
return
176+
}
177+
178+
// Find the entry for this exact address; this loop is quick since we
179+
// bound the number of items in the heap.
180+
//
181+
// TODO(andrew): this means we iterate over the entire heap once per
182+
// endpoint; even if the heap is small, if we have a lot of input
183+
// endpoints this can be expensive?
184+
for i, entry := range *epHeap {
185+
if entry.endpoint == ep {
186+
entry.until = until
187+
heap.Fix(epHeap, i)
188+
return
189+
}
190+
}
191+
}
192+
193+
// addLocked will store the provided endpoint(s) in the cache for a fixed
194+
// period of time, ensuring that the size of the endpoint cache remains below
195+
// the maximum.
196+
//
197+
// et.mu must be held.
198+
func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
199+
key := ep.Addr.Addr()
200+
201+
// Create or get the heap for this endpoint's addr
202+
epHeap := et.endpoints[key]
203+
if epHeap == nil {
204+
epHeap = new(endpointHeap)
205+
mak.Set(&et.endpoints, key, epHeap)
206+
}
207+
208+
// Find the entry for this exact address; this loop is quick
209+
// since we bound the number of items in the heap.
210+
found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool {
211+
return v.endpoint == ep
212+
})
213+
if !found {
214+
// Add address to heap; either the endpoint is new, or the heap
215+
// was newly-created and thus empty.
216+
heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until})
217+
}
218+
219+
// Now that we've added everything, pop from our heap until we're below
220+
// the limit. This is a min-heap, so popping removes the lowest (and
221+
// thus oldest) endpoint.
222+
for epHeap.Len() > endpointTrackerMaxPerAddr {
223+
heap.Pop(epHeap)
224+
}
225+
}
226+
227+
// removeExpired will remove all expired entries from the cache.
228+
//
229+
// et.mu must be held.
230+
func (et *endpointTracker) removeExpiredLocked(now time.Time) {
231+
for k, epHeap := range et.endpoints {
232+
// The minimum element is oldest/earliest endpoint; repeatedly
233+
// pop from the heap while it's in the past.
234+
for epHeap.Len() > 0 {
235+
minElem := epHeap.Min()
236+
if now.After(minElem.until) {
237+
heap.Pop(epHeap)
238+
} else {
239+
break
240+
}
241+
}
242+
243+
if epHeap.Len() == 0 {
244+
// Free up space in the map by removing the empty heap.
245+
delete(et.endpoints, k)
246+
}
247+
}
248+
}

0 commit comments

Comments
 (0)