diff --git a/.github/workflows/vet.yml b/.github/workflows/vet.yml new file mode 100644 index 0000000000000..7eff6b45fd37b --- /dev/null +++ b/.github/workflows/vet.yml @@ -0,0 +1,38 @@ +name: tailscale.com/cmd/vet + +env: + HOME: ${{ github.workspace }} + # GOMODCACHE is the same definition on all OSes. Within the workspace, we use + # toplevel directories "src" (for the checked out source code), and "gomodcache" + # and other caches as siblings to follow. + GOMODCACHE: ${{ github.workspace }}/gomodcache + +on: + push: + branches: + - main + - "release-branch/*" + paths: + - "**.go" + pull_request: + paths: + - "**.go" + +jobs: + vet: + runs-on: [ self-hosted, linux ] + timeout-minutes: 5 + + steps: + - name: Check out code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + + - name: Build 'go vet' tool + working-directory: src + run: ./tool/go build -o /tmp/vettool tailscale.com/cmd/vet + + - name: Run 'go vet' + working-directory: src + run: ./tool/go vet -vettool=/tmp/vettool tailscale.com/... diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index a5877cb112eff..348483df57558 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,147 +1,103 @@ -# Contributor Covenant Code of Conduct +# Tailscale Community Code of Conduct ## Our Pledge -We are committed to creating an open, welcoming, diverse, inclusive, -healthy and respectful community. +We are committed to creating an open, welcoming, diverse, inclusive, healthy and respectful community. +Unacceptable, harmful and inappropriate behavior will not be tolerated. ## Our Standards -Examples of behavior that contributes to a positive environment for our -community include: -* Demonstrating empathy and kindness toward other people. -* Being respectful of differing opinions, viewpoints, and experiences. -* Giving and gracefully accepting constructive feedback. -* Accepting responsibility and apologizing to those affected by our - mistakes, and learning from the experience. -* Focusing on what is best not just for us as individuals, but for the - overall community. +Examples of behavior that contributes to a positive environment for our community include: + +- Demonstrating empathy and kindness toward other people. +- Being respectful of differing opinions, viewpoints, and experiences. +- Giving and gracefully accepting constructive feedback. +- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience. +- Focusing on what is best not just for us as individuals, but for the overall community. Examples of unacceptable behavior include without limitation: -* The use of sexualized language or imagery, and sexual attention or - advances of any kind. -* The use of violent, intimidating or bullying language or imagery. -* Trolling, insulting or derogatory comments, and personal or - political attacks. -* Public or private harassment. -* Publishing others' private information, such as a physical or email - address, without their explicit permission. -* Spamming community channels and members, such as sending repeat messages, - low-effort content, or automated messages. -* Phishing or any similar activity; -* Distributing or promoting malware; -* Other conduct which could reasonably be considered inappropriate in a - professional setting. - -Please also see the Tailscale Acceptable Use Policy, available at -[tailscale.com/tailscale-aup](https://tailscale.com/tailscale-aup). 
-
-# Reporting Incidents
-
-Instances of abusive, harassing, or otherwise unacceptable behavior
-may be reported to Tailscale directly via info@tailscale.com, or to
-the community leaders or moderators via DM or similar.
+
+- The use of language, imagery or emojis (collectively "content") that is racist, sexist, homophobic, transphobic, or otherwise harassing or discriminatory based on any protected characteristic.
+- The use of sexualized content and sexual attention or advances of any kind.
+- The use of violent, intimidating or bullying content.
+- Trolling, concern trolling, insulting or derogatory comments, and personal or political attacks.
+- Public or private harassment.
+- Publishing others' personal information, such as a photo, physical address, email address, online profile information, or other personal information, without their explicit permission or with the intent to bully or harass the other person.
+- Posting deepfake or other AI-generated content about or involving another person without their explicit permission.
+- Spamming community channels and members, such as sending repeat messages, low-effort content, or automated messages.
+- Phishing or any similar activity.
+- Distributing or promoting malware.
+- The use of any coded or suggestive content to hide or provoke otherwise unacceptable behavior.
+- Other conduct which could reasonably be considered harmful, illegal, or inappropriate in a professional setting.
+
+Please also see the Tailscale Acceptable Use Policy, available at [tailscale.com/tailscale-aup](https://tailscale.com/tailscale-aup).
+
+## Reporting Incidents
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to Tailscale directly via <info@tailscale.com>, or to the community leaders or moderators via DM or similar.
 
 All complaints will be reviewed and investigated promptly and fairly.
 
 We will respect the privacy and safety of the reporter of any issues.
 
-Please note that this community is not moderated by staff 24/7, and we
-do not have, and do not undertake, any obligation to prescreen, monitor,
-edit, or remove any content or data, or to actively seek facts or
-circumstances indicating illegal activity. While we strive to keep the
-community safe and welcoming, moderation may not be immediate at all hours.
+Please note that this community is not moderated by staff 24/7, and we do not have, and do not undertake, any obligation to prescreen, monitor, edit, or remove any content or data, or to actively seek facts or circumstances indicating illegal activity.
+While we strive to keep the community safe and welcoming, moderation may not be immediate at all hours.
 
 If you encounter any issues, report them using the appropriate channels.
 
-## Enforcement
-
-Community leaders and moderators are responsible for clarifying and
-enforcing our standards of acceptable behavior and will take appropriate
-and fair corrective action in response to any behavior that they deem
-inappropriate, threatening, offensive, or harmful.
+## Enforcement Guidelines
 
-Community leaders and moderators have the right and responsibility to remove,
-edit, or reject comments, commits, code, wiki edits, issues, and other
-contributions that are not aligned to this Community Code of Conduct.
-Tailscale retains full discretion to take action (or not) in response
-to a violation of these guidelines with or without notice or liability
-to you. 
We will interpret our policies and resolve disputes in favor of -protecting users, customers, the public, our community and our company, -as a whole. +Community leaders and moderators are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. -## Enforcement Guidelines +Community leaders and moderators have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Community Code of Conduct. +Tailscale retains full discretion to take action (or not) in response to a violation of these guidelines with or without notice or liability to you. +We will interpret our policies and resolve disputes in favor of protecting users, customers, the public, our community and our company, as a whole. -Community leaders will follow these Community Impact Guidelines in -determining the consequences for any action they deem in violation of -this Code of Conduct: +Community leaders will follow these community enforcement guidelines in determining the consequences for any action they deem in violation of this Code of Conduct, +and retain full discretion to apply the enforcement guidelines as necessary depending on the circumstances: ### 1. Correction -Community Impact: Use of inappropriate language or other behavior -deemed unprofessional or unwelcome in the community. +Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. -Consequence: A private, written warning from community leaders, -providing clarity around the nature of the violation and an -explanation of why the behavior was inappropriate. A public apology -may be requested. +Consequence: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. +A public apology may be requested. ### 2. Warning -Community Impact: A violation through a single incident or series -of actions. +Community Impact: A violation through a single incident or series of actions. -Consequence: A warning with consequences for continued -behavior. No interaction with the people involved, including -unsolicited interaction with those enforcing this Community Code of Conduct, -for a specified period of time. This includes avoiding interactions in -community spaces as well as external channels like social -media. Violating these terms may lead to a temporary or permanent ban. +Consequence: A warning with consequences for continued behavior. +No interaction with the people involved, including unsolicited interaction with those enforcing this Community Code of Conduct, for a specified period of time. +This includes avoiding interactions in community spaces as well as external channels like social media. +Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban -Community Impact: A serious violation of community standards, -including sustained inappropriate behavior. +Community Impact: A serious violation of community standards, including sustained inappropriate behavior. -Consequence: A temporary ban from any sort of interaction or -public communication with the community for a specified period of -time. 
No public or private interaction with the people involved,
-including unsolicited interaction with those enforcing the Code of Conduct,
-is allowed during this period. Violating these terms may lead to a permanent ban.
+Consequence: A temporary ban from any sort of interaction or public communication with the community for a specified period of time.
+No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
 
 ### 4. Permanent Ban
 
-Community Impact: Demonstrating a pattern of violation of community
-standards, including sustained inappropriate behavior, harassment of
-an individual, or aggression toward or disparagement of
-classes of individuals.
+Community Impact: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
 
-Consequence: A permanent ban from any sort of public interaction
-within the community.
+Consequence: A permanent ban from any sort of public interaction within the community.
 
 ## Acceptable Use Policy
 
-Violation of this Community Code of Conduct may also violate the
-Tailscale Acceptable Use Policy, which may result in suspension or
-termination of your Tailscale account. For more information, please
-see the Tailscale Acceptable Use Policy, available at
-[tailscale.com/tailscale-aup](https://tailscale.com/tailscale-aup).
+Violation of this Community Code of Conduct may also violate the Tailscale Acceptable Use Policy, which may result in suspension or termination of your Tailscale account.
+For more information, please see the Tailscale Acceptable Use Policy, available at [tailscale.com/tailscale-aup](https://tailscale.com/tailscale-aup).
 
 ## Privacy
 
-Please see the Tailscale [Privacy Policy](http://tailscale.com/privacy-policy)
-for more information about how Tailscale collects, uses, discloses and protects
-information.
+Please see the Tailscale [Privacy Policy](https://tailscale.com/privacy-policy) for more information about how Tailscale collects, uses, discloses and protects information.
 
 ## Attribution
 
-This Code of Conduct is adapted from the [Contributor
-Covenant][homepage], version 2.0, available at
-https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at <https://www.contributor-covenant.org/version/2/0/code_of_conduct.html>.
 
-Community Impact Guidelines were inspired by [Mozilla's code of
-conduct enforcement ladder](https://github.com/mozilla/diversity).
+Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
 
 [homepage]: https://www.contributor-covenant.org
 
-For answers to common questions about this code of conduct, see the
-FAQ at https://www.contributor-covenant.org/faq. Translations are
-available at https://www.contributor-covenant.org/translations.
-
+For answers to common questions about this code of conduct, see the FAQ at <https://www.contributor-covenant.org/faq>.
+Translations are available at <https://www.contributor-covenant.org/translations>.
diff --git a/VERSION.txt b/VERSION.txt
index 237c0b66ad7cd..7f47174bcbd13 100644
--- a/VERSION.txt
+++ b/VERSION.txt
@@ -1 +1 @@
-1.90.8
+1.92.5
diff --git a/appc/appconnector.go b/appc/appconnector.go
index e7b5032f0edc4..d41f9e8ba6357 100644
--- a/appc/appconnector.go
+++ b/appc/appconnector.go
@@ -16,9 +16,9 @@ import (
 	"net/netip"
 	"slices"
 	"strings"
-	"sync"
 	"time"
 
+	"tailscale.com/syncs"
 	"tailscale.com/types/appctype"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/views"
@@ -139,7 +139,7 @@ type AppConnector struct {
 	hasStoredRoutes bool
 
 	// mu guards the fields that follow
-	mu sync.Mutex
+	mu syncs.Mutex
 
 	// domains is a map of lower case domain names with no trailing dot, to an
 	// ordered list of resolved IP addresses.
@@ -203,12 +203,12 @@ func NewAppConnector(c Config) *AppConnector {
 		ac.wildcards = c.RouteInfo.Wildcards
 		ac.controlRoutes = c.RouteInfo.Control
 	}
-	ac.writeRateMinute = newRateLogger(time.Now, time.Minute, func(c int64, s time.Time, l int64) {
-		ac.logf("routeInfo write rate: %d in minute starting at %v (%d routes)", c, s, l)
-		metricStoreRoutes(c, l)
+	ac.writeRateMinute = newRateLogger(time.Now, time.Minute, func(c int64, s time.Time, ln int64) {
+		ac.logf("routeInfo write rate: %d in minute starting at %v (%d routes)", c, s, ln)
+		metricStoreRoutes(c, ln)
 	})
-	ac.writeRateDay = newRateLogger(time.Now, 24*time.Hour, func(c int64, s time.Time, l int64) {
-		ac.logf("routeInfo write rate: %d in 24 hours starting at %v (%d routes)", c, s, l)
+	ac.writeRateDay = newRateLogger(time.Now, 24*time.Hour, func(c int64, s time.Time, ln int64) {
+		ac.logf("routeInfo write rate: %d in 24 hours starting at %v (%d routes)", c, s, ln)
 	})
 	return ac
 }
@@ -510,8 +510,8 @@ func (e *AppConnector) addDomainAddrLocked(domain string, addr netip.Addr) {
 	slices.SortFunc(e.domains[domain], compareAddr)
 }
 
-func compareAddr(l, r netip.Addr) int {
-	return l.Compare(r)
+func compareAddr(a, b netip.Addr) int {
+	return a.Compare(b)
 }
 
 // routesWithout returns a without b where a and b
diff --git a/appc/ippool.go b/appc/ippool.go
new file mode 100644
index 0000000000000..a2e86a7c296a8
--- /dev/null
+++ b/appc/ippool.go
@@ -0,0 +1,61 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package appc
+
+import (
+	"errors"
+	"net/netip"
+
+	"go4.org/netipx"
+)
+
+// errPoolExhausted is returned when there are no more addresses to iterate over.
+var errPoolExhausted = errors.New("ip pool exhausted")
+
+// ippool allows for iteration over all the addresses within a netipx.IPSet.
+// netipx.IPSet has a Ranges call that returns the "minimum and sorted set of IP ranges that covers [the set]".
+// netipx.IPRange is "an inclusive range of IP addresses from the same address family". So we can iterate over
+// all the addresses in the set by keeping track of the last address we returned, calling Next on the last address
+// to get the new one, and if we run off the edge of the current range, starting on the next one.
+type ippool struct {
+	// ranges defines the addresses in the pool
+	ranges []netipx.IPRange
+	// last is internal tracking of the last address provided.
+	last netip.Addr
+	// rangeIdx is internal tracking of which netipx.IPRange from the IPSet we are currently on.
+	rangeIdx int
+}
+
+func newIPPool(ipset *netipx.IPSet) *ippool {
+	if ipset == nil {
+		return &ippool{}
+	}
+	return &ippool{ranges: ipset.Ranges()}
+}
+
+// next returns the next address from the set, or errPoolExhausted if we have
+// iterated over the whole set.
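+//
+// A hypothetical sketch of a caller draining the pool (the names here are
+// illustrative, not part of the package):
+//
+//	pool := newIPPool(ipset)
+//	for {
+//		addr, err := pool.next()
+//		if errors.Is(err, errPoolExhausted) {
+//			break
+//		}
+//		use(addr) // e.g. hand the address out
+//	}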
+func (ipp *ippool) next() (netip.Addr, error) { + if ipp.rangeIdx >= len(ipp.ranges) { + // ipset is empty or we have iterated off the end + return netip.Addr{}, errPoolExhausted + } + if !ipp.last.IsValid() { + // not initialized yet + ipp.last = ipp.ranges[0].From() + return ipp.last, nil + } + currRange := ipp.ranges[ipp.rangeIdx] + if ipp.last == currRange.To() { + // then we need to move to the next range + ipp.rangeIdx++ + if ipp.rangeIdx >= len(ipp.ranges) { + return netip.Addr{}, errPoolExhausted + } + ipp.last = ipp.ranges[ipp.rangeIdx].From() + return ipp.last, nil + } + ipp.last = ipp.last.Next() + return ipp.last, nil +} diff --git a/appc/ippool_test.go b/appc/ippool_test.go new file mode 100644 index 0000000000000..64b76738f661e --- /dev/null +++ b/appc/ippool_test.go @@ -0,0 +1,60 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appc + +import ( + "errors" + "net/netip" + "testing" + + "go4.org/netipx" + "tailscale.com/util/must" +) + +func TestNext(t *testing.T) { + a := ippool{} + _, err := a.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } + + var isb netipx.IPSetBuilder + ipset := must.Get(isb.IPSet()) + b := newIPPool(ipset) + _, err = b.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } + + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("192.168.0.0"), netip.MustParseAddr("192.168.0.2"))) + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("200.0.0.0"), netip.MustParseAddr("200.0.0.0"))) + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("201.0.0.0"), netip.MustParseAddr("201.0.0.1"))) + ipset = must.Get(isb.IPSet()) + c := newIPPool(ipset) + expected := []string{ + "192.168.0.0", + "192.168.0.1", + "192.168.0.2", + "200.0.0.0", + "201.0.0.0", + "201.0.0.1", + } + for i, want := range expected { + addr, err := c.next() + if err != nil { + t.Fatal(err) + } + if addr != netip.MustParseAddr(want) { + t.Fatalf("next call %d want: %s, got: %v", i, want, addr) + } + } + _, err = c.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } + _, err = c.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } +} diff --git a/atomicfile/atomicfile_test.go b/atomicfile/atomicfile_test.go index 78c93e664f738..a081c90409788 100644 --- a/atomicfile/atomicfile_test.go +++ b/atomicfile/atomicfile_test.go @@ -31,11 +31,11 @@ func TestDoesNotOverwriteIrregularFiles(t *testing.T) { // The least troublesome thing to make that is not a file is a unix socket. // Making a null device sadly requires root. 
- l, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) + ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) if err != nil { t.Fatal(err) } - defer l.Close() + defer ln.Close() err = WriteFile(path, []byte("hello"), 0644) if err == nil { diff --git a/atomicfile/zsyscall_windows.go b/atomicfile/zsyscall_windows.go index f2f0b6d08cbb7..bd1bf8113ca2a 100644 --- a/atomicfile/zsyscall_windows.go +++ b/atomicfile/zsyscall_windows.go @@ -44,7 +44,7 @@ var ( ) func replaceFileW(replaced *uint16, replacement *uint16, backup *uint16, flags uint32, exclude unsafe.Pointer, reserved unsafe.Pointer) (err error) { - r1, _, e1 := syscall.Syscall6(procReplaceFileW.Addr(), 6, uintptr(unsafe.Pointer(replaced)), uintptr(unsafe.Pointer(replacement)), uintptr(unsafe.Pointer(backup)), uintptr(flags), uintptr(exclude), uintptr(reserved)) + r1, _, e1 := syscall.SyscallN(procReplaceFileW.Addr(), uintptr(unsafe.Pointer(replaced)), uintptr(unsafe.Pointer(replacement)), uintptr(unsafe.Pointer(backup)), uintptr(flags), uintptr(exclude), uintptr(reserved)) if int32(r1) == 0 { err = errnoErr(e1) } diff --git a/chirp/chirp_test.go b/chirp/chirp_test.go index a57ef224b2c1b..c545c277d6e87 100644 --- a/chirp/chirp_test.go +++ b/chirp/chirp_test.go @@ -24,7 +24,7 @@ type fakeBIRD struct { func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) + ln, err := net.Listen("unix", sock) if err != nil { t.Fatal(err) } @@ -33,7 +33,7 @@ func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { pe[p] = false } return &fakeBIRD{ - Listener: l, + Listener: ln, protocolsEnabled: pe, sock: sock, } @@ -123,12 +123,12 @@ type hangingListener struct { func newHangingListener(t *testing.T) *hangingListener { sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) + ln, err := net.Listen("unix", sock) if err != nil { t.Fatal(err) } return &hangingListener{ - Listener: l, + Listener: ln, t: t, done: make(chan struct{}), sock: sock, diff --git a/client/local/local.go b/client/local/local.go index 2382a12252a20..72ddbb55f773a 100644 --- a/client/local/local.go +++ b/client/local/local.go @@ -38,6 +38,7 @@ import ( "tailscale.com/net/udprelay/status" "tailscale.com/paths" "tailscale.com/safesocket" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/appctype" "tailscale.com/types/dnstype" @@ -1363,7 +1364,7 @@ type IPNBusWatcher struct { httpRes *http.Response dec *json.Decoder - mu sync.Mutex + mu syncs.Mutex closed bool } @@ -1400,6 +1401,23 @@ func (lc *Client) SuggestExitNode(ctx context.Context) (apitype.ExitNodeSuggesti return decodeJSON[apitype.ExitNodeSuggestionResponse](body) } +// CheckSOMarkInUse reports whether the socket mark option is in use. This will only +// be true if tailscale is running on Linux and tailscaled uses SO_MARK. +func (lc *Client) CheckSOMarkInUse(ctx context.Context) (bool, error) { + body, err := lc.get200(ctx, "/localapi/v0/check-so-mark-in-use") + if err != nil { + return false, err + } + var res struct { + UseSOMark bool `json:"useSoMark"` + } + + if err := json.Unmarshal(body, &res); err != nil { + return false, fmt.Errorf("invalid JSON from check-so-mark-in-use: %w", err) + } + return res.UseSOMark, nil +} + // ShutdownTailscaled requests a graceful shutdown of tailscaled. 
func (lc *Client) ShutdownTailscaled(ctx context.Context) error { _, err := lc.send(ctx, "POST", "/localapi/v0/shutdown", 200, nil) diff --git a/client/systray/systray.go b/client/systray/systray.go index 4ac08058854e4..bc099a1ec23a2 100644 --- a/client/systray/systray.go +++ b/client/systray/systray.go @@ -158,6 +158,18 @@ func init() { // onReady is called by the systray package when the menu is ready to be built. func (menu *Menu) onReady() { log.Printf("starting") + if os.Getuid() == 0 || os.Getuid() != os.Geteuid() || os.Getenv("SUDO_USER") != "" || os.Getenv("DOAS_USER") != "" { + fmt.Fprintln(os.Stderr, ` +It appears that you might be running the systray with sudo/doas. +This can lead to issues with D-Bus, and should be avoided. + +The systray application should be run with the same user as your desktop session. +This usually means that you should run the application like: + +tailscale systray + +See https://tailscale.com/kb/1597/linux-systray for more information.`) + } setAppIcon(disconnected) menu.rebuild() @@ -500,7 +512,7 @@ func (menu *Menu) watchIPNBus() { } func (menu *Menu) watchIPNBusInner() error { - watcher, err := menu.lc.WatchIPNBus(menu.bgCtx, ipn.NotifyNoPrivateKeys) + watcher, err := menu.lc.WatchIPNBus(menu.bgCtx, 0) if err != nil { return fmt.Errorf("watching ipn bus: %w", err) } diff --git a/client/web/src/hooks/exit-nodes.ts b/client/web/src/hooks/exit-nodes.ts index b3ce0a9fa12ec..5e47fbc227cd4 100644 --- a/client/web/src/hooks/exit-nodes.ts +++ b/client/web/src/hooks/exit-nodes.ts @@ -66,7 +66,7 @@ export default function useExitNodes(node: NodeData, filter?: string) { // match from a list of exit node `options` to `nodes`. const addBestMatchNode = ( options: ExitNode[], - name: (l: ExitNodeLocation) => string + name: (loc: ExitNodeLocation) => string ) => { const bestNode = highestPriorityNode(options) if (!bestNode || !bestNode.Location) { @@ -86,7 +86,7 @@ export default function useExitNodes(node: NodeData, filter?: string) { locationNodesMap.forEach( // add one node per country (countryNodes) => - addBestMatchNode(flattenMap(countryNodes), (l) => l.Country) + addBestMatchNode(flattenMap(countryNodes), (loc) => loc.Country) ) } else { // Otherwise, show the best match on a city-level, @@ -97,12 +97,12 @@ export default function useExitNodes(node: NodeData, filter?: string) { countryNodes.forEach( // add one node per city (cityNodes) => - addBestMatchNode(cityNodes, (l) => `${l.Country}: ${l.City}`) + addBestMatchNode(cityNodes, (loc) => `${loc.Country}: ${loc.City}`) ) // add the "Country: Best Match" node addBestMatchNode( flattenMap(countryNodes), - (l) => `${l.Country}: Best Match` + (loc) => `${loc.Country}: Best Match` ) }) } diff --git a/clientupdate/clientupdate.go b/clientupdate/clientupdate.go index 84b289615f911..3a0a8d03e0425 100644 --- a/clientupdate/clientupdate.go +++ b/clientupdate/clientupdate.go @@ -418,13 +418,13 @@ func parseSynoinfo(path string) (string, error) { // Extract the CPU in the middle (88f6282 in the above example). 
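	// A hypothetical line of that shape (the model suffix here is made up):
	//
	//	unique="synology_88f6282_ds214"
	//
	// from which the SplitN below yields parts[1] == "88f6282".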
s := bufio.NewScanner(f) for s.Scan() { - l := s.Text() - if !strings.HasPrefix(l, "unique=") { + line := s.Text() + if !strings.HasPrefix(line, "unique=") { continue } - parts := strings.SplitN(l, "_", 3) + parts := strings.SplitN(line, "_", 3) if len(parts) != 3 { - return "", fmt.Errorf(`malformed %q: found %q, expected format like 'unique="synology_$cpu_$model'`, path, l) + return "", fmt.Errorf(`malformed %q: found %q, expected format like 'unique="synology_$cpu_$model'`, path, line) } return parts[1], nil } diff --git a/cmd/cigocacher/cigocacher.go b/cmd/cigocacher/cigocacher.go new file mode 100644 index 0000000000000..b38df4c2b40a5 --- /dev/null +++ b/cmd/cigocacher/cigocacher.go @@ -0,0 +1,308 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// cigocacher is an opinionated-to-Tailscale client for gocached. It connects +// at a URL like "https://ci-gocached-azure-1.corp.ts.net:31364", but that is +// stored in a GitHub actions variable so that its hostname can be updated for +// all branches at the same time in sync with the actual infrastructure. +// +// It authenticates using GitHub OIDC tokens, and all HTTP errors are ignored +// so that its failure mode is just that builds get slower and fall back to +// disk-only cache. +package main + +import ( + "bytes" + "context" + jsonv1 "encoding/json" + "errors" + "flag" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync/atomic" + "time" + + "github.com/bradfitz/go-tool-cache/cacheproc" + "github.com/bradfitz/go-tool-cache/cachers" +) + +func main() { + var ( + auth = flag.Bool("auth", false, "auth with cigocached and exit, printing the access token as output") + token = flag.String("token", "", "the cigocached access token to use, as created using --auth") + cigocachedURL = flag.String("cigocached-url", "", "optional cigocached URL (scheme, host, and port). 
empty means to not use one.") + verbose = flag.Bool("verbose", false, "enable verbose logging") + ) + flag.Parse() + + if *auth { + if *cigocachedURL == "" { + log.Print("--cigocached-url is empty, skipping auth") + return + } + tk, err := fetchAccessToken(httpClient(), os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL"), os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN"), *cigocachedURL) + if err != nil { + log.Printf("error fetching access token, skipping auth: %v", err) + return + } + fmt.Println(tk) + return + } + + d, err := os.UserCacheDir() + if err != nil { + log.Fatal(err) + } + d = filepath.Join(d, "go-cacher") + log.Printf("Defaulting to cache dir %v ...", d) + if err := os.MkdirAll(d, 0750); err != nil { + log.Fatal(err) + } + + c := &cigocacher{ + disk: &cachers.DiskCache{Dir: d}, + verbose: *verbose, + } + if *cigocachedURL != "" { + log.Printf("Using cigocached at %s", *cigocachedURL) + c.gocached = &gocachedClient{ + baseURL: *cigocachedURL, + cl: httpClient(), + accessToken: *token, + verbose: *verbose, + } + } + var p *cacheproc.Process + p = &cacheproc.Process{ + Close: func() error { + log.Printf("gocacheprog: closing; %d gets (%d hits, %d misses, %d errors); %d puts (%d errors)", + p.Gets.Load(), p.GetHits.Load(), p.GetMisses.Load(), p.GetErrors.Load(), p.Puts.Load(), p.PutErrors.Load()) + return c.close() + }, + Get: c.get, + Put: c.put, + } + + if err := p.Run(); err != nil { + log.Fatal(err) + } +} + +func httpClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err == nil { + // This does not run in a tailnet. We serve corp.ts.net + // TLS certs, and override DNS resolution to lookup the + // private IP for the VM by its hostname. 
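+				// For example, the URL shape from the package doc,
+				// "ci-gocached-azure-1.corp.ts.net:31364", is dialed as
+				// host "ci-gocached-azure-1" so the environment's own DNS
+				// can resolve the VM's private IP.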
+ if vm, ok := strings.CutSuffix(host, ".corp.ts.net"); ok { + addr = net.JoinHostPort(vm, port) + } + } + var d net.Dialer + return d.DialContext(ctx, network, addr) + }, + }, + } +} + +type cigocacher struct { + disk *cachers.DiskCache + gocached *gocachedClient + verbose bool + + getNanos atomic.Int64 // total nanoseconds spent in gets + putNanos atomic.Int64 // total nanoseconds spent in puts + getHTTP atomic.Int64 // HTTP get requests made + getHTTPBytes atomic.Int64 // HTTP get bytes transferred + getHTTPHits atomic.Int64 // HTTP get hits + getHTTPMisses atomic.Int64 // HTTP get misses + getHTTPErrors atomic.Int64 // HTTP get errors ignored on best-effort basis + getHTTPNanos atomic.Int64 // total nanoseconds spent in HTTP gets + putHTTP atomic.Int64 // HTTP put requests made + putHTTPBytes atomic.Int64 // HTTP put bytes transferred + putHTTPErrors atomic.Int64 // HTTP put errors ignored on best-effort basis + putHTTPNanos atomic.Int64 // total nanoseconds spent in HTTP puts +} + +func (c *cigocacher) get(ctx context.Context, actionID string) (outputID, diskPath string, err error) { + t0 := time.Now() + defer func() { + c.getNanos.Add(time.Since(t0).Nanoseconds()) + }() + if c.gocached == nil { + return c.disk.Get(ctx, actionID) + } + + outputID, diskPath, err = c.disk.Get(ctx, actionID) + if err == nil && outputID != "" { + return outputID, diskPath, nil + } + + c.getHTTP.Add(1) + t0HTTP := time.Now() + defer func() { + c.getHTTPNanos.Add(time.Since(t0HTTP).Nanoseconds()) + }() + outputID, res, err := c.gocached.get(ctx, actionID) + if err != nil { + c.getHTTPErrors.Add(1) + return "", "", nil + } + if outputID == "" || res == nil { + c.getHTTPMisses.Add(1) + return "", "", nil + } + + defer res.Body.Close() + + // TODO(tomhjp): make sure we timeout if cigocached disappears, but for some + // reason, this seemed to tank network performance. + // ctx, cancel := context.WithTimeout(ctx, httpTimeout(res.ContentLength)) + // defer cancel() + diskPath, err = c.disk.Put(ctx, actionID, outputID, res.ContentLength, res.Body) + if err != nil { + return "", "", fmt.Errorf("error filling disk cache from HTTP: %w", err) + } + + c.getHTTPHits.Add(1) + c.getHTTPBytes.Add(res.ContentLength) + return outputID, diskPath, nil +} + +func (c *cigocacher) put(ctx context.Context, actionID, outputID string, size int64, r io.Reader) (diskPath string, err error) { + t0 := time.Now() + defer func() { + c.putNanos.Add(time.Since(t0).Nanoseconds()) + }() + if c.gocached == nil { + return c.disk.Put(ctx, actionID, outputID, size, r) + } + + c.putHTTP.Add(1) + var diskReader, httpReader io.Reader + tee := &bestEffortTeeReader{r: r} + if size == 0 { + // Special case the empty file so NewRequest sets "Content-Length: 0", + // as opposed to thinking we didn't set it and not being able to sniff its size + // from the type. + diskReader, httpReader = bytes.NewReader(nil), bytes.NewReader(nil) + } else { + pr, pw := io.Pipe() + defer pw.Close() + // The diskReader is in the driving seat. We will try to forward data + // to httpReader as well, but only best-effort. + diskReader = tee + tee.w = pw + httpReader = pr + } + httpErrCh := make(chan error) + go func() { + // TODO(tomhjp): make sure we timeout if cigocached disappears, but for some + // reason, this seemed to tank network performance. 
+ // ctx, cancel := context.WithTimeout(ctx, httpTimeout(size)) + // defer cancel() + t0HTTP := time.Now() + defer func() { + c.putHTTPNanos.Add(time.Since(t0HTTP).Nanoseconds()) + }() + httpErrCh <- c.gocached.put(ctx, actionID, outputID, size, httpReader) + }() + + diskPath, err = c.disk.Put(ctx, actionID, outputID, size, diskReader) + if err != nil { + return "", fmt.Errorf("error writing to disk cache: %w", errors.Join(err, tee.err)) + } + + select { + case err := <-httpErrCh: + if err != nil { + c.putHTTPErrors.Add(1) + } else { + c.putHTTPBytes.Add(size) + } + case <-ctx.Done(): + } + + return diskPath, nil +} + +func (c *cigocacher) close() error { + log.Printf("cigocacher HTTP stats: %d gets (%.1fMiB, %.2fs, %d hits, %d misses, %d errors ignored); %d puts (%.1fMiB, %.2fs, %d errors ignored)", + c.getHTTP.Load(), float64(c.getHTTPBytes.Load())/float64(1<<20), float64(c.getHTTPNanos.Load())/float64(time.Second), c.getHTTPHits.Load(), c.getHTTPMisses.Load(), c.getHTTPErrors.Load(), + c.putHTTP.Load(), float64(c.putHTTPBytes.Load())/float64(1<<20), float64(c.putHTTPNanos.Load())/float64(time.Second), c.putHTTPErrors.Load()) + if !c.verbose || c.gocached == nil { + return nil + } + + stats, err := c.gocached.fetchStats() + if err != nil { + log.Printf("error fetching gocached stats: %v", err) + } else { + log.Printf("gocached session stats: %s", stats) + } + + return nil +} + +func fetchAccessToken(cl *http.Client, idTokenURL, idTokenRequestToken, gocachedURL string) (string, error) { + req, err := http.NewRequest("GET", idTokenURL+"&audience=gocached", nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+idTokenRequestToken) + resp, err := cl.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + type idTokenResp struct { + Value string `json:"value"` + } + var idToken idTokenResp + if err := jsonv1.NewDecoder(resp.Body).Decode(&idToken); err != nil { + return "", err + } + + req, _ = http.NewRequest("POST", gocachedURL+"/auth/exchange-token", strings.NewReader(`{"jwt":"`+idToken.Value+`"}`)) + req.Header.Set("Content-Type", "application/json") + resp, err = cl.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + type accessTokenResp struct { + AccessToken string `json:"access_token"` + } + var accessToken accessTokenResp + if err := jsonv1.NewDecoder(resp.Body).Decode(&accessToken); err != nil { + return "", err + } + + return accessToken.AccessToken, nil +} + +type bestEffortTeeReader struct { + r io.Reader + w io.WriteCloser + err error +} + +func (t *bestEffortTeeReader) Read(p []byte) (int, error) { + n, err := t.r.Read(p) + if n > 0 && t.w != nil { + if _, err := t.w.Write(p[:n]); err != nil { + t.err = errors.Join(err, t.w.Close()) + t.w = nil + } + } + return n, err +} diff --git a/cmd/cigocacher/http.go b/cmd/cigocacher/http.go new file mode 100644 index 0000000000000..57d3bfb45f53e --- /dev/null +++ b/cmd/cigocacher/http.go @@ -0,0 +1,115 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "fmt" + "io" + "log" + "net/http" +) + +type gocachedClient struct { + baseURL string // base URL of the cacher server, like "http://localhost:31364". + cl *http.Client // http.Client to use. + accessToken string // Bearer token to use in the Authorization header. + verbose bool +} + +// drainAndClose reads and throws away a small bounded amount of data. 
This is a +// best-effort attempt to allow connection reuse; Go's HTTP/1 Transport won't +// reuse a TCP connection unless you fully consume HTTP responses. +func drainAndClose(body io.ReadCloser) { + io.CopyN(io.Discard, body, 4<<10) + body.Close() +} + +func tryReadErrorMessage(res *http.Response) []byte { + msg, _ := io.ReadAll(io.LimitReader(res.Body, 4<<10)) + return msg +} + +func (c *gocachedClient) get(ctx context.Context, actionID string) (outputID string, resp *http.Response, err error) { + // TODO(tomhjp): make sure we timeout if cigocached disappears, but for some + // reason, this seemed to tank network performance. + // // Set a generous upper limit on the time we'll wait for a response. We'll + // // shorten this deadline later once we know the content length. + // ctx, cancel := context.WithTimeout(ctx, time.Minute) + // defer cancel() + req, _ := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/action/"+actionID, nil) + req.Header.Set("Want-Object", "1") // opt in to single roundtrip protocol + if c.accessToken != "" { + req.Header.Set("Authorization", "Bearer "+c.accessToken) + } + + res, err := c.cl.Do(req) + if err != nil { + return "", nil, err + } + defer func() { + if resp == nil { + drainAndClose(res.Body) + } + }() + if res.StatusCode == http.StatusNotFound { + return "", nil, nil + } + if res.StatusCode != http.StatusOK { + msg := tryReadErrorMessage(res) + if c.verbose { + log.Printf("error GET /action/%s: %v, %s", actionID, res.Status, msg) + } + return "", nil, fmt.Errorf("unexpected GET /action/%s status %v", actionID, res.Status) + } + + outputID = res.Header.Get("Go-Output-Id") + if outputID == "" { + return "", nil, fmt.Errorf("missing Go-Output-Id header in response") + } + if res.ContentLength == -1 { + return "", nil, fmt.Errorf("no Content-Length from server") + } + return outputID, res, nil +} + +func (c *gocachedClient) put(ctx context.Context, actionID, outputID string, size int64, body io.Reader) error { + req, _ := http.NewRequestWithContext(ctx, "PUT", c.baseURL+"/"+actionID+"/"+outputID, body) + req.ContentLength = size + if c.accessToken != "" { + req.Header.Set("Authorization", "Bearer "+c.accessToken) + } + res, err := c.cl.Do(req) + if err != nil { + if c.verbose { + log.Printf("error PUT /%s/%s: %v", actionID, outputID, err) + } + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + msg := tryReadErrorMessage(res) + if c.verbose { + log.Printf("error PUT /%s/%s: %v, %s", actionID, outputID, res.Status, msg) + } + return fmt.Errorf("unexpected PUT /%s/%s status %v", actionID, outputID, res.Status) + } + + return nil +} + +func (c *gocachedClient) fetchStats() (string, error) { + req, _ := http.NewRequest("GET", c.baseURL+"/session/stats", nil) + req.Header.Set("Authorization", "Bearer "+c.accessToken) + resp, err := c.cl.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + return string(b), nil +} diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index 544d00518e113..a81bd10bd5401 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -192,45 +192,34 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) writef("\t}") writef("}") - } else if codegen.ContainsPointers(elem) { + } else if codegen.IsViewType(elem) || !codegen.ContainsPointers(elem) { + // If the map values are view types (which are + // 
immutable and don't need cloning) or don't + // themselves contain pointers, we can just + // clone the map itself. + it.Import("", "maps") + writef("\tdst.%s = maps.Clone(src.%s)", fname, fname) + } else { + // Otherwise we need to clone each element of + // the map using our recursive helper. writef("if dst.%s != nil {", fname) writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) writef("\tfor k, v := range src.%s {", fname) - switch elem := elem.Underlying().(type) { - case *types.Pointer: - writef("\t\tif v == nil { dst.%s[k] = nil } else {", fname) - if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) { - if _, isIface := base.(*types.Interface); isIface { - it.Import("", "tailscale.com/types/ptr") - writef("\t\t\tdst.%s[k] = ptr.To((*v).Clone())", fname) - } else { - writef("\t\t\tdst.%s[k] = v.Clone()", fname) - } - } else { - it.Import("", "tailscale.com/types/ptr") - writef("\t\t\tdst.%s[k] = ptr.To(*v)", fname) - } - writef("}") - case *types.Interface: - if cloneResultType := methodResultType(elem, "Clone"); cloneResultType != nil { - if _, isPtr := cloneResultType.(*types.Pointer); isPtr { - writef("\t\tdst.%s[k] = *(v.Clone())", fname) - } else { - writef("\t\tdst.%s[k] = v.Clone()", fname) - } - } else { - writef(`panic("%s (%v) does not have a Clone method")`, fname, elem) - } - default: - writef("\t\tdst.%s[k] = *(v.Clone())", fname) - } - + // Use a recursive helper here; this handles + // arbitrarily nested maps in addition to + // simpler types. + writeMapValueClone(mapValueCloneParams{ + Buf: buf, + It: it, + Elem: elem, + SrcExpr: "v", + DstExpr: fmt.Sprintf("dst.%s[k]", fname), + BaseIndent: "\t", + Depth: 1, + }) writef("\t}") writef("}") - } else { - it.Import("", "maps") - writef("\tdst.%s = maps.Clone(src.%s)", fname, fname) } case *types.Interface: // If ft is an interface with a "Clone() ft" method, it can be used to clone the field. @@ -271,3 +260,99 @@ func methodResultType(typ types.Type, method string) types.Type { } return sig.Results().At(0).Type() } + +type mapValueCloneParams struct { + // Buf is the buffer to write generated code to + Buf *bytes.Buffer + // It is the import tracker for managing imports. + It *codegen.ImportTracker + // Elem is the type of the map value to clone + Elem types.Type + // SrcExpr is the expression for the source value (e.g., "v", "v2", "v3") + SrcExpr string + // DstExpr is the expression for the destination (e.g., "dst.Field[k]", "dst.Field[k][k2]") + DstExpr string + // BaseIndent is the "base" indentation string for the generated code + // (i.e. 1 or more tabs). Additional indentation will be added based on + // the Depth parameter. + BaseIndent string + // Depth is the current nesting depth (1 for first level, 2 for second, etc.) + Depth int +} + +// writeMapValueClone generates code to clone a map value recursively. +// It handles arbitrary nesting of maps, pointers, and interfaces. +func writeMapValueClone(params mapValueCloneParams) { + indent := params.BaseIndent + strings.Repeat("\t", params.Depth) + writef := func(format string, args ...any) { + fmt.Fprintf(params.Buf, indent+format+"\n", args...) 
+	}
+
+	switch elem := params.Elem.Underlying().(type) {
+	case *types.Pointer:
+		writef("if %s == nil { %s = nil } else {", params.SrcExpr, params.DstExpr)
+		if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) {
+			if _, isIface := base.(*types.Interface); isIface {
+				params.It.Import("", "tailscale.com/types/ptr")
+				writef("\t%s = ptr.To((*%s).Clone())", params.DstExpr, params.SrcExpr)
+			} else {
+				writef("\t%s = %s.Clone()", params.DstExpr, params.SrcExpr)
+			}
+		} else {
+			params.It.Import("", "tailscale.com/types/ptr")
+			writef("\t%s = ptr.To(*%s)", params.DstExpr, params.SrcExpr)
+		}
+		writef("}")
+
+	case *types.Map:
+		// Recursively handle nested maps
+		innerElem := elem.Elem()
+		if codegen.IsViewType(innerElem) || !codegen.ContainsPointers(innerElem) {
+			// Inner map values don't need deep cloning
+			params.It.Import("", "maps")
+			writef("%s = maps.Clone(%s)", params.DstExpr, params.SrcExpr)
+		} else {
+			// Inner map values need cloning
+			keyType := params.It.QualifiedName(elem.Key())
+			valueType := params.It.QualifiedName(innerElem)
+			// Generate unique variable names for nested loops based on depth
+			keyVar := fmt.Sprintf("k%d", params.Depth+1)
+			valVar := fmt.Sprintf("v%d", params.Depth+1)
+
+			writef("if %s == nil {", params.SrcExpr)
+			writef("\t%s = nil", params.DstExpr)
+			writef("\tcontinue")
+			writef("}")
+			writef("%s = map[%s]%s{}", params.DstExpr, keyType, valueType)
+			writef("for %s, %s := range %s {", keyVar, valVar, params.SrcExpr)
+
+			// Recursively generate cloning code for the nested map value
+			nestedDstExpr := fmt.Sprintf("%s[%s]", params.DstExpr, keyVar)
+			writeMapValueClone(mapValueCloneParams{
+				Buf:        params.Buf,
+				It:         params.It,
+				Elem:       innerElem,
+				SrcExpr:    valVar,
+				DstExpr:    nestedDstExpr,
+				BaseIndent: params.BaseIndent,
+				Depth:      params.Depth + 1,
+			})
+
+			writef("}")
+		}
+
+	case *types.Interface:
+		if cloneResultType := methodResultType(elem, "Clone"); cloneResultType != nil {
+			if _, isPtr := cloneResultType.(*types.Pointer); isPtr {
+				writef("%s = *(%s.Clone())", params.DstExpr, params.SrcExpr)
+			} else {
+				writef("%s = %s.Clone()", params.DstExpr, params.SrcExpr)
+			}
+		} else {
+			writef(`panic("map value (%v) does not have a Clone method")`, elem)
+		}
+
+	default:
+		writef("%s = *(%s.Clone())", params.DstExpr, params.SrcExpr)
+	}
+}
diff --git a/cmd/cloner/cloner_test.go b/cmd/cloner/cloner_test.go
index 3556c14bc109e..754a4ac49a220 100644
--- a/cmd/cloner/cloner_test.go
+++ b/cmd/cloner/cloner_test.go
@@ -108,3 +108,109 @@
+func TestMapWithPointers(t *testing.T) {
+	num1, num2 := 42, 100
+	orig := &clonerex.MapWithPointers{
+		Nested: map[string]*int{
+			"foo": &num1,
+			"bar": &num2,
+		},
+		WithCloneMethod: map[string]*clonerex.SliceContainer{
+			"container1": {Slice: []*int{&num1, &num2}},
+			"container2": {Slice: []*int{&num1}},
+		},
+		CloneInterface: map[string]clonerex.Cloneable{
+			"impl1": &clonerex.CloneableImpl{Value: 123},
+			"impl2": &clonerex.CloneableImpl{Value: 456},
+		},
+	}
+
+	cloned := orig.Clone()
+	if !reflect.DeepEqual(orig, cloned) {
+		t.Errorf("Clone() = %v, want %v", cloned, orig)
+	}
+
+	// Mutate cloned.Nested pointer values
+	*cloned.Nested["foo"] = 999
+	if *orig.Nested["foo"] == 999 {
+		t.Errorf("Clone() aliased memory in Nested: original was modified")
+	}
+
+	// Mutate cloned.WithCloneMethod slice values
+	*cloned.WithCloneMethod["container1"].Slice[0] = 888
+	if *orig.WithCloneMethod["container1"].Slice[0] == 888 {
+		t.Errorf("Clone() aliased memory in 
WithCloneMethod: original was modified") + } + + // Mutate cloned.CloneInterface values + if impl, ok := cloned.CloneInterface["impl1"].(*clonerex.CloneableImpl); ok { + impl.Value = 777 + if origImpl, ok := orig.CloneInterface["impl1"].(*clonerex.CloneableImpl); ok { + if origImpl.Value == 777 { + t.Errorf("Clone() aliased memory in CloneInterface: original was modified") + } + } + } +} + +func TestDeeplyNestedMap(t *testing.T) { + num := 123 + orig := &clonerex.DeeplyNestedMap{ + ThreeLevels: map[string]map[string]map[string]int{ + "a": { + "b": {"c": 1, "d": 2}, + "e": {"f": 3}, + }, + "g": { + "h": {"i": 4}, + }, + }, + FourLevels: map[string]map[string]map[string]map[string]*clonerex.SliceContainer{ + "l1a": { + "l2a": { + "l3a": { + "l4a": {Slice: []*int{&num}}, + "l4b": {Slice: []*int{&num, &num}}, + }, + }, + }, + }, + } + + cloned := orig.Clone() + if !reflect.DeepEqual(orig, cloned) { + t.Errorf("Clone() = %v, want %v", cloned, orig) + } + + // Mutate the clone's ThreeLevels map + cloned.ThreeLevels["a"]["b"]["c"] = 777 + if orig.ThreeLevels["a"]["b"]["c"] == 777 { + t.Errorf("Clone() aliased memory in ThreeLevels: original was modified") + } + + // Mutate the clone's FourLevels map at the deepest pointer level + *cloned.FourLevels["l1a"]["l2a"]["l3a"]["l4a"].Slice[0] = 666 + if *orig.FourLevels["l1a"]["l2a"]["l3a"]["l4a"].Slice[0] == 666 { + t.Errorf("Clone() aliased memory in FourLevels: original was modified") + } + + // Add a new top-level key to the clone's FourLevels map + newNum := 999 + cloned.FourLevels["l1b"] = map[string]map[string]map[string]*clonerex.SliceContainer{ + "l2b": { + "l3b": { + "l4c": {Slice: []*int{&newNum}}, + }, + }, + } + if _, exists := orig.FourLevels["l1b"]; exists { + t.Errorf("Clone() aliased FourLevels map: new top-level key appeared in original") + } + + // Add a new nested key to the clone's FourLevels map + cloned.FourLevels["l1a"]["l2a"]["l3a"]["l4c"] = &clonerex.SliceContainer{Slice: []*int{&newNum}} + if _, exists := orig.FourLevels["l1a"]["l2a"]["l3a"]["l4c"]; exists { + t.Errorf("Clone() aliased FourLevels map: new nested key appeared in original") + } +} diff --git a/cmd/cloner/clonerex/clonerex.go b/cmd/cloner/clonerex/clonerex.go index 6463f91442a32..b9f6d60dedb35 100644 --- a/cmd/cloner/clonerex/clonerex.go +++ b/cmd/cloner/clonerex/clonerex.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type SliceContainer,InterfaceContainer +//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap // Package clonerex is an example package for the cloner tool. 
package clonerex @@ -32,3 +32,15 @@ func (c *CloneableImpl) Clone() Cloneable { type InterfaceContainer struct { Interface Cloneable } + +type MapWithPointers struct { + Nested map[string]*int + WithCloneMethod map[string]*SliceContainer + CloneInterface map[string]Cloneable +} + +// DeeplyNestedMap tests arbitrary depth of map nesting (3+ levels) +type DeeplyNestedMap struct { + ThreeLevels map[string]map[string]map[string]int + FourLevels map[string]map[string]map[string]map[string]*SliceContainer +} diff --git a/cmd/cloner/clonerex/clonerex_clone.go b/cmd/cloner/clonerex/clonerex_clone.go index 533d7e723d3ea..13e1276c4e4b8 100644 --- a/cmd/cloner/clonerex/clonerex_clone.go +++ b/cmd/cloner/clonerex/clonerex_clone.go @@ -6,6 +6,8 @@ package clonerex import ( + "maps" + "tailscale.com/types/ptr" ) @@ -54,9 +56,114 @@ var _InterfaceContainerCloneNeedsRegeneration = InterfaceContainer(struct { Interface Cloneable }{}) +// Clone makes a deep copy of MapWithPointers. +// The result aliases no memory with the original. +func (src *MapWithPointers) Clone() *MapWithPointers { + if src == nil { + return nil + } + dst := new(MapWithPointers) + *dst = *src + if dst.Nested != nil { + dst.Nested = map[string]*int{} + for k, v := range src.Nested { + if v == nil { + dst.Nested[k] = nil + } else { + dst.Nested[k] = ptr.To(*v) + } + } + } + if dst.WithCloneMethod != nil { + dst.WithCloneMethod = map[string]*SliceContainer{} + for k, v := range src.WithCloneMethod { + if v == nil { + dst.WithCloneMethod[k] = nil + } else { + dst.WithCloneMethod[k] = v.Clone() + } + } + } + if dst.CloneInterface != nil { + dst.CloneInterface = map[string]Cloneable{} + for k, v := range src.CloneInterface { + dst.CloneInterface[k] = v.Clone() + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _MapWithPointersCloneNeedsRegeneration = MapWithPointers(struct { + Nested map[string]*int + WithCloneMethod map[string]*SliceContainer + CloneInterface map[string]Cloneable +}{}) + +// Clone makes a deep copy of DeeplyNestedMap. +// The result aliases no memory with the original. +func (src *DeeplyNestedMap) Clone() *DeeplyNestedMap { + if src == nil { + return nil + } + dst := new(DeeplyNestedMap) + *dst = *src + if dst.ThreeLevels != nil { + dst.ThreeLevels = map[string]map[string]map[string]int{} + for k, v := range src.ThreeLevels { + if v == nil { + dst.ThreeLevels[k] = nil + continue + } + dst.ThreeLevels[k] = map[string]map[string]int{} + for k2, v2 := range v { + dst.ThreeLevels[k][k2] = maps.Clone(v2) + } + } + } + if dst.FourLevels != nil { + dst.FourLevels = map[string]map[string]map[string]map[string]*SliceContainer{} + for k, v := range src.FourLevels { + if v == nil { + dst.FourLevels[k] = nil + continue + } + dst.FourLevels[k] = map[string]map[string]map[string]*SliceContainer{} + for k2, v2 := range v { + if v2 == nil { + dst.FourLevels[k][k2] = nil + continue + } + dst.FourLevels[k][k2] = map[string]map[string]*SliceContainer{} + for k3, v3 := range v2 { + if v3 == nil { + dst.FourLevels[k][k2][k3] = nil + continue + } + dst.FourLevels[k][k2][k3] = map[string]*SliceContainer{} + for k4, v4 := range v3 { + if v4 == nil { + dst.FourLevels[k][k2][k3][k4] = nil + } else { + dst.FourLevels[k][k2][k3][k4] = v4.Clone() + } + } + } + } + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. 
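+// (The struct-literal conversion below only compiles while the anonymous
+// struct's fields exactly match DeeplyNestedMap's, so any change to the type
+// fails the build until the generator is re-run.)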
+var _DeeplyNestedMapCloneNeedsRegeneration = DeeplyNestedMap(struct { + ThreeLevels map[string]map[string]map[string]int + FourLevels map[string]map[string]map[string]map[string]*SliceContainer +}{}) + // Clone duplicates src into dst and reports whether it succeeded. // To succeed, must be of types <*T, *T> or <*T, **T>, -// where T is one of SliceContainer,InterfaceContainer. +// where T is one of SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap. func Clone(dst, src any) bool { switch src := src.(type) { case *SliceContainer: @@ -77,6 +184,24 @@ func Clone(dst, src any) bool { *dst = src.Clone() return true } + case *MapWithPointers: + switch dst := dst.(type) { + case *MapWithPointers: + *dst = *src.Clone() + return true + case **MapWithPointers: + *dst = src.Clone() + return true + } + case *DeeplyNestedMap: + switch dst := dst.(type) { + case *DeeplyNestedMap: + *dst = *src.Clone() + return true + case **DeeplyNestedMap: + *dst = src.Clone() + return true + } } return false } diff --git a/cmd/containerboot/main_test.go b/cmd/containerboot/main_test.go index 96feef682af5b..f92f353334de2 100644 --- a/cmd/containerboot/main_test.go +++ b/cmd/containerboot/main_test.go @@ -1287,8 +1287,8 @@ type localAPI struct { notify *ipn.Notify } -func (l *localAPI) Start() error { - path := filepath.Join(l.FSRoot, "tmp/tailscaled.sock.fake") +func (lc *localAPI) Start() error { + path := filepath.Join(lc.FSRoot, "tmp/tailscaled.sock.fake") if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { return err } @@ -1298,30 +1298,30 @@ func (l *localAPI) Start() error { return err } - l.srv = &http.Server{ - Handler: l, + lc.srv = &http.Server{ + Handler: lc, } - l.Path = path - l.cond = sync.NewCond(&l.Mutex) - go l.srv.Serve(ln) + lc.Path = path + lc.cond = sync.NewCond(&lc.Mutex) + go lc.srv.Serve(ln) return nil } -func (l *localAPI) Close() { - l.srv.Close() +func (lc *localAPI) Close() { + lc.srv.Close() } -func (l *localAPI) Notify(n *ipn.Notify) { +func (lc *localAPI) Notify(n *ipn.Notify) { if n == nil { return } - l.Lock() - defer l.Unlock() - l.notify = n - l.cond.Broadcast() + lc.Lock() + defer lc.Unlock() + lc.notify = n + lc.cond.Broadcast() } -func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (lc *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/localapi/v0/serve-config": if r.Method != "POST" { @@ -1348,11 +1348,11 @@ func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { f.Flush() } enc := json.NewEncoder(w) - l.Lock() - defer l.Unlock() + lc.Lock() + defer lc.Unlock() for { - if l.notify != nil { - if err := enc.Encode(l.notify); err != nil { + if lc.notify != nil { + if err := enc.Encode(lc.notify); err != nil { // Usually broken pipe as the test client disconnects. 
return } @@ -1360,7 +1360,7 @@ func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { f.Flush() } } - l.cond.Wait() + lc.cond.Wait() } } diff --git a/cmd/derper/cert.go b/cmd/derper/cert.go index b95755c64d2a7..d383c82f01157 100644 --- a/cmd/derper/cert.go +++ b/cmd/derper/cert.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/json" "encoding/pem" "errors" @@ -24,6 +25,7 @@ import ( "regexp" "time" + "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" "tailscale.com/tailcfg" ) @@ -42,17 +44,33 @@ type certProvider interface { HTTPHandler(fallback http.Handler) http.Handler } -func certProviderByCertMode(mode, dir, hostname string) (certProvider, error) { +func certProviderByCertMode(mode, dir, hostname, eabKID, eabKey string) (certProvider, error) { if dir == "" { return nil, errors.New("missing required --certdir flag") } switch mode { - case "letsencrypt": + case "letsencrypt", "gcp": certManager := &autocert.Manager{ Prompt: autocert.AcceptTOS, HostPolicy: autocert.HostWhitelist(hostname), Cache: autocert.DirCache(dir), } + if mode == "gcp" { + if eabKID == "" || eabKey == "" { + return nil, errors.New("--certmode=gcp requires --acme-eab-kid and --acme-eab-key flags") + } + keyBytes, err := decodeEABKey(eabKey) + if err != nil { + return nil, err + } + certManager.Client = &acme.Client{ + DirectoryURL: "https://dv.acme-v02.api.pki.goog/directory", + } + certManager.ExternalAccountBinding = &acme.ExternalAccountBinding{ + KID: eabKID, + Key: keyBytes, + } + } if hostname == "derp.tailscale.com" { certManager.HostPolicy = prodAutocertHostPolicy certManager.Email = "security@tailscale.com" @@ -209,3 +227,17 @@ func createSelfSignedIPCert(crtPath, keyPath, ipStr string) (*tls.Certificate, e } return &tlsCert, nil } + +// decodeEABKey decodes a base64-encoded EAB key. +// It accepts both standard base64 (with padding) and base64url (without padding). +func decodeEABKey(s string) ([]byte, error) { + // Try base64url first (no padding), then standard base64 (with padding). + // This handles both ACME spec format and gcloud output format. 
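+	//
+	// For example, the two encodings of the same key used in the tests both
+	// decode to "test-key":
+	//
+	//	"dGVzdC1rZXk"  (base64url, unpadded)
+	//	"dGVzdC1rZXk=" (standard base64, padded)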
+ if b, err := base64.RawURLEncoding.DecodeString(s); err == nil { + return b, nil + } + if b, err := base64.StdEncoding.DecodeString(s); err == nil { + return b, nil + } + return nil, errors.New("invalid base64 encoding for EAB key") +} diff --git a/cmd/derper/cert_test.go b/cmd/derper/cert_test.go index c8a3229e9f41c..3a8da46108428 100644 --- a/cmd/derper/cert_test.go +++ b/cmd/derper/cert_test.go @@ -91,7 +91,7 @@ func TestCertIP(t *testing.T) { t.Fatalf("Error closing key.pem: %v", err) } - cp, err := certProviderByCertMode("manual", dir, hostname) + cp, err := certProviderByCertMode("manual", dir, hostname, "", "") if err != nil { t.Fatal(err) } @@ -169,3 +169,37 @@ func TestPinnedCertRawIP(t *testing.T) { } defer connClose.Close() } + +func TestGCPCertMode(t *testing.T) { + dir := t.TempDir() + + // Missing EAB credentials + _, err := certProviderByCertMode("gcp", dir, "test.example.com", "", "") + if err == nil { + t.Fatal("expected error when EAB credentials are missing") + } + + // Invalid base64 + _, err = certProviderByCertMode("gcp", dir, "test.example.com", "kid", "not-valid!") + if err == nil { + t.Fatal("expected error for invalid base64") + } + + // Valid base64url (no padding) + cp, err := certProviderByCertMode("gcp", dir, "test.example.com", "kid", "dGVzdC1rZXk") + if err != nil { + t.Fatalf("base64url: %v", err) + } + if cp == nil { + t.Fatal("base64url: nil certProvider") + } + + // Valid standard base64 (with padding, gcloud format) + cp, err = certProviderByCertMode("gcp", dir, "test.example.com", "kid", "dGVzdC1rZXk=") + if err != nil { + t.Fatalf("base64: %v", err) + } + if cp == nil { + t.Fatal("base64: nil certProvider") + } +} diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 01c278fbd1691..6ce5f044228a9 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -30,9 +30,9 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ - LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus - LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs - LD github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs + L github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus + L github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs + L github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs W 💣 github.com/tailscale/go-winio from tailscale.com/safesocket W 💣 github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W 💣 github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio @@ -72,7 +72,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa google.golang.org/protobuf/reflect/protoregistry from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/runtime/protoiface from google.golang.org/protobuf/internal/impl+ google.golang.org/protobuf/runtime/protoimpl from github.com/prometheus/client_model/go+ - google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ + 💣 google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ tailscale.com from 
tailscale.com/version 💣 tailscale.com/atomicfile from tailscale.com/cmd/derper+ tailscale.com/client/local from tailscale.com/derp/derpserver @@ -139,7 +139,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/types/structs from tailscale.com/ipn+ tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/ipn+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/net/netmon tailscale.com/util/cloudenv from tailscale.com/hostinfo+ tailscale.com/util/ctxkey from tailscale.com/tsweb+ @@ -168,7 +168,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/version from tailscale.com/cmd/derper+ tailscale.com/version/distro from tailscale.com/envknob+ tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap - golang.org/x/crypto/acme from golang.org/x/crypto/acme/autocert + golang.org/x/crypto/acme from golang.org/x/crypto/acme/autocert+ golang.org/x/crypto/acme/autocert from tailscale.com/cmd/derper golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 857d7def3b6ff..aeb2adb5dc61d 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -60,9 +60,11 @@ var ( httpPort = flag.Int("http-port", 80, "The port on which to serve HTTP. Set to -1 to disable. The listener is bound to the same IP (if any) as specified in the -a flag.") stunPort = flag.Int("stun-port", 3478, "The UDP port on which to serve STUN. The listener is bound to the same IP (if any) as specified in the -a flag.") configPath = flag.String("c", "", "config file path") - certMode = flag.String("certmode", "letsencrypt", "mode for getting a cert. possible options: manual, letsencrypt") - certDir = flag.String("certdir", tsweb.DefaultCertDir("derper-certs"), "directory to store LetsEncrypt certs, if addr's port is :443") - hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443. When --certmode=manual, this can be an IP address to avoid SNI checks") + certMode = flag.String("certmode", "letsencrypt", "mode for getting a cert. possible options: manual, letsencrypt, gcp") + certDir = flag.String("certdir", tsweb.DefaultCertDir("derper-certs"), "directory to store ACME (e.g. LetsEncrypt) certs, if addr's port is :443") + hostname = flag.String("hostname", "derp.tailscale.com", "TLS host name for certs, if addr's port is :443. When --certmode=manual, this can be an IP address to avoid SNI checks") + acmeEABKid = flag.String("acme-eab-kid", "", "ACME External Account Binding (EAB) Key ID (required for --certmode=gcp)") + acmeEABKey = flag.String("acme-eab-key", "", "ACME External Account Binding (EAB) HMAC key, base64-encoded (required for --certmode=gcp)") runSTUN = flag.Bool("stun", true, "whether to run a STUN server. It will bind to the same IP (if any) as the --addr flag value.") runDERP = flag.Bool("derp", true, "whether to run a DERP server. The only reason to set this false is if you're decommissioning a server but want to keep its bootstrap DNS functionality still running.") flagHome = flag.String("home", "", "what to serve at the root path. 
It may be left empty (the default, for a default homepage), \"blank\" for a blank page, or a URL to redirect to")
@@ -343,7 +345,7 @@ func main() {
	if serveTLS {
		log.Printf("derper: serving on %s with TLS", *addr)
		var certManager certProvider
-		certManager, err = certProviderByCertMode(*certMode, *certDir, *hostname)
+		certManager, err = certProviderByCertMode(*certMode, *certDir, *hostname, *acmeEABKid, *acmeEABKey)
		if err != nil {
			log.Fatalf("derper: can not start cert provider: %v", err)
		}
@@ -481,32 +483,32 @@ func newRateLimitedListener(ln net.Listener, limit rate.Limit, burst int) *rateL
	return &rateLimitedListener{Listener: ln, lim: rate.NewLimiter(limit, burst)}
}

-func (l *rateLimitedListener) ExpVar() expvar.Var {
+func (ln *rateLimitedListener) ExpVar() expvar.Var {
	m := new(metrics.Set)
-	m.Set("counter_accepted_connections", &l.numAccepts)
-	m.Set("counter_rejected_connections", &l.numRejects)
+	m.Set("counter_accepted_connections", &ln.numAccepts)
+	m.Set("counter_rejected_connections", &ln.numRejects)
	return m
}

var errLimitedConn = errors.New("cannot accept connection; rate limited")

-func (l *rateLimitedListener) Accept() (net.Conn, error) {
+func (ln *rateLimitedListener) Accept() (net.Conn, error) {
	// Even under a rate limited situation, we accept the connection immediately
	// and close it, rather than being slow at accepting new connections.
	// This provides two benefits: 1) it signals to the client that something
	// is going on on the server, and 2) it prevents new connections from
	// piling up and occupying resources in the OS kernel.
	// The client will retry as needed (with backoffs in place).
-	cn, err := l.Listener.Accept()
+	cn, err := ln.Listener.Accept()
	if err != nil {
		return nil, err
	}
-	if !l.lim.Allow() {
-		l.numRejects.Add(1)
+	if !ln.lim.Allow() {
+		ln.numRejects.Add(1)
		cn.Close()
		return nil, errLimitedConn
	}
-	l.numAccepts.Add(1)
+	ln.numAccepts.Add(1)
	return cn, nil
}
diff --git a/cmd/jsonimports/format.go b/cmd/jsonimports/format.go
new file mode 100644
index 0000000000000..6dbd175583a4d
--- /dev/null
+++ b/cmd/jsonimports/format.go
@@ -0,0 +1,175 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+	"bytes"
+	"go/ast"
+	"go/format"
+	"go/parser"
+	"go/token"
+	"go/types"
+	"path"
+	"slices"
+	"strconv"
+	"strings"
+
+	"tailscale.com/util/must"
+)
+
+// mustFormatFile formats a Go source file and adjusts "json" imports.
+// It panics if there are any parsing errors.
+//
+// - "encoding/json" is imported under the name "jsonv1" or "jsonv1std"
+// - "encoding/json/v2" is rewritten to import "github.com/go-json-experiment/json" instead
+// - "encoding/json/jsontext" is rewritten to import "github.com/go-json-experiment/json/jsontext" instead
+// - "github.com/go-json-experiment/json" is imported under the name "jsonv2"
+// - "github.com/go-json-experiment/json/v1" is imported under the name "jsonv1"
+//
+// If no changes to the file are made, it returns the input unchanged.
+func mustFormatFile(in []byte) (out []byte) {
+	fset := token.NewFileSet()
+	f := must.Get(parser.ParseFile(fset, "", in, parser.ParseComments))
+
+	// Check for the existence of "json" imports.
+ jsonImports := make(map[string][]*ast.ImportSpec) + for _, imp := range f.Imports { + switch pkgPath := must.Get(strconv.Unquote(imp.Path.Value)); pkgPath { + case + "encoding/json", + "encoding/json/v2", + "encoding/json/jsontext", + "github.com/go-json-experiment/json", + "github.com/go-json-experiment/json/v1", + "github.com/go-json-experiment/json/jsontext": + jsonImports[pkgPath] = append(jsonImports[pkgPath], imp) + } + } + if len(jsonImports) == 0 { + return in + } + + // Best-effort local type-check of the file + // to resolve local declarations to detect shadowed variables. + typeInfo := &types.Info{Uses: make(map[*ast.Ident]types.Object)} + (&types.Config{ + Error: func(err error) {}, + }).Check("", fset, []*ast.File{f}, typeInfo) + + // Rewrite imports to instead use "github.com/go-json-experiment/json". + // This ensures that code continues to build even if + // goexperiment.jsonv2 is *not* specified. + // As of https://github.com/go-json-experiment/json/pull/186, + // imports to "github.com/go-json-experiment/json" are identical + // to the standard library if built with goexperiment.jsonv2. + for fromPath, toPath := range map[string]string{ + "encoding/json/v2": "github.com/go-json-experiment/json", + "encoding/json/jsontext": "github.com/go-json-experiment/json/jsontext", + } { + for _, imp := range jsonImports[fromPath] { + imp.Path.Value = strconv.Quote(toPath) + jsonImports[toPath] = append(jsonImports[toPath], imp) + } + delete(jsonImports, fromPath) + } + + // While in a transitory state, where both v1 and v2 json imports + // may exist in our codebase, always explicitly import with + // either jsonv1 or jsonv2 in the package name to avoid ambiguities + // when looking at a particular Marshal or Unmarshal call site. + renames := make(map[string]string) // mapping of old names to new names + deletes := make(map[*ast.ImportSpec]bool) // set of imports to delete + for pkgPath, imps := range jsonImports { + var newName string + switch pkgPath { + case "encoding/json": + newName = "jsonv1" + // If "github.com/go-json-experiment/json/v1" is also imported, + // then use jsonv1std for "encoding/json" to avoid a conflict. + if len(jsonImports["github.com/go-json-experiment/json/v1"]) > 0 { + newName += "std" + } + case "github.com/go-json-experiment/json": + newName = "jsonv2" + case "github.com/go-json-experiment/json/v1": + newName = "jsonv1" + } + + // Rename the import if different than expected. + if oldName := importName(imps[0]); oldName != newName && newName != "" { + renames[oldName] = newName + pos := imps[0].Pos() // preserve original positioning + imps[0].Name = ast.NewIdent(newName) + imps[0].Name.NamePos = pos + } + + // For all redundant imports, use the first imported name. + for _, imp := range imps[1:] { + renames[importName(imp)] = importName(imps[0]) + deletes[imp] = true + } + } + if len(deletes) > 0 { + f.Imports = slices.DeleteFunc(f.Imports, func(imp *ast.ImportSpec) bool { + return deletes[imp] + }) + for _, decl := range f.Decls { + if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT { + genDecl.Specs = slices.DeleteFunc(genDecl.Specs, func(spec ast.Spec) bool { + return deletes[spec.(*ast.ImportSpec)] + }) + } + } + } + if len(renames) > 0 { + ast.Walk(astVisitor(func(n ast.Node) bool { + if sel, ok := n.(*ast.SelectorExpr); ok { + if id, ok := sel.X.(*ast.Ident); ok { + // Just because the selector looks like "json.Marshal" + // does not mean that it is referencing the "json" package. 
+ // There could be a local "json" declaration that shadows + // the package import. Check partial type information + // to see if there was a local declaration. + if obj, ok := typeInfo.Uses[id]; ok { + if _, ok := obj.(*types.PkgName); !ok { + return true + } + } + + if newName, ok := renames[id.String()]; ok { + id.Name = newName + } + } + } + return true + }), f) + } + + bb := new(bytes.Buffer) + must.Do(format.Node(bb, fset, f)) + return must.Get(format.Source(bb.Bytes())) +} + +// importName is the local package name used for an import. +// If no explicit local name is used, then it uses string parsing +// to derive the package name from the path, relying on the convention +// that the package name is the base name of the package path. +func importName(imp *ast.ImportSpec) string { + if imp.Name != nil { + return imp.Name.String() + } + pkgPath, _ := strconv.Unquote(imp.Path.Value) + pkgPath = strings.TrimRight(pkgPath, "/v0123456789") // exclude version directories + return path.Base(pkgPath) +} + +// astVisitor is a function that implements [ast.Visitor]. +type astVisitor func(ast.Node) bool + +func (f astVisitor) Visit(node ast.Node) ast.Visitor { + if !f(node) { + return nil + } + return f +} diff --git a/cmd/jsonimports/format_test.go b/cmd/jsonimports/format_test.go new file mode 100644 index 0000000000000..28654eb4550ee --- /dev/null +++ b/cmd/jsonimports/format_test.go @@ -0,0 +1,162 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "go/format" + "testing" + + "tailscale.com/util/must" + "tailscale.com/util/safediff" +) + +func TestFormatFile(t *testing.T) { + tests := []struct{ in, want string }{{ + in: `package foobar + + import ( + "encoding/json" + jsonv2exp "github.com/go-json-experiment/json" + ) + + func main() { + json.Marshal() + jsonv2exp.Marshal() + { + var json T // deliberately shadow "json" package name + json.Marshal() // should not be re-written + } + } + `, + want: `package foobar + + import ( + jsonv1 "encoding/json" + jsonv2 "github.com/go-json-experiment/json" + ) + + func main() { + jsonv1.Marshal() + jsonv2.Marshal() + { + var json T // deliberately shadow "json" package name + json.Marshal() // should not be re-written + } + } + `, + }, { + in: `package foobar + + import ( + "github.com/go-json-experiment/json" + jsonv2exp "github.com/go-json-experiment/json" + ) + + func main() { + json.Marshal() + jsonv2exp.Marshal() + } + `, + want: `package foobar + import ( + jsonv2 "github.com/go-json-experiment/json" + ) + func main() { + jsonv2.Marshal() + jsonv2.Marshal() + } + `, + }, { + in: `package foobar + import "github.com/go-json-experiment/json/v1" + func main() { + json.Marshal() + } + `, + want: `package foobar + import jsonv1 "github.com/go-json-experiment/json/v1" + func main() { + jsonv1.Marshal() + } + `, + }, { + in: `package foobar + import ( + "encoding/json" + jsonv1in2 "github.com/go-json-experiment/json/v1" + ) + func main() { + json.Marshal() + jsonv1in2.Marshal() + } + `, + want: `package foobar + import ( + jsonv1std "encoding/json" + jsonv1 "github.com/go-json-experiment/json/v1" + ) + func main() { + jsonv1std.Marshal() + jsonv1.Marshal() + } + `, + }, { + in: `package foobar + import ( + "encoding/json" + jsonv1in2 "github.com/go-json-experiment/json/v1" + ) + func main() { + json.Marshal() + jsonv1in2.Marshal() + } + `, + want: `package foobar + import ( + jsonv1std "encoding/json" + jsonv1 "github.com/go-json-experiment/json/v1" + ) + func main() { + 
jsonv1std.Marshal() + jsonv1.Marshal() + } + `, + }, { + in: `package foobar + import ( + "encoding/json" + j2 "encoding/json/v2" + "encoding/json/jsontext" + ) + func main() { + json.Marshal() + j2.Marshal() + jsontext.NewEncoder + } + `, + want: `package foobar + import ( + jsonv1 "encoding/json" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + ) + func main() { + jsonv1.Marshal() + jsonv2.Marshal() + jsontext.NewEncoder + } + `, + }} + for _, tt := range tests { + got := string(must.Get(format.Source([]byte(tt.in)))) + got = string(mustFormatFile([]byte(got))) + want := string(must.Get(format.Source([]byte(tt.want)))) + if got != want { + diff, _ := safediff.Lines(got, want, -1) + t.Errorf("mismatch (-got +want)\n%s", diff) + t.Error(got) + t.Error(want) + } + } +} diff --git a/cmd/jsonimports/jsonimports.go b/cmd/jsonimports/jsonimports.go new file mode 100644 index 0000000000000..4be2e10cbe091 --- /dev/null +++ b/cmd/jsonimports/jsonimports.go @@ -0,0 +1,124 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The jsonimports tool formats all Go source files in the repository +// to enforce that "json" imports are consistent. +// +// With Go 1.25, the "encoding/json/v2" and "encoding/json/jsontext" +// packages are now available under goexperiment.jsonv2. +// This leads to possible confusion over the following: +// +// - "encoding/json" +// - "encoding/json/v2" +// - "encoding/json/jsontext" +// - "github.com/go-json-experiment/json/v1" +// - "github.com/go-json-experiment/json" +// - "github.com/go-json-experiment/json/jsontext" +// +// In order to enforce consistent usage, we apply the following rules: +// +// - Until the Go standard library formally accepts "encoding/json/v2" +// and "encoding/json/jsontext" into the standard library +// (i.e., they are no longer considered experimental), +// we forbid any code from directly importing those packages. +// Go code should instead import "github.com/go-json-experiment/json" +// and "github.com/go-json-experiment/json/jsontext". +// The latter packages contain aliases to the standard library +// if built on Go 1.25 with the goexperiment.jsonv2 tag specified. +// +// - Imports of "encoding/json" or "github.com/go-json-experiment/json/v1" +// must be explicitly imported under the package name "jsonv1". +// If both packages need to be imported, then the former should +// be imported under the package name "jsonv1std". +// +// - Imports of "github.com/go-json-experiment/json" +// must be explicitly imported under the package name "jsonv2". +// +// The latter two rules exist to provide clarity when reading code. +// Without them, it is unclear whether "json.Marshal" refers to v1 or v2. +// With them, however, it is clear that "jsonv1.Marshal" is calling v1 and +// that "jsonv2.Marshal" is calling v2. +// +// TODO(@joetsai): At this present moment, there is no guidance given on +// whether to use v1 or v2 for newly written Go source code. +// I will write a document in the near future providing more guidance. +// Feel free to continue using v1 "encoding/json" as you are accustomed to. +package main + +import ( + "bytes" + "flag" + "fmt" + "os" + "os/exec" + "runtime" + "strings" + "sync" + + "tailscale.com/syncs" + "tailscale.com/util/must" + "tailscale.com/util/safediff" +) + +func main() { + update := flag.Bool("update", false, "update all Go source files") + flag.Parse() + + // Change working directory to Git repository root. 
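+	// ("git rev-parse --show-toplevel" prints the absolute path of the
+	// repository root followed by a newline, hence the TrimSuffix below.)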
+	repoRoot := strings.TrimSuffix(string(must.Get(exec.Command(
+		"git", "rev-parse", "--show-toplevel",
+	).Output())), "\n")
+	must.Do(os.Chdir(repoRoot))
+
+	// Iterate over all indexed files in the Git repository.
+	var printMu sync.Mutex
+	var group sync.WaitGroup
+	sema := syncs.NewSemaphore(runtime.NumCPU())
+	var numDiffs int // guarded by printMu
+	files := string(must.Get(exec.Command("git", "ls-files").Output()))
+	for file := range strings.Lines(files) {
+		sema.Acquire()
+		group.Go(func() {
+			defer sema.Release()
+
+			// Ignore non-Go source files.
+			file = strings.TrimSuffix(file, "\n")
+			if !strings.HasSuffix(file, ".go") {
+				return
+			}
+
+			// Format all "json" imports in the Go source file.
+			srcIn := must.Get(os.ReadFile(file))
+			srcOut := mustFormatFile(srcIn)
+
+			// Print differences for each formatted file.
+			if !bytes.Equal(srcIn, srcOut) {
+				printMu.Lock()
+				numDiffs++ // incremented under printMu since goroutines update it concurrently
+				fmt.Println(file)
+				lines, _ := safediff.Lines(string(srcIn), string(srcOut), -1)
+				for line := range strings.Lines(lines) {
+					fmt.Print("\t", line)
+				}
+				fmt.Println()
+				printMu.Unlock()

+				// If -update is specified, write out the changes.
+				if *update {
+					mode := must.Get(os.Stat(file)).Mode()
+					must.Do(os.WriteFile(file, srcOut, mode))
+				}
+			}
+		})
+	}
+	group.Wait()
+
+	// Report whether any differences were detected.
+	if numDiffs > 0 && !*update {
+		fmt.Printf(`%d files with "json" imports that need formatting`+"\n", numDiffs)
+		fmt.Println("Please run:")
+		fmt.Println("\t./tool/go run tailscale.com/cmd/jsonimports -update")
+		os.Exit(1)
+	}
+}
diff --git a/cmd/k8s-operator/api-server-proxy-pg.go b/cmd/k8s-operator/api-server-proxy-pg.go
index 252859eb37197..1a81e4967e5d8 100644
--- a/cmd/k8s-operator/api-server-proxy-pg.go
+++ b/cmd/k8s-operator/api-server-proxy-pg.go
@@ -157,12 +157,6 @@ func (r *KubeAPIServerTSServiceReconciler) maybeProvision(ctx context.Context, s
	// 1. Check there isn't a Tailscale Service with the same hostname
	// already created and not owned by this ProxyGroup.
	existingTSSvc, err := r.tsClient.GetVIPService(ctx, serviceName)
-	if isErrorFeatureFlagNotEnabled(err) {
-		logger.Warn(msgFeatureFlagNotEnabled)
-		r.recorder.Event(pg, corev1.EventTypeWarning, warningTailscaleServiceFeatureFlagNotEnabled, msgFeatureFlagNotEnabled)
-		tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid, metav1.ConditionFalse, reasonKubeAPIServerProxyInvalid, msgFeatureFlagNotEnabled, pg.Generation, r.clock, logger)
-		return nil
-	}
	if err != nil && !isErrorTailscaleServiceNotFound(err) {
		return fmt.Errorf("error getting Tailscale Service %q: %w", serviceName, err)
	}
diff --git a/cmd/k8s-operator/api-server-proxy-pg_test.go b/cmd/k8s-operator/api-server-proxy-pg_test.go
index dfef63f22ff04..dee5057236675 100644
--- a/cmd/k8s-operator/api-server-proxy-pg_test.go
+++ b/cmd/k8s-operator/api-server-proxy-pg_test.go
@@ -182,9 +182,7 @@ func TestAPIServerProxyReconciler(t *testing.T) {
	expectEqual(t, fc, certSecretRoleBinding(pg, ns, defaultDomain))

	// Simulate certs being issued; should observe AdvertiseServices config change.
-	if err := populateTLSSecret(t.Context(), fc, pgName, defaultDomain); err != nil {
-		t.Fatalf("populating TLS Secret: %v", err)
-	}
+	populateTLSSecret(t, fc, pgName, defaultDomain)
	expectReconciled(t, r, "", pgName)

	expectedCfg.AdvertiseServices = []string{"svc:" + pgName}
@@ -247,9 +245,7 @@
	expectMissing[rbacv1.RoleBinding](t, fc, ns, defaultDomain)

	// Check we get the new hostname in the status once ready.
- if err := populateTLSSecret(t.Context(), fc, pgName, updatedDomain); err != nil { - t.Fatalf("populating TLS Secret: %v", err) - } + populateTLSSecret(t, fc, pgName, updatedDomain) mustUpdate(t, fc, "operator-ns", "test-pg-0", func(s *corev1.Secret) { s.Data["profile-foo"] = []byte(`{"AdvertiseServices":["svc:test-pg"],"Config":{"NodeID":"node-foo"}}`) }) diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 6cffda2ddb2c8..c76a4236e1105 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -12,6 +12,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket + github.com/creachadair/msync/trigger from tailscale.com/logtail 💣 github.com/davecgh/go-spew/spew from k8s.io/apimachinery/pkg/util/dump W 💣 github.com/dblohm7/wingoes from tailscale.com/net/tshttpproxy+ W 💣 github.com/dblohm7/wingoes/com from tailscale.com/util/osdiag+ @@ -70,8 +71,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + 💣 github.com/klauspost/compress/internal/le from github.com/klauspost/compress/huff0+ github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd - github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe+ + github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd github.com/mailru/easyjson/buffer from github.com/mailru/easyjson/jwriter 💣 github.com/mailru/easyjson/jlexer from github.com/go-openapi/swag @@ -84,6 +86,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ 💣 github.com/modern-go/reflect2 from github.com/json-iterator/go github.com/munnerz/goautoneg from k8s.io/kube-openapi/pkg/handler3+ github.com/opencontainers/go-digest from github.com/distribution/reference + github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal github.com/pkg/errors from github.com/evanphx/json-patch/v5+ D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil from github.com/prometheus/client_golang/prometheus/promhttp @@ -92,6 +95,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/prometheus/client_golang/prometheus/collectors from sigs.k8s.io/controller-runtime/pkg/internal/controller/metrics+ github.com/prometheus/client_golang/prometheus/internal from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/client_golang/prometheus/promhttp from sigs.k8s.io/controller-runtime/pkg/metrics/server+ + github.com/prometheus/client_golang/prometheus/promhttp/internal from github.com/prometheus/client_golang/prometheus/promhttp github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ @@ -178,10 +182,10 @@ 
tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ google.golang.org/protobuf/reflect/protoregistry from github.com/golang/protobuf/proto+ google.golang.org/protobuf/runtime/protoiface from github.com/golang/protobuf/proto+ google.golang.org/protobuf/runtime/protoimpl from github.com/golang/protobuf/proto+ - google.golang.org/protobuf/types/descriptorpb from github.com/google/gnostic-models/openapiv3+ - google.golang.org/protobuf/types/gofeaturespb from google.golang.org/protobuf/reflect/protodesc - google.golang.org/protobuf/types/known/anypb from github.com/google/gnostic-models/compiler+ - google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ + 💣 google.golang.org/protobuf/types/descriptorpb from github.com/google/gnostic-models/openapiv3+ + 💣 google.golang.org/protobuf/types/gofeaturespb from google.golang.org/protobuf/reflect/protodesc + 💣 google.golang.org/protobuf/types/known/anypb from github.com/google/gnostic-models/compiler+ + 💣 google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ gopkg.in/evanphx/json-patch.v4 from k8s.io/client-go/testing gopkg.in/inf.v0 from k8s.io/apimachinery/pkg/api/resource gopkg.in/yaml.v3 from github.com/go-openapi/swag+ @@ -824,7 +828,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/tsweb from tailscale.com/util/eventbus tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ - tailscale.com/types/bools from tailscale.com/tsnet + tailscale.com/types/bools from tailscale.com/tsnet+ tailscale.com/types/dnstype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/empty from tailscale.com/ipn+ tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ @@ -847,7 +851,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/cmd/k8s-operator+ tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/cmd/k8s-operator+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ LW tailscale.com/util/cmpver from tailscale.com/net/dns+ @@ -995,7 +999,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ - crypto/fips140 from crypto/tls/internal/fips140tls + crypto/fips140 from crypto/tls/internal/fips140tls+ crypto/hkdf from crypto/internal/hpke+ crypto/hmac from crypto/tls+ crypto/internal/boring from crypto/aes+ diff --git a/cmd/k8s-operator/deploy/chart/Chart.yaml b/cmd/k8s-operator/deploy/chart/Chart.yaml index 363d87d15954a..9db6389d1d944 100644 --- a/cmd/k8s-operator/deploy/chart/Chart.yaml +++ b/cmd/k8s-operator/deploy/chart/Chart.yaml @@ -26,4 +26,4 @@ maintainers: version: 0.1.0 # appVersion will be set to Tailscale repo tag at release time. 
-appVersion: "unstable" +appVersion: "stable" diff --git a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml index ad0a6fb66f51e..d6e9d1bf48ef8 100644 --- a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml @@ -3,8 +3,8 @@ # If old setting used, enable both old (operator) and new (ProxyGroup) workflows. # If new setting used, enable only new workflow. -{{ if or (eq .Values.apiServerProxyConfig.mode "true") - (eq .Values.apiServerProxyConfig.allowImpersonation "true") }} +{{ if or (eq (toString .Values.apiServerProxyConfig.mode) "true") + (eq (toString .Values.apiServerProxyConfig.allowImpersonation) "true") }} apiVersion: v1 kind: ServiceAccount metadata: @@ -25,7 +25,7 @@ kind: ClusterRoleBinding metadata: name: tailscale-auth-proxy subjects: -{{- if eq .Values.apiServerProxyConfig.mode "true" }} +{{- if eq (toString .Values.apiServerProxyConfig.mode) "true" }} - kind: ServiceAccount name: operator namespace: {{ .Release.Namespace }} diff --git a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml index 51d0a88c36671..df9cb8ce1bcb0 100644 --- a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml @@ -35,13 +35,23 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} volumes: + {{- if .Values.oauthSecretVolume }} + - name: oauth + {{- toYaml .Values.oauthSecretVolume | nindent 10 }} + {{- else if .Values.oauth.audience }} + - name: oidc-jwt + projected: + defaultMode: 420 + sources: + - serviceAccountToken: + audience: {{ .Values.oauth.audience }} + expirationSeconds: 3600 + path: token + {{- else }} - name: oauth - {{- with .Values.oauthSecretVolume }} - {{- toYaml . | nindent 10 }} - {{- else }} secret: secretName: operator-oauth - {{- end }} + {{- end }} containers: - name: operator {{- with .Values.operatorConfig.securityContext }} @@ -72,10 +82,20 @@ spec: value: {{ .Values.loginServer }} - name: OPERATOR_INGRESS_CLASS_NAME value: {{ .Values.ingressClass.name }} + {{- if .Values.oauthSecretVolume }} - name: CLIENT_ID_FILE value: /oauth/client_id - name: CLIENT_SECRET_FILE value: /oauth/client_secret + {{- else if .Values.oauth.audience }} + - name: CLIENT_ID + value: {{ .Values.oauth.clientId }} + {{- else }} + - name: CLIENT_ID_FILE + value: /oauth/client_id + - name: CLIENT_SECRET_FILE + value: /oauth/client_secret + {{- end }} {{- $proxyTag := printf ":%s" ( .Values.proxyConfig.image.tag | default .Chart.AppVersion )}} - name: PROXY_IMAGE value: {{ coalesce .Values.proxyConfig.image.repo .Values.proxyConfig.image.repository }}{{- if .Values.proxyConfig.image.digest -}}{{ printf "@%s" .Values.proxyConfig.image.digest}}{{- else -}}{{ printf "%s" $proxyTag }}{{- end }} @@ -101,9 +121,19 @@ spec: {{- toYaml . | nindent 12 }} {{- end }} volumeMounts: - - name: oauth - mountPath: /oauth - readOnly: true + {{- if .Values.oauthSecretVolume }} + - name: oauth + mountPath: /oauth + readOnly: true + {{- else if .Values.oauth.audience }} + - name: oidc-jwt + mountPath: /var/run/secrets/tailscale/serviceaccount + readOnly: true + {{- else }} + - name: oauth + mountPath: /oauth + readOnly: true + {{- end }} {{- with .Values.operatorConfig.nodeSelector }} nodeSelector: {{- toYaml . 
| nindent 8 }} diff --git a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml index b44fde0a17b49..759ba341a8f21 100644 --- a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml @@ -1,7 +1,7 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -{{ if and .Values.oauth .Values.oauth.clientId -}} +{{ if and .Values.oauth .Values.oauth.clientId (not .Values.oauth.audience) -}} apiVersion: v1 kind: Secret metadata: diff --git a/cmd/k8s-operator/deploy/chart/values.yaml b/cmd/k8s-operator/deploy/chart/values.yaml index 048a1d95ca2fa..cdcb2c1c50b7a 100644 --- a/cmd/k8s-operator/deploy/chart/values.yaml +++ b/cmd/k8s-operator/deploy/chart/values.yaml @@ -1,13 +1,20 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -# Operator oauth credentials. If set a Kubernetes Secret with the provided -# values will be created in the operator namespace. If unset a Secret named -# operator-oauth must be precreated or oauthSecretVolume needs to be adjusted. -# This block will be overridden by oauthSecretVolume, if set. -oauth: {} - # clientId: "" - # clientSecret: "" +# Operator oauth credentials. If unset a Secret named operator-oauth must be +# precreated or oauthSecretVolume needs to be adjusted. This block will be +# overridden by oauthSecretVolume, if set. +oauth: + # The Client ID the operator will authenticate with. + clientId: "" + # If set a Kubernetes Secret with the provided value will be created in + # the operator namespace, and mounted into the operator Pod. Takes precedence + # over oauth.audience. + clientSecret: "" + # The audience for oauth.clientId if using a workload identity federation + # OAuth client. Mutually exclusive with oauth.clientSecret. + # See https://tailscale.com/kb/1581/workload-identity-federation. + audience: "" # URL of the control plane to be used by all resources managed by the operator. loginServer: "" diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml index 0f3dcfcca52c8..48db3ef4bd84d 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml @@ -68,6 +68,11 @@ spec: Corresponds to --ui tsrecorder flag https://tailscale.com/kb/1246/tailscale-ssh-session-recording#deploy-a-recorder-node. Required if S3 storage is not set up, to ensure that recordings are accessible. type: boolean + replicas: + description: Replicas specifies how many instances of tsrecorder to run. Defaults to 1. + type: integer + format: int32 + minimum: 0 statefulSet: description: |- Configuration parameters for the Recorder's StatefulSet. The operator @@ -1683,6 +1688,9 @@ spec: items: type: string pattern: ^tag:[a-zA-Z][a-zA-Z0-9-]*$ + x-kubernetes-validations: + - rule: '!(self.replicas > 1 && (!has(self.storage) || !has(self.storage.s3)))' + message: S3 storage must be used when deploying multiple Recorder replicas status: description: |- RecorderStatus describes the status of the recorder. 
This is set diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index c7c5ef0a7d3b2..2757f09e5f36b 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -3348,6 +3348,11 @@ spec: Corresponds to --ui tsrecorder flag https://tailscale.com/kb/1246/tailscale-ssh-session-recording#deploy-a-recorder-node. Required if S3 storage is not set up, to ensure that recordings are accessible. type: boolean + replicas: + description: Replicas specifies how many instances of tsrecorder to run. Defaults to 1. + format: int32 + minimum: 0 + type: integer statefulSet: description: |- Configuration parameters for the Recorder's StatefulSet. The operator @@ -4964,6 +4969,9 @@ spec: type: string type: array type: object + x-kubernetes-validations: + - message: S3 storage must be used when deploying multiple Recorder replicas + rule: '!(self.replicas > 1 && (!has(self.storage) || !has(self.storage.s3)))' status: description: |- RecorderStatus describes the status of the recorder. This is set @@ -5366,7 +5374,7 @@ spec: - name: CLIENT_SECRET_FILE value: /oauth/client_secret - name: PROXY_IMAGE - value: tailscale/tailscale:unstable + value: tailscale/tailscale:stable - name: PROXY_TAGS value: tag:k8s - name: APISERVER_PROXY @@ -5381,7 +5389,7 @@ spec: valueFrom: fieldRef: fieldPath: metadata.uid - image: tailscale/k8s-operator:unstable + image: tailscale/k8s-operator:stable imagePullPolicy: Always name: operator volumeMounts: diff --git a/cmd/k8s-operator/egress-eps.go b/cmd/k8s-operator/egress-eps.go index 3441e12ba93ec..88da9935320bf 100644 --- a/cmd/k8s-operator/egress-eps.go +++ b/cmd/k8s-operator/egress-eps.go @@ -36,21 +36,21 @@ type egressEpsReconciler struct { // It compares tailnet service state stored in egress proxy state Secrets by containerboot with the desired // configuration stored in proxy-cfg ConfigMap to determine if the endpoint is ready. 
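// (Concretely, podIsReadyToRouteTraffic below checks that each proxy Pod's
// state Secret reports the same Pod IP, tailnet target and ports as the
// desired egress service configuration.)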
func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) {
-	l := er.logger.With("Service", req.NamespacedName)
-	l.Debugf("starting reconcile")
-	defer l.Debugf("reconcile finished")
+	lg := er.logger.With("Service", req.NamespacedName)
+	lg.Debugf("starting reconcile")
+	defer lg.Debugf("reconcile finished")

	eps := new(discoveryv1.EndpointSlice)
	err = er.Get(ctx, req.NamespacedName, eps)
	if apierrors.IsNotFound(err) {
-		l.Debugf("EndpointSlice not found")
+		lg.Debugf("EndpointSlice not found")
		return reconcile.Result{}, nil
	}
	if err != nil {
		return reconcile.Result{}, fmt.Errorf("failed to get EndpointSlice: %w", err)
	}
	if !eps.DeletionTimestamp.IsZero() {
-		l.Debugf("EnpointSlice is being deleted")
+		lg.Debugf("EndpointSlice is being deleted")
		return res, nil
	}
@@ -64,7 +64,7 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ
	}
	err = er.Get(ctx, client.ObjectKeyFromObject(svc), svc)
	if apierrors.IsNotFound(err) {
-		l.Infof("ExternalName Service %s/%s not found, perhaps it was deleted", svc.Namespace, svc.Name)
+		lg.Infof("ExternalName Service %s/%s not found, perhaps it was deleted", svc.Namespace, svc.Name)
		return res, nil
	}
	if err != nil {
@@ -77,7 +77,7 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ
	}
	oldEps := eps.DeepCopy()
	tailnetSvc := tailnetSvcName(svc)
-	l = l.With("tailnet-service-name", tailnetSvc)
+	lg = lg.With("tailnet-service-name", tailnetSvc)

	// Retrieve the desired tailnet service configuration from the ConfigMap.
	proxyGroupName := eps.Labels[labelProxyGroup]
@@ -88,12 +88,12 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ
	if cfgs == nil {
		// TODO(irbekrm): this path would be hit if egress service was once exposed on a ProxyGroup that later
		// got deleted. Probably the EndpointSlices then need to be deleted too; need to rethink this flow.
-		l.Debugf("No egress config found, likely because ProxyGroup has not been created")
+		lg.Debugf("No egress config found, likely because ProxyGroup has not been created")
		return res, nil
	}
	cfg, ok := (*cfgs)[tailnetSvc]
	if !ok {
-		l.Infof("[unexpected] configuration for tailnet service %s not found", tailnetSvc)
+		lg.Infof("[unexpected] configuration for tailnet service %s not found", tailnetSvc)
		return res, nil
	}
@@ -105,7 +105,7 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ
	}
	newEndpoints := make([]discoveryv1.Endpoint, 0)
	for _, pod := range podList.Items {
-		ready, err := er.podIsReadyToRouteTraffic(ctx, pod, &cfg, tailnetSvc, l)
+		ready, err := er.podIsReadyToRouteTraffic(ctx, pod, &cfg, tailnetSvc, lg)
		if err != nil {
			return res, fmt.Errorf("error verifying if Pod is ready to route traffic: %w", err)
		}
@@ -130,7 +130,7 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ
	// run a cleanup for deleted Pods etc.
	eps.Endpoints = newEndpoints
	if !reflect.DeepEqual(eps, oldEps) {
-		l.Infof("Updating EndpointSlice to ensure traffic is routed to ready proxy Pods")
+		lg.Infof("Updating EndpointSlice to ensure traffic is routed to ready proxy Pods")
		if err := er.Update(ctx, eps); err != nil {
			return res, fmt.Errorf("error updating EndpointSlice: %w", err)
		}
@@ -154,11 +154,11 @@ func podIPv4(pod *corev1.Pod) (string, error) {
// podIsReadyToRouteTraffic returns true if it appears that the proxy Pod has configured firewall rules to be able to
// route traffic to the given tailnet service.
It retrieves the proxy's state Secret and compares the tailnet service // status written there to the desired service configuration. -func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod corev1.Pod, cfg *egressservices.Config, tailnetSvcName string, l *zap.SugaredLogger) (bool, error) { - l = l.With("proxy_pod", pod.Name) - l.Debugf("checking whether proxy is ready to route to egress service") +func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod corev1.Pod, cfg *egressservices.Config, tailnetSvcName string, lg *zap.SugaredLogger) (bool, error) { + lg = lg.With("proxy_pod", pod.Name) + lg.Debugf("checking whether proxy is ready to route to egress service") if !pod.DeletionTimestamp.IsZero() { - l.Debugf("proxy Pod is being deleted, ignore") + lg.Debugf("proxy Pod is being deleted, ignore") return false, nil } podIP, err := podIPv4(&pod) @@ -166,7 +166,7 @@ func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod return false, fmt.Errorf("error determining Pod IP address: %v", err) } if podIP == "" { - l.Infof("[unexpected] Pod does not have an IPv4 address, and IPv6 is not currently supported") + lg.Infof("[unexpected] Pod does not have an IPv4 address, and IPv6 is not currently supported") return false, nil } stateS := &corev1.Secret{ @@ -177,7 +177,7 @@ func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod } err = er.Get(ctx, client.ObjectKeyFromObject(stateS), stateS) if apierrors.IsNotFound(err) { - l.Debugf("proxy does not have a state Secret, waiting...") + lg.Debugf("proxy does not have a state Secret, waiting...") return false, nil } if err != nil { @@ -185,7 +185,7 @@ func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod } svcStatusBS := stateS.Data[egressservices.KeyEgressServices] if len(svcStatusBS) == 0 { - l.Debugf("proxy's state Secret does not contain egress services status, waiting...") + lg.Debugf("proxy's state Secret does not contain egress services status, waiting...") return false, nil } svcStatus := &egressservices.Status{} @@ -193,22 +193,22 @@ func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod return false, fmt.Errorf("error unmarshalling egress service status: %w", err) } if !strings.EqualFold(podIP, svcStatus.PodIPv4) { - l.Infof("proxy's egress service status is for Pod IP %s, current proxy's Pod IP %s, waiting for the proxy to reconfigure...", svcStatus.PodIPv4, podIP) + lg.Infof("proxy's egress service status is for Pod IP %s, current proxy's Pod IP %s, waiting for the proxy to reconfigure...", svcStatus.PodIPv4, podIP) return false, nil } st, ok := (*svcStatus).Services[tailnetSvcName] if !ok { - l.Infof("proxy's state Secret does not have egress service status, waiting...") + lg.Infof("proxy's state Secret does not have egress service status, waiting...") return false, nil } if !reflect.DeepEqual(cfg.TailnetTarget, st.TailnetTarget) { - l.Infof("proxy has configured egress service for tailnet target %v, current target is %v, waiting for proxy to reconfigure...", st.TailnetTarget, cfg.TailnetTarget) + lg.Infof("proxy has configured egress service for tailnet target %v, current target is %v, waiting for proxy to reconfigure...", st.TailnetTarget, cfg.TailnetTarget) return false, nil } if !reflect.DeepEqual(cfg.Ports, st.Ports) { - l.Debugf("proxy has configured egress service for ports %#+v, wants ports %#+v, waiting for proxy to reconfigure", st.Ports, cfg.Ports) + lg.Debugf("proxy has configured 
egress service for ports %#+v, wants ports %#+v, waiting for proxy to reconfigure", st.Ports, cfg.Ports) return false, nil } - l.Debugf("proxy is ready to route traffic to egress service") + lg.Debugf("proxy is ready to route traffic to egress service") return true, nil } diff --git a/cmd/k8s-operator/egress-pod-readiness.go b/cmd/k8s-operator/egress-pod-readiness.go index f3a812ecb9030..a732e08612c86 100644 --- a/cmd/k8s-operator/egress-pod-readiness.go +++ b/cmd/k8s-operator/egress-pod-readiness.go @@ -71,9 +71,9 @@ type egressPodsReconciler struct { // If the Pod does not appear to be serving the health check endpoint (pre-v1.80 proxies), the reconciler just sets the // readiness condition for backwards compatibility reasons. func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { - l := er.logger.With("Pod", req.NamespacedName) - l.Debugf("starting reconcile") - defer l.Debugf("reconcile finished") + lg := er.logger.With("Pod", req.NamespacedName) + lg.Debugf("starting reconcile") + defer lg.Debugf("reconcile finished") pod := new(corev1.Pod) err = er.Get(ctx, req.NamespacedName, pod) @@ -84,11 +84,11 @@ func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Req return reconcile.Result{}, fmt.Errorf("failed to get Pod: %w", err) } if !pod.DeletionTimestamp.IsZero() { - l.Debugf("Pod is being deleted, do nothing") + lg.Debugf("Pod is being deleted, do nothing") return res, nil } if pod.Labels[LabelParentType] != proxyTypeProxyGroup { - l.Infof("[unexpected] reconciler called for a Pod that is not a ProxyGroup Pod") + lg.Infof("[unexpected] reconciler called for a Pod that is not a ProxyGroup Pod") return res, nil } @@ -97,7 +97,7 @@ func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Req if !slices.ContainsFunc(pod.Spec.ReadinessGates, func(r corev1.PodReadinessGate) bool { return r.ConditionType == tsEgressReadinessGate }) { - l.Debug("Pod does not have egress readiness gate set, skipping") + lg.Debug("Pod does not have egress readiness gate set, skipping") return res, nil } @@ -107,7 +107,7 @@ func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Req return res, fmt.Errorf("error getting ProxyGroup %q: %w", proxyGroupName, err) } if pg.Spec.Type != typeEgress { - l.Infof("[unexpected] reconciler called for %q ProxyGroup Pod", pg.Spec.Type) + lg.Infof("[unexpected] reconciler called for %q ProxyGroup Pod", pg.Spec.Type) return res, nil } // Get all ClusterIP Services for all egress targets exposed to cluster via this ProxyGroup. 
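	// (Each such Service fronts one egress target; its health check endpoint
	// is what lookupPodRouteViaSvc probes below to verify that this Pod is
	// routable via the Service.)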
@@ -125,7 +125,7 @@ func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Req
		return c.Type == tsEgressReadinessGate
	})
	if idx != -1 {
-		l.Debugf("Pod is already ready, do nothing")
+		lg.Debugf("Pod is already ready, do nothing")
		return res, nil
	}
@@ -134,7 +134,7 @@ func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Req
	for _, svc := range svcs.Items {
		s := svc
		go func() {
-			ll := l.With("service_name", s.Name)
+			ll := lg.With("service_name", s.Name)
			d := retrieveClusterDomain(er.tsNamespace, ll)
			healthCheckAddr := healthCheckForSvc(&s, d)
			if healthCheckAddr == "" {
@@ -178,22 +178,22 @@ func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Req
		return res, fmt.Errorf("error verifying connectivity: %w", err)
	}
	if rm := routesMissing.Load(); rm {
-		l.Info("Pod is not yet added as an endpoint for all egress targets, waiting...")
+		lg.Info("Pod is not yet added as an endpoint for all egress targets, waiting...")
		return reconcile.Result{RequeueAfter: shortRequeue}, nil
	}
-	if err := er.setPodReady(ctx, pod, l); err != nil {
+	if err := er.setPodReady(ctx, pod, lg); err != nil {
		return res, fmt.Errorf("error setting Pod as ready: %w", err)
	}
	return res, nil
}

-func (er *egressPodsReconciler) setPodReady(ctx context.Context, pod *corev1.Pod, l *zap.SugaredLogger) error {
+func (er *egressPodsReconciler) setPodReady(ctx context.Context, pod *corev1.Pod, lg *zap.SugaredLogger) error {
	if slices.ContainsFunc(pod.Status.Conditions, func(c corev1.PodCondition) bool {
		return c.Type == tsEgressReadinessGate
	}) {
		return nil
	}
-	l.Infof("Pod is ready to route traffic to all egress targets")
+	lg.Infof("Pod is ready to route traffic to all egress targets")
	pod.Status.Conditions = append(pod.Status.Conditions, corev1.PodCondition{
		Type:   tsEgressReadinessGate,
		Status: corev1.ConditionTrue,
@@ -216,11 +216,11 @@ const (
)

// lookupPodRouteViaSvc attempts to reach a Pod using a health check endpoint served by a Service and returns the state of the health check.
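// The check passes only if the health check response echoes the Pod's own
// IPv4 address in the kubetypes.PodIPv4Header header, i.e. the Service is
// currently routing to that Pod.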
-func (er *egressPodsReconciler) lookupPodRouteViaSvc(ctx context.Context, pod *corev1.Pod, healthCheckAddr string, l *zap.SugaredLogger) (healthCheckState, error) { +func (er *egressPodsReconciler) lookupPodRouteViaSvc(ctx context.Context, pod *corev1.Pod, healthCheckAddr string, lg *zap.SugaredLogger) (healthCheckState, error) { if !slices.ContainsFunc(pod.Spec.Containers[0].Env, func(e corev1.EnvVar) bool { return e.Name == "TS_ENABLE_HEALTH_CHECK" && e.Value == "true" }) { - l.Debugf("Pod does not have health check enabled, unable to verify if it is currently routable via Service") + lg.Debugf("Pod does not have health check enabled, unable to verify if it is currently routable via Service") return cannotVerify, nil } wantsIP, err := podIPv4(pod) @@ -248,7 +248,7 @@ func (er *egressPodsReconciler) lookupPodRouteViaSvc(ctx context.Context, pod *c defer resp.Body.Close() gotIP := resp.Header.Get(kubetypes.PodIPv4Header) if gotIP == "" { - l.Debugf("Health check does not return Pod's IP header, unable to verify if Pod is currently routable via Service") + lg.Debugf("Health check does not return Pod's IP header, unable to verify if Pod is currently routable via Service") return cannotVerify, nil } if !strings.EqualFold(wantsIP, gotIP) { diff --git a/cmd/k8s-operator/egress-services-readiness.go b/cmd/k8s-operator/egress-services-readiness.go index ecf99b63cda44..80f3c7d285141 100644 --- a/cmd/k8s-operator/egress-services-readiness.go +++ b/cmd/k8s-operator/egress-services-readiness.go @@ -47,13 +47,13 @@ type egressSvcsReadinessReconciler struct { // route traffic to the target. It compares proxy Pod IPs with the endpoints set on the EndpointSlice for the egress // service to determine how many replicas are currently able to route traffic. func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { - l := esrr.logger.With("Service", req.NamespacedName) - l.Debugf("starting reconcile") - defer l.Debugf("reconcile finished") + lg := esrr.logger.With("Service", req.NamespacedName) + lg.Debugf("starting reconcile") + defer lg.Debugf("reconcile finished") svc := new(corev1.Service) if err = esrr.Get(ctx, req.NamespacedName, svc); apierrors.IsNotFound(err) { - l.Debugf("Service not found") + lg.Debugf("Service not found") return res, nil } else if err != nil { return res, fmt.Errorf("failed to get Service: %w", err) @@ -64,7 +64,7 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re ) oldStatus := svc.Status.DeepCopy() defer func() { - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, st, reason, msg, esrr.clock, l) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, st, reason, msg, esrr.clock, lg) if !apiequality.Semantic.DeepEqual(oldStatus, &svc.Status) { err = errors.Join(err, esrr.Status().Update(ctx, svc)) } @@ -79,7 +79,7 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re return res, err } if eps == nil { - l.Infof("EndpointSlice for Service does not yet exist, waiting...") + lg.Infof("EndpointSlice for Service does not yet exist, waiting...") reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady st = metav1.ConditionFalse return res, nil @@ -91,7 +91,7 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re } err = esrr.Get(ctx, client.ObjectKeyFromObject(pg), pg) if apierrors.IsNotFound(err) { - l.Infof("ProxyGroup for Service does not exist, waiting...") + lg.Infof("ProxyGroup for 
Service does not exist, waiting...") reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady st = metav1.ConditionFalse return res, nil @@ -103,7 +103,7 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re return res, err } if !tsoperator.ProxyGroupAvailable(pg) { - l.Infof("ProxyGroup for Service is not ready, waiting...") + lg.Infof("ProxyGroup for Service is not ready, waiting...") reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady st = metav1.ConditionFalse return res, nil @@ -111,7 +111,7 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re replicas := pgReplicas(pg) if replicas == 0 { - l.Infof("ProxyGroup replicas set to 0") + lg.Infof("ProxyGroup replicas set to 0") reason, msg = reasonNoProxies, reasonNoProxies st = metav1.ConditionFalse return res, nil @@ -128,16 +128,16 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re return res, err } if pod == nil { - l.Warnf("[unexpected] ProxyGroup is ready, but replica %d was not found", i) + lg.Warnf("[unexpected] ProxyGroup is ready, but replica %d was not found", i) reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady return res, nil } - l.Debugf("looking at Pod with IPs %v", pod.Status.PodIPs) + lg.Debugf("looking at Pod with IPs %v", pod.Status.PodIPs) ready := false for _, ep := range eps.Endpoints { - l.Debugf("looking at endpoint with addresses %v", ep.Addresses) - if endpointReadyForPod(&ep, pod, l) { - l.Debugf("endpoint is ready for Pod") + lg.Debugf("looking at endpoint with addresses %v", ep.Addresses) + if endpointReadyForPod(&ep, pod, lg) { + lg.Debugf("endpoint is ready for Pod") ready = true break } @@ -163,10 +163,10 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re // endpointReadyForPod returns true if the endpoint is for the Pod's IPv4 address and is ready to serve traffic. // Endpoint must not be nil. -func endpointReadyForPod(ep *discoveryv1.Endpoint, pod *corev1.Pod, l *zap.SugaredLogger) bool { +func endpointReadyForPod(ep *discoveryv1.Endpoint, pod *corev1.Pod, lg *zap.SugaredLogger) bool { podIP, err := podIPv4(pod) if err != nil { - l.Warnf("[unexpected] error retrieving Pod's IPv4 address: %v", err) + lg.Warnf("[unexpected] error retrieving Pod's IPv4 address: %v", err) return false } // Currently we only ever set a single address on and Endpoint and nothing else is meant to modify this. 
diff --git a/cmd/k8s-operator/egress-services-readiness_test.go b/cmd/k8s-operator/egress-services-readiness_test.go index f80759aef927b..fdff4fafa3240 100644 --- a/cmd/k8s-operator/egress-services-readiness_test.go +++ b/cmd/k8s-operator/egress-services-readiness_test.go @@ -49,12 +49,12 @@ func TestEgressServiceReadiness(t *testing.T) { }, } fakeClusterIPSvc := &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: "my-app", Namespace: "operator-ns"}} - l := egressSvcEpsLabels(egressSvc, fakeClusterIPSvc) + labels := egressSvcEpsLabels(egressSvc, fakeClusterIPSvc) eps := &discoveryv1.EndpointSlice{ ObjectMeta: metav1.ObjectMeta{ Name: "my-app", Namespace: "operator-ns", - Labels: l, + Labels: labels, }, AddressType: discoveryv1.AddressTypeIPv4, } @@ -118,26 +118,26 @@ func TestEgressServiceReadiness(t *testing.T) { }) } -func setClusterNotReady(svc *corev1.Service, cl tstime.Clock, l *zap.SugaredLogger) { - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionFalse, reasonClusterResourcesNotReady, reasonClusterResourcesNotReady, cl, l) +func setClusterNotReady(svc *corev1.Service, cl tstime.Clock, lg *zap.SugaredLogger) { + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionFalse, reasonClusterResourcesNotReady, reasonClusterResourcesNotReady, cl, lg) } -func setNotReady(svc *corev1.Service, cl tstime.Clock, l *zap.SugaredLogger, replicas int32) { +func setNotReady(svc *corev1.Service, cl tstime.Clock, lg *zap.SugaredLogger, replicas int32) { msg := fmt.Sprintf(msgReadyToRouteTemplate, 0, replicas) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionFalse, reasonNotReady, msg, cl, l) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionFalse, reasonNotReady, msg, cl, lg) } -func setReady(svc *corev1.Service, cl tstime.Clock, l *zap.SugaredLogger, replicas, readyReplicas int32) { +func setReady(svc *corev1.Service, cl tstime.Clock, lg *zap.SugaredLogger, replicas, readyReplicas int32) { reason := reasonPartiallyReady if readyReplicas == replicas { reason = reasonReady } msg := fmt.Sprintf(msgReadyToRouteTemplate, readyReplicas, replicas) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionTrue, reason, msg, cl, l) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionTrue, reason, msg, cl, lg) } -func setPGReady(pg *tsapi.ProxyGroup, cl tstime.Clock, l *zap.SugaredLogger) { - tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, "foo", "foo", pg.Generation, cl, l) +func setPGReady(pg *tsapi.ProxyGroup, cl tstime.Clock, lg *zap.SugaredLogger) { + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, "foo", "foo", pg.Generation, cl, lg) } func setEndpointForReplica(pg *tsapi.ProxyGroup, ordinal int32, eps *discoveryv1.EndpointSlice) { @@ -153,14 +153,14 @@ func setEndpointForReplica(pg *tsapi.ProxyGroup, ordinal int32, eps *discoveryv1 } func pod(pg *tsapi.ProxyGroup, ordinal int32) *corev1.Pod { - l := pgLabels(pg.Name, nil) - l[appsv1.PodIndexLabel] = fmt.Sprintf("%d", ordinal) + labels := pgLabels(pg.Name, nil) + labels[appsv1.PodIndexLabel] = fmt.Sprintf("%d", ordinal) ip := fmt.Sprintf("10.0.0.%d", ordinal) return &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: fmt.Sprintf("%s-%d", pg.Name, ordinal), Namespace: "operator-ns", - Labels: l, + Labels: labels, }, Status: corev1.PodStatus{ PodIPs: []corev1.PodIP{{IP: ip}}, diff --git a/cmd/k8s-operator/egress-services.go 
b/cmd/k8s-operator/egress-services.go index ca6562071eba7..05be8efed9402 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -98,12 +98,12 @@ type egressSvcsReconciler struct { // - updates the egress service config in a ConfigMap mounted to the ProxyGroup proxies with the tailnet target and the // portmappings. func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { - l := esr.logger.With("Service", req.NamespacedName) - defer l.Info("reconcile finished") + lg := esr.logger.With("Service", req.NamespacedName) + defer lg.Info("reconcile finished") svc := new(corev1.Service) if err = esr.Get(ctx, req.NamespacedName, svc); apierrors.IsNotFound(err) { - l.Info("Service not found") + lg.Info("Service not found") return res, nil } else if err != nil { return res, fmt.Errorf("failed to get Service: %w", err) @@ -111,7 +111,7 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re // Name of the 'egress service', meaning the tailnet target. tailnetSvc := tailnetSvcName(svc) - l = l.With("tailnet-service", tailnetSvc) + lg = lg.With("tailnet-service", tailnetSvc) // Note that resources for egress Services are only cleaned up when the // Service is actually deleted (and not if, for example, user decides to @@ -119,8 +119,8 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re // assume that the egress ExternalName Services are always created for // Tailscale operator specifically. if !svc.DeletionTimestamp.IsZero() { - l.Info("Service is being deleted, ensuring resource cleanup") - return res, esr.maybeCleanup(ctx, svc, l) + lg.Info("Service is being deleted, ensuring resource cleanup") + return res, esr.maybeCleanup(ctx, svc, lg) } oldStatus := svc.Status.DeepCopy() @@ -131,7 +131,7 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re }() // Validate the user-created ExternalName Service and the associated ProxyGroup. 
- if ok, err := esr.validateClusterResources(ctx, svc, l); err != nil { + if ok, err := esr.validateClusterResources(ctx, svc, lg); err != nil { return res, fmt.Errorf("error validating cluster resources: %w", err) } else if !ok { return res, nil @@ -141,8 +141,8 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re svc.Finalizers = append(svc.Finalizers, FinalizerName) if err := esr.updateSvcSpec(ctx, svc); err != nil { err := fmt.Errorf("failed to add finalizer: %w", err) - r := svcConfiguredReason(svc, false, l) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, l) + r := svcConfiguredReason(svc, false, lg) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, lg) return res, err } esr.mu.Lock() @@ -151,16 +151,16 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re esr.mu.Unlock() } - if err := esr.maybeCleanupProxyGroupConfig(ctx, svc, l); err != nil { + if err := esr.maybeCleanupProxyGroupConfig(ctx, svc, lg); err != nil { err = fmt.Errorf("cleaning up resources for previous ProxyGroup failed: %w", err) - r := svcConfiguredReason(svc, false, l) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, l) + r := svcConfiguredReason(svc, false, lg) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, lg) return res, err } - if err := esr.maybeProvision(ctx, svc, l); err != nil { + if err := esr.maybeProvision(ctx, svc, lg); err != nil { if strings.Contains(err.Error(), optimisticLockErrorMsg) { - l.Infof("optimistic lock error, retrying: %s", err) + lg.Infof("optimistic lock error, retrying: %s", err) } else { return reconcile.Result{}, err } @@ -169,15 +169,15 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re return res, nil } -func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1.Service, l *zap.SugaredLogger) (err error) { - r := svcConfiguredReason(svc, false, l) +func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1.Service, lg *zap.SugaredLogger) (err error) { + r := svcConfiguredReason(svc, false, lg) st := metav1.ConditionFalse defer func() { msg := r if st != metav1.ConditionTrue && err != nil { msg = err.Error() } - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, st, r, msg, esr.clock, l) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, st, r, msg, esr.clock, lg) }() crl := egressSvcChildResourceLabels(svc) @@ -189,36 +189,36 @@ func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1 if clusterIPSvc == nil { clusterIPSvc = esr.clusterIPSvcForEgress(crl) } - upToDate := svcConfigurationUpToDate(svc, l) + upToDate := svcConfigurationUpToDate(svc, lg) provisioned := true if !upToDate { - if clusterIPSvc, provisioned, err = esr.provision(ctx, svc.Annotations[AnnotationProxyGroup], svc, clusterIPSvc, l); err != nil { + if clusterIPSvc, provisioned, err = esr.provision(ctx, svc.Annotations[AnnotationProxyGroup], svc, clusterIPSvc, lg); err != nil { return err } } if !provisioned { - l.Infof("unable to provision cluster resources") + lg.Infof("unable to provision cluster resources") return nil } // Update ExternalName Service to point at the ClusterIP Service. 
- clusterDomain := retrieveClusterDomain(esr.tsNamespace, l) + clusterDomain := retrieveClusterDomain(esr.tsNamespace, lg) clusterIPSvcFQDN := fmt.Sprintf("%s.%s.svc.%s", clusterIPSvc.Name, clusterIPSvc.Namespace, clusterDomain) if svc.Spec.ExternalName != clusterIPSvcFQDN { - l.Infof("Configuring ExternalName Service to point to ClusterIP Service %s", clusterIPSvcFQDN) + lg.Infof("Configuring ExternalName Service to point to ClusterIP Service %s", clusterIPSvcFQDN) svc.Spec.ExternalName = clusterIPSvcFQDN if err = esr.updateSvcSpec(ctx, svc); err != nil { err = fmt.Errorf("error updating ExternalName Service: %w", err) return err } } - r = svcConfiguredReason(svc, true, l) + r = svcConfiguredReason(svc, true, lg) st = metav1.ConditionTrue return nil } -func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName string, svc, clusterIPSvc *corev1.Service, l *zap.SugaredLogger) (*corev1.Service, bool, error) { - l.Infof("updating configuration...") +func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName string, svc, clusterIPSvc *corev1.Service, lg *zap.SugaredLogger) (*corev1.Service, bool, error) { + lg.Infof("updating configuration...") usedPorts, err := esr.usedPortsForPG(ctx, proxyGroupName) if err != nil { return nil, false, fmt.Errorf("error calculating used ports for ProxyGroup %s: %w", proxyGroupName, err) @@ -246,7 +246,7 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s } } if !found { - l.Debugf("portmapping %s:%d -> %s:%d is no longer required, removing", pm.Protocol, pm.TargetPort.IntVal, pm.Protocol, pm.Port) + lg.Debugf("portmapping %s:%d -> %s:%d is no longer required, removing", pm.Protocol, pm.TargetPort.IntVal, pm.Protocol, pm.Port) clusterIPSvc.Spec.Ports = slices.Delete(clusterIPSvc.Spec.Ports, i, i+1) } } @@ -277,7 +277,7 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s return nil, false, fmt.Errorf("unable to allocate additional ports on ProxyGroup %s, %d ports already used. 
Create another ProxyGroup or open an issue if you believe this is unexpected.", proxyGroupName, maxPorts) } p := unusedPort(usedPorts) - l.Debugf("mapping tailnet target port %d to container port %d", wantsPM.Port, p) + lg.Debugf("mapping tailnet target port %d to container port %d", wantsPM.Port, p) usedPorts.Insert(p) clusterIPSvc.Spec.Ports = append(clusterIPSvc.Spec.Ports, corev1.ServicePort{ Name: wantsPM.Name, @@ -343,14 +343,14 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s return nil, false, fmt.Errorf("error retrieving egress services configuration: %w", err) } if cm == nil { - l.Info("ConfigMap not yet created, waiting..") + lg.Info("ConfigMap not yet created, waiting..") return nil, false, nil } tailnetSvc := tailnetSvcName(svc) gotCfg := (*cfgs)[tailnetSvc] - wantsCfg := egressSvcCfg(svc, clusterIPSvc, esr.tsNamespace, l) + wantsCfg := egressSvcCfg(svc, clusterIPSvc, esr.tsNamespace, lg) if !reflect.DeepEqual(gotCfg, wantsCfg) { - l.Debugf("updating egress services ConfigMap %s", cm.Name) + lg.Debugf("updating egress services ConfigMap %s", cm.Name) mak.Set(cfgs, tailnetSvc, wantsCfg) bs, err := json.Marshal(cfgs) if err != nil { @@ -361,7 +361,7 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s return nil, false, fmt.Errorf("error updating egress services ConfigMap: %w", err) } } - l.Infof("egress service configuration has been updated") + lg.Infof("egress service configuration has been updated") return clusterIPSvc, true, nil } @@ -402,7 +402,7 @@ func (esr *egressSvcsReconciler) maybeCleanup(ctx context.Context, svc *corev1.S return nil } -func (esr *egressSvcsReconciler) maybeCleanupProxyGroupConfig(ctx context.Context, svc *corev1.Service, l *zap.SugaredLogger) error { +func (esr *egressSvcsReconciler) maybeCleanupProxyGroupConfig(ctx context.Context, svc *corev1.Service, lg *zap.SugaredLogger) error { wantsProxyGroup := svc.Annotations[AnnotationProxyGroup] cond := tsoperator.GetServiceCondition(svc, tsapi.EgressSvcConfigured) if cond == nil { @@ -416,7 +416,7 @@ func (esr *egressSvcsReconciler) maybeCleanupProxyGroupConfig(ctx context.Contex return nil } esr.logger.Infof("egress Service configured on ProxyGroup %s, wants ProxyGroup %s, cleaning up...", ss[2], wantsProxyGroup) - if err := esr.ensureEgressSvcCfgDeleted(ctx, svc, l); err != nil { + if err := esr.ensureEgressSvcCfgDeleted(ctx, svc, lg); err != nil { return fmt.Errorf("error deleting egress service config: %w", err) } return nil @@ -471,17 +471,17 @@ func (esr *egressSvcsReconciler) ensureEgressSvcCfgDeleted(ctx context.Context, Namespace: esr.tsNamespace, }, } - l := logger.With("ConfigMap", client.ObjectKeyFromObject(cm)) - l.Debug("ensuring that egress service configuration is removed from proxy config") + lggr := logger.With("ConfigMap", client.ObjectKeyFromObject(cm)) + lggr.Debug("ensuring that egress service configuration is removed from proxy config") if err := esr.Get(ctx, client.ObjectKeyFromObject(cm), cm); apierrors.IsNotFound(err) { - l.Debugf("ConfigMap not found") + lggr.Debugf("ConfigMap not found") return nil } else if err != nil { return fmt.Errorf("error retrieving ConfigMap: %w", err) } bs := cm.BinaryData[egressservices.KeyEgressServices] if len(bs) == 0 { - l.Debugf("ConfigMap does not contain egress service configs") + lggr.Debugf("ConfigMap does not contain egress service configs") return nil } cfgs := &egressservices.Configs{} @@ -491,12 +491,12 @@ func (esr *egressSvcsReconciler) ensureEgressSvcCfgDeleted(ctx 
context.Context, tailnetSvc := tailnetSvcName(svc) _, ok := (*cfgs)[tailnetSvc] if !ok { - l.Debugf("ConfigMap does not contain egress service config, likely because it was already deleted") + lggr.Debugf("ConfigMap does not contain egress service config, likely because it was already deleted") return nil } - l.Infof("before deleting config %+#v", *cfgs) + lggr.Infof("before deleting config %+#v", *cfgs) delete(*cfgs, tailnetSvc) - l.Infof("after deleting config %+#v", *cfgs) + lggr.Infof("after deleting config %+#v", *cfgs) bs, err := json.Marshal(cfgs) if err != nil { return fmt.Errorf("error marshalling egress services configs: %w", err) @@ -505,7 +505,7 @@ func (esr *egressSvcsReconciler) ensureEgressSvcCfgDeleted(ctx context.Context, return esr.Update(ctx, cm) } -func (esr *egressSvcsReconciler) validateClusterResources(ctx context.Context, svc *corev1.Service, l *zap.SugaredLogger) (bool, error) { +func (esr *egressSvcsReconciler) validateClusterResources(ctx context.Context, svc *corev1.Service, lg *zap.SugaredLogger) (bool, error) { proxyGroupName := svc.Annotations[AnnotationProxyGroup] pg := &tsapi.ProxyGroup{ ObjectMeta: metav1.ObjectMeta{ @@ -513,36 +513,36 @@ func (esr *egressSvcsReconciler) validateClusterResources(ctx context.Context, s }, } if err := esr.Get(ctx, client.ObjectKeyFromObject(pg), pg); apierrors.IsNotFound(err) { - l.Infof("ProxyGroup %q not found, waiting...", proxyGroupName) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, l) + lg.Infof("ProxyGroup %q not found, waiting...", proxyGroupName) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, lg) tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) return false, nil } else if err != nil { err := fmt.Errorf("unable to retrieve ProxyGroup %s: %w", proxyGroupName, err) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, err.Error(), esr.clock, l) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, err.Error(), esr.clock, lg) tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) return false, err } if violations := validateEgressService(svc, pg); len(violations) > 0 { msg := fmt.Sprintf("invalid egress Service: %s", strings.Join(violations, ", ")) esr.recorder.Event(svc, corev1.EventTypeWarning, "INVALIDSERVICE", msg) - l.Info(msg) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionFalse, reasonEgressSvcInvalid, msg, esr.clock, l) + lg.Info(msg) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionFalse, reasonEgressSvcInvalid, msg, esr.clock, lg) tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) return false, nil } if !tsoperator.ProxyGroupAvailable(pg) { - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, l) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, lg) tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) } - l.Debugf("egress service is valid") - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionTrue, reasonEgressSvcValid, reasonEgressSvcValid, esr.clock, l) + lg.Debugf("egress 
service is valid") + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionTrue, reasonEgressSvcValid, reasonEgressSvcValid, esr.clock, lg) return true, nil } -func egressSvcCfg(externalNameSvc, clusterIPSvc *corev1.Service, ns string, l *zap.SugaredLogger) egressservices.Config { - d := retrieveClusterDomain(ns, l) +func egressSvcCfg(externalNameSvc, clusterIPSvc *corev1.Service, ns string, lg *zap.SugaredLogger) egressservices.Config { + d := retrieveClusterDomain(ns, lg) tt := tailnetTargetFromSvc(externalNameSvc) hep := healthCheckForSvc(clusterIPSvc, d) cfg := egressservices.Config{ @@ -691,18 +691,18 @@ func egressSvcChildResourceLabels(svc *corev1.Service) map[string]string { // egressEpsLabels returns labels to be added to an EndpointSlice created for an egress service. func egressSvcEpsLabels(extNSvc, clusterIPSvc *corev1.Service) map[string]string { - l := egressSvcChildResourceLabels(extNSvc) + lbels := egressSvcChildResourceLabels(extNSvc) // Adding this label is what makes kube proxy set up rules to route traffic sent to the clusterIP Service to the // endpoints defined on this EndpointSlice. // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#ownership - l[discoveryv1.LabelServiceName] = clusterIPSvc.Name + lbels[discoveryv1.LabelServiceName] = clusterIPSvc.Name // Kubernetes recommends setting this label. // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#management - l[discoveryv1.LabelManagedBy] = "tailscale.com" - return l + lbels[discoveryv1.LabelManagedBy] = "tailscale.com" + return lbels } -func svcConfigurationUpToDate(svc *corev1.Service, l *zap.SugaredLogger) bool { +func svcConfigurationUpToDate(svc *corev1.Service, lg *zap.SugaredLogger) bool { cond := tsoperator.GetServiceCondition(svc, tsapi.EgressSvcConfigured) if cond == nil { return false @@ -710,21 +710,21 @@ func svcConfigurationUpToDate(svc *corev1.Service, l *zap.SugaredLogger) bool { if cond.Status != metav1.ConditionTrue { return false } - wantsReadyReason := svcConfiguredReason(svc, true, l) + wantsReadyReason := svcConfiguredReason(svc, true, lg) return strings.EqualFold(wantsReadyReason, cond.Reason) } -func cfgHash(c cfg, l *zap.SugaredLogger) string { +func cfgHash(c cfg, lg *zap.SugaredLogger) string { bs, err := json.Marshal(c) if err != nil { // Don't use l.Error as that messes up component logs with, in this case, unnecessary stack trace. - l.Infof("error marhsalling Config: %v", err) + lg.Infof("error marhsalling Config: %v", err) return "" } h := sha256.New() if _, err := h.Write(bs); err != nil { // Don't use l.Error as that messes up component logs with, in this case, unnecessary stack trace. 
- l.Infof("error producing Config hash: %v", err) + lg.Infof("error producing Config hash: %v", err) return "" } return fmt.Sprintf("%x", h.Sum(nil)) @@ -736,7 +736,7 @@ type cfg struct { ProxyGroup string `json:"proxyGroup"` } -func svcConfiguredReason(svc *corev1.Service, configured bool, l *zap.SugaredLogger) string { +func svcConfiguredReason(svc *corev1.Service, configured bool, lg *zap.SugaredLogger) string { var r string if configured { r = "ConfiguredFor:" @@ -750,7 +750,7 @@ func svcConfiguredReason(svc *corev1.Service, configured bool, l *zap.SugaredLog TailnetTarget: tt, ProxyGroup: svc.Annotations[AnnotationProxyGroup], } - r += fmt.Sprintf(":Config:%s", cfgHash(s, l)) + r += fmt.Sprintf(":Config:%s", cfgHash(s, lg)) return r } diff --git a/cmd/k8s-operator/egress-services_test.go b/cmd/k8s-operator/egress-services_test.go index d8a5dfd32c1c2..202804d3011fd 100644 --- a/cmd/k8s-operator/egress-services_test.go +++ b/cmd/k8s-operator/egress-services_test.go @@ -249,9 +249,9 @@ func portsForEndpointSlice(svc *corev1.Service) []discoveryv1.EndpointPort { return ports } -func mustHaveConfigForSvc(t *testing.T, cl client.Client, extNSvc, clusterIPSvc *corev1.Service, cm *corev1.ConfigMap, l *zap.Logger) { +func mustHaveConfigForSvc(t *testing.T, cl client.Client, extNSvc, clusterIPSvc *corev1.Service, cm *corev1.ConfigMap, lg *zap.Logger) { t.Helper() - wantsCfg := egressSvcCfg(extNSvc, clusterIPSvc, clusterIPSvc.Namespace, l.Sugar()) + wantsCfg := egressSvcCfg(extNSvc, clusterIPSvc, clusterIPSvc.Namespace, lg.Sugar()) if err := cl.Get(context.Background(), client.ObjectKeyFromObject(cm), cm); err != nil { t.Fatalf("Error retrieving ConfigMap: %v", err) } diff --git a/cmd/k8s-operator/generate/main.go b/cmd/k8s-operator/generate/main.go index 5fd5d551b5e02..08bdc350d500c 100644 --- a/cmd/k8s-operator/generate/main.go +++ b/cmd/k8s-operator/generate/main.go @@ -69,7 +69,7 @@ func main() { }() log.Print("Templating Helm chart contents") helmTmplCmd := exec.Command("./tool/helm", "template", "operator", "./cmd/k8s-operator/deploy/chart", - "--namespace=tailscale") + "--namespace=tailscale", "--set=oauth.clientSecret=''") helmTmplCmd.Dir = repoRoot var out bytes.Buffer helmTmplCmd.Stdout = &out diff --git a/cmd/k8s-operator/ingress-for-pg.go b/cmd/k8s-operator/ingress-for-pg.go index 3afeb528f7f8f..460a1914ee799 100644 --- a/cmd/k8s-operator/ingress-for-pg.go +++ b/cmd/k8s-operator/ingress-for-pg.go @@ -29,6 +29,7 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" @@ -154,11 +155,6 @@ func (r *HAIngressReconciler) maybeProvision(ctx context.Context, hostname strin // needs to be explicitly enabled for a tailnet to be able to use them. 
serviceName := tailcfg.ServiceName("svc:" + hostname) existingTSSvc, err := r.tsClient.GetVIPService(ctx, serviceName) - if isErrorFeatureFlagNotEnabled(err) { - logger.Warn(msgFeatureFlagNotEnabled) - r.recorder.Event(ing, corev1.EventTypeWarning, warningTailscaleServiceFeatureFlagNotEnabled, msgFeatureFlagNotEnabled) - return false, nil - } if err != nil && !isErrorTailscaleServiceNotFound(err) { return false, fmt.Errorf("error getting Tailscale Service %q: %w", hostname, err) } @@ -453,11 +449,6 @@ func (r *HAIngressReconciler) maybeCleanupProxyGroup(ctx context.Context, proxyG if !found { logger.Infof("Tailscale Service %q is not owned by any Ingress, cleaning up", tsSvcName) tsService, err := r.tsClient.GetVIPService(ctx, tsSvcName) - if isErrorFeatureFlagNotEnabled(err) { - msg := fmt.Sprintf("Unable to proceed with cleanup: %s.", msgFeatureFlagNotEnabled) - logger.Warn(msg) - return false, nil - } if isErrorTailscaleServiceNotFound(err) { return false, nil } @@ -514,16 +505,7 @@ func (r *HAIngressReconciler) maybeCleanup(ctx context.Context, hostname string, logger.Infof("Ensuring that Tailscale Service %q configuration is cleaned up", hostname) serviceName := tailcfg.ServiceName("svc:" + hostname) svc, err := r.tsClient.GetVIPService(ctx, serviceName) - if err != nil { - if isErrorFeatureFlagNotEnabled(err) { - msg := fmt.Sprintf("Unable to proceed with cleanup: %s.", msgFeatureFlagNotEnabled) - logger.Warn(msg) - r.recorder.Event(ing, corev1.EventTypeWarning, warningTailscaleServiceFeatureFlagNotEnabled, msg) - return false, nil - } - if isErrorTailscaleServiceNotFound(err) { - return false, nil - } + if err != nil && !isErrorTailscaleServiceNotFound(err) { return false, fmt.Errorf("error getting Tailscale Service: %w", err) } @@ -729,10 +711,15 @@ func (r *HAIngressReconciler) cleanupTailscaleService(ctx context.Context, svc * } if len(o.OwnerRefs) == 1 { logger.Infof("Deleting Tailscale Service %q", svc.Name) - return false, r.tsClient.DeleteVIPService(ctx, svc.Name) + if err = r.tsClient.DeleteVIPService(ctx, svc.Name); err != nil && !isErrorTailscaleServiceNotFound(err) { + return false, err + } + + return false, nil } + o.OwnerRefs = slices.Delete(o.OwnerRefs, ix, ix+1) - logger.Infof("Deleting Tailscale Service %q", svc.Name) + logger.Infof("Creating/Updating Tailscale Service %q", svc.Name) json, err := json.Marshal(o) if err != nil { return false, fmt.Errorf("error marshalling updated Tailscale Service owner reference: %w", err) @@ -1122,14 +1109,6 @@ func hasCerts(ctx context.Context, cl client.Client, lc localClient, ns string, return len(cert) > 0 && len(key) > 0, nil } -func isErrorFeatureFlagNotEnabled(err error) bool { - // messageFFNotEnabled is the error message returned by - // Tailscale control plane when a Tailscale Service API call is made for a - // tailnet that does not have the Tailscale Services feature flag enabled. 
- const messageFFNotEnabled = "feature unavailable for tailnet" - return err != nil && strings.Contains(err.Error(), messageFFNotEnabled) -} - func isErrorTailscaleServiceNotFound(err error) bool { var errResp tailscale.ErrResponse ok := errors.As(err, &errResp) diff --git a/cmd/k8s-operator/ingress-for-pg_test.go b/cmd/k8s-operator/ingress-for-pg_test.go index 77e5ecb37a677..5cc806ad1bf7a 100644 --- a/cmd/k8s-operator/ingress-for-pg_test.go +++ b/cmd/k8s-operator/ingress-for-pg_test.go @@ -25,6 +25,7 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" @@ -67,7 +68,7 @@ func TestIngressPGReconciler(t *testing.T) { // Verify initial reconciliation expectReconciled(t, ingPGR, "default", "test-ingress") - populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + populateTLSSecret(t, fc, "test-pg", "my-svc.ts.net") expectReconciled(t, ingPGR, "default", "test-ingress") verifyServeConfig(t, fc, "svc:my-svc", false) verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) @@ -89,7 +90,7 @@ func TestIngressPGReconciler(t *testing.T) { expectReconciled(t, ingPGR, "default", "test-ingress") // Verify Tailscale Service uses custom tags - tsSvc, err := ft.GetVIPService(context.Background(), "svc:my-svc") + tsSvc, err := ft.GetVIPService(t.Context(), "svc:my-svc") if err != nil { t.Fatalf("getting Tailscale Service: %v", err) } @@ -134,7 +135,7 @@ func TestIngressPGReconciler(t *testing.T) { // Verify second Ingress reconciliation expectReconciled(t, ingPGR, "default", "my-other-ingress") - populateTLSSecret(context.Background(), fc, "test-pg", "my-other-svc.ts.net") + populateTLSSecret(t, fc, "test-pg", "my-other-svc.ts.net") expectReconciled(t, ingPGR, "default", "my-other-ingress") verifyServeConfig(t, fc, "svc:my-other-svc", false) verifyTailscaleService(t, ft, "svc:my-other-svc", []string{"tcp:443"}) @@ -151,14 +152,14 @@ func TestIngressPGReconciler(t *testing.T) { verifyTailscaledConfig(t, fc, "test-pg", []string{"svc:my-svc", "svc:my-other-svc"}) // Delete second Ingress - if err := fc.Delete(context.Background(), ing2); err != nil { + if err := fc.Delete(t.Context(), ing2); err != nil { t.Fatalf("deleting second Ingress: %v", err) } expectReconciled(t, ingPGR, "default", "my-other-ingress") // Verify second Ingress cleanup cm := &corev1.ConfigMap{} - if err := fc.Get(context.Background(), types.NamespacedName{ + if err := fc.Get(t.Context(), types.NamespacedName{ Name: "test-pg-ingress-config", Namespace: "operator-ns", }, cm); err != nil { @@ -199,7 +200,7 @@ func TestIngressPGReconciler(t *testing.T) { expectEqual(t, fc, certSecretRoleBinding(pg, "operator-ns", "my-svc.ts.net")) // Delete the first Ingress and verify cleanup - if err := fc.Delete(context.Background(), ing); err != nil { + if err := fc.Delete(t.Context(), ing); err != nil { t.Fatalf("deleting Ingress: %v", err) } @@ -207,7 +208,7 @@ func TestIngressPGReconciler(t *testing.T) { // Verify the ConfigMap was cleaned up cm = &corev1.ConfigMap{} - if err := fc.Get(context.Background(), types.NamespacedName{ + if err := fc.Get(t.Context(), types.NamespacedName{ Name: "test-pg-second-ingress-config", Namespace: "operator-ns", }, cm); err != nil { @@ -228,6 +229,47 @@ func TestIngressPGReconciler(t *testing.T) { expectMissing[corev1.Secret](t, fc, "operator-ns", "my-svc.ts.net") expectMissing[rbacv1.Role](t, fc, "operator-ns", 
"my-svc.ts.net") expectMissing[rbacv1.RoleBinding](t, fc, "operator-ns", "my-svc.ts.net") + + // Create a third ingress + ing3 := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "my-other-ingress", + Namespace: "default", + UID: types.UID("5678-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-other-svc.tailnetxyz.ts.net"}}, + }, + }, + } + + mustCreate(t, fc, ing3) + expectReconciled(t, ingPGR, ing3.Namespace, ing3.Name) + + // Delete the service from "control" + ft.vipServices = make(map[tailcfg.ServiceName]*tailscale.VIPService) + + // Delete the ingress and confirm we don't get stuck due to the VIP service not existing. + if err = fc.Delete(t.Context(), ing3); err != nil { + t.Fatalf("deleting Ingress: %v", err) + } + + expectReconciled(t, ingPGR, ing3.Namespace, ing3.Name) + expectMissing[networkingv1.Ingress](t, fc, ing3.Namespace, ing3.Name) } func TestIngressPGReconciler_UpdateIngressHostname(t *testing.T) { @@ -262,7 +304,7 @@ func TestIngressPGReconciler_UpdateIngressHostname(t *testing.T) { // Verify initial reconciliation expectReconciled(t, ingPGR, "default", "test-ingress") - populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + populateTLSSecret(t, fc, "test-pg", "my-svc.ts.net") expectReconciled(t, ingPGR, "default", "test-ingress") verifyServeConfig(t, fc, "svc:my-svc", false) verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) @@ -273,13 +315,13 @@ func TestIngressPGReconciler_UpdateIngressHostname(t *testing.T) { ing.Spec.TLS[0].Hosts[0] = "updated-svc" }) expectReconciled(t, ingPGR, "default", "test-ingress") - populateTLSSecret(context.Background(), fc, "test-pg", "updated-svc.ts.net") + populateTLSSecret(t, fc, "test-pg", "updated-svc.ts.net") expectReconciled(t, ingPGR, "default", "test-ingress") verifyServeConfig(t, fc, "svc:updated-svc", false) verifyTailscaleService(t, ft, "svc:updated-svc", []string{"tcp:443"}) verifyTailscaledConfig(t, fc, "test-pg", []string{"svc:updated-svc"}) - _, err := ft.GetVIPService(context.Background(), tailcfg.ServiceName("svc:my-svc")) + _, err := ft.GetVIPService(context.Background(), "svc:my-svc") if err == nil { t.Fatalf("svc:my-svc not cleaned up") } @@ -500,7 +542,7 @@ func TestIngressPGReconciler_HTTPEndpoint(t *testing.T) { // Verify initial reconciliation with HTTP enabled expectReconciled(t, ingPGR, "default", "test-ingress") - populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + populateTLSSecret(t, fc, "test-pg", "my-svc.ts.net") expectReconciled(t, ingPGR, "default", "test-ingress") verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:80", "tcp:443"}) verifyServeConfig(t, fc, "svc:my-svc", true) @@ -717,7 +759,9 @@ func TestOwnerAnnotations(t *testing.T) { } } -func populateTLSSecret(ctx context.Context, c client.Client, pgName, domain string) error { +func populateTLSSecret(t *testing.T, c client.Client, pgName, domain string) { + t.Helper() + secret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: domain, @@ -736,10 +780,12 @@ func populateTLSSecret(ctx context.Context, c client.Client, pgName, domain stri 
}, } - _, err := createOrUpdate(ctx, c, "operator-ns", secret, func(s *corev1.Secret) { + _, err := createOrUpdate(t.Context(), c, "operator-ns", secret, func(s *corev1.Secret) { s.Data = secret.Data }) - return err + if err != nil { + t.Fatalf("failed to populate TLS secret: %v", err) + } } func verifyTailscaleService(t *testing.T, ft *fakeTSClient, serviceName string, wantPorts []string) { diff --git a/cmd/k8s-operator/nameserver.go b/cmd/k8s-operator/nameserver.go index 5de1c47ba2b7e..39db5f0f9cf16 100644 --- a/cmd/k8s-operator/nameserver.go +++ b/cmd/k8s-operator/nameserver.go @@ -26,6 +26,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/yaml" + tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" @@ -45,10 +46,7 @@ const ( messageMultipleDNSConfigsPresent = "Multiple DNSConfig resources found in cluster. Please ensure no more than one is present." defaultNameserverImageRepo = "tailscale/k8s-nameserver" - // TODO (irbekrm): once we start publishing nameserver images for stable - // track, replace 'unstable' here with the version of this operator - // instance. - defaultNameserverImageTag = "unstable" + defaultNameserverImageTag = "stable" ) // NameserverReconciler knows how to create nameserver resources in cluster in diff --git a/cmd/k8s-operator/nameserver_test.go b/cmd/k8s-operator/nameserver_test.go index 6da52d8a21490..858cd973d82c2 100644 --- a/cmd/k8s-operator/nameserver_test.go +++ b/cmd/k8s-operator/nameserver_test.go @@ -19,6 +19,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/yaml" + operatorutils "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/tstest" @@ -182,7 +183,7 @@ func TestNameserverReconciler(t *testing.T) { dnsCfg.Spec.Nameserver.Image = nil }) expectReconciled(t, reconciler, "", "test") - wantsDeploy.Spec.Template.Spec.Containers[0].Image = "tailscale/k8s-nameserver:unstable" + wantsDeploy.Spec.Template.Spec.Containers[0].Image = "tailscale/k8s-nameserver:stable" expectEqual(t, fc, wantsDeploy) }) } diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index 89c8ff3e205bf..816fea5664557 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -44,10 +44,10 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager/signals" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/reconcile" - "tailscale.com/envknob" "tailscale.com/client/local" "tailscale.com/client/tailscale" + "tailscale.com/envknob" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/store/kubestore" @@ -164,22 +164,24 @@ func main() { runReconcilers(rOpts) } -// initTSNet initializes the tsnet.Server and logs in to Tailscale. It uses the -// CLIENT_ID_FILE and CLIENT_SECRET_FILE environment variables to authenticate -// with Tailscale. +// initTSNet initializes the tsnet.Server and logs in to Tailscale. If CLIENT_ID +// is set, it authenticates to the Tailscale API using the federated OIDC workload +// identity flow. Otherwise, it uses the CLIENT_ID_FILE and CLIENT_SECRET_FILE +// environment variables to authenticate with static credentials. 
func initTSNet(zlog *zap.SugaredLogger, loginServer string) (*tsnet.Server, tsClient) { var ( - clientIDPath = defaultEnv("CLIENT_ID_FILE", "") - clientSecretPath = defaultEnv("CLIENT_SECRET_FILE", "") + clientID = defaultEnv("CLIENT_ID", "") // Used for workload identity federation. + clientIDPath = defaultEnv("CLIENT_ID_FILE", "") // Used for static client credentials. + clientSecretPath = defaultEnv("CLIENT_SECRET_FILE", "") // Used for static client credentials. hostname = defaultEnv("OPERATOR_HOSTNAME", "tailscale-operator") kubeSecret = defaultEnv("OPERATOR_SECRET", "") operatorTags = defaultEnv("OPERATOR_INITIAL_TAGS", "tag:k8s-operator") ) startlog := zlog.Named("startup") - if clientIDPath == "" || clientSecretPath == "" { - startlog.Fatalf("CLIENT_ID_FILE and CLIENT_SECRET_FILE must be set") + if clientID == "" && (clientIDPath == "" || clientSecretPath == "") { + startlog.Fatalf("CLIENT_ID_FILE and CLIENT_SECRET_FILE must be set") // TODO(tomhjp): error message can mention WIF once it's publicly available. } - tsc, err := newTSClient(context.Background(), clientIDPath, clientSecretPath, loginServer) + tsc, err := newTSClient(zlog.Named("ts-api-client"), clientID, clientIDPath, clientSecretPath, loginServer) if err != nil { startlog.Fatalf("error creating Tailscale client: %v", err) } @@ -636,7 +638,7 @@ func runReconcilers(opts reconcilerOpts) { recorder: eventRecorder, tsNamespace: opts.tailscaleNamespace, Client: mgr.GetClient(), - l: opts.log.Named("recorder-reconciler"), + log: opts.log.Named("recorder-reconciler"), clock: tstime.DefaultClock{}, tsClient: opts.tsClient, loginServer: opts.loginServer, @@ -691,7 +693,7 @@ func runReconcilers(opts reconcilerOpts) { Complete(&ProxyGroupReconciler{ recorder: eventRecorder, Client: mgr.GetClient(), - l: opts.log.Named("proxygroup-reconciler"), + log: opts.log.Named("proxygroup-reconciler"), clock: tstime.DefaultClock{}, tsClient: opts.tsClient, @@ -1120,7 +1122,7 @@ func serviceHandlerForIngress(cl client.Client, logger *zap.SugaredLogger, ingre reqs := make([]reconcile.Request, 0) for _, ing := range ingList.Items { if ing.Spec.IngressClassName == nil || *ing.Spec.IngressClassName != ingressClassName { - return nil + continue } if hasProxyGroupAnnotation(&ing) { // We don't want to reconcile backend Services for Ingresses for ProxyGroups. 
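[Editor's note: the serviceHandlerForIngress fix at the end of the hunk above replaces an early return nil with continue. With multiple IngressClasses in a cluster, the old code aborted the whole mapping as soon as it hit a non-tailscale Ingress, so Ingresses with the tailscale class listed later were never queued for reconciliation; the test added in the next file, Test_serviceHandlerForIngress_multipleIngressClasses, exercises exactly that. A self-contained sketch of the skip-versus-abort pattern, with illustrative names:]

package main

import "fmt"

// requestsFor mimics the mapper: collect work for matching items only.
// An early return on the first mismatch (the old bug) would drop "ts-ing"
// whenever "nginx-ing" happened to be iterated first.
func requestsFor(ingressClasses map[string]string, want string) []string {
    var reqs []string
    for name, class := range ingressClasses {
        if class != want {
            continue // skip this item; do not abort the whole mapping
        }
        reqs = append(reqs, name)
    }
    return reqs
}

func main() {
    classes := map[string]string{"nginx-ing": "nginx", "ts-ing": "tailscale"}
    fmt.Println(requestsFor(classes, "tailscale")) // [ts-ing]
}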
diff --git a/cmd/k8s-operator/operator_test.go b/cmd/k8s-operator/operator_test.go index 5af237342e8cd..e11235768dea2 100644 --- a/cmd/k8s-operator/operator_test.go +++ b/cmd/k8s-operator/operator_test.go @@ -1282,8 +1282,8 @@ func TestServiceProxyClassAnnotation(t *testing.T) { slist := &corev1.SecretList{} fc.List(context.Background(), slist, client.InNamespace("operator-ns")) for _, i := range slist.Items { - l, _ := json.Marshal(i.Labels) - t.Logf("found secret %q with labels %q ", i.Name, string(l)) + labels, _ := json.Marshal(i.Labels) + t.Logf("found secret %q with labels %q ", i.Name, string(labels)) } _, shortName := findGenName(t, fc, "default", "test", "svc") @@ -1698,6 +1698,42 @@ func Test_serviceHandlerForIngress(t *testing.T) { } } +func Test_serviceHandlerForIngress_multipleIngressClasses(t *testing.T) { + fc := fake.NewFakeClient() + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "backend", Namespace: "default"}, + } + mustCreate(t, fc, svc) + + mustCreate(t, fc, &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{Name: "nginx-ing", Namespace: "default"}, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("nginx"), + DefaultBackend: &networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}, + }, + }) + + mustCreate(t, fc, &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{Name: "ts-ing", Namespace: "default"}, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}, + }, + }) + + got := serviceHandlerForIngress(fc, zl.Sugar(), "tailscale")(context.Background(), svc) + want := []reconcile.Request{{NamespacedName: types.NamespacedName{Namespace: "default", Name: "ts-ing"}}} + + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("unexpected reconcile requests (-got +want):\n%s", diff) + } +} + func Test_clusterDomainFromResolverConf(t *testing.T) { zl, err := zap.NewDevelopment() if err != nil { diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go index debeb5c6b3442..946e017a26f00 100644 --- a/cmd/k8s-operator/proxygroup.go +++ b/cmd/k8s-operator/proxygroup.go @@ -80,7 +80,7 @@ var ( // ProxyGroupReconciler ensures cluster resources for a ProxyGroup definition. type ProxyGroupReconciler struct { client.Client - l *zap.SugaredLogger + log *zap.SugaredLogger recorder record.EventRecorder clock tstime.Clock tsClient tsClient @@ -101,7 +101,7 @@ type ProxyGroupReconciler struct { } func (r *ProxyGroupReconciler) logger(name string) *zap.SugaredLogger { - return r.l.With("ProxyGroup", name) + return r.log.With("ProxyGroup", name) } func (r *ProxyGroupReconciler) Reconcile(ctx context.Context, req reconcile.Request) (_ reconcile.Result, err error) { diff --git a/cmd/k8s-operator/proxygroup_specs.go b/cmd/k8s-operator/proxygroup_specs.go index e185499f0e19d..930b7049d8ea9 100644 --- a/cmd/k8s-operator/proxygroup_specs.go +++ b/cmd/k8s-operator/proxygroup_specs.go @@ -182,6 +182,14 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode string Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig/$(POD_NAME)", }, + { + // This ensures that cert renewals can succeed if ACME account + // keys have changed since issuance. 
We cannot guarantee or + // validate that the account key has not changed, see + // https://github.com/tailscale/tailscale/issues/18251 + Name: "TS_DEBUG_ACME_FORCE_RENEWAL", + Value: "true", + }, } if port != nil { @@ -347,6 +355,14 @@ func kubeAPIServerStatefulSet(pg *tsapi.ProxyGroup, namespace, image string, por Name: "$(POD_NAME)-config", }.String(), }, + { + // This ensures that cert renewals can succeed if ACME account + // keys have changed since issuance. We cannot guarantee or + // validate that the account key has not changed, see + // https://github.com/tailscale/tailscale/issues/18251 + Name: "TS_DEBUG_ACME_FORCE_RENEWAL", + Value: "true", + }, } if port != nil { @@ -524,16 +540,16 @@ func pgSecretLabels(pgName, secretType string) map[string]string { } func pgLabels(pgName string, customLabels map[string]string) map[string]string { - l := make(map[string]string, len(customLabels)+3) + labels := make(map[string]string, len(customLabels)+3) for k, v := range customLabels { - l[k] = v + labels[k] = v } - l[kubetypes.LabelManaged] = "true" - l[LabelParentType] = "proxygroup" - l[LabelParentName] = pgName + labels[kubetypes.LabelManaged] = "true" + labels[LabelParentType] = "proxygroup" + labels[LabelParentName] = pgName - return l + return labels } func pgOwnerReference(owner *tsapi.ProxyGroup) []metav1.OwnerReference { diff --git a/cmd/k8s-operator/proxygroup_test.go b/cmd/k8s-operator/proxygroup_test.go index d763cf92276ec..2bcc9fb7a9720 100644 --- a/cmd/k8s-operator/proxygroup_test.go +++ b/cmd/k8s-operator/proxygroup_test.go @@ -670,7 +670,7 @@ func TestProxyGroupWithStaticEndpoints(t *testing.T) { t.Logf("created node %q with data", n.name) } - reconciler.l = zl.Sugar().With("TestName", tt.name).With("Reconcile", i) + reconciler.log = zl.Sugar().With("TestName", tt.name).With("Reconcile", i) pg.Spec.Replicas = r.replicas pc.Spec.StaticEndpoints = r.staticEndpointConfig @@ -784,7 +784,7 @@ func TestProxyGroupWithStaticEndpoints(t *testing.T) { Client: fc, tsClient: tsClient, recorder: fr, - l: zl.Sugar().With("TestName", tt.name).With("Reconcile", "cleanup"), + log: zl.Sugar().With("TestName", tt.name).With("Reconcile", "cleanup"), clock: cl, } @@ -845,7 +845,7 @@ func TestProxyGroup(t *testing.T) { Client: fc, tsClient: tsClient, recorder: fr, - l: zl.Sugar(), + log: zl.Sugar(), clock: cl, } crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} @@ -1049,7 +1049,7 @@ func TestProxyGroupTypes(t *testing.T) { tsNamespace: tsNamespace, tsProxyImage: testProxyImage, Client: fc, - l: zl.Sugar(), + log: zl.Sugar(), tsClient: &fakeTSClient{}, clock: tstest.NewClock(tstest.ClockOpts{}), } @@ -1289,24 +1289,24 @@ func TestKubeAPIServerStatusConditionFlow(t *testing.T) { tsNamespace: tsNamespace, tsProxyImage: testProxyImage, Client: fc, - l: zap.Must(zap.NewDevelopment()).Sugar(), + log: zap.Must(zap.NewDevelopment()).Sugar(), tsClient: &fakeTSClient{}, clock: tstest.NewClock(tstest.ClockOpts{}), } expectReconciled(t, r, "", pg.Name) pg.ObjectMeta.Finalizers = append(pg.ObjectMeta.Finalizers, FinalizerName) - tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionFalse, reasonProxyGroupCreating, "", 0, r.clock, r.l) - tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.l) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionFalse, reasonProxyGroupCreating, "", 0, r.clock, r.log) + 
tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.log) expectEqual(t, fc, pg, omitPGStatusConditionMessages) // Set kube-apiserver valid. mustUpdateStatus(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { - tsoperator.SetProxyGroupCondition(p, tsapi.KubeAPIServerProxyValid, metav1.ConditionTrue, reasonKubeAPIServerProxyValid, "", 1, r.clock, r.l) + tsoperator.SetProxyGroupCondition(p, tsapi.KubeAPIServerProxyValid, metav1.ConditionTrue, reasonKubeAPIServerProxyValid, "", 1, r.clock, r.log) }) expectReconciled(t, r, "", pg.Name) - tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid, metav1.ConditionTrue, reasonKubeAPIServerProxyValid, "", 1, r.clock, r.l) - tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.l) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid, metav1.ConditionTrue, reasonKubeAPIServerProxyValid, "", 1, r.clock, r.log) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.log) expectEqual(t, fc, pg, omitPGStatusConditionMessages) // Set available. @@ -1318,17 +1318,17 @@ func TestKubeAPIServerStatusConditionFlow(t *testing.T) { TailnetIPs: []string{"1.2.3.4", "::1"}, }, } - tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, reasonProxyGroupAvailable, "", 0, r.clock, r.l) - tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.l) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, reasonProxyGroupAvailable, "", 0, r.clock, r.log) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.log) expectEqual(t, fc, pg, omitPGStatusConditionMessages) // Set kube-apiserver configured. 
mustUpdateStatus(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { - tsoperator.SetProxyGroupCondition(p, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionTrue, reasonKubeAPIServerProxyConfigured, "", 1, r.clock, r.l) + tsoperator.SetProxyGroupCondition(p, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionTrue, reasonKubeAPIServerProxyConfigured, "", 1, r.clock, r.log) }) expectReconciled(t, r, "", pg.Name) - tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionTrue, reasonKubeAPIServerProxyConfigured, "", 1, r.clock, r.l) - tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionTrue, reasonProxyGroupReady, "", 1, r.clock, r.l) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionTrue, reasonKubeAPIServerProxyConfigured, "", 1, r.clock, r.log) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionTrue, reasonProxyGroupReady, "", 1, r.clock, r.log) expectEqual(t, fc, pg, omitPGStatusConditionMessages) } @@ -1342,7 +1342,7 @@ func TestKubeAPIServerType_DoesNotOverwriteServicesConfig(t *testing.T) { tsNamespace: tsNamespace, tsProxyImage: testProxyImage, Client: fc, - l: zap.Must(zap.NewDevelopment()).Sugar(), + log: zap.Must(zap.NewDevelopment()).Sugar(), tsClient: &fakeTSClient{}, clock: tstest.NewClock(tstest.ClockOpts{}), } @@ -1427,7 +1427,7 @@ func TestIngressAdvertiseServicesConfigPreserved(t *testing.T) { tsNamespace: tsNamespace, tsProxyImage: testProxyImage, Client: fc, - l: zap.Must(zap.NewDevelopment()).Sugar(), + log: zap.Must(zap.NewDevelopment()).Sugar(), tsClient: &fakeTSClient{}, clock: tstest.NewClock(tstest.ClockOpts{}), } @@ -1902,7 +1902,7 @@ func TestProxyGroupLetsEncryptStaging(t *testing.T) { defaultProxyClass: tt.defaultProxyClass, Client: fc, tsClient: &fakeTSClient{}, - l: zl.Sugar(), + log: zl.Sugar(), clock: cl, } diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index c52ffce85495b..d1cfda3815c79 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -670,6 +670,14 @@ func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.S Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig/$(POD_NAME)", }, + corev1.EnvVar{ + // This ensures that cert renewals can succeed if ACME account + // keys have changed since issuance. 
We cannot guarantee or + // validate that the account key has not changed, see + // https://github.com/tailscale/tailscale/issues/18251 + Name: "TS_DEBUG_ACME_FORCE_RENEWAL", + Value: "true", + }, ) if sts.ForwardClusterTrafficViaL7IngressProxy { diff --git a/cmd/k8s-operator/sts_test.go b/cmd/k8s-operator/sts_test.go index ea28e77a14c36..afe54ed98bc49 100644 --- a/cmd/k8s-operator/sts_test.go +++ b/cmd/k8s-operator/sts_test.go @@ -71,11 +71,11 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { SecurityContext: &corev1.PodSecurityContext{ RunAsUser: ptr.To(int64(0)), }, - ImagePullSecrets: []corev1.LocalObjectReference{{Name: "docker-creds"}}, - NodeName: "some-node", - NodeSelector: map[string]string{"beta.kubernetes.io/os": "linux"}, - Affinity: &corev1.Affinity{NodeAffinity: &corev1.NodeAffinity{RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{}}}, - Tolerations: []corev1.Toleration{{Key: "", Operator: "Exists"}}, + ImagePullSecrets: []corev1.LocalObjectReference{{Name: "docker-creds"}}, + NodeName: "some-node", + NodeSelector: map[string]string{"beta.kubernetes.io/os": "linux"}, + Affinity: &corev1.Affinity{NodeAffinity: &corev1.NodeAffinity{RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{}}}, + Tolerations: []corev1.Toleration{{Key: "", Operator: "Exists"}}, PriorityClassName: "high-priority", TopologySpreadConstraints: []corev1.TopologySpreadConstraint{ { diff --git a/cmd/k8s-operator/svc-for-pg.go b/cmd/k8s-operator/svc-for-pg.go index 62cc36bd4a82b..144d3755811da 100644 --- a/cmd/k8s-operator/svc-for-pg.go +++ b/cmd/k8s-operator/svc-for-pg.go @@ -207,11 +207,6 @@ func (r *HAServiceReconciler) maybeProvision(ctx context.Context, hostname strin // already created and not owned by this Service. serviceName := tailcfg.ServiceName("svc:" + hostname) existingTSSvc, err := r.tsClient.GetVIPService(ctx, serviceName) - if isErrorFeatureFlagNotEnabled(err) { - logger.Warn(msgFeatureFlagNotEnabled) - r.recorder.Event(svc, corev1.EventTypeWarning, warningTailscaleServiceFeatureFlagNotEnabled, msgFeatureFlagNotEnabled) - return false, nil - } if err != nil && !isErrorTailscaleServiceNotFound(err) { return false, fmt.Errorf("error getting Tailscale Service %q: %w", hostname, err) } @@ -530,11 +525,6 @@ func (r *HAServiceReconciler) tailnetCertDomain(ctx context.Context) (string, er // It returns true if an existing Tailscale Service was updated to remove owner reference, as well as any error that occurred. 
func cleanupTailscaleService(ctx context.Context, tsClient tsClient, name tailcfg.ServiceName, operatorID string, logger *zap.SugaredLogger) (updated bool, err error) { svc, err := tsClient.GetVIPService(ctx, name) - if isErrorFeatureFlagNotEnabled(err) { - msg := fmt.Sprintf("Unable to proceed with cleanup: %s.", msgFeatureFlagNotEnabled) - logger.Warn(msg) - return false, nil - } if err != nil { errResp := &tailscale.ErrResponse{} ok := errors.As(err, errResp) diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index b4c468c8e8e94..a9e79cee5e714 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -91,6 +91,7 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef {Name: "POD_UID", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.uid"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, {Name: "TS_KUBE_SECRET", Value: "$(POD_NAME)"}, {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig/$(POD_NAME)"}, + {Name: "TS_DEBUG_ACME_FORCE_RENEWAL", Value: "true"}, }, SecurityContext: &corev1.SecurityContext{ Privileged: ptr.To(true), @@ -280,6 +281,7 @@ func expectedSTSUserspace(t *testing.T, cl client.Client, opts configOpts) *apps {Name: "POD_UID", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.uid"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, {Name: "TS_KUBE_SECRET", Value: "$(POD_NAME)"}, {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig/$(POD_NAME)"}, + {Name: "TS_DEBUG_ACME_FORCE_RENEWAL", Value: "true"}, {Name: "TS_SERVE_CONFIG", Value: "/etc/tailscaled/$(POD_NAME)/serve-config"}, {Name: "TS_INTERNAL_APP", Value: opts.app}, }, diff --git a/cmd/k8s-operator/tsclient.go b/cmd/k8s-operator/tsclient.go index 50620c26ddf27..d22fa1797dd5c 100644 --- a/cmd/k8s-operator/tsclient.go +++ b/cmd/k8s-operator/tsclient.go @@ -8,8 +8,13 @@ package main import ( "context" "fmt" + "net/http" "os" + "sync" + "time" + "go.uber.org/zap" + "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" @@ -20,30 +25,53 @@ import ( // call should be performed on the default tailnet for the provided credentials. const ( defaultTailnet = "-" + oidcJWTPath = "/var/run/secrets/tailscale/serviceaccount/token" ) -func newTSClient(ctx context.Context, clientIDPath, clientSecretPath, loginServer string) (tsClient, error) { - clientID, err := os.ReadFile(clientIDPath) - if err != nil { - return nil, fmt.Errorf("error reading client ID %q: %w", clientIDPath, err) - } - clientSecret, err := os.ReadFile(clientSecretPath) - if err != nil { - return nil, fmt.Errorf("reading client secret %q: %w", clientSecretPath, err) - } - const tokenURLPath = "/api/v2/oauth/token" - tokenURL := fmt.Sprintf("%s%s", ipn.DefaultControlURL, tokenURLPath) +func newTSClient(logger *zap.SugaredLogger, clientID, clientIDPath, clientSecretPath, loginServer string) (*tailscale.Client, error) { + baseURL := ipn.DefaultControlURL if loginServer != "" { - tokenURL = fmt.Sprintf("%s%s", loginServer, tokenURLPath) + baseURL = loginServer } - credentials := clientcredentials.Config{ - ClientID: string(clientID), - ClientSecret: string(clientSecret), - TokenURL: tokenURL, + + var httpClient *http.Client + if clientID == "" { + // Use static client credentials mounted to disk. 
+ id, err := os.ReadFile(clientIDPath) + if err != nil { + return nil, fmt.Errorf("error reading client ID %q: %w", clientIDPath, err) + } + secret, err := os.ReadFile(clientSecretPath) + if err != nil { + return nil, fmt.Errorf("reading client secret %q: %w", clientSecretPath, err) + } + credentials := clientcredentials.Config{ + ClientID: string(id), + ClientSecret: string(secret), + TokenURL: fmt.Sprintf("%s%s", baseURL, "/api/v2/oauth/token"), + } + tokenSrc := credentials.TokenSource(context.Background()) + httpClient = oauth2.NewClient(context.Background(), tokenSrc) + } else { + // Use workload identity federation. + tokenSrc := &jwtTokenSource{ + logger: logger, + jwtPath: oidcJWTPath, + baseCfg: clientcredentials.Config{ + ClientID: clientID, + TokenURL: fmt.Sprintf("%s%s", baseURL, "/api/v2/oauth/token-exchange"), + }, + } + httpClient = &http.Client{ + Transport: &oauth2.Transport{ + Source: tokenSrc, + }, + } } + c := tailscale.NewClient(defaultTailnet, nil) c.UserAgent = "tailscale-k8s-operator" - c.HTTPClient = credentials.Client(ctx) + c.HTTPClient = httpClient if loginServer != "" { c.BaseURL = loginServer } @@ -63,3 +91,43 @@ type tsClient interface { // DeleteVIPService is a method for deleting a Tailscale Service. DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error } + +// jwtTokenSource implements the [oauth2.TokenSource] interface, but with the +// ability to regenerate a fresh underlying token source each time a new value +// of the JWT parameter is needed due to expiration. +type jwtTokenSource struct { + logger *zap.SugaredLogger + jwtPath string // Path to the file containing an automatically refreshed JWT. + baseCfg clientcredentials.Config // Holds config that doesn't change for the lifetime of the process. + + mu sync.Mutex // Guards underlying. + underlying oauth2.TokenSource // The oauth2 client implementation. Does its own separate caching of the access token. +} + +func (s *jwtTokenSource) Token() (*oauth2.Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.underlying != nil { + t, err := s.underlying.Token() + if err == nil && t != nil && t.Valid() { + return t, nil + } + } + + s.logger.Debugf("Refreshing JWT from %s", s.jwtPath) + tk, err := os.ReadFile(s.jwtPath) + if err != nil { + return nil, fmt.Errorf("error reading JWT from %q: %w", s.jwtPath, err) + } + + // Shallow copy of the base config. 
+ credentials := s.baseCfg + credentials.EndpointParams = map[string][]string{ + "jwt": {string(tk)}, + } + + src := credentials.TokenSource(context.Background()) + s.underlying = oauth2.ReuseTokenSourceWithExpiry(nil, src, time.Minute) + return s.underlying.Token() +} diff --git a/cmd/k8s-operator/tsclient_test.go b/cmd/k8s-operator/tsclient_test.go new file mode 100644 index 0000000000000..16de512d5809f --- /dev/null +++ b/cmd/k8s-operator/tsclient_test.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "go.uber.org/zap" + "golang.org/x/oauth2" +) + +func TestNewStaticClient(t *testing.T) { + const ( + clientIDFile = "client-id" + clientSecretFile = "client-secret" + ) + + tmp := t.TempDir() + clientIDPath := filepath.Join(tmp, clientIDFile) + if err := os.WriteFile(clientIDPath, []byte("test-client-id"), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", clientIDPath, err) + } + clientSecretPath := filepath.Join(tmp, clientSecretFile) + if err := os.WriteFile(clientSecretPath, []byte("test-client-secret"), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", clientSecretPath, err) + } + + srv := testAPI(t, 3600) + cl, err := newTSClient(zap.NewNop().Sugar(), "", clientIDPath, clientSecretPath, srv.URL) + if err != nil { + t.Fatalf("error creating Tailscale client: %v", err) + } + + resp, err := cl.HTTPClient.Get(srv.URL) + if err != nil { + t.Fatalf("error making test API call: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + want := "Bearer " + testToken("/api/v2/oauth/token", "test-client-id", "test-client-secret", "") + if string(got) != want { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestNewWorkloadIdentityClient(t *testing.T) { + // 5 seconds is within expiryDelta leeway, so the access token will + // immediately be considered expired and get refreshed on each access. + srv := testAPI(t, 5) + cl, err := newTSClient(zap.NewNop().Sugar(), "test-client-id", "", "", srv.URL) + if err != nil { + t.Fatalf("error creating Tailscale client: %v", err) + } + + // Modify the path where the JWT will be read from. 
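A note on the 5-second expiry used in this test: `golang.org/x/oauth2` already treats tokens as expired slightly before their stated expiry (10 seconds by default), and the token source built in `jwtTokenSource.Token` widens that window to a full minute via `ReuseTokenSourceWithExpiry`. A token with only 5 seconds of life is therefore never considered valid, forcing a refresh on every request, which is what the assertions below rely on. A tiny illustration of the default behavior:

```go
package main

import (
	"fmt"
	"time"

	"golang.org/x/oauth2"
)

func main() {
	// With the package's default 10s early-expiry window, a token expiring
	// 5s from now already reports itself as invalid, so it gets refreshed.
	tok := &oauth2.Token{AccessToken: "x", Expiry: time.Now().Add(5 * time.Second)}
	fmt.Println(tok.Valid()) // false
}
```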
+ oauth2Transport, ok := cl.HTTPClient.Transport.(*oauth2.Transport) + if !ok { + t.Fatalf("expected oauth2.Transport, got %T", cl.HTTPClient.Transport) + } + jwtTokenSource, ok := oauth2Transport.Source.(*jwtTokenSource) + if !ok { + t.Fatalf("expected jwtTokenSource, got %T", oauth2Transport.Source) + } + tmp := t.TempDir() + jwtPath := filepath.Join(tmp, "token") + jwtTokenSource.jwtPath = jwtPath + + for _, jwt := range []string{"test-jwt", "updated-test-jwt"} { + if err := os.WriteFile(jwtPath, []byte(jwt), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", jwtPath, err) + } + resp, err := cl.HTTPClient.Get(srv.URL) + if err != nil { + t.Fatalf("error making test API call: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + if want := "Bearer " + testToken("/api/v2/oauth/token-exchange", "test-client-id", "", jwt); string(got) != want { + t.Errorf("got %q; want %q", got, want) + } + } +} + +func testAPI(t *testing.T, expirationSeconds int) *httptest.Server { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("test server got request: %s %s", r.Method, r.URL.Path) + switch r.URL.Path { + case "/api/v2/oauth/token", "/api/v2/oauth/token-exchange": + id, secret, ok := r.BasicAuth() + if !ok { + t.Fatal("missing or invalid basic auth") + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": testToken(r.URL.Path, id, secret, r.FormValue("jwt")), + "token_type": "Bearer", + "expires_in": expirationSeconds, + }); err != nil { + t.Fatalf("error writing response: %v", err) + } + case "/": + // Echo back the authz header for test assertions. + _, err := w.Write([]byte(r.Header.Get("Authorization"))) + if err != nil { + t.Fatalf("error writing response: %v", err) + } + default: + w.WriteHeader(http.StatusNotFound) + } + })) + t.Cleanup(srv.Close) + return srv +} + +func testToken(path, id, secret, jwt string) string { + return fmt.Sprintf("%s|%s|%s|%s", path, id, secret, jwt) +} diff --git a/cmd/k8s-operator/tsrecorder.go b/cmd/k8s-operator/tsrecorder.go index ec95ecf40dab5..bfb01fa86de67 100644 --- a/cmd/k8s-operator/tsrecorder.go +++ b/cmd/k8s-operator/tsrecorder.go @@ -12,6 +12,7 @@ import ( "fmt" "net/http" "slices" + "strconv" "strings" "sync" @@ -29,6 +30,7 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/client/tailscale" tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" @@ -54,7 +56,7 @@ var gaugeRecorderResources = clientmetric.NewGauge(kubetypes.MetricRecorderCount // Recorder CRs. 
type RecorderReconciler struct { client.Client - l *zap.SugaredLogger + log *zap.SugaredLogger recorder record.EventRecorder clock tstime.Clock tsNamespace string @@ -66,16 +68,16 @@ type RecorderReconciler struct { } func (r *RecorderReconciler) logger(name string) *zap.SugaredLogger { - return r.l.With("Recorder", name) + return r.log.With("Recorder", name) } -func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Request) (_ reconcile.Result, err error) { +func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Request) (reconcile.Result, error) { logger := r.logger(req.Name) logger.Debugf("starting reconcile") defer logger.Debugf("reconcile finished") tsr := new(tsapi.Recorder) - err = r.Get(ctx, req.NamespacedName, tsr) + err := r.Get(ctx, req.NamespacedName, tsr) if apierrors.IsNotFound(err) { logger.Debugf("Recorder not found, assuming it was deleted") return reconcile.Result{}, nil @@ -98,7 +100,7 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques } tsr.Finalizers = slices.Delete(tsr.Finalizers, ix, ix+1) - if err := r.Update(ctx, tsr); err != nil { + if err = r.Update(ctx, tsr); err != nil { return reconcile.Result{}, err } return reconcile.Result{}, nil @@ -110,10 +112,11 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques if !apiequality.Semantic.DeepEqual(oldTSRStatus, &tsr.Status) { // An error encountered here should get returned by the Reconcile function. if updateErr := r.Client.Status().Update(ctx, tsr); updateErr != nil { - err = errors.Join(err, updateErr) + return reconcile.Result{}, errors.Join(err, updateErr) } } - return reconcile.Result{}, err + + return reconcile.Result{}, nil } if !slices.Contains(tsr.Finalizers, FinalizerName) { @@ -123,12 +126,12 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques // operation is underway. logger.Infof("ensuring Recorder is set up") tsr.Finalizers = append(tsr.Finalizers, FinalizerName) - if err := r.Update(ctx, tsr); err != nil { + if err = r.Update(ctx, tsr); err != nil { return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderCreationFailed, reasonRecorderCreationFailed) } } - if err := r.validate(ctx, tsr); err != nil { + if err = r.validate(ctx, tsr); err != nil { message := fmt.Sprintf("Recorder is invalid: %s", err) r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderInvalid, message) return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderInvalid, message) @@ -160,19 +163,29 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco gaugeRecorderResources.Set(int64(r.recorders.Len())) r.mu.Unlock() - if err := r.ensureAuthSecretCreated(ctx, tsr); err != nil { + if err := r.ensureAuthSecretsCreated(ctx, tsr); err != nil { return fmt.Errorf("error creating secrets: %w", err) } - // State Secret is precreated so we can use the Recorder CR as its owner ref. - sec := tsrStateSecret(tsr, r.tsNamespace) - if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sec, func(s *corev1.Secret) { - s.ObjectMeta.Labels = sec.ObjectMeta.Labels - s.ObjectMeta.Annotations = sec.ObjectMeta.Annotations - }); err != nil { - return fmt.Errorf("error creating state Secret: %w", err) + + // State Secrets are pre-created so we can use the Recorder CR as its owner ref. 
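The replica loop that creates these Secrets follows below. For reference, the per-replica naming convention introduced by this change — a state Secret `<name>-<ordinal>` matching the StatefulSet pod name, plus an auth Secret `<name>-auth-<ordinal>` — can be sketched standalone as:

```go
package main

import "fmt"

// expectedSecretNames lists the Secrets each replica owns, keyed by the
// StatefulSet pod ordinal (Go 1.22+ range-over-int, as used in the diff).
func expectedSecretNames(recorderName string, replicas int32) (state, auth []string) {
	for ordinal := range replicas {
		state = append(state, fmt.Sprintf("%s-%d", recorderName, ordinal))
		auth = append(auth, fmt.Sprintf("%s-auth-%d", recorderName, ordinal))
	}
	return state, auth
}

func main() {
	state, auth := expectedSecretNames("test", 3)
	fmt.Println(state) // [test-0 test-1 test-2]
	fmt.Println(auth)  // [test-auth-0 test-auth-1 test-auth-2]
}
```

For a Recorder named `test` with three replicas this yields `test-0..test-2` and `test-auth-0..test-auth-2`, which is exactly what the updated tests later assert.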
+ var replicas int32 = 1 + if tsr.Spec.Replicas != nil { + replicas = *tsr.Spec.Replicas + } + + for replica := range replicas { + sec := tsrStateSecret(tsr, r.tsNamespace, replica) + _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sec, func(s *corev1.Secret) { + s.ObjectMeta.Labels = sec.ObjectMeta.Labels + s.ObjectMeta.Annotations = sec.ObjectMeta.Annotations + }) + if err != nil { + return fmt.Errorf("error creating state Secret %q: %w", sec.Name, err) + } } + sa := tsrServiceAccount(tsr, r.tsNamespace) - if _, err := createOrMaybeUpdate(ctx, r.Client, r.tsNamespace, sa, func(s *corev1.ServiceAccount) error { + _, err := createOrMaybeUpdate(ctx, r.Client, r.tsNamespace, sa, func(s *corev1.ServiceAccount) error { // Perform this check within the update function to make sure we don't // have a race condition between the previous check and the update. if err := saOwnedByRecorder(s, tsr); err != nil { @@ -183,54 +196,68 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco s.ObjectMeta.Annotations = sa.ObjectMeta.Annotations return nil - }); err != nil { + }) + if err != nil { return fmt.Errorf("error creating ServiceAccount: %w", err) } + role := tsrRole(tsr, r.tsNamespace) - if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, role, func(r *rbacv1.Role) { + _, err = createOrUpdate(ctx, r.Client, r.tsNamespace, role, func(r *rbacv1.Role) { r.ObjectMeta.Labels = role.ObjectMeta.Labels r.ObjectMeta.Annotations = role.ObjectMeta.Annotations r.Rules = role.Rules - }); err != nil { + }) + if err != nil { return fmt.Errorf("error creating Role: %w", err) } + roleBinding := tsrRoleBinding(tsr, r.tsNamespace) - if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, roleBinding, func(r *rbacv1.RoleBinding) { + _, err = createOrUpdate(ctx, r.Client, r.tsNamespace, roleBinding, func(r *rbacv1.RoleBinding) { r.ObjectMeta.Labels = roleBinding.ObjectMeta.Labels r.ObjectMeta.Annotations = roleBinding.ObjectMeta.Annotations r.RoleRef = roleBinding.RoleRef r.Subjects = roleBinding.Subjects - }); err != nil { + }) + if err != nil { return fmt.Errorf("error creating RoleBinding: %w", err) } + ss := tsrStatefulSet(tsr, r.tsNamespace, r.loginServer) - if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, ss, func(s *appsv1.StatefulSet) { + _, err = createOrUpdate(ctx, r.Client, r.tsNamespace, ss, func(s *appsv1.StatefulSet) { s.ObjectMeta.Labels = ss.ObjectMeta.Labels s.ObjectMeta.Annotations = ss.ObjectMeta.Annotations s.Spec = ss.Spec - }); err != nil { + }) + if err != nil { return fmt.Errorf("error creating StatefulSet: %w", err) } // ServiceAccount name may have changed, in which case we need to clean up // the previous ServiceAccount. RoleBinding will already be updated to point // to the new ServiceAccount. - if err := r.maybeCleanupServiceAccounts(ctx, tsr, sa.Name); err != nil { + if err = r.maybeCleanupServiceAccounts(ctx, tsr, sa.Name); err != nil { return fmt.Errorf("error cleaning up ServiceAccounts: %w", err) } + // If we have scaled the recorder down, we will have dangling state secrets + // that we need to clean up. 
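The cleanup path invoked just below removes each dangling replica's device from control before deleting its Secret. Treating a 404 from the control API as success is the recurring pattern; distilled into a hypothetical helper that reuses this file's imports (`errors`, `net/http`, and the `tailscale` client package):

```go
// deleteDeviceIfPresent captures the 404-tolerant delete used in both
// maybeCleanupSecrets and maybeCleanup below: a NotFound response means the
// device is already gone, which is the outcome deletion wanted anyway.
func (r *RecorderReconciler) deleteDeviceIfPresent(ctx context.Context, nodeID string) error {
	err := r.tsClient.DeleteDevice(ctx, nodeID)
	var errResp *tailscale.ErrResponse
	if errors.As(err, &errResp) && errResp.Status == http.StatusNotFound {
		return nil // already gone, e.g. deleted via the admin console
	}
	return err
}
```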
+ if err = r.maybeCleanupSecrets(ctx, tsr); err != nil { + return fmt.Errorf("error cleaning up Secrets: %w", err) + } + var devices []tsapi.RecorderTailnetDevice + for replica := range replicas { + dev, ok, err := r.getDeviceInfo(ctx, tsr.Name, replica) + switch { + case err != nil: + return fmt.Errorf("failed to get device info: %w", err) + case !ok: + logger.Debugf("no Tailscale hostname known yet, waiting for Recorder pod to finish auth") + continue + } - device, ok, err := r.getDeviceInfo(ctx, tsr.Name) - if err != nil { - return fmt.Errorf("failed to get device info: %w", err) + devices = append(devices, dev) } - if !ok { - logger.Debugf("no Tailscale hostname known yet, waiting for Recorder pod to finish auth") - return nil - } - - devices = append(devices, device) tsr.Status.Devices = devices @@ -257,22 +284,89 @@ func saOwnedByRecorder(sa *corev1.ServiceAccount, tsr *tsapi.Recorder) error { func (r *RecorderReconciler) maybeCleanupServiceAccounts(ctx context.Context, tsr *tsapi.Recorder, currentName string) error { logger := r.logger(tsr.Name) - // List all ServiceAccounts owned by this Recorder. + options := []client.ListOption{ + client.InNamespace(r.tsNamespace), + client.MatchingLabels(tsrLabels("recorder", tsr.Name, nil)), + } + sas := &corev1.ServiceAccountList{} - if err := r.List(ctx, sas, client.InNamespace(r.tsNamespace), client.MatchingLabels(labels("recorder", tsr.Name, nil))); err != nil { + if err := r.List(ctx, sas, options...); err != nil { return fmt.Errorf("error listing ServiceAccounts for cleanup: %w", err) } - for _, sa := range sas.Items { - if sa.Name == currentName { + + for _, serviceAccount := range sas.Items { + if serviceAccount.Name == currentName { + continue + } + + err := r.Delete(ctx, &serviceAccount) + switch { + case apierrors.IsNotFound(err): + logger.Debugf("ServiceAccount %s not found, likely already deleted", serviceAccount.Name) + continue + case err != nil: + return fmt.Errorf("error deleting ServiceAccount %s: %w", serviceAccount.Name, err) + } + } + + return nil +} + +func (r *RecorderReconciler) maybeCleanupSecrets(ctx context.Context, tsr *tsapi.Recorder) error { + options := []client.ListOption{ + client.InNamespace(r.tsNamespace), + client.MatchingLabels(tsrLabels("recorder", tsr.Name, nil)), + } + + secrets := &corev1.SecretList{} + if err := r.List(ctx, secrets, options...); err != nil { + return fmt.Errorf("error listing Secrets for cleanup: %w", err) + } + + // Get the largest ordinal suffix that we expect. Then we'll go through the list of secrets owned by this + // recorder and remove them. 
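The loop below implements this by parsing the trailing `-<ordinal>` off each Secret name; any Secret whose ordinal is at or above the desired replica count is a leftover from a scaled-down pod. Equivalent standalone logic (illustrative helper, not part of the change):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// isDanglingSecret reports whether a Secret named like "test-2" or
// "test-auth-2" carries a replica ordinal that should no longer exist.
func isDanglingSecret(name string, replicas int32) (bool, error) {
	parts := strings.Split(name, "-")
	ordinal, err := strconv.ParseUint(parts[len(parts)-1], 10, 32)
	if err != nil {
		return false, fmt.Errorf("no ordinal suffix in %q: %w", name, err)
	}
	return int32(ordinal) >= replicas, nil
}

func main() {
	fmt.Println(isDanglingSecret("test-2", 2)) // true <nil>: replica 2 was scaled away
	fmt.Println(isDanglingSecret("test-1", 2)) // false <nil>: still in use
}
```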
+ var replicas int32 = 1 + if tsr.Spec.Replicas != nil { + replicas = *tsr.Spec.Replicas + } + + for _, secret := range secrets.Items { + parts := strings.Split(secret.Name, "-") + if len(parts) == 0 { + continue + } + + ordinal, err := strconv.ParseUint(parts[len(parts)-1], 10, 32) + if err != nil { + return fmt.Errorf("error parsing secret name %q: %w", secret.Name, err) + } + + if int32(ordinal) < replicas { continue } - if err := r.Delete(ctx, &sa); err != nil { - if apierrors.IsNotFound(err) { - logger.Debugf("ServiceAccount %s not found, likely already deleted", sa.Name) - } else { - return fmt.Errorf("error deleting ServiceAccount %s: %w", sa.Name, err) + + devicePrefs, ok, err := getDevicePrefs(&secret) + if err != nil { + return err + } + + if ok { + var errResp *tailscale.ErrResponse + + r.log.Debugf("deleting device %s", devicePrefs.Config.NodeID) + err = r.tsClient.DeleteDevice(ctx, string(devicePrefs.Config.NodeID)) + switch { + case errors.As(err, &errResp) && errResp.Status == http.StatusNotFound: + // This device has possibly already been deleted in the admin console. So we can ignore this + // and move on to removing the secret. + case err != nil: + return err } } + + if err = r.Delete(ctx, &secret); err != nil { + return err + } } return nil @@ -284,30 +378,38 @@ func (r *RecorderReconciler) maybeCleanupServiceAccounts(ctx context.Context, ts func (r *RecorderReconciler) maybeCleanup(ctx context.Context, tsr *tsapi.Recorder) (bool, error) { logger := r.logger(tsr.Name) - prefs, ok, err := r.getDevicePrefs(ctx, tsr.Name) - if err != nil { - return false, err + var replicas int32 = 1 + if tsr.Spec.Replicas != nil { + replicas = *tsr.Spec.Replicas } - if !ok { - logger.Debugf("state Secret %s-0 not found or does not contain node ID, continuing cleanup", tsr.Name) - r.mu.Lock() - r.recorders.Remove(tsr.UID) - gaugeRecorderResources.Set(int64(r.recorders.Len())) - r.mu.Unlock() - return true, nil - } - - id := string(prefs.Config.NodeID) - logger.Debugf("deleting device %s from control", string(id)) - if err := r.tsClient.DeleteDevice(ctx, string(id)); err != nil { - errResp := &tailscale.ErrResponse{} - if ok := errors.As(err, errResp); ok && errResp.Status == http.StatusNotFound { - logger.Debugf("device %s not found, likely because it has already been deleted from control", string(id)) - } else { + + for replica := range replicas { + devicePrefs, ok, err := r.getDevicePrefs(ctx, tsr.Name, replica) + if err != nil { + return false, err + } + if !ok { + logger.Debugf("state Secret %s-%d not found or does not contain node ID, continuing cleanup", tsr.Name, replica) + r.mu.Lock() + r.recorders.Remove(tsr.UID) + gaugeRecorderResources.Set(int64(r.recorders.Len())) + r.mu.Unlock() + return true, nil + } + + nodeID := string(devicePrefs.Config.NodeID) + logger.Debugf("deleting device %s from control", nodeID) + if err = r.tsClient.DeleteDevice(ctx, nodeID); err != nil { + errResp := &tailscale.ErrResponse{} + if errors.As(err, errResp) && errResp.Status == http.StatusNotFound { + logger.Debugf("device %s not found, likely because it has already been deleted from control", nodeID) + continue + } + return false, fmt.Errorf("error deleting device: %w", err) } - } else { - logger.Debugf("device %s deleted from control", string(id)) + + logger.Debugf("device %s deleted from control", nodeID) } // Unlike most log entries in the reconcile loop, this will get printed @@ -319,38 +421,46 @@ func (r *RecorderReconciler) maybeCleanup(ctx context.Context, tsr *tsapi.Record 
r.recorders.Remove(tsr.UID) gaugeRecorderResources.Set(int64(r.recorders.Len())) r.mu.Unlock() + return true, nil } -func (r *RecorderReconciler) ensureAuthSecretCreated(ctx context.Context, tsr *tsapi.Recorder) error { - logger := r.logger(tsr.Name) - key := types.NamespacedName{ - Namespace: r.tsNamespace, - Name: tsr.Name, - } - if err := r.Get(ctx, key, &corev1.Secret{}); err == nil { - // No updates, already created the auth key. - logger.Debugf("auth Secret %s already exists", key.Name) - return nil - } else if !apierrors.IsNotFound(err) { - return err +func (r *RecorderReconciler) ensureAuthSecretsCreated(ctx context.Context, tsr *tsapi.Recorder) error { + var replicas int32 = 1 + if tsr.Spec.Replicas != nil { + replicas = *tsr.Spec.Replicas } - // Create the auth key Secret which is going to be used by the StatefulSet - // to authenticate with Tailscale. - logger.Debugf("creating authkey for new Recorder") tags := tsr.Spec.Tags if len(tags) == 0 { tags = tsapi.Tags{"tag:k8s"} } - authKey, err := newAuthKey(ctx, r.tsClient, tags.Stringify()) - if err != nil { - return err - } - logger.Debug("creating a new Secret for the Recorder") - if err := r.Create(ctx, tsrAuthSecret(tsr, r.tsNamespace, authKey)); err != nil { - return err + logger := r.logger(tsr.Name) + + for replica := range replicas { + key := types.NamespacedName{ + Namespace: r.tsNamespace, + Name: fmt.Sprintf("%s-auth-%d", tsr.Name, replica), + } + + err := r.Get(ctx, key, &corev1.Secret{}) + switch { + case err == nil: + logger.Debugf("auth Secret %q already exists", key.Name) + continue + case !apierrors.IsNotFound(err): + return fmt.Errorf("failed to get Secret %q: %w", key.Name, err) + } + + authKey, err := newAuthKey(ctx, r.tsClient, tags.Stringify()) + if err != nil { + return err + } + + if err = r.Create(ctx, tsrAuthSecret(tsr, r.tsNamespace, authKey, replica)); err != nil { + return err + } } return nil @@ -361,6 +471,10 @@ func (r *RecorderReconciler) validate(ctx context.Context, tsr *tsapi.Recorder) return errors.New("must either enable UI or use S3 storage to ensure recordings are accessible") } + if tsr.Spec.Replicas != nil && *tsr.Spec.Replicas > 1 && tsr.Spec.Storage.S3 == nil { + return errors.New("must use S3 storage when using multiple replicas to ensure recordings are accessible") + } + // Check any custom ServiceAccount config doesn't conflict with pre-existing // ServiceAccounts. This check is performed once during validation to ensure // errors are raised early, but also again during any Updates to prevent a race. 
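`ensureAuthSecretsCreated` above leans on the standard controller-runtime get-or-create shape: only a NotFound error falls through to creating the object, and any other `Get` failure aborts the reconcile so it can be retried. Reduced to its skeleton (a sketch under the same imports as the file, not a literal excerpt):

```go
err := r.Get(ctx, key, &corev1.Secret{})
switch {
case err == nil:
	// Auth Secret already exists; leave the key inside untouched.
case apierrors.IsNotFound(err):
	// Fall through: mint a new auth key and create the Secret.
default:
	return fmt.Errorf("failed to get Secret %q: %w", key.Name, err)
}
```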
@@ -394,11 +508,11 @@ func (r *RecorderReconciler) validate(ctx context.Context, tsr *tsapi.Recorder) return nil } -func (r *RecorderReconciler) getStateSecret(ctx context.Context, tsrName string) (*corev1.Secret, error) { +func (r *RecorderReconciler) getStateSecret(ctx context.Context, tsrName string, replica int32) (*corev1.Secret, error) { secret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Namespace: r.tsNamespace, - Name: fmt.Sprintf("%s-0", tsrName), + Name: fmt.Sprintf("%s-%d", tsrName, replica), }, } if err := r.Get(ctx, client.ObjectKeyFromObject(secret), secret); err != nil { @@ -412,8 +526,8 @@ func (r *RecorderReconciler) getStateSecret(ctx context.Context, tsrName string) return secret, nil } -func (r *RecorderReconciler) getDevicePrefs(ctx context.Context, tsrName string) (prefs prefs, ok bool, err error) { - secret, err := r.getStateSecret(ctx, tsrName) +func (r *RecorderReconciler) getDevicePrefs(ctx context.Context, tsrName string, replica int32) (prefs prefs, ok bool, err error) { + secret, err := r.getStateSecret(ctx, tsrName, replica) if err != nil || secret == nil { return prefs, false, err } @@ -441,8 +555,8 @@ func getDevicePrefs(secret *corev1.Secret) (prefs prefs, ok bool, err error) { return prefs, ok, nil } -func (r *RecorderReconciler) getDeviceInfo(ctx context.Context, tsrName string) (d tsapi.RecorderTailnetDevice, ok bool, err error) { - secret, err := r.getStateSecret(ctx, tsrName) +func (r *RecorderReconciler) getDeviceInfo(ctx context.Context, tsrName string, replica int32) (d tsapi.RecorderTailnetDevice, ok bool, err error) { + secret, err := r.getStateSecret(ctx, tsrName, replica) if err != nil || secret == nil { return tsapi.RecorderTailnetDevice{}, false, err } diff --git a/cmd/k8s-operator/tsrecorder_specs.go b/cmd/k8s-operator/tsrecorder_specs.go index f5eedc2a1d1da..b4a10f2962ae9 100644 --- a/cmd/k8s-operator/tsrecorder_specs.go +++ b/cmd/k8s-operator/tsrecorder_specs.go @@ -12,30 +12,36 @@ import ( corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/types/ptr" "tailscale.com/version" ) func tsrStatefulSet(tsr *tsapi.Recorder, namespace string, loginServer string) *appsv1.StatefulSet { - return &appsv1.StatefulSet{ + var replicas int32 = 1 + if tsr.Spec.Replicas != nil { + replicas = *tsr.Spec.Replicas + } + + ss := &appsv1.StatefulSet{ ObjectMeta: metav1.ObjectMeta{ Name: tsr.Name, Namespace: namespace, - Labels: labels("recorder", tsr.Name, tsr.Spec.StatefulSet.Labels), + Labels: tsrLabels("recorder", tsr.Name, tsr.Spec.StatefulSet.Labels), OwnerReferences: tsrOwnerReference(tsr), Annotations: tsr.Spec.StatefulSet.Annotations, }, Spec: appsv1.StatefulSetSpec{ - Replicas: ptr.To[int32](1), + Replicas: ptr.To(replicas), Selector: &metav1.LabelSelector{ - MatchLabels: labels("recorder", tsr.Name, tsr.Spec.StatefulSet.Pod.Labels), + MatchLabels: tsrLabels("recorder", tsr.Name, tsr.Spec.StatefulSet.Pod.Labels), }, Template: corev1.PodTemplateSpec{ ObjectMeta: metav1.ObjectMeta{ Name: tsr.Name, Namespace: namespace, - Labels: labels("recorder", tsr.Name, tsr.Spec.StatefulSet.Pod.Labels), + Labels: tsrLabels("recorder", tsr.Name, tsr.Spec.StatefulSet.Pod.Labels), Annotations: tsr.Spec.StatefulSet.Pod.Annotations, }, Spec: corev1.PodSpec{ @@ -59,7 +65,7 @@ func tsrStatefulSet(tsr *tsapi.Recorder, namespace string, loginServer string) * ImagePullPolicy: tsr.Spec.StatefulSet.Pod.Container.ImagePullPolicy, Resources: 
tsr.Spec.StatefulSet.Pod.Container.Resources, SecurityContext: tsr.Spec.StatefulSet.Pod.Container.SecurityContext, - Env: env(tsr, loginServer), + Env: tsrEnv(tsr, loginServer), EnvFrom: func() []corev1.EnvFromSource { if tsr.Spec.Storage.S3 == nil || tsr.Spec.Storage.S3.Credentials.Secret.Name == "" { return nil @@ -95,6 +101,28 @@ func tsrStatefulSet(tsr *tsapi.Recorder, namespace string, loginServer string) * }, }, } + + for replica := range replicas { + volumeName := fmt.Sprintf("authkey-%d", replica) + + ss.Spec.Template.Spec.Containers[0].VolumeMounts = append(ss.Spec.Template.Spec.Containers[0].VolumeMounts, corev1.VolumeMount{ + Name: volumeName, + ReadOnly: true, + MountPath: fmt.Sprintf("/etc/tailscaled/%s-%d", ss.Name, replica), + }) + + ss.Spec.Template.Spec.Volumes = append(ss.Spec.Template.Spec.Volumes, corev1.Volume{ + Name: volumeName, + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: fmt.Sprintf("%s-auth-%d", tsr.Name, replica), + Items: []corev1.KeyToPath{{Key: "authkey", Path: "authkey"}}, + }, + }, + }) + } + + return ss } func tsrServiceAccount(tsr *tsapi.Recorder, namespace string) *corev1.ServiceAccount { @@ -102,7 +130,7 @@ func tsrServiceAccount(tsr *tsapi.Recorder, namespace string) *corev1.ServiceAcc ObjectMeta: metav1.ObjectMeta{ Name: tsrServiceAccountName(tsr), Namespace: namespace, - Labels: labels("recorder", tsr.Name, nil), + Labels: tsrLabels("recorder", tsr.Name, nil), OwnerReferences: tsrOwnerReference(tsr), Annotations: tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations, }, @@ -120,11 +148,24 @@ func tsrServiceAccountName(tsr *tsapi.Recorder) string { } func tsrRole(tsr *tsapi.Recorder, namespace string) *rbacv1.Role { + var replicas int32 = 1 + if tsr.Spec.Replicas != nil { + replicas = *tsr.Spec.Replicas + } + + resourceNames := make([]string, 0) + for replica := range replicas { + resourceNames = append(resourceNames, + fmt.Sprintf("%s-%d", tsr.Name, replica), // State secret. + fmt.Sprintf("%s-auth-%d", tsr.Name, replica), // Auth key secret. + ) + } + return &rbacv1.Role{ ObjectMeta: metav1.ObjectMeta{ Name: tsr.Name, Namespace: namespace, - Labels: labels("recorder", tsr.Name, nil), + Labels: tsrLabels("recorder", tsr.Name, nil), OwnerReferences: tsrOwnerReference(tsr), }, Rules: []rbacv1.PolicyRule{ @@ -136,10 +177,7 @@ func tsrRole(tsr *tsapi.Recorder, namespace string) *rbacv1.Role { "patch", "update", }, - ResourceNames: []string{ - tsr.Name, // Contains the auth key. - fmt.Sprintf("%s-0", tsr.Name), // Contains the node state. 
- }, + ResourceNames: resourceNames, }, { APIGroups: []string{""}, @@ -159,7 +197,7 @@ func tsrRoleBinding(tsr *tsapi.Recorder, namespace string) *rbacv1.RoleBinding { ObjectMeta: metav1.ObjectMeta{ Name: tsr.Name, Namespace: namespace, - Labels: labels("recorder", tsr.Name, nil), + Labels: tsrLabels("recorder", tsr.Name, nil), OwnerReferences: tsrOwnerReference(tsr), }, Subjects: []rbacv1.Subject{ @@ -176,12 +214,12 @@ func tsrRoleBinding(tsr *tsapi.Recorder, namespace string) *rbacv1.RoleBinding { } } -func tsrAuthSecret(tsr *tsapi.Recorder, namespace string, authKey string) *corev1.Secret { +func tsrAuthSecret(tsr *tsapi.Recorder, namespace string, authKey string, replica int32) *corev1.Secret { return &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Namespace: namespace, - Name: tsr.Name, - Labels: labels("recorder", tsr.Name, nil), + Name: fmt.Sprintf("%s-auth-%d", tsr.Name, replica), + Labels: tsrLabels("recorder", tsr.Name, nil), OwnerReferences: tsrOwnerReference(tsr), }, StringData: map[string]string{ @@ -190,30 +228,19 @@ func tsrAuthSecret(tsr *tsapi.Recorder, namespace string, authKey string) *corev } } -func tsrStateSecret(tsr *tsapi.Recorder, namespace string) *corev1.Secret { +func tsrStateSecret(tsr *tsapi.Recorder, namespace string, replica int32) *corev1.Secret { return &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ - Name: fmt.Sprintf("%s-0", tsr.Name), + Name: fmt.Sprintf("%s-%d", tsr.Name, replica), Namespace: namespace, - Labels: labels("recorder", tsr.Name, nil), + Labels: tsrLabels("recorder", tsr.Name, nil), OwnerReferences: tsrOwnerReference(tsr), }, } } -func env(tsr *tsapi.Recorder, loginServer string) []corev1.EnvVar { +func tsrEnv(tsr *tsapi.Recorder, loginServer string) []corev1.EnvVar { envs := []corev1.EnvVar{ - { - Name: "TS_AUTHKEY", - ValueFrom: &corev1.EnvVarSource{ - SecretKeyRef: &corev1.SecretKeySelector{ - LocalObjectReference: corev1.LocalObjectReference{ - Name: tsr.Name, - }, - Key: "authkey", - }, - }, - }, { Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{ @@ -231,6 +258,10 @@ func env(tsr *tsapi.Recorder, loginServer string) []corev1.EnvVar { }, }, }, + { + Name: "TS_AUTHKEY_FILE", + Value: "/etc/tailscaled/$(POD_NAME)/authkey", + }, { Name: "TS_STATE", Value: "kube:$(POD_NAME)", @@ -280,18 +311,18 @@ func env(tsr *tsapi.Recorder, loginServer string) []corev1.EnvVar { return envs } -func labels(app, instance string, customLabels map[string]string) map[string]string { - l := make(map[string]string, len(customLabels)+3) +func tsrLabels(app, instance string, customLabels map[string]string) map[string]string { + labels := make(map[string]string, len(customLabels)+3) for k, v := range customLabels { - l[k] = v + labels[k] = v } // ref: https://kubernetes.io/docs/concepts/overview/working-with-objects/common-labels/ - l["app.kubernetes.io/name"] = app - l["app.kubernetes.io/instance"] = instance - l["app.kubernetes.io/managed-by"] = "tailscale-operator" + labels["app.kubernetes.io/name"] = app + labels["app.kubernetes.io/instance"] = instance + labels["app.kubernetes.io/managed-by"] = "tailscale-operator" - return l + return labels } func tsrOwnerReference(owner metav1.Object) []metav1.OwnerReference { diff --git a/cmd/k8s-operator/tsrecorder_specs_test.go b/cmd/k8s-operator/tsrecorder_specs_test.go index 49332d09b6a08..0d78129fc76b3 100644 --- a/cmd/k8s-operator/tsrecorder_specs_test.go +++ b/cmd/k8s-operator/tsrecorder_specs_test.go @@ -12,6 +12,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 
"k8s.io/apimachinery/pkg/apis/meta/v1" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/types/ptr" ) @@ -23,6 +24,7 @@ func TestRecorderSpecs(t *testing.T) { Name: "test", }, Spec: tsapi.RecorderSpec{ + Replicas: ptr.To[int32](3), StatefulSet: tsapi.RecorderStatefulSet{ Labels: map[string]string{ "ss-label-key": "ss-label-value", @@ -101,10 +103,10 @@ func TestRecorderSpecs(t *testing.T) { } // Pod-level. - if diff := cmp.Diff(ss.Labels, labels("recorder", "test", tsr.Spec.StatefulSet.Labels)); diff != "" { + if diff := cmp.Diff(ss.Labels, tsrLabels("recorder", "test", tsr.Spec.StatefulSet.Labels)); diff != "" { t.Errorf("(-got +want):\n%s", diff) } - if diff := cmp.Diff(ss.Spec.Template.Labels, labels("recorder", "test", tsr.Spec.StatefulSet.Pod.Labels)); diff != "" { + if diff := cmp.Diff(ss.Spec.Template.Labels, tsrLabels("recorder", "test", tsr.Spec.StatefulSet.Pod.Labels)); diff != "" { t.Errorf("(-got +want):\n%s", diff) } if diff := cmp.Diff(ss.Spec.Template.Spec.Affinity, tsr.Spec.StatefulSet.Pod.Affinity); diff != "" { @@ -124,7 +126,7 @@ func TestRecorderSpecs(t *testing.T) { } // Container-level. - if diff := cmp.Diff(ss.Spec.Template.Spec.Containers[0].Env, env(tsr, tsLoginServer)); diff != "" { + if diff := cmp.Diff(ss.Spec.Template.Spec.Containers[0].Env, tsrEnv(tsr, tsLoginServer)); diff != "" { t.Errorf("(-got +want):\n%s", diff) } if diff := cmp.Diff(ss.Spec.Template.Spec.Containers[0].Image, tsr.Spec.StatefulSet.Pod.Container.Image); diff != "" { @@ -139,5 +141,17 @@ func TestRecorderSpecs(t *testing.T) { if diff := cmp.Diff(ss.Spec.Template.Spec.Containers[0].Resources, tsr.Spec.StatefulSet.Pod.Container.Resources); diff != "" { t.Errorf("(-got +want):\n%s", diff) } + + if *ss.Spec.Replicas != *tsr.Spec.Replicas { + t.Errorf("expected %d replicas, got %d", *tsr.Spec.Replicas, *ss.Spec.Replicas) + } + + if len(ss.Spec.Template.Spec.Volumes) != int(*tsr.Spec.Replicas)+1 { + t.Errorf("expected %d volumes, got %d", *tsr.Spec.Replicas+1, len(ss.Spec.Template.Spec.Volumes)) + } + + if len(ss.Spec.Template.Spec.Containers[0].VolumeMounts) != int(*tsr.Spec.Replicas)+1 { + t.Errorf("expected %d volume mounts, got %d", *tsr.Spec.Replicas+1, len(ss.Spec.Template.Spec.Containers[0].VolumeMounts)) + } }) } diff --git a/cmd/k8s-operator/tsrecorder_test.go b/cmd/k8s-operator/tsrecorder_test.go index 990bd68193e8b..f7ff797b1ebba 100644 --- a/cmd/k8s-operator/tsrecorder_test.go +++ b/cmd/k8s-operator/tsrecorder_test.go @@ -8,6 +8,7 @@ package main import ( "context" "encoding/json" + "fmt" "strings" "testing" @@ -20,9 +21,11 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/tstest" + "tailscale.com/types/ptr" ) const ( @@ -36,6 +39,9 @@ func TestRecorder(t *testing.T) { Name: "test", Finalizers: []string{"tailscale.com/finalizer"}, }, + Spec: tsapi.RecorderSpec{ + Replicas: ptr.To[int32](3), + }, } fc := fake.NewClientBuilder(). 
@@ -52,7 +58,7 @@ func TestRecorder(t *testing.T) { Client: fc, tsClient: tsClient, recorder: fr, - l: zl.Sugar(), + log: zl.Sugar(), clock: cl, loginServer: tsLoginServer, } @@ -80,6 +86,15 @@ func TestRecorder(t *testing.T) { }) expectReconciled(t, reconciler, "", tsr.Name) + expectedEvent = "Warning RecorderInvalid Recorder is invalid: must use S3 storage when using multiple replicas to ensure recordings are accessible" + expectEvents(t, fr, []string{expectedEvent}) + + tsr.Spec.Storage.S3 = &tsapi.S3{} + mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { + t.Spec = tsr.Spec + }) + expectReconciled(t, reconciler, "", tsr.Name) + // Only check part of this error message, because it's defined in an // external package and may change. if err := fc.Get(context.Background(), client.ObjectKey{ @@ -180,33 +195,47 @@ func TestRecorder(t *testing.T) { }) t.Run("populate_node_info_in_state_secret_and_see_it_appear_in_status", func(t *testing.T) { - bytes, err := json.Marshal(map[string]any{ - "Config": map[string]any{ - "NodeID": "nodeid-123", - "UserProfile": map[string]any{ - "LoginName": "test-0.example.ts.net", - }, - }, - }) - if err != nil { - t.Fatal(err) - } const key = "profile-abc" - mustUpdate(t, fc, tsNamespace, "test-0", func(s *corev1.Secret) { - s.Data = map[string][]byte{ - currentProfileKey: []byte(key), - key: bytes, + for replica := range *tsr.Spec.Replicas { + bytes, err := json.Marshal(map[string]any{ + "Config": map[string]any{ + "NodeID": fmt.Sprintf("node-%d", replica), + "UserProfile": map[string]any{ + "LoginName": fmt.Sprintf("test-%d.example.ts.net", replica), + }, + }, + }) + if err != nil { + t.Fatal(err) } - }) + + name := fmt.Sprintf("%s-%d", "test", replica) + mustUpdate(t, fc, tsNamespace, name, func(s *corev1.Secret) { + s.Data = map[string][]byte{ + currentProfileKey: []byte(key), + key: bytes, + } + }) + } expectReconciled(t, reconciler, "", tsr.Name) tsr.Status.Devices = []tsapi.RecorderTailnetDevice{ { - Hostname: "hostname-nodeid-123", + Hostname: "hostname-node-0", TailnetIPs: []string{"1.2.3.4", "::1"}, URL: "https://test-0.example.ts.net", }, + { + Hostname: "hostname-node-1", + TailnetIPs: []string{"1.2.3.4", "::1"}, + URL: "https://test-1.example.ts.net", + }, + { + Hostname: "hostname-node-2", + TailnetIPs: []string{"1.2.3.4", "::1"}, + URL: "https://test-2.example.ts.net", + }, } expectEqual(t, fc, tsr) }) @@ -222,7 +251,7 @@ func TestRecorder(t *testing.T) { if expected := 0; reconciler.recorders.Len() != expected { t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) } - if diff := cmp.Diff(tsClient.deleted, []string{"nodeid-123"}); diff != "" { + if diff := cmp.Diff(tsClient.deleted, []string{"node-0", "node-1", "node-2"}); diff != "" { t.Fatalf("unexpected deleted devices (-got +want):\n%s", diff) } // The fake client does not clean up objects whose owner has been @@ -233,26 +262,38 @@ func TestRecorder(t *testing.T) { func expectRecorderResources(t *testing.T, fc client.WithWatch, tsr *tsapi.Recorder, shouldExist bool) { t.Helper() - auth := tsrAuthSecret(tsr, tsNamespace, "secret-authkey") - state := tsrStateSecret(tsr, tsNamespace) + var replicas int32 = 1 + if tsr.Spec.Replicas != nil { + replicas = *tsr.Spec.Replicas + } + role := tsrRole(tsr, tsNamespace) roleBinding := tsrRoleBinding(tsr, tsNamespace) serviceAccount := tsrServiceAccount(tsr, tsNamespace) statefulSet := tsrStatefulSet(tsr, tsNamespace, tsLoginServer) if shouldExist { - expectEqual(t, fc, auth) - expectEqual(t, fc, state) expectEqual(t, fc, 
role) expectEqual(t, fc, roleBinding) expectEqual(t, fc, serviceAccount) expectEqual(t, fc, statefulSet, removeResourceReqs) } else { - expectMissing[corev1.Secret](t, fc, auth.Namespace, auth.Name) - expectMissing[corev1.Secret](t, fc, state.Namespace, state.Name) expectMissing[rbacv1.Role](t, fc, role.Namespace, role.Name) expectMissing[rbacv1.RoleBinding](t, fc, roleBinding.Namespace, roleBinding.Name) expectMissing[corev1.ServiceAccount](t, fc, serviceAccount.Namespace, serviceAccount.Name) expectMissing[appsv1.StatefulSet](t, fc, statefulSet.Namespace, statefulSet.Name) } + + for replica := range replicas { + auth := tsrAuthSecret(tsr, tsNamespace, "secret-authkey", replica) + state := tsrStateSecret(tsr, tsNamespace, replica) + + if shouldExist { + expectEqual(t, fc, auth) + expectEqual(t, fc, state) + } else { + expectMissing[corev1.Secret](t, fc, auth.Namespace, auth.Name) + expectMissing[corev1.Secret](t, fc, state.Namespace, state.Name) + } + } } diff --git a/cmd/k8s-proxy/internal/config/config.go b/cmd/k8s-proxy/internal/config/config.go index 4013047e76f0c..0f0bd1bfcf39d 100644 --- a/cmd/k8s-proxy/internal/config/config.go +++ b/cmd/k8s-proxy/internal/config/config.go @@ -50,32 +50,32 @@ func NewConfigLoader(logger *zap.SugaredLogger, client clientcorev1.CoreV1Interf } } -func (l *configLoader) WatchConfig(ctx context.Context, path string) error { +func (ld *configLoader) WatchConfig(ctx context.Context, path string) error { secretNamespacedName, isKubeSecret := strings.CutPrefix(path, "kube:") if isKubeSecret { secretNamespace, secretName, ok := strings.Cut(secretNamespacedName, string(types.Separator)) if !ok { return fmt.Errorf("invalid Kubernetes Secret reference %q, expected format /", path) } - if err := l.watchConfigSecretChanges(ctx, secretNamespace, secretName); err != nil && !errors.Is(err, context.Canceled) { + if err := ld.watchConfigSecretChanges(ctx, secretNamespace, secretName); err != nil && !errors.Is(err, context.Canceled) { return fmt.Errorf("error watching config Secret %q: %w", secretNamespacedName, err) } return nil } - if err := l.watchConfigFileChanges(ctx, path); err != nil && !errors.Is(err, context.Canceled) { + if err := ld.watchConfigFileChanges(ctx, path); err != nil && !errors.Is(err, context.Canceled) { return fmt.Errorf("error watching config file %q: %w", path, err) } return nil } -func (l *configLoader) reloadConfig(ctx context.Context, raw []byte) error { - if bytes.Equal(raw, l.previous) { - if l.cfgIgnored != nil && testenv.InTest() { - l.once.Do(func() { - close(l.cfgIgnored) +func (ld *configLoader) reloadConfig(ctx context.Context, raw []byte) error { + if bytes.Equal(raw, ld.previous) { + if ld.cfgIgnored != nil && testenv.InTest() { + ld.once.Do(func() { + close(ld.cfgIgnored) }) } return nil @@ -89,14 +89,14 @@ func (l *configLoader) reloadConfig(ctx context.Context, raw []byte) error { select { case <-ctx.Done(): return ctx.Err() - case l.cfgChan <- &cfg: + case ld.cfgChan <- &cfg: } - l.previous = raw + ld.previous = raw return nil } -func (l *configLoader) watchConfigFileChanges(ctx context.Context, path string) error { +func (ld *configLoader) watchConfigFileChanges(ctx context.Context, path string) error { var ( tickChan <-chan time.Time eventChan <-chan fsnotify.Event @@ -106,14 +106,14 @@ func (l *configLoader) watchConfigFileChanges(ctx context.Context, path string) if w, err := fsnotify.NewWatcher(); err != nil { // Creating a new fsnotify watcher would fail for example if inotify was not able to create a new file 
descriptor. // See https://github.com/tailscale/tailscale/issues/15081 - l.logger.Infof("Failed to create fsnotify watcher on config file %q; watching for changes on 5s timer: %v", path, err) + ld.logger.Infof("Failed to create fsnotify watcher on config file %q; watching for changes on 5s timer: %v", path, err) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() tickChan = ticker.C } else { dir := filepath.Dir(path) file := filepath.Base(path) - l.logger.Infof("Watching directory %q for changes to config file %q", dir, file) + ld.logger.Infof("Watching directory %q for changes to config file %q", dir, file) defer w.Close() if err := w.Add(dir); err != nil { return fmt.Errorf("failed to add fsnotify watch: %w", err) @@ -128,7 +128,7 @@ func (l *configLoader) watchConfigFileChanges(ctx context.Context, path string) if err != nil { return fmt.Errorf("error reading config file %q: %w", path, err) } - if err := l.reloadConfig(ctx, b); err != nil { + if err := ld.reloadConfig(ctx, b); err != nil { return fmt.Errorf("error loading initial config file %q: %w", path, err) } @@ -163,14 +163,14 @@ func (l *configLoader) watchConfigFileChanges(ctx context.Context, path string) if len(b) == 0 { continue } - if err := l.reloadConfig(ctx, b); err != nil { + if err := ld.reloadConfig(ctx, b); err != nil { return fmt.Errorf("error reloading config file %q: %v", path, err) } } } -func (l *configLoader) watchConfigSecretChanges(ctx context.Context, secretNamespace, secretName string) error { - secrets := l.client.Secrets(secretNamespace) +func (ld *configLoader) watchConfigSecretChanges(ctx context.Context, secretNamespace, secretName string) error { + secrets := ld.client.Secrets(secretNamespace) w, err := secrets.Watch(ctx, metav1.ListOptions{ TypeMeta: metav1.TypeMeta{ Kind: "Secret", @@ -198,11 +198,11 @@ func (l *configLoader) watchConfigSecretChanges(ctx context.Context, secretNames return fmt.Errorf("failed to get config Secret %q: %w", secretName, err) } - if err := l.configFromSecret(ctx, secret); err != nil { + if err := ld.configFromSecret(ctx, secret); err != nil { return fmt.Errorf("error loading initial config: %w", err) } - l.logger.Infof("Watching config Secret %q for changes", secretName) + ld.logger.Infof("Watching config Secret %q for changes", secretName) for { var secret *corev1.Secret select { @@ -237,7 +237,7 @@ func (l *configLoader) watchConfigSecretChanges(ctx context.Context, secretNames if secret == nil || secret.Data == nil { continue } - if err := l.configFromSecret(ctx, secret); err != nil { + if err := ld.configFromSecret(ctx, secret); err != nil { return fmt.Errorf("error reloading config Secret %q: %v", secret.Name, err) } case watch.Error: @@ -250,13 +250,13 @@ func (l *configLoader) watchConfigSecretChanges(ctx context.Context, secretNames } } -func (l *configLoader) configFromSecret(ctx context.Context, s *corev1.Secret) error { +func (ld *configLoader) configFromSecret(ctx context.Context, s *corev1.Secret) error { b := s.Data[kubetypes.KubeAPIServerConfigFile] if len(b) == 0 { return fmt.Errorf("config Secret %q does not contain expected config in key %q", s.Name, kubetypes.KubeAPIServerConfigFile) } - if err := l.reloadConfig(ctx, b); err != nil { + if err := ld.reloadConfig(ctx, b); err != nil { return err } diff --git a/cmd/k8s-proxy/internal/config/config_test.go b/cmd/k8s-proxy/internal/config/config_test.go index 1603dbe1f398f..bcb1b9ebd14e6 100644 --- a/cmd/k8s-proxy/internal/config/config_test.go +++ b/cmd/k8s-proxy/internal/config/config_test.go @@ 
-125,15 +125,15 @@ func TestWatchConfig(t *testing.T) { } } configChan := make(chan *conf.Config) - l := NewConfigLoader(zap.Must(zap.NewDevelopment()).Sugar(), cl.CoreV1(), configChan) - l.cfgIgnored = make(chan struct{}) + loader := NewConfigLoader(zap.Must(zap.NewDevelopment()).Sugar(), cl.CoreV1(), configChan) + loader.cfgIgnored = make(chan struct{}) errs := make(chan error) ctx, cancel := context.WithCancel(t.Context()) defer cancel() writeFile(t, tc.initialConfig) go func() { - errs <- l.WatchConfig(ctx, cfgPath) + errs <- loader.WatchConfig(ctx, cfgPath) }() for i, p := range tc.phases { @@ -159,7 +159,7 @@ func TestWatchConfig(t *testing.T) { } else if !strings.Contains(err.Error(), p.expectedErr) { t.Fatalf("expected error to contain %q, got %q", p.expectedErr, err.Error()) } - case <-l.cfgIgnored: + case <-loader.cfgIgnored: if p.expectedConf != nil { t.Fatalf("expected config to be reloaded, but got ignored signal") } @@ -192,13 +192,13 @@ func TestWatchConfigSecret_Rewatches(t *testing.T) { }) configChan := make(chan *conf.Config) - l := NewConfigLoader(zap.Must(zap.NewDevelopment()).Sugar(), cl.CoreV1(), configChan) + loader := NewConfigLoader(zap.Must(zap.NewDevelopment()).Sugar(), cl.CoreV1(), configChan) mustCreateOrUpdate(t, cl, secretFrom(expected[0])) errs := make(chan error) go func() { - errs <- l.watchConfigSecretChanges(t.Context(), "default", "config-secret") + errs <- loader.watchConfigSecretChanges(t.Context(), "default", "config-secret") }() for i := range 2 { @@ -212,7 +212,7 @@ func TestWatchConfigSecret_Rewatches(t *testing.T) { } case err := <-errs: t.Fatalf("unexpected error: %v", err) - case <-l.cfgIgnored: + case <-loader.cfgIgnored: t.Fatalf("expected config to be reloaded, but got ignored signal") case <-time.After(5 * time.Second): t.Fatalf("timed out waiting for expected event") diff --git a/cmd/natc/ippool/consensusippool.go b/cmd/natc/ippool/consensusippool.go index 64807b6c272f5..bfa909b69a3b4 100644 --- a/cmd/natc/ippool/consensusippool.go +++ b/cmd/natc/ippool/consensusippool.go @@ -422,9 +422,9 @@ func (ipp *ConsensusIPPool) applyCheckoutAddr(nid tailcfg.NodeID, domain string, } // Apply is part of the raft.FSM interface. It takes an incoming log entry and applies it to the state. 
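For orientation before the renamed method below: `Apply` is the hot path of a `raft.FSM`. hashicorp/raft hands every committed log entry to it, and the method must deterministically decode and apply the command; panicking on undecodable data, as natc does, treats log corruption as unrecoverable. A toy `Apply` with the same shape (a hypothetical counter FSM, not the natc pool; a real FSM also implements `Snapshot` and `Restore`):

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/hashicorp/raft"
)

// counterFSM shows only the Apply half of the raft.FSM contract.
type counterFSM struct{ n int }

func (f *counterFSM) Apply(lg *raft.Log) any {
	var c struct{ Add int }
	if err := json.Unmarshal(lg.Data, &c); err != nil {
		panic(fmt.Sprintf("failed to unmarshal command: %s", err))
	}
	f.n += c.Add
	return f.n // delivered to the leader's raft.Apply caller
}

func main() {
	fsm := &counterFSM{}
	fmt.Println(fsm.Apply(&raft.Log{Data: []byte(`{"Add": 2}`)})) // 2
}
```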
-func (ipp *ConsensusIPPool) Apply(l *raft.Log) any { +func (ipp *ConsensusIPPool) Apply(lg *raft.Log) any { var c tsconsensus.Command - if err := json.Unmarshal(l.Data, &c); err != nil { + if err := json.Unmarshal(lg.Data, &c); err != nil { panic(fmt.Sprintf("failed to unmarshal command: %s", err.Error())) } switch c.Name { diff --git a/cmd/netlogfmt/main.go b/cmd/netlogfmt/main.go index 65e87098fec5e..b8aba4aaa6196 100644 --- a/cmd/netlogfmt/main.go +++ b/cmd/netlogfmt/main.go @@ -44,25 +44,51 @@ import ( "github.com/dsnet/try" jsonv2 "github.com/go-json-experiment/json" "github.com/go-json-experiment/json/jsontext" + "tailscale.com/tailcfg" + "tailscale.com/types/bools" "tailscale.com/types/logid" "tailscale.com/types/netlogtype" "tailscale.com/util/must" ) var ( - resolveNames = flag.Bool("resolve-names", false, "convert tailscale IP addresses to hostnames; must also specify --api-key and --tailnet-id") - apiKey = flag.String("api-key", "", "API key to query the Tailscale API with; see https://login.tailscale.com/admin/settings/keys") - tailnetName = flag.String("tailnet-name", "", "tailnet domain name to lookup devices in; see https://login.tailscale.com/admin/settings/general") + resolveNames = flag.Bool("resolve-names", false, "This is equivalent to specifying \"--resolve-addrs=name\".") + resolveAddrs = flag.String("resolve-addrs", "", "Resolve each tailscale IP address as a node ID, name, or user.\n"+ + "If network flow logs do not support embedded node information,\n"+ + "then --api-key and --tailnet-name must also be provided.\n"+ + "Valid values include \"nodeId\", \"name\", or \"user\".") + apiKey = flag.String("api-key", "", "The API key to query the Tailscale API with.\nSee https://login.tailscale.com/admin/settings/keys") + tailnetName = flag.String("tailnet-name", "", "The Tailnet name to lookup nodes within.\nSee https://login.tailscale.com/admin/settings/general") ) -var namesByAddr map[netip.Addr]string +var ( + tailnetNodesByAddr map[netip.Addr]netlogtype.Node + tailnetNodesByID map[tailcfg.StableNodeID]netlogtype.Node +) func main() { flag.Parse() if *resolveNames { - namesByAddr = mustMakeNamesByAddr() + *resolveAddrs = "name" + } + *resolveAddrs = strings.ToLower(*resolveAddrs) // make case-insensitive + *resolveAddrs = strings.TrimSuffix(*resolveAddrs, "s") // allow plural form + *resolveAddrs = strings.ReplaceAll(*resolveAddrs, " ", "") // ignore spaces + *resolveAddrs = strings.ReplaceAll(*resolveAddrs, "-", "") // ignore dashes + *resolveAddrs = strings.ReplaceAll(*resolveAddrs, "_", "") // ignore underscores + switch *resolveAddrs { + case "id", "nodeid": + *resolveAddrs = "nodeid" + case "name", "hostname": + *resolveAddrs = "name" + case "user", "tag", "usertag", "taguser": + *resolveAddrs = "user" // tag resolution is implied + default: + log.Fatalf("--resolve-addrs must be \"nodeId\", \"name\", or \"user\"") } + mustLoadTailnetNodes() + // The logic handles a stream of arbitrary JSON. // So long as a JSON object seems like a network log message, // then this will unmarshal and print it. @@ -103,7 +129,7 @@ func processArray(dec *jsontext.Decoder) { func processObject(dec *jsontext.Decoder) { var hasTraffic bool - var rawMsg []byte + var rawMsg jsontext.Value try.E1(dec.ReadToken()) // parse '{' for dec.PeekKind() != '}' { // Capture any members that could belong to a network log message. 
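One of the quieter fixes in the member loop below: the old code built the quoted member name by hand (`'"' + name + '"'`), which produces invalid JSON whenever the name needs escaping. `jsontext.AppendQuote` performs the escaping and emits a proper JSON string token. A minimal demonstration, using the same `go-json-experiment` module this file already imports:

```go
package main

import (
	"fmt"

	"github.com/go-json-experiment/json/jsontext"
)

func main() {
	// AppendQuote escapes src and appends it to dst as one JSON string token.
	buf, err := jsontext.AppendQuote(nil, `say "hi"`)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(buf)) // "say \"hi\""
}
```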
@@ -111,13 +137,13 @@ func processObject(dec *jsontext.Decoder) { case "virtualTraffic", "subnetTraffic", "exitTraffic", "physicalTraffic": hasTraffic = true fallthrough - case "logtail", "nodeId", "logged", "start", "end": + case "logtail", "nodeId", "logged", "srcNode", "dstNodes", "start", "end": if len(rawMsg) == 0 { rawMsg = append(rawMsg, '{') } else { rawMsg = append(rawMsg[:len(rawMsg)-1], ',') } - rawMsg = append(append(append(rawMsg, '"'), name.String()...), '"') + rawMsg, _ = jsontext.AppendQuote(rawMsg, name.String()) rawMsg = append(rawMsg, ':') rawMsg = append(rawMsg, try.E1(dec.ReadValue())...) rawMsg = append(rawMsg, '}') @@ -145,6 +171,32 @@ type message struct { } func printMessage(msg message) { + var nodesByAddr map[netip.Addr]netlogtype.Node + var tailnetDNS string // e.g., ".acme-corp.ts.net" + if *resolveAddrs != "" { + nodesByAddr = make(map[netip.Addr]netlogtype.Node) + insertNode := func(node netlogtype.Node) { + for _, addr := range node.Addresses { + nodesByAddr[addr] = node + } + } + for _, node := range msg.DstNodes { + insertNode(node) + } + insertNode(msg.SrcNode) + + // Derive the Tailnet DNS of the self node. + detectTailnetDNS := func(nodeName string) { + if prefix, ok := strings.CutSuffix(nodeName, ".ts.net"); ok { + if i := strings.LastIndexByte(prefix, '.'); i > 0 { + tailnetDNS = nodeName[i:] + } + } + } + detectTailnetDNS(msg.SrcNode.Name) + detectTailnetDNS(tailnetNodesByID[msg.NodeID].Name) + } + // Construct a table of network traffic per connection. rows := [][7]string{{3: "Tx[P/s]", 4: "Tx[B/s]", 5: "Rx[P/s]", 6: "Rx[B/s]"}} duration := msg.End.Sub(msg.Start) @@ -175,16 +227,25 @@ func printMessage(msg message) { if !a.IsValid() { return "" } - if name, ok := namesByAddr[a.Addr()]; ok { - if a.Port() == 0 { - return name + name := a.Addr().String() + node, ok := tailnetNodesByAddr[a.Addr()] + if !ok { + node, ok = nodesByAddr[a.Addr()] + } + if ok { + switch *resolveAddrs { + case "nodeid": + name = cmp.Or(string(node.NodeID), name) + case "name": + name = cmp.Or(strings.TrimSuffix(string(node.Name), tailnetDNS), name) + case "user": + name = cmp.Or(bools.IfElse(len(node.Tags) > 0, fmt.Sprint(node.Tags), node.User), name) } - return name + ":" + strconv.Itoa(int(a.Port())) } - if a.Port() == 0 { - return a.Addr().String() + if a.Port() != 0 { + return name + ":" + strconv.Itoa(int(a.Port())) } - return a.String() + return name } for _, cc := range traffic { row := [7]string{ @@ -279,8 +340,10 @@ func printMessage(msg message) { } } -func mustMakeNamesByAddr() map[netip.Addr]string { +func mustLoadTailnetNodes() { switch { + case *apiKey == "" && *tailnetName == "": + return // rely on embedded node information in the logs themselves case *apiKey == "": log.Fatalf("--api-key must be specified with --resolve-names") case *tailnetName == "": @@ -300,57 +363,19 @@ func mustMakeNamesByAddr() map[netip.Addr]string { // Unmarshal the API response. var m struct { - Devices []struct { - Name string `json:"name"` - Addrs []netip.Addr `json:"addresses"` - } `json:"devices"` + Devices []netlogtype.Node `json:"devices"` } must.Do(json.Unmarshal(b, &m)) - // Construct a unique mapping of Tailscale IP addresses to hostnames. - // For brevity, we start with the first segment of the name and - // use more segments until we find the shortest prefix that is unique - // for all names in the tailnet. 
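(The remainder of the deleted prefix-uniquing code follows.) Its replacement is the `detectTailnetDNS` closure in `printMessage` above, which derives the tailnet's DNS suffix from any fully qualified node name so that `--resolve-addrs=name` can print short hostnames. Worked standalone (a hypothetical helper; the closure in the diff assigns to `tailnetDNS` instead of returning):

```go
package main

import (
	"fmt"
	"strings"
)

// tailnetSuffix mirrors the detectTailnetDNS logic: strip ".ts.net", find the
// last remaining dot, and keep everything from that dot onward.
func tailnetSuffix(nodeName string) string {
	if prefix, ok := strings.CutSuffix(nodeName, ".ts.net"); ok {
		if i := strings.LastIndexByte(prefix, '.'); i > 0 {
			return nodeName[i:]
		}
	}
	return ""
}

func main() {
	fmt.Println(tailnetSuffix("host.acme-corp.ts.net")) // ".acme-corp.ts.net"
}
```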
- seen := make(map[string]bool) - namesByAddr := make(map[netip.Addr]string) -retry: - for i := range 10 { - clear(seen) - clear(namesByAddr) - for _, d := range m.Devices { - name := fieldPrefix(d.Name, i) - if seen[name] { - continue retry - } - seen[name] = true - for _, a := range d.Addrs { - namesByAddr[a] = name - } - } - return namesByAddr - } - panic("unable to produce unique mapping of address to names") -} - -// fieldPrefix returns the first n number of dot-separated segments. -// -// Example: -// -// fieldPrefix("foo.bar.baz", 0) returns "" -// fieldPrefix("foo.bar.baz", 1) returns "foo" -// fieldPrefix("foo.bar.baz", 2) returns "foo.bar" -// fieldPrefix("foo.bar.baz", 3) returns "foo.bar.baz" -// fieldPrefix("foo.bar.baz", 4) returns "foo.bar.baz" -func fieldPrefix(s string, n int) string { - s0 := s - for i := 0; i < n && len(s) > 0; i++ { - if j := strings.IndexByte(s, '.'); j >= 0 { - s = s[j+1:] - } else { - s = "" + // Construct a mapping of Tailscale IP addresses to node information. + tailnetNodesByAddr = make(map[netip.Addr]netlogtype.Node) + tailnetNodesByID = make(map[tailcfg.StableNodeID]netlogtype.Node) + for _, node := range m.Devices { + for _, addr := range node.Addresses { + tailnetNodesByAddr[addr] = node } + tailnetNodesByID[node.NodeID] = node } - return strings.TrimSuffix(s0[:len(s0)-len(s)], ".") } func appendRepeatByte(b []byte, c byte, n int) []byte { diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index c020b4a1f1605..2115c8095b351 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -141,7 +141,7 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro // in the netmap. // We set the NotifyInitialNetMap flag so we will always get woken with the // current netmap, before only being woken on changes. - bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) + bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap) if err != nil { log.Fatalf("watching IPN bus: %v", err) } diff --git a/cmd/sniproxy/sniproxy_test.go b/cmd/sniproxy/sniproxy_test.go index cd2e070bd336f..65e059efaa1d4 100644 --- a/cmd/sniproxy/sniproxy_test.go +++ b/cmd/sniproxy/sniproxy_test.go @@ -152,17 +152,17 @@ func TestSNIProxyWithNetmapConfig(t *testing.T) { configCapKey: []tailcfg.RawMessage{tailcfg.RawMessage(b)}, }) - // Lets spin up a second node (to represent the client). + // Let's spin up a second node (to represent the client). client, _, _ := startNode(t, ctx, controlURL, "client") // Make sure that the sni node has received its config. - l, err := sni.LocalClient() + lc, err := sni.LocalClient() if err != nil { t.Fatal(err) } gotConfigured := false for range 100 { - s, err := l.StatusWithoutPeers(ctx) + s, err := lc.StatusWithoutPeers(ctx) if err != nil { t.Fatal(err) } @@ -176,7 +176,7 @@ func TestSNIProxyWithNetmapConfig(t *testing.T) { t.Error("sni node never received its configuration from the coordination server!") } - // Lets make the client open a connection to the sniproxy node, and + // Let's make the client open a connection to the sniproxy node, and // make sure it results in a connection to our test listener. 
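On the `sniproxy` change above: dropping `ipn.NotifyNoPrivateKeys` from the watch mask leaves the initial-netmap behavior intact, since `ipn.NotifyInitialNetMap` is what guarantees the first notification already carries a netmap. A hedged sketch of how such a watcher is typically consumed (`lc` stands in for the node's LocalClient; the loop structure is illustrative, not sniproxy's actual code):

```go
watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap)
if err != nil {
	log.Fatalf("watching IPN bus: %v", err)
}
defer watcher.Close()
for {
	notify, err := watcher.Next() // first call already carries a netmap
	if err != nil {
		log.Fatalf("IPN bus: %v", err)
	}
	if notify.NetMap != nil {
		// React to the current netmap, e.g. re-read sniproxy config.
	}
}
```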
w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port)) if err != nil { @@ -208,10 +208,10 @@ func TestSNIProxyWithFlagConfig(t *testing.T) { sni, _, ip := startNode(t, ctx, controlURL, "snitest") go run(ctx, sni, 0, sni.Hostname, false, 0, "", fmt.Sprintf("tcp/%d/localhost", ln.Addr().(*net.TCPAddr).Port)) - // Lets spin up a second node (to represent the client). + // Let's spin up a second node (to represent the client). client, _, _ := startNode(t, ctx, controlURL, "client") - // Lets make the client open a connection to the sniproxy node, and + // Let's make the client open a connection to the sniproxy node, and // make sure it results in a connection to our test listener. w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port)) if err != nil { diff --git a/cmd/stund/depaware.txt b/cmd/stund/depaware.txt index bd8eebb7b1d27..7b945dd77ea79 100644 --- a/cmd/stund/depaware.txt +++ b/cmd/stund/depaware.txt @@ -14,9 +14,9 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ - LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus - LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs - LD github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs + L github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus + L github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs + L github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs 💣 go4.org/mem from tailscale.com/metrics+ go4.org/netipx from tailscale.com/net/tsaddr google.golang.org/protobuf/encoding/protodelim from github.com/prometheus/common/expfmt @@ -47,7 +47,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar google.golang.org/protobuf/reflect/protoregistry from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/runtime/protoiface from google.golang.org/protobuf/internal/impl+ google.golang.org/protobuf/runtime/protoimpl from github.com/prometheus/client_model/go+ - google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ + 💣 google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ tailscale.com from tailscale.com/version tailscale.com/envknob from tailscale.com/tsweb+ tailscale.com/feature from tailscale.com/tsweb @@ -82,8 +82,9 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar tailscale.com/util/mak from tailscale.com/syncs+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto tailscale.com/util/rands from tailscale.com/tsweb + tailscale.com/util/set from tailscale.com/types/key tailscale.com/util/slicesx from tailscale.com/tailcfg - tailscale.com/util/testenv from tailscale.com/types/logger + tailscale.com/util/testenv from tailscale.com/types/logger+ tailscale.com/util/vizerror from tailscale.com/tailcfg+ tailscale.com/version from tailscale.com/envknob+ tailscale.com/version/distro from tailscale.com/envknob @@ -94,7 +95,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar golang.org/x/crypto/nacl/box from 
tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/exp/constraints from tailscale.com/tsweb/varz + golang.org/x/exp/constraints from tailscale.com/tsweb/varz+ golang.org/x/sys/cpu from golang.org/x/crypto/blake2b+ LD golang.org/x/sys/unix from github.com/prometheus/procfs+ W golang.org/x/sys/windows from github.com/prometheus/client_golang/prometheus diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index 71ed505690243..153dc9303bbb0 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -135,18 +135,18 @@ type lportsPool struct { ports []int } -func (l *lportsPool) get() int { - l.Lock() - defer l.Unlock() - ret := l.ports[0] - l.ports = append(l.ports[:0], l.ports[1:]...) +func (pl *lportsPool) get() int { + pl.Lock() + defer pl.Unlock() + ret := pl.ports[0] + pl.ports = append(pl.ports[:0], pl.ports[1:]...) return ret } -func (l *lportsPool) put(i int) { - l.Lock() - defer l.Unlock() - l.ports = append(l.ports, int(i)) +func (pl *lportsPool) put(i int) { + pl.Lock() + defer pl.Unlock() + pl.ports = append(pl.ports, int(i)) } var ( @@ -173,19 +173,19 @@ func init() { // measure dial time. type lportForTCPConn int -func (l *lportForTCPConn) Close() error { - if *l == 0 { +func (lp *lportForTCPConn) Close() error { + if *lp == 0 { return nil } - lports.put(int(*l)) + lports.put(int(*lp)) return nil } -func (l *lportForTCPConn) Write([]byte) (int, error) { +func (lp *lportForTCPConn) Write([]byte) (int, error) { return 0, errors.New("unimplemented") } -func (l *lportForTCPConn) Read([]byte) (int, error) { +func (lp *lportForTCPConn) Read([]byte) (int, error) { return 0, errors.New("unimplemented") } diff --git a/cmd/sync-containers/main.go b/cmd/sync-containers/main.go index 6317b4943ae82..63efa54531b10 100644 --- a/cmd/sync-containers/main.go +++ b/cmd/sync-containers/main.go @@ -65,9 +65,9 @@ func main() { } add, remove := diffTags(stags, dtags) - if l := len(add); l > 0 { + if ln := len(add); ln > 0 { log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", ")) - if *max > 0 && l > *max { + if *max > 0 && ln > *max { log.Printf("Limiting sync to %d tags", *max) add = add[:*max] } diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index 2e1bec8c9bcb0..8762b7aaeb905 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -174,6 +174,7 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { curUser string // os.Getenv("USER") on the client side goos string // empty means "linux" distro distro.Distro + backendState string // empty means "Running" want string }{ @@ -188,6 +189,28 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { }, want: "", }, + { + name: "bare_up_needs_login_default_prefs", + flags: []string{}, + curPrefs: ipn.NewPrefs(), + backendState: ipn.NeedsLogin.String(), + want: "", + }, + { + name: "bare_up_needs_login_losing_prefs", + flags: []string{}, + curPrefs: &ipn.Prefs{ + // defaults: + ControlURL: ipn.DefaultControlURL, + WantRunning: false, + NetfilterMode: preftype.NetfilterOn, + NoStatefulFiltering: opt.NewBool(true), + // non-default: + CorpDNS: false, + }, + backendState: ipn.NeedsLogin.String(), + want: accidentalUpPrefix + " --accept-dns=false", + }, { name: "losing_hostname", flags: []string{"--accept-dns"}, @@ -620,9 +643,13 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, 
func(t *testing.T) { - goos := "linux" - if tt.goos != "" { - goos = tt.goos + goos := stdcmp.Or(tt.goos, "linux") + backendState := stdcmp.Or(tt.backendState, ipn.Running.String()) + // Needs to match the other conditions in checkForAccidentalSettingReverts + tt.curPrefs.Persist = &persist.Persist{ + UserProfile: tailcfg.UserProfile{ + LoginName: "janet", + }, } var upArgs upArgsT flagSet := newUpFlagSet(goos, &upArgs, "up") @@ -638,10 +665,11 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { curExitNodeIP: tt.curExitNodeIP, distro: tt.distro, user: tt.curUser, + backendState: backendState, } applyImplicitPrefs(newPrefs, tt.curPrefs, upEnv) var got string - if err := checkForAccidentalSettingReverts(newPrefs, tt.curPrefs, upEnv); err != nil { + if _, err := checkForAccidentalSettingReverts(newPrefs, tt.curPrefs, upEnv); err != nil { got = err.Error() } if strings.TrimSpace(got) != tt.want { @@ -1011,13 +1039,10 @@ func TestUpdatePrefs(t *testing.T) { wantErrSubtr string }{ { - name: "bare_up_means_up", - flags: []string{}, - curPrefs: &ipn.Prefs{ - ControlURL: ipn.DefaultControlURL, - WantRunning: false, - Hostname: "foo", - }, + name: "bare_up_means_up", + flags: []string{}, + curPrefs: ipn.NewPrefs(), + wantSimpleUp: false, // user profile not set, so no simple up }, { name: "just_up", @@ -1031,6 +1056,32 @@ func TestUpdatePrefs(t *testing.T) { }, wantSimpleUp: true, }, + { + name: "just_up_needs_login_default_prefs", + flags: []string{}, + curPrefs: ipn.NewPrefs(), + env: upCheckEnv{ + backendState: "NeedsLogin", + }, + wantSimpleUp: false, + }, + { + name: "just_up_needs_login_losing_prefs", + flags: []string{}, + curPrefs: &ipn.Prefs{ + // defaults: + ControlURL: ipn.DefaultControlURL, + WantRunning: false, + NetfilterMode: preftype.NetfilterOn, + // non-default: + CorpDNS: false, + }, + env: upCheckEnv{ + backendState: "NeedsLogin", + }, + wantSimpleUp: false, + wantErrSubtr: "tailscale up --accept-dns=false", + }, { name: "just_edit", flags: []string{}, diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index 2836ae29814e7..2facd66ae0278 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -182,6 +182,12 @@ func debugCmd() *ffcli.Command { Exec: localAPIAction("rebind"), ShortHelp: "Force a magicsock rebind", }, + { + Name: "rotate-disco-key", + ShortUsage: "tailscale debug rotate-disco-key", + Exec: localAPIAction("rotate-disco-key"), + ShortHelp: "Rotate the discovery key", + }, { Name: "derp-set-on-demand", ShortUsage: "tailscale debug derp-set-on-demand", @@ -258,7 +264,6 @@ func debugCmd() *ffcli.Command { fs.BoolVar(&watchIPNArgs.netmap, "netmap", true, "include netmap in messages") fs.BoolVar(&watchIPNArgs.initial, "initial", false, "include initial status") fs.BoolVar(&watchIPNArgs.rateLimit, "rate-limit", true, "rate limit messags") - fs.BoolVar(&watchIPNArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") fs.IntVar(&watchIPNArgs.count, "count", 0, "exit after printing this many statuses, or 0 to keep going forever") return fs })(), @@ -270,7 +275,6 @@ func debugCmd() *ffcli.Command { ShortHelp: "Print the current network map", FlagSet: (func() *flag.FlagSet { fs := newFlagSet("netmap") - fs.BoolVar(&netmapArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") return fs })(), }, @@ -614,11 +618,10 @@ func runPrefs(ctx context.Context, args []string) error { } var watchIPNArgs struct { - netmap bool - initial bool - showPrivateKey bool - 
rateLimit bool - count int + netmap bool + initial bool + rateLimit bool + count int } func runWatchIPN(ctx context.Context, args []string) error { @@ -626,9 +629,6 @@ func runWatchIPN(ctx context.Context, args []string) error { if watchIPNArgs.initial { mask = ipn.NotifyInitialState | ipn.NotifyInitialPrefs | ipn.NotifyInitialNetMap } - if !watchIPNArgs.showPrivateKey { - mask |= ipn.NotifyNoPrivateKeys - } if watchIPNArgs.rateLimit { mask |= ipn.NotifyRateLimit } @@ -652,18 +652,11 @@ func runWatchIPN(ctx context.Context, args []string) error { return nil } -var netmapArgs struct { - showPrivateKey bool -} - func runNetmap(ctx context.Context, args []string) error { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() var mask ipn.NotifyWatchOpt = ipn.NotifyInitialNetMap - if !netmapArgs.showPrivateKey { - mask |= ipn.NotifyNoPrivateKeys - } watcher, err := localClient.WatchIPNBus(ctx, mask) if err != nil { return err diff --git a/cmd/tailscale/cli/jsonoutput/jsonoutput.go b/cmd/tailscale/cli/jsonoutput/jsonoutput.go new file mode 100644 index 0000000000000..aa49acc28baae --- /dev/null +++ b/cmd/tailscale/cli/jsonoutput/jsonoutput.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsonoutput provides stable and versioned JSON serialisation for CLI output. +// This allows us to provide stable output to scripts/clients, but also make +// breaking changes to the output when it's useful. +// +// Historically we only used `--json` as a boolean flag, so changing the output +// could break scripts that rely on the existing format. +// +// This package allows callers to pass a version number to `--json` and get +// a consistent output. We'll bump the version when we make a breaking change +// that's likely to break scripts that rely on the existing output, e.g. if +// we remove a field or change the type/format. +// +// Passing just the boolean flag `--json` will always return v1, to preserve +// compatibility with scripts written before we versioned our output. +package jsonoutput + +import ( + "errors" + "fmt" + "strconv" +) + +// JSONSchemaVersion implements flag.Value, and tracks whether the CLI has +// been called with `--json`, and if so, with what value. +type JSONSchemaVersion struct { + // IsSet tracks if the flag was provided at all. + IsSet bool + + // Value tracks the desired schema version, which defaults to 1 if + // the user passes `--json` without an argument. + Value int +} + +// String returns the default value which is printed in the CLI help text. +func (v *JSONSchemaVersion) String() string { + if v.IsSet { + return strconv.Itoa(v.Value) + } else { + return "(not set)" + } +} + +// Set is called when the user passes the flag as a command-line argument. +func (v *JSONSchemaVersion) Set(s string) error { + if v.IsSet { + return errors.New("received multiple instances of --json; only pass it once") + } + + v.IsSet = true + + // If the user doesn't supply a schema version, default to 1. + // This ensures that any existing scripts will continue to get their + // current output. + if s == "true" { + v.Value = 1 + return nil + } + + version, err := strconv.Atoi(s) + if err != nil { + return fmt.Errorf("invalid integer value passed to --json: %q", s) + } + v.Value = version + return nil +} + +// IsBoolFlag tells the flag package that JSONSchemaVersion can be set +// without an argument. 
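+//
+// As a sketch of the resulting flag semantics (illustrative only, assuming a
+// standard flag.FlagSet):
+//
+//	var v JSONSchemaVersion
+//	fs := flag.NewFlagSet("example", flag.ContinueOnError)
+//	fs.Var(&v, "json", "output in JSON format")
+//	fs.Parse([]string{"--json"}) // v.IsSet == true, v.Value == 1
+//
+// whereas parsing []string{"--json=2"} on a fresh flag set yields
+// v.IsSet == true and v.Value == 2.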
+func (v *JSONSchemaVersion) IsBoolFlag() bool { + return true +} + +// ResponseEnvelope is a set of fields common to all versioned JSON output. +type ResponseEnvelope struct { + // SchemaVersion is the version of the JSON output, e.g. "1", "2", "3" + SchemaVersion string + + // ResponseWarning tells a user if a newer version of the JSON output + // is available. + ResponseWarning string `json:"_WARNING,omitzero"` +} diff --git a/cmd/tailscale/cli/jsonoutput/network-lock-v1.go b/cmd/tailscale/cli/jsonoutput/network-lock-v1.go new file mode 100644 index 0000000000000..8a2d2de336b3d --- /dev/null +++ b/cmd/tailscale/cli/jsonoutput/network-lock-v1.go @@ -0,0 +1,203 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsonoutput + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + + "tailscale.com/ipn/ipnstate" + "tailscale.com/tka" +) + +// PrintNetworkLockJSONV1 prints the stored TKA state as a JSON object to the CLI, +// in a stable "v1" format. +// +// This format includes: +// +// - the AUM hash as a base32-encoded string +// - the raw AUM as base64-encoded bytes +// - the expanded AUM, which prints named fields for consumption by other tools +func PrintNetworkLockJSONV1(out io.Writer, updates []ipnstate.NetworkLockUpdate) error { + messages := make([]logMessageV1, len(updates)) + + for i, update := range updates { + var aum tka.AUM + if err := aum.Unserialize(update.Raw); err != nil { + return fmt.Errorf("decoding: %w", err) + } + + h := aum.Hash() + + if !bytes.Equal(h[:], update.Hash[:]) { + return fmt.Errorf("incorrect AUM hash: got %v, want %v", h, update) + } + + messages[i] = toLogMessageV1(aum, update) + } + + result := struct { + ResponseEnvelope + Messages []logMessageV1 + }{ + ResponseEnvelope: ResponseEnvelope{ + SchemaVersion: "1", + }, + Messages: messages, + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + return enc.Encode(result) +} + +// toLogMessageV1 converts a [tka.AUM] and [ipnstate.NetworkLockUpdate] to the +// JSON output returned by the CLI. 
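+//
+// The Raw field lets a consumer re-verify Hash independently. A sketch of
+// such a check (illustrative only; msg holds the printed Hash and Raw values):
+//
+//	raw, _ := base64.URLEncoding.DecodeString(msg.Raw)
+//	var aum tka.AUM
+//	if err := aum.Unserialize(raw); err == nil {
+//		verified := aum.Hash().String() == msg.Hash
+//		_ = verified
+//	}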
+func toLogMessageV1(aum tka.AUM, update ipnstate.NetworkLockUpdate) logMessageV1 {
+	expandedAUM := expandedAUMV1{}
+	expandedAUM.MessageKind = aum.MessageKind.String()
+	if len(aum.PrevAUMHash) > 0 {
+		expandedAUM.PrevAUMHash = aum.PrevAUMHash.String()
+	}
+	if key := aum.Key; key != nil {
+		expandedAUM.Key = toExpandedKeyV1(key)
+	}
+	if keyID := aum.KeyID; keyID != nil {
+		expandedAUM.KeyID = fmt.Sprintf("tlpub:%x", keyID)
+	}
+	if state := aum.State; state != nil {
+		expandedState := expandedStateV1{}
+		if h := state.LastAUMHash; h != nil {
+			expandedState.LastAUMHash = h.String()
+		}
+		for _, secret := range state.DisablementSecrets {
+			expandedState.DisablementSecrets = append(expandedState.DisablementSecrets, fmt.Sprintf("%x", secret))
+		}
+		for _, key := range state.Keys {
+			expandedState.Keys = append(expandedState.Keys, toExpandedKeyV1(&key))
+		}
+		expandedState.StateID1 = state.StateID1
+		expandedState.StateID2 = state.StateID2
+		expandedAUM.State = expandedState
+	}
+	if votes := aum.Votes; votes != nil {
+		expandedAUM.Votes = *votes
+	}
+	expandedAUM.Meta = aum.Meta
+	for _, signature := range aum.Signatures {
+		expandedAUM.Signatures = append(expandedAUM.Signatures, expandedSignatureV1{
+			KeyID:     fmt.Sprintf("tlpub:%x", signature.KeyID),
+			Signature: base64.URLEncoding.EncodeToString(signature.Signature),
+		})
+	}
+
+	return logMessageV1{
+		Hash: aum.Hash().String(),
+		AUM:  expandedAUM,
+		Raw:  base64.URLEncoding.EncodeToString(update.Raw),
+	}
+}
+
+// toExpandedKeyV1 converts a [tka.Key] to the JSON output returned
+// by the CLI.
+func toExpandedKeyV1(key *tka.Key) expandedKeyV1 {
+	return expandedKeyV1{
+		Kind:   key.Kind.String(),
+		Votes:  key.Votes,
+		Public: fmt.Sprintf("tlpub:%x", key.Public),
+		Meta:   key.Meta,
+	}
+}
+
+// logMessageV1 is the JSON representation of an AUM as both raw bytes and
+// in its expanded form, and the CLI output is a list of these entries.
+type logMessageV1 struct {
+	// The BLAKE2s digest of the CBOR-encoded AUM. This is printed as a
+	// base32-encoded string, e.g. KCE…XZQ
+	Hash string
+
+	// The expanded form of the AUM, which presents the fields in a more
+	// accessible format than doing a CBOR decoding.
+	AUM expandedAUMV1
+
+	// The raw bytes of the CBOR-encoded AUM, encoded as base64.
+	// This is useful for verifying the AUM hash.
+	Raw string
+}
+
+// expandedAUMV1 is the expanded version of a [tka.AUM], designed so external tools
+// can read the AUM without knowing our CBOR definitions.
+type expandedAUMV1 struct {
+	MessageKind string
+	PrevAUMHash string `json:"PrevAUMHash,omitzero"`
+
+	// Key encodes a public key to be added to the key authority.
+	// This field is used for AddKey AUMs.
+	Key expandedKeyV1 `json:"Key,omitzero"`
+
+	// KeyID references a public key which is part of the key authority.
+	// This field is used for RemoveKey and UpdateKey AUMs.
+	KeyID string `json:"KeyID,omitzero"`
+
+	// State describes the full state of the key authority.
+	// This field is used for Checkpoint AUMs.
+	State expandedStateV1 `json:"State,omitzero"`
+
+	// Votes and Meta describe properties of a key in the key authority.
+	// These fields are used for UpdateKey AUMs.
+	Votes uint              `json:"Votes,omitzero"`
+	Meta  map[string]string `json:"Meta,omitzero"`
+
+	// Signatures lists the signatures over this AUM.
+	Signatures []expandedSignatureV1 `json:"Signatures,omitzero"`
+}
+
+// expandedKeyV1 is the expanded version of a [tka.Key], which describes
+// the public components of a key known to network-lock.
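+//
+// For example (illustrative, mirroring the fixtures in network-lock_test.go),
+// a key renders as:
+//
+//	{"Kind": "25519", "Votes": 1, "Public": "tlpub:0101", "Meta": {"en": "one"}}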
+type expandedKeyV1 struct {
+	Kind string
+
+	// Votes describes the weight applied to signatures using this key.
+	Votes uint
+
+	// Public encodes the public key of the key as a hex string.
+	Public string
+
+	// Meta describes arbitrary metadata about the key. This could be
+	// used to store the name of the key, for instance.
+	Meta map[string]string `json:"Meta,omitzero"`
+}
+
+// expandedStateV1 is the expanded version of a [tka.State], which describes
+// Tailnet Key Authority state at an instant in time.
+type expandedStateV1 struct {
+	// LastAUMHash is the blake2s digest of the last-applied AUM.
+	LastAUMHash string `json:"LastAUMHash,omitzero"`
+
+	// DisablementSecrets are KDF-derived values which can be used
+	// to turn off the TKA in the event of a consensus-breaking bug.
+	DisablementSecrets []string
+
+	// Keys are the public keys of either:
+	//
+	// 1. The signing nodes currently trusted by the TKA.
+	// 2. Ephemeral keys that were used to generate pre-signed auth keys.
+	Keys []expandedKeyV1
+
+	// StateIDs are nonces, generated on enablement and fixed for
+	// the lifetime of the Tailnet Key Authority.
+	StateID1 uint64
+	StateID2 uint64
+}
+
+// expandedSignatureV1 is the expanded form of a [tka.Signature], which
+// describes a signature over an AUM. This signature can be verified
+// using the key referenced by KeyID.
+type expandedSignatureV1 struct {
+	KeyID     string
+	Signature string
+}
diff --git a/cmd/tailscale/cli/netcheck.go b/cmd/tailscale/cli/netcheck.go
index 5ae8db8fa3fbb..a8a8992f5ba23 100644
--- a/cmd/tailscale/cli/netcheck.go
+++ b/cmd/tailscale/cli/netcheck.go
@@ -180,7 +180,11 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error {
 		printf("\t* Nearest DERP: unknown (no response to latency probes)\n")
 	} else {
 		if report.PreferredDERP != 0 {
-			printf("\t* Nearest DERP: %v\n", dm.Regions[report.PreferredDERP].RegionName)
+			if region, ok := dm.Regions[report.PreferredDERP]; ok {
+				printf("\t* Nearest DERP: %v\n", region.RegionName)
+			} else {
+				printf("\t* Nearest DERP: %v (region not found in map)\n", report.PreferredDERP)
+			}
 		} else {
 			printf("\t* Nearest DERP: [none]\n")
 		}
diff --git a/cmd/tailscale/cli/network-lock.go b/cmd/tailscale/cli/network-lock.go
index a15d9ab88b596..73b1d62016a75 100644
--- a/cmd/tailscale/cli/network-lock.go
+++ b/cmd/tailscale/cli/network-lock.go
@@ -10,10 +10,11 @@ import (
 	"context"
 	"crypto/rand"
 	"encoding/hex"
-	"encoding/json"
+	jsonv1 "encoding/json"
 	"errors"
 	"flag"
 	"fmt"
+	"io"
 	"os"
 	"strconv"
 	"strings"
@@ -21,6 +22,7 @@ import (
 
 	"github.com/mattn/go-isatty"
 	"github.com/peterbourgon/ff/v3/ffcli"
+	"tailscale.com/cmd/tailscale/cli/jsonoutput"
 	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/tka"
 	"tailscale.com/tsconst"
@@ -219,7 +221,7 @@ func runNetworkLockStatus(ctx context.Context, args []string) error {
 	}
 
 	if nlStatusArgs.json {
-		enc := json.NewEncoder(os.Stdout)
+		enc := jsonv1.NewEncoder(os.Stdout)
 		enc.SetIndent("", "  ")
 		return enc.Encode(st)
 	}
@@ -600,7 +602,7 @@ func runNetworkLockDisablementKDF(ctx context.Context, args []string) error {
 
 var nlLogArgs struct {
 	limit int
-	json  bool
+	json  jsonoutput.JSONSchemaVersion
 }
 
 var nlLogCmd = &ffcli.Command{
@@ -612,7 +614,7 @@ var nlLogCmd = &ffcli.Command{
 	FlagSet: (func() *flag.FlagSet {
 		fs := newFlagSet("lock log")
 		fs.IntVar(&nlLogArgs.limit, "limit", 50, "max number of updates to list")
-		fs.BoolVar(&nlLogArgs.json, "json", false, "output in JSON format (WARNING: format subject to change)")
+		fs.Var(&nlLogArgs.json, "json", "output in JSON format")
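+		// The flag accepts both the bare boolean form and an explicit
+		// version: `tailscale lock log --json` and `tailscale lock log
+		// --json=1` both select the v1 schema (usage sketch; see the
+		// jsonoutput package docs).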
		return fs
	})(),
}
@@ -678,7 +680,7 @@ func nlDescribeUpdate(update ipnstate.NetworkLockUpdate, color bool) (string, er
 
 	default:
 		// Print a JSON encoding of the AUM as a fallback.
-		e := json.NewEncoder(&stanza)
+		e := jsonv1.NewEncoder(&stanza)
 		e.SetIndent("", "\t")
 		if err := e.Encode(aum); err != nil {
 			return "", err
@@ -702,14 +704,21 @@ func runNetworkLockLog(ctx context.Context, args []string) error {
 	if err != nil {
 		return fixTailscaledConnectError(err)
 	}
-	if nlLogArgs.json {
-		enc := json.NewEncoder(Stdout)
-		enc.SetIndent("", "  ")
-		return enc.Encode(updates)
-	}
 	out, useColor := colorableOutput()
+	return printNetworkLockLog(updates, out, nlLogArgs.json, useColor)
+}
+
+func printNetworkLockLog(updates []ipnstate.NetworkLockUpdate, out io.Writer, jsonSchema jsonoutput.JSONSchemaVersion, useColor bool) error {
+	if jsonSchema.IsSet {
+		if jsonSchema.Value == 1 {
+			return jsonoutput.PrintNetworkLockJSONV1(out, updates)
+		} else {
+			return fmt.Errorf("unrecognised version: %d", jsonSchema.Value)
+		}
+	}
+
 	for _, update := range updates {
 		stanza, err := nlDescribeUpdate(update, useColor)
 		if err != nil {
diff --git a/cmd/tailscale/cli/network-lock_test.go b/cmd/tailscale/cli/network-lock_test.go
new file mode 100644
index 0000000000000..ccd2957ab560e
--- /dev/null
+++ b/cmd/tailscale/cli/network-lock_test.go
@@ -0,0 +1,204 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package cli
+
+import (
+	"bytes"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"tailscale.com/cmd/tailscale/cli/jsonoutput"
+	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/tka"
+	"tailscale.com/types/tkatype"
+)
+
+func TestNetworkLockLogOutput(t *testing.T) {
+	votes := uint(1)
+	aum1 := tka.AUM{
+		MessageKind: tka.AUMAddKey,
+		Key: &tka.Key{
+			Kind:   tka.Key25519,
+			Votes:  1,
+			Public: []byte{2, 2},
+		},
+	}
+	h1 := aum1.Hash()
+	aum2 := tka.AUM{
+		MessageKind: tka.AUMRemoveKey,
+		KeyID:       []byte{3, 3},
+		PrevAUMHash: h1[:],
+		Signatures: []tkatype.Signature{
+			{
+				KeyID:     []byte{3, 4},
+				Signature: []byte{4, 5},
+			},
+		},
+		Meta: map[string]string{"en": "three", "de": "drei", "es": "tres"},
+	}
+	h2 := aum2.Hash()
+	aum3 := tka.AUM{
+		MessageKind: tka.AUMCheckpoint,
+		PrevAUMHash: h2[:],
+		State: &tka.State{
+			Keys: []tka.Key{
+				{
+					Kind:   tka.Key25519,
+					Votes:  1,
+					Public: []byte{1, 1},
+					Meta:   map[string]string{"en": "one", "de": "eins", "es": "uno"},
+				},
+			},
+			DisablementSecrets: [][]byte{
+				{1, 2, 3},
+				{4, 5, 6},
+				{7, 8, 9},
+			},
+		},
+		Votes: &votes,
+	}
+
+	updates := []ipnstate.NetworkLockUpdate{
+		{
+			Hash:   aum3.Hash(),
+			Change: aum3.MessageKind.String(),
+			Raw:    aum3.Serialize(),
+		},
+		{
+			Hash:   aum2.Hash(),
+			Change: aum2.MessageKind.String(),
+			Raw:    aum2.Serialize(),
+		},
+		{
+			Hash:   aum1.Hash(),
+			Change: aum1.MessageKind.String(),
+			Raw:    aum1.Serialize(),
+		},
+	}
+
+	t.Run("human-readable", func(t *testing.T) {
+		t.Parallel()
+
+		var outBuf bytes.Buffer
+		json := jsonoutput.JSONSchemaVersion{}
+		useColor := false
+
+		printNetworkLockLog(updates, &outBuf, json, useColor)
+
+		t.Logf("%s", outBuf.String())
+
+		want := `update 4M4Q3IXBARPQMFVXHJBDCYQMWU5H5FBKD7MFF75HE4O5JMIWR2UA (checkpoint)
+Disablement values:
+ - 010203
+ - 040506
+ - 070809
+Keys:
+  Type: 25519
+  KeyID: tlpub:0101
+  Metadata: map[de:eins en:one es:uno]
+
+update BKVVXHOVBW7Y7YXYTLVVLMNSYG6DS5GVRVSYZLASNU3AQKA732XQ (remove-key)
+KeyID: tlpub:0303
+
+update UKJIKFHILQ62AEN7MQIFHXJ6SFVDGQCQA3OHVI3LWVPM736EMSAA (add-key)
+Type: 25519
+KeyID: tlpub:0202
+
+`
+
+		if diff := cmp.Diff(outBuf.String(), want); diff != "" {
+			t.Fatalf("wrong output (-got, +want):\n%s", diff)
+		}
+	})
+
+	jsonV1 := `{
+  "SchemaVersion": "1",
+  "Messages": [
+    {
+      "Hash": "4M4Q3IXBARPQMFVXHJBDCYQMWU5H5FBKD7MFF75HE4O5JMIWR2UA",
+      "AUM": {
+        "MessageKind": "checkpoint",
+        "PrevAUMHash": "BKVVXHOVBW7Y7YXYTLVVLMNSYG6DS5GVRVSYZLASNU3AQKA732XQ",
+        "State": {
+          "DisablementSecrets": [
+            "010203",
+            "040506",
+            "070809"
+          ],
+          "Keys": [
+            {
+              "Kind": "25519",
+              "Votes": 1,
+              "Public": "tlpub:0101",
+              "Meta": {
+                "de": "eins",
+                "en": "one",
+                "es": "uno"
+              }
+            }
+          ],
+          "StateID1": 0,
+          "StateID2": 0
+        },
+        "Votes": 1
+      },
+      "Raw": "pAEFAlggCqtbndUNv4_i-JrrVbGywbw5dNWNZYysEm02CCgf3q8FowH2AoNDAQIDQwQFBkMHCAkDgaQBAQIBA0IBAQyjYmRlZGVpbnNiZW5jb25lYmVzY3VubwYB"
+    },
+    {
+      "Hash": "BKVVXHOVBW7Y7YXYTLVVLMNSYG6DS5GVRVSYZLASNU3AQKA732XQ",
+      "AUM": {
+        "MessageKind": "remove-key",
+        "PrevAUMHash": "UKJIKFHILQ62AEN7MQIFHXJ6SFVDGQCQA3OHVI3LWVPM736EMSAA",
+        "KeyID": "tlpub:0303",
+        "Meta": {
+          "de": "drei",
+          "en": "three",
+          "es": "tres"
+        },
+        "Signatures": [
+          {
+            "KeyID": "tlpub:0304",
+            "Signature": "BAU="
+          }
+        ]
+      },
+      "Raw": "pQECAlggopKFFOhcPaARv2QQU90-kWozQFAG3Hqja7Vez-_EZIAEQgMDB6NiZGVkZHJlaWJlbmV0aHJlZWJlc2R0cmVzF4GiAUIDBAJCBAU="
+    },
+    {
+      "Hash": "UKJIKFHILQ62AEN7MQIFHXJ6SFVDGQCQA3OHVI3LWVPM736EMSAA",
+      "AUM": {
+        "MessageKind": "add-key",
+        "Key": {
+          "Kind": "25519",
+          "Votes": 1,
+          "Public": "tlpub:0202"
+        }
+      },
+      "Raw": "owEBAvYDowEBAgEDQgIC"
+    }
+  ]
+}
+`
+
+	t.Run("json-1", func(t *testing.T) {
+		t.Parallel()
+
+		var outBuf bytes.Buffer
+		json := jsonoutput.JSONSchemaVersion{
+			IsSet: true,
+			Value: 1,
+		}
+		useColor := false
+
+		printNetworkLockLog(updates, &outBuf, json, useColor)
+
+		want := jsonV1
+		t.Logf("%s", outBuf.String())
+
+		if diff := cmp.Diff(outBuf.String(), want); diff != "" {
+			t.Fatalf("wrong output (-got, +want):\n%s", diff)
+		}
+	})
+}
diff --git a/cmd/tailscale/cli/serve_legacy.go b/cmd/tailscale/cli/serve_legacy.go
index 95808fdf2eb34..580393ce489b1 100644
--- a/cmd/tailscale/cli/serve_legacy.go
+++ b/cmd/tailscale/cli/serve_legacy.go
@@ -149,6 +149,7 @@ type localServeClient interface {
 	IncrementCounter(ctx context.Context, name string, delta int) error
 	GetPrefs(ctx context.Context) (*ipn.Prefs, error)
 	EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn.Prefs, error)
+	CheckSOMarkInUse(ctx context.Context) (bool, error)
 }
 
 // serveEnv is the environment the serve command runs within.
All I/O should be @@ -162,20 +163,21 @@ type serveEnv struct { json bool // output JSON (status only for now) // v2 specific flags - bg bgBoolFlag // background mode - setPath string // serve path - https uint // HTTP port - http uint // HTTP port - tcp uint // TCP port - tlsTerminatedTCP uint // a TLS terminated TCP port - subcmd serveMode // subcommand - yes bool // update without prompt - service tailcfg.ServiceName // service name - tun bool // redirect traffic to OS for service - allServices bool // apply config file to all services + bg bgBoolFlag // background mode + setPath string // serve path + https uint // HTTP port + http uint // HTTP port + tcp uint // TCP port + tlsTerminatedTCP uint // a TLS terminated TCP port + proxyProtocol uint // PROXY protocol version (1 or 2) + subcmd serveMode // subcommand + yes bool // update without prompt + service tailcfg.ServiceName // service name + tun bool // redirect traffic to OS for service + allServices bool // apply config file to all services + acceptAppCaps []tailcfg.PeerCapability // app capabilities to forward lc localServeClient // localClient interface, specific to serve - // optional stuff for tests: testFlagOut io.Writer testStdout io.Writer @@ -570,7 +572,7 @@ func (e *serveEnv) handleTCPServe(ctx context.Context, srcType string, srcPort u return fmt.Errorf("cannot serve TCP; already serving web on %d", srcPort) } - sc.SetTCPForwarding(srcPort, fwdAddr, terminateTLS, dnsName) + sc.SetTCPForwarding(srcPort, fwdAddr, terminateTLS, 0 /* proxy proto */, dnsName) if !reflect.DeepEqual(cursc, sc) { if err := e.lc.SetServeConfig(ctx, sc); err != nil { diff --git a/cmd/tailscale/cli/serve_legacy_test.go b/cmd/tailscale/cli/serve_legacy_test.go index c509508dfb1f0..819017ad81bb5 100644 --- a/cmd/tailscale/cli/serve_legacy_test.go +++ b/cmd/tailscale/cli/serve_legacy_test.go @@ -860,6 +860,8 @@ type fakeLocalServeClient struct { setCount int // counts calls to SetServeConfig queryFeatureResponse *mockQueryFeatureResponse // mock response to QueryFeature calls prefs *ipn.Prefs // fake preferences, used to test GetPrefs and SetPrefs + SOMarkInUse bool // fake SO mark in use status + statusWithoutPeers *ipnstate.Status // nil for fakeStatus } // fakeStatus is a fake ipnstate.Status value for tests. @@ -880,7 +882,10 @@ var fakeStatus = &ipnstate.Status{ } func (lc *fakeLocalServeClient) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { - return fakeStatus, nil + if lc.statusWithoutPeers == nil { + return fakeStatus, nil + } + return lc.statusWithoutPeers, nil } func (lc *fakeLocalServeClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { @@ -933,6 +938,10 @@ func (lc *fakeLocalServeClient) IncrementCounter(ctx context.Context, name strin return nil // unused in tests } +func (lc *fakeLocalServeClient) CheckSOMarkInUse(ctx context.Context) (bool, error) { + return lc.SOMarkInUse, nil +} + // exactError returns an error checker that wants exactly the provided want error. // If optName is non-empty, it's used in the error message. 
func exactErr(want error, optName ...string) func(error) string { diff --git a/cmd/tailscale/cli/serve_v2.go b/cmd/tailscale/cli/serve_v2.go index 9b0af2cad7a0c..89d247be9f773 100644 --- a/cmd/tailscale/cli/serve_v2.go +++ b/cmd/tailscale/cli/serve_v2.go @@ -20,6 +20,8 @@ import ( "os/signal" "path" "path/filepath" + "regexp" + "runtime" "slices" "sort" "strconv" @@ -32,6 +34,7 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" "tailscale.com/types/ipproto" + "tailscale.com/util/dnsname" "tailscale.com/util/mak" "tailscale.com/util/prompt" "tailscale.com/util/set" @@ -96,6 +99,41 @@ func (b *bgBoolFlag) String() string { return strconv.FormatBool(b.Value) } +type acceptAppCapsFlag struct { + Value *[]tailcfg.PeerCapability +} + +// An application capability name has the form {domain}/{name}. +// Both parts must use the (simplified) FQDN label character set. +// The "name" can contain forward slashes. +// \pL = Unicode Letter, \pN = Unicode Number, - = Hyphen +var validAppCap = regexp.MustCompile(`^([\pL\pN-]+\.)+[\pL\pN-]+\/[\pL\pN-/]+$`) + +// Set appends s to the list of appCaps to accept. +func (u *acceptAppCapsFlag) Set(s string) error { + if s == "" { + return nil + } + appCaps := strings.Split(s, ",") + for _, appCap := range appCaps { + appCap = strings.TrimSpace(appCap) + if !validAppCap.MatchString(appCap) { + return fmt.Errorf("%q does not match the form {domain}/{name}, where domain must be a fully qualified domain name", appCap) + } + *u.Value = append(*u.Value, tailcfg.PeerCapability(appCap)) + } + return nil +} + +// String returns the string representation of the slice of appCaps to accept. +func (u *acceptAppCapsFlag) String() string { + s := make([]string, len(*u.Value)) + for i, v := range *u.Value { + s[i] = string(v) + } + return strings.Join(s, ",") +} + var serveHelpCommon = strings.TrimSpace(` can be a file, directory, text, or most commonly the location to a service running on the local machine. The location to the location service can be expressed as a port number (e.g., 3000), @@ -199,10 +237,12 @@ func newServeV2Command(e *serveEnv, subcmd serveMode) *ffcli.Command { fs.UintVar(&e.https, "https", 0, "Expose an HTTPS server at the specified port (default mode)") if subcmd == serve { fs.UintVar(&e.http, "http", 0, "Expose an HTTP server at the specified port") + fs.Var(&acceptAppCapsFlag{Value: &e.acceptAppCaps}, "accept-app-caps", "App capabilities to forward to the server (specify multiple capabilities with a comma-separated list)") + fs.Var(&serviceNameFlag{Value: &e.service}, "service", "Serve for a service with distinct virtual IP instead on node itself.") } fs.UintVar(&e.tcp, "tcp", 0, "Expose a TCP forwarder to forward raw TCP packets at the specified port") fs.UintVar(&e.tlsTerminatedTCP, "tls-terminated-tcp", 0, "Expose a TCP forwarder to forward TLS-terminated TCP packets at the specified port") - fs.Var(&serviceNameFlag{Value: &e.service}, "service", "Serve for a service with distinct virtual IP instead on node itself.") + fs.UintVar(&e.proxyProtocol, "proxy-protocol", 0, "PROXY protocol version (1 or 2) for TCP forwarding") fs.BoolVar(&e.yes, "yes", false, "Update without interactive prompts (default false)") fs.BoolVar(&e.tun, "tun", false, "Forward all traffic to the local machine (default false), only supported for services. 
Refer to docs for more information.") }), @@ -255,7 +295,7 @@ func newServeV2Command(e *serveEnv, subcmd serveMode) *ffcli.Command { Name: "get-config", ShortUsage: fmt.Sprintf("tailscale %s get-config [--service=] [--all]", info.Name), ShortHelp: "Get service configuration to save to a file", - LongHelp: hidden + "Get the configuration for services that this node is currently hosting in a\n" + + LongHelp: "Get the configuration for services that this node is currently hosting in a\n" + "format that can later be provided to set-config. This can be used to declaratively set\n" + "configuration for a service host.", Exec: e.runServeGetConfig, @@ -268,10 +308,11 @@ func newServeV2Command(e *serveEnv, subcmd serveMode) *ffcli.Command { Name: "set-config", ShortUsage: fmt.Sprintf("tailscale %s set-config [--service=] [--all]", info.Name), ShortHelp: "Define service configuration from a file", - LongHelp: hidden + "Read the provided configuration file and use it to declaratively set the configuration\n" + + LongHelp: "Read the provided configuration file and use it to declaratively set the configuration\n" + "for either a single service, or for all services that this node is hosting. If --service is specified,\n" + "all endpoint handlers for that service are overwritten. If --all is specified, all endpoint handlers for\n" + - "all services are overwritten.", + "all services are overwritten.\n\n" + + "For information on the file format, see tailscale.com/kb/1589/tailscale-services-configuration-file", Exec: e.runServeSetConfig, FlagSet: e.newFlags("serve-set-config", func(fs *flag.FlagSet) { fs.BoolVar(&e.allServices, "all", false, "apply config to all services") @@ -375,6 +416,14 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { return errHelpFunc(subcmd) } + if (srvType == serveTypeHTTP || srvType == serveTypeHTTPS) && e.proxyProtocol != 0 { + return fmt.Errorf("PROXY protocol is only supported for TCP forwarding, not HTTP/HTTPS") + } + // Validate PROXY protocol version + if e.proxyProtocol != 0 && e.proxyProtocol != 1 && e.proxyProtocol != 2 { + return fmt.Errorf("invalid PROXY protocol version %d; must be 1 or 2", e.proxyProtocol) + } + sc, err := e.lc.GetServeConfig(ctx) if err != nil { return fmt.Errorf("error getting serve config: %w", err) @@ -420,20 +469,19 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { svcName = e.service dnsName = e.service.String() } + tagged := st.Self.Tags != nil && st.Self.Tags.Len() > 0 + if forService && !tagged && !turnOff { + return errors.New("service hosts must be tagged nodes") + } if !forService && srvType == serveTypeTUN { return errors.New("tun mode is only supported for services") } wantFg := !e.bg.Value && !turnOff if wantFg { - // validate the config before creating a WatchIPNBus session - if err := e.validateConfig(parentSC, srvPort, srvType, svcName); err != nil { - return err - } - // if foreground mode, create a WatchIPNBus session // and use the nested config for all following operations // TODO(marwan-at-work): nested-config validations should happen here or previous to this point. - watcher, err = e.lc.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) + watcher, err = e.lc.WatchIPNBus(ctx, ipn.NotifyInitialState) if err != nil { return err } @@ -455,9 +503,6 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { // only unset serve when trying to unset with type and port flags. 
err = e.unsetServe(sc, dnsName, srvType, srvPort, mount, magicDNSSuffix) } else { - if err := e.validateConfig(parentSC, srvPort, srvType, svcName); err != nil { - return err - } if forService { e.addServiceToPrefs(ctx, svcName) } @@ -465,7 +510,10 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { if len(args) > 0 { target = args[0] } - err = e.setServe(sc, dnsName, srvType, srvPort, mount, target, funnel, magicDNSSuffix) + if err := e.shouldWarnRemoteDestCompatibility(ctx, target); err != nil { + return err + } + err = e.setServe(sc, dnsName, srvType, srvPort, mount, target, funnel, magicDNSSuffix, e.acceptAppCaps, int(e.proxyProtocol)) msg = e.messageForPort(sc, st, dnsName, srvType, srvPort) } if err != nil { @@ -786,7 +834,7 @@ func (e *serveEnv) runServeSetConfig(ctx context.Context, args []string) (err er for name, details := range scf.Services { for ppr, ep := range details.Endpoints { if ep.Protocol == conffile.ProtoTUN { - err := e.setServe(sc, name.String(), serveTypeTUN, 0, "", "", false, magicDNSSuffix) + err := e.setServe(sc, name.String(), serveTypeTUN, 0, "", "", false, magicDNSSuffix, nil, 0 /* proxy protocol */) if err != nil { return err } @@ -808,7 +856,7 @@ func (e *serveEnv) runServeSetConfig(ctx context.Context, args []string) (err er portStr := fmt.Sprint(destPort) target = fmt.Sprintf("%s://%s", ep.Protocol, net.JoinHostPort(ep.Destination, portStr)) } - err := e.setServe(sc, name.String(), serveType, port, "/", target, false, magicDNSSuffix) + err := e.setServe(sc, name.String(), serveType, port, "/", target, false, magicDNSSuffix, nil, 0 /* proxy protocol */) if err != nil { return fmt.Errorf("service %q: %w", name, err) } @@ -851,72 +899,12 @@ func (e *serveEnv) runServeSetConfig(ctx context.Context, args []string) (err er return e.lc.SetServeConfig(ctx, sc) } -const backgroundExistsMsg = "background configuration already exists, use `tailscale %s --%s=%d off` to remove the existing configuration" - -// validateConfig checks if the serve config is valid to serve the type wanted on the port. -// dnsName is a FQDN or a serviceName (with `svc:` prefix). 
-func (e *serveEnv) validateConfig(sc *ipn.ServeConfig, port uint16, wantServe serveType, svcName tailcfg.ServiceName) error { - var tcpHandlerForPort *ipn.TCPPortHandler - if svcName != noService { - svc := sc.Services[svcName] - if svc == nil { - return nil - } - if wantServe == serveTypeTUN && (svc.TCP != nil || svc.Web != nil) { - return errors.New("service already has a TCP or Web handler, cannot serve in TUN mode") - } - if svc.Tun && wantServe != serveTypeTUN { - return errors.New("service is already being served in TUN mode") - } - if svc.TCP[port] == nil { - return nil - } - tcpHandlerForPort = svc.TCP[port] - } else { - sc, isFg := sc.FindConfig(port) - if sc == nil { - return nil - } - if isFg { - return errors.New("foreground already exists under this port") - } - if !e.bg.Value { - return fmt.Errorf(backgroundExistsMsg, infoMap[e.subcmd].Name, wantServe.String(), port) - } - tcpHandlerForPort = sc.TCP[port] - } - existingServe := serveFromPortHandler(tcpHandlerForPort) - if wantServe != existingServe { - target := svcName - if target == noService { - target = "machine" - } - return fmt.Errorf("want to serve %q but port is already serving %q for %q", wantServe, existingServe, target) - } - return nil -} - -func serveFromPortHandler(tcp *ipn.TCPPortHandler) serveType { - switch { - case tcp.HTTP: - return serveTypeHTTP - case tcp.HTTPS: - return serveTypeHTTPS - case tcp.TerminateTLS != "": - return serveTypeTLSTerminatedTCP - case tcp.TCPForward != "": - return serveTypeTCP - default: - return -1 - } -} - -func (e *serveEnv) setServe(sc *ipn.ServeConfig, dnsName string, srvType serveType, srvPort uint16, mount string, target string, allowFunnel bool, mds string) error { +func (e *serveEnv) setServe(sc *ipn.ServeConfig, dnsName string, srvType serveType, srvPort uint16, mount string, target string, allowFunnel bool, mds string, caps []tailcfg.PeerCapability, proxyProtocol int) error { // update serve config based on the type switch srvType { case serveTypeHTTPS, serveTypeHTTP: useTLS := srvType == serveTypeHTTPS - err := e.applyWebServe(sc, dnsName, srvPort, useTLS, mount, target, mds) + err := e.applyWebServe(sc, dnsName, srvPort, useTLS, mount, target, mds, caps) if err != nil { return fmt.Errorf("failed apply web serve: %w", err) } @@ -924,7 +912,7 @@ func (e *serveEnv) setServe(sc *ipn.ServeConfig, dnsName string, srvType serveTy if e.setPath != "" { return fmt.Errorf("cannot mount a path for TCP serve") } - err := e.applyTCPServe(sc, dnsName, srvType, srvPort, target) + err := e.applyTCPServe(sc, dnsName, srvType, srvPort, target, proxyProtocol) if err != nil { return fmt.Errorf("failed to apply TCP serve: %w", err) } @@ -948,16 +936,17 @@ func (e *serveEnv) setServe(sc *ipn.ServeConfig, dnsName string, srvType serveTy } var ( - msgFunnelAvailable = "Available on the internet:" - msgServeAvailable = "Available within your tailnet:" - msgServiceWaitingApproval = "This machine is configured as a service proxy for %s, but approval from an admin is required. Once approved, it will be available in your Tailnet as:" - msgRunningInBackground = "%s started and running in the background." - msgRunningTunService = "IPv4 and IPv6 traffic to %s is being routed to your operating system." 
-	msgDisableProxy           = "To disable the proxy, run: tailscale %s --%s=%d off"
-	msgDisableServiceProxy    = "To disable the proxy, run: tailscale serve --service=%s --%s=%d off"
-	msgDisableServiceTun      = "To disable the service in TUN mode, run: tailscale serve --service=%s --tun off"
-	msgDisableService         = "To remove config for the service, run: tailscale serve clear %s"
-	msgToExit                 = "Press Ctrl+C to exit."
+	msgFunnelAvailable             = "Available on the internet:"
+	msgServeAvailable              = "Available within your tailnet:"
+	msgServiceWaitingApproval      = "This machine is configured as a service proxy for %s, but approval from an admin is required. Once approved, it will be available in your Tailnet as:"
+	msgRunningInBackground         = "%s started and running in the background."
+	msgRunningTunService           = "IPv4 and IPv6 traffic to %s is being routed to your operating system."
+	msgDisableProxy                = "To disable the proxy, run: tailscale %s --%s=%d off"
+	msgDisableServiceProxy         = "To disable the proxy, run: tailscale serve --service=%s --%s=%d off"
+	msgDisableServiceTun           = "To disable the service in TUN mode, run: tailscale serve --service=%s --tun off"
+	msgDisableService              = "To remove config for the service, run: tailscale serve clear %s"
+	msgWarnRemoteDestCompatibility = "Warning: %s doesn't support connecting to remote destinations from a non-default route, see tailscale.com/kb/1552/tailscale-services for details."
+	msgToExit                      = "Press Ctrl+C to exit."
 )
 
 // messageForPort returns a message for the given port based on the
@@ -1050,6 +1039,9 @@ func (e *serveEnv) messageForPort(sc *ipn.ServeConfig, st *ipnstate.Status, dnsN
 		if tcpHandler.TerminateTLS != "" {
 			tlsStatus = "TLS terminated"
 		}
+		if ver := tcpHandler.ProxyProtocol; ver != 0 {
+			tlsStatus = fmt.Sprintf("%s, PROXY protocol v%d", tlsStatus, ver)
+		}
 		output.WriteString(fmt.Sprintf("|-- tcp://%s:%d (%s)\n", host, srvPort, tlsStatus))
 
 		for _, a := range ips {
@@ -1080,7 +1072,78 @@ func (e *serveEnv) messageForPort(sc *ipn.ServeConfig, st *ipnstate.Status, dnsN
 	return output.String()
 }
 
-func (e *serveEnv) applyWebServe(sc *ipn.ServeConfig, dnsName string, srvPort uint16, useTLS bool, mount, target string, mds string) error {
+// isRemote reports whether the given destination from serve config
+// is a remote destination.
+func isRemote(target string) bool {
+	// target being a port number means it's localhost
+	if _, err := strconv.ParseUint(target, 10, 16); err == nil {
+		return false
+	}
+
+	// prepend tmp:// if no scheme is present just to help parsing
+	if !strings.Contains(target, "://") {
+		target = "tmp://" + target
+	}
+
+	// make sure we can parse the target, whether it's a full URL or just a host:port
+	u, err := url.ParseRequestURI(target)
+	if err != nil {
+		// If we can't parse the target, it doesn't matter if it's remote or not
+		return false
+	}
+	validHN := dnsname.ValidHostname(u.Hostname()) == nil
+	validIP := net.ParseIP(u.Hostname()) != nil
+	if !validHN && !validIP {
+		return false
+	}
+	if u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1" || u.Hostname() == "::1" {
+		return false
+	}
+	return true
+}
+
+// shouldWarnRemoteDestCompatibility returns an error when the current
+// OS/environment may not be compatible with the service's proxy destination.
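+//
+// Illustrative classification by isRemote (a sketch; the hostnames are
+// examples only): "3000" and "localhost:3000" are treated as local, while
+// "192.168.1.1:3000" and "http://db.internal:5432" are treated as remote.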
+func (e *serveEnv) shouldWarnRemoteDestCompatibility(ctx context.Context, target string) error { + // no target means nothing to check + if target == "" { + return nil + } + + if filepath.IsAbs(target) || strings.HasPrefix(target, "text:") { + // local path or text target, nothing to check + return nil + } + + // only check for remote destinations + if !isRemote(target) { + return nil + } + + // Check if running as Mac extension and warn + if version.IsMacAppStore() || version.IsMacSysExt() { + return fmt.Errorf(msgWarnRemoteDestCompatibility, "the MacOS extension") + } + + // Check for linux, if it's running with TS_FORCE_LINUX_BIND_TO_DEVICE=true + // and tailscale bypass mark is not working. If any of these conditions are true, and the dest is + // a remote destination, return true. + if runtime.GOOS == "linux" { + SOMarkInUse, err := e.lc.CheckSOMarkInUse(ctx) + if err != nil { + log.Printf("error checking SO mark in use: %v", err) + return nil + } + if !SOMarkInUse { + return fmt.Errorf(msgWarnRemoteDestCompatibility, "the Linux tailscaled without SO_MARK") + } + } + + return nil +} + +func (e *serveEnv) applyWebServe(sc *ipn.ServeConfig, dnsName string, srvPort uint16, useTLS bool, mount, target, mds string, caps []tailcfg.PeerCapability) error { h := new(ipn.HTTPHandler) switch { case strings.HasPrefix(target, "text:"): @@ -1114,6 +1177,7 @@ func (e *serveEnv) applyWebServe(sc *ipn.ServeConfig, dnsName string, srvPort ui return err } h.Proxy = t + h.AcceptAppCaps = caps } // TODO: validation needs to check nested foreground configs @@ -1127,7 +1191,7 @@ func (e *serveEnv) applyWebServe(sc *ipn.ServeConfig, dnsName string, srvPort ui return nil } -func (e *serveEnv) applyTCPServe(sc *ipn.ServeConfig, dnsName string, srcType serveType, srcPort uint16, target string) error { +func (e *serveEnv) applyTCPServe(sc *ipn.ServeConfig, dnsName string, srcType serveType, srcPort uint16, target string, proxyProtocol int) error { var terminateTLS bool switch srcType { case serveTypeTCP: @@ -1138,6 +1202,8 @@ func (e *serveEnv) applyTCPServe(sc *ipn.ServeConfig, dnsName string, srcType se return fmt.Errorf("invalid TCP target %q", target) } + svcName := tailcfg.AsServiceName(dnsName) + targetURL, err := ipn.ExpandProxyTargetValue(target, []string{"tcp"}, "tcp") if err != nil { return fmt.Errorf("unable to expand target: %v", err) @@ -1149,13 +1215,11 @@ func (e *serveEnv) applyTCPServe(sc *ipn.ServeConfig, dnsName string, srcType se } // TODO: needs to account for multiple configs from foreground mode - svcName := tailcfg.AsServiceName(dnsName) if sc.IsServingWeb(srcPort, svcName) { return fmt.Errorf("cannot serve TCP; already serving web on %d for %s", srcPort, dnsName) } - sc.SetTCPForwarding(srcPort, dstURL.Host, terminateTLS, dnsName) - + sc.SetTCPForwarding(srcPort, dstURL.Host, terminateTLS, proxyProtocol, dnsName) return nil } diff --git a/cmd/tailscale/cli/serve_v2_test.go b/cmd/tailscale/cli/serve_v2_test.go index 1deeaf3eaa9b5..513c0d1ec97d4 100644 --- a/cmd/tailscale/cli/serve_v2_test.go +++ b/cmd/tailscale/cli/serve_v2_test.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "reflect" + "regexp" "slices" "strconv" "strings" @@ -22,6 +23,7 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/types/views" ) func TestServeDevConfigMutations(t *testing.T) { @@ -33,10 +35,11 @@ func TestServeDevConfigMutations(t *testing.T) { } // group is a group of steps that share the same - // config mutation, but always starts from an empty config + // 
config mutation type group struct { - name string - steps []step + name string + steps []step + initialState fakeLocalServeClient // use the zero value for empty config } // creaet a temporary directory for path-based destinations @@ -217,10 +220,20 @@ func TestServeDevConfigMutations(t *testing.T) { }}, }, { - name: "invalid_host", + name: "ip_host", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, steps: []step{{ - command: cmd("serve --https=443 --bg http://somehost:3000"), // invalid host - wantErr: anyErr(), + command: cmd("serve --https=443 --bg http://192.168.1.1:3000"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://192.168.1.1:3000"}, + }}, + }, + }, }}, }, { @@ -230,6 +243,16 @@ func TestServeDevConfigMutations(t *testing.T) { wantErr: anyErr(), }}, }, + { + name: "no_scheme_remote_host_tcp", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, + steps: []step{{ + command: cmd("serve --https=443 --bg 192.168.1.1:3000"), + wantErr: exactErrMsg(errHelp), + }}, + }, { name: "turn_off_https", steps: []step{ @@ -399,15 +422,11 @@ func TestServeDevConfigMutations(t *testing.T) { }, }}, }, - { - name: "unknown_host_tcp", - steps: []step{{ - command: cmd("serve --tls-terminated-tcp=443 --bg tcp://somehost:5432"), - wantErr: exactErrMsg(errHelp), - }}, - }, { name: "tcp_port_too_low", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, steps: []step{{ command: cmd("serve --tls-terminated-tcp=443 --bg tcp://somehost:0"), wantErr: exactErrMsg(errHelp), @@ -415,6 +434,9 @@ func TestServeDevConfigMutations(t *testing.T) { }, { name: "tcp_port_too_high", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, steps: []step{{ command: cmd("serve --tls-terminated-tcp=443 --bg tcp://somehost:65536"), wantErr: exactErrMsg(errHelp), @@ -529,6 +551,9 @@ func TestServeDevConfigMutations(t *testing.T) { }, { name: "bad_path", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, steps: []step{{ command: cmd("serve --bg --https=443 bad/path"), wantErr: exactErrMsg(errHelp), @@ -795,36 +820,186 @@ func TestServeDevConfigMutations(t *testing.T) { }, }, { - name: "forground_with_bg_conflict", + name: "advertise_service", + initialState: fakeLocalServeClient{ + statusWithoutPeers: &ipnstate.Status{ + BackendState: ipn.Running.String(), + Self: &ipnstate.PeerStatus{ + DNSName: "foo.test.ts.net", + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrFunnel: nil, + tailcfg.CapabilityFunnelPorts + "?ports=443,8443": nil, + }, + Tags: ptrToReadOnlySlice([]string{"some-tag"}), + }, + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + SOMarkInUse: true, + }, + steps: []step{{ + command: cmd("serve --service=svc:foo --http=80 text:foo"), + want: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:80": {Handlers: map[string]*ipn.HTTPHandler{ + "/": {Text: "foo"}, + }}, + }, + }, + }, + }, + }}, + }, + { + name: "advertise_service_from_untagged_node", + steps: []step{{ + command: cmd("serve --service=svc:foo --http=80 text:foo"), + wantErr: anyErr(), + }}, + }, + { + name: "forward_grant_header", steps: []step{ { - command: cmd("serve --bg --http=3000 localhost:3000"), + command: cmd("serve --bg 
--accept-app-caps=example.com/cap/foo 3000"), want: &ipn.ServeConfig{ - TCP: map[uint16]*ipn.TCPPortHandler{3000: {HTTP: true}}, + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, Web: map[ipn.HostPort]*ipn.WebServerConfig{ - "foo.test.ts.net:3000": {Handlers: map[string]*ipn.HTTPHandler{ - "/": {Proxy: "http://localhost:3000"}, + "foo.test.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: "http://127.0.0.1:3000", + AcceptAppCaps: []tailcfg.PeerCapability{"example.com/cap/foo"}, + }, + }}, + }, + }, + }, + { + command: cmd("serve --bg --accept-app-caps=example.com/cap/foo,example.com/cap/bar 3000"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: "http://127.0.0.1:3000", + AcceptAppCaps: []tailcfg.PeerCapability{"example.com/cap/foo", "example.com/cap/bar"}, + }, }}, }, }, }, { - command: cmd("serve --http=3000 localhost:3000"), - wantErr: exactErrMsg(fmt.Errorf(backgroundExistsMsg, "serve", "http", 3000)), + command: cmd("serve --bg --accept-app-caps=example.com/cap/bar 3000"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: "http://127.0.0.1:3000", + AcceptAppCaps: []tailcfg.PeerCapability{"example.com/cap/bar"}, + }, + }}, + }, + }, + }, + }, + }, + { + name: "invalid_accept_caps_invalid_app_cap", + steps: []step{ + { + command: cmd("serve --bg --accept-app-caps=example.com/cap/fine,NOTFINE 3000"), // should be {domain.tld}/{name} + wantErr: func(err error) (badErrMsg string) { + if err == nil || !strings.Contains(err.Error(), fmt.Sprintf("%q does not match", "NOTFINE")) { + return fmt.Sprintf("wanted validation error that quotes the non-matching capability (and nothing more) but got %q", err.Error()) + } + return "" + }, + }, + }, + }, + { + name: "tcp_with_proxy_protocol_v1", + steps: []step{{ + command: cmd("serve --tcp=8000 --proxy-protocol=1 --bg tcp://localhost:5432"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8000: { + TCPForward: "localhost:5432", + ProxyProtocol: 1, + }, + }, + }, + }}, + }, + { + name: "tls_terminated_tcp_with_proxy_protocol_v2", + steps: []step{{ + command: cmd("serve --tls-terminated-tcp=443 --proxy-protocol=2 --bg tcp://localhost:5432"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + TCPForward: "localhost:5432", + TerminateTLS: "foo.test.ts.net", + ProxyProtocol: 2, + }, + }, + }, + }}, + }, + { + name: "tcp_update_to_add_proxy_protocol", + steps: []step{ + { + command: cmd("serve --tcp=8000 --bg tcp://localhost:5432"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8000: {TCPForward: "localhost:5432"}, + }, + }, + }, + { + command: cmd("serve --tcp=8000 --proxy-protocol=1 --bg tcp://localhost:5432"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8000: { + TCPForward: "localhost:5432", + ProxyProtocol: 1, + }, + }, + }, }, }, }, + { + name: "tcp_proxy_protocol_invalid_version", + steps: []step{{ + command: cmd("serve --tcp=8000 --proxy-protocol=3 --bg tcp://localhost:5432"), + wantErr: anyErr(), + }}, + }, + { + name: "proxy_protocol_without_tcp", + steps: []step{{ + command: cmd("serve --https=443 --proxy-protocol=1 --bg http://localhost:3000"), + wantErr: anyErr(), + }}, + }, } for _, group := range groups 
{ t.Run(group.name, func(t *testing.T) { - lc := &fakeLocalServeClient{} + lc := group.initialState for i, st := range group.steps { var stderr bytes.Buffer var stdout bytes.Buffer var flagOut bytes.Buffer e := &serveEnv{ - lc: lc, + lc: &lc, testFlagOut: &flagOut, testStdout: &stdout, testStderr: &stderr, @@ -872,190 +1047,6 @@ func TestServeDevConfigMutations(t *testing.T) { } } -func TestValidateConfig(t *testing.T) { - tests := [...]struct { - name string - desc string - cfg *ipn.ServeConfig - svc tailcfg.ServiceName - servePort uint16 - serveType serveType - bg bgBoolFlag - wantErr bool - }{ - { - name: "nil_config", - desc: "when config is nil, all requests valid", - cfg: nil, - servePort: 3000, - serveType: serveTypeHTTPS, - }, - { - name: "new_bg_tcp", - desc: "no error when config exists but we're adding a new bg tcp port", - cfg: &ipn.ServeConfig{ - TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {HTTPS: true}, - }, - }, - bg: bgBoolFlag{true, false}, - servePort: 10000, - serveType: serveTypeHTTPS, - }, - { - name: "override_bg_tcp", - desc: "no error when overwriting previous port under the same serve type", - cfg: &ipn.ServeConfig{ - TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {TCPForward: "http://localhost:4545"}, - }, - }, - bg: bgBoolFlag{true, false}, - servePort: 443, - serveType: serveTypeTCP, - }, - { - name: "override_bg_tcp", - desc: "error when overwriting previous port under a different serve type", - cfg: &ipn.ServeConfig{ - TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {HTTPS: true}, - }, - }, - bg: bgBoolFlag{true, false}, - servePort: 443, - serveType: serveTypeHTTP, - wantErr: true, - }, - { - name: "new_fg_port", - desc: "no error when serving a new foreground port", - cfg: &ipn.ServeConfig{ - TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {HTTPS: true}, - }, - Foreground: map[string]*ipn.ServeConfig{ - "abc123": { - TCP: map[uint16]*ipn.TCPPortHandler{ - 3000: {HTTPS: true}, - }, - }, - }, - }, - servePort: 4040, - serveType: serveTypeTCP, - }, - { - name: "same_fg_port", - desc: "error when overwriting a previous fg port", - cfg: &ipn.ServeConfig{ - Foreground: map[string]*ipn.ServeConfig{ - "abc123": { - TCP: map[uint16]*ipn.TCPPortHandler{ - 3000: {HTTPS: true}, - }, - }, - }, - }, - servePort: 3000, - serveType: serveTypeTCP, - wantErr: true, - }, - { - name: "new_service_tcp", - desc: "no error when adding a new service port", - cfg: &ipn.ServeConfig{ - Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ - "svc:foo": { - TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, - }, - }, - }, - svc: "svc:foo", - servePort: 8080, - serveType: serveTypeTCP, - }, - { - name: "override_service_tcp", - desc: "no error when overwriting a previous service port", - cfg: &ipn.ServeConfig{ - Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ - "svc:foo": { - TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {TCPForward: "http://localhost:4545"}, - }, - }, - }, - }, - svc: "svc:foo", - servePort: 443, - serveType: serveTypeTCP, - }, - { - name: "override_service_tcp", - desc: "error when overwriting a previous service port with a different serve type", - cfg: &ipn.ServeConfig{ - Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ - "svc:foo": { - TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {HTTPS: true}, - }, - }, - }, - }, - svc: "svc:foo", - servePort: 443, - serveType: serveTypeHTTP, - wantErr: true, - }, - { - name: "override_service_tcp", - desc: "error when setting previous tcp service to tun mode", - cfg: &ipn.ServeConfig{ - Services: 
map[tailcfg.ServiceName]*ipn.ServiceConfig{ - "svc:foo": { - TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {TCPForward: "http://localhost:4545"}, - }, - }, - }, - }, - svc: "svc:foo", - serveType: serveTypeTUN, - wantErr: true, - }, - { - name: "override_service_tun", - desc: "error when setting previous tun service to tcp forwarder", - cfg: &ipn.ServeConfig{ - Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ - "svc:foo": { - Tun: true, - }, - }, - }, - svc: "svc:foo", - serveType: serveTypeTCP, - servePort: 443, - wantErr: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - se := serveEnv{bg: tc.bg} - err := se.validateConfig(tc.cfg, tc.servePort, tc.serveType, tc.svc) - if err == nil && tc.wantErr { - t.Fatal("expected an error but got nil") - } - if err != nil && !tc.wantErr { - t.Fatalf("expected no error but got: %v", err) - } - }) - } - -} - func TestSrcTypeFromFlags(t *testing.T) { tests := []struct { name string @@ -1130,6 +1121,118 @@ func TestSrcTypeFromFlags(t *testing.T) { } } +func TestAcceptSetAppCapsFlag(t *testing.T) { + testCases := []struct { + name string + inputs []string + expectErr bool + expectErrToMatch *regexp.Regexp + expectedValue []tailcfg.PeerCapability + }{ + { + name: "valid_simple", + inputs: []string{"example.com/name"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"example.com/name"}, + }, + { + name: "valid_unicode", + inputs: []string{"bücher.de/something"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"bücher.de/something"}, + }, + { + name: "more_valid_unicode", + inputs: []string{"example.tw/某某某"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"example.tw/某某某"}, + }, + { + name: "valid_path_slashes", + inputs: []string{"domain.com/path/to/name"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"domain.com/path/to/name"}, + }, + { + name: "valid_multiple_sets", + inputs: []string{"one.com/foo,two.com/bar"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"one.com/foo", "two.com/bar"}, + }, + { + name: "valid_empty_string", + inputs: []string{""}, + expectErr: false, + expectedValue: nil, // Empty string should be a no-op and not append anything. + }, + { + name: "invalid_path_chars", + inputs: []string{"domain.com/path_with_underscore"}, + expectErr: true, + expectErrToMatch: regexp.MustCompile(`"domain.com/path_with_underscore"`), + expectedValue: nil, // Slice should remain empty. 
+ }, + { + name: "valid_subdomain", + inputs: []string{"sub.domain.com/name"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"sub.domain.com/name"}, + }, + { + name: "invalid_no_path", + inputs: []string{"domain.com/"}, + expectErr: true, + expectErrToMatch: regexp.MustCompile(`"domain.com/"`), + expectedValue: nil, + }, + { + name: "invalid_no_domain", + inputs: []string{"/path/only"}, + expectErr: true, + expectErrToMatch: regexp.MustCompile(`"/path/only"`), + expectedValue: nil, + }, + { + name: "some_invalid_some_valid", + inputs: []string{"one.com/foo,bad/bar,two.com/baz"}, + expectErr: true, + expectErrToMatch: regexp.MustCompile(`"bad/bar"`), + expectedValue: []tailcfg.PeerCapability{"one.com/foo"}, // Parsing will stop after first error + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var v []tailcfg.PeerCapability + flag := &acceptAppCapsFlag{Value: &v} + + var err error + for _, s := range tc.inputs { + err = flag.Set(s) + if err != nil { + break + } + } + + if tc.expectErr && err == nil { + t.Errorf("expected an error, but got none") + } + if tc.expectErrToMatch != nil { + if !tc.expectErrToMatch.MatchString(err.Error()) { + t.Errorf("expected error to match %q, but was %q", tc.expectErrToMatch, err) + } + } + if !tc.expectErr && err != nil { + t.Errorf("did not expect an error, but got: %v", err) + } + + if !reflect.DeepEqual(tc.expectedValue, v) { + t.Errorf("unexpected value, got: %q, want: %q", v, tc.expectedValue) + } + }) + } +} + func TestCleanURLPath(t *testing.T) { tests := []struct { input string @@ -1672,18 +1775,19 @@ func TestSetServe(t *testing.T) { e := &serveEnv{} magicDNSSuffix := "test.ts.net" tests := []struct { - name string - desc string - cfg *ipn.ServeConfig - st *ipnstate.Status - dnsName string - srvType serveType - srvPort uint16 - mountPath string - target string - allowFunnel bool - expected *ipn.ServeConfig - expectErr bool + name string + desc string + cfg *ipn.ServeConfig + st *ipnstate.Status + dnsName string + srvType serveType + srvPort uint16 + mountPath string + target string + allowFunnel bool + proxyProtocol int + expected *ipn.ServeConfig + expectErr bool }{ { name: "add new handler", @@ -1966,7 +2070,7 @@ func TestSetServe(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := e.setServe(tt.cfg, tt.dnsName, tt.srvType, tt.srvPort, tt.mountPath, tt.target, tt.allowFunnel, magicDNSSuffix) + err := e.setServe(tt.cfg, tt.dnsName, tt.srvType, tt.srvPort, tt.mountPath, tt.target, tt.allowFunnel, magicDNSSuffix, nil, tt.proxyProtocol) if err != nil && !tt.expectErr { t.Fatalf("got error: %v; did not expect error.", err) } @@ -2249,3 +2353,8 @@ func exactErrMsg(want error) func(error) string { return fmt.Sprintf("\ngot: %v\nwant: %v\n", got, want) } } + +func ptrToReadOnlySlice[T any](s []T) *views.Slice[T] { + vs := views.SliceOf(s) + return &vs +} diff --git a/cmd/tailscale/cli/set.go b/cmd/tailscale/cli/set.go index 43f8bbbc34afd..31662392f8437 100644 --- a/cmd/tailscale/cli/set.go +++ b/cmd/tailscale/cli/set.go @@ -11,6 +11,7 @@ import ( "net/netip" "os/exec" "runtime" + "slices" "strconv" "strings" @@ -25,6 +26,7 @@ import ( "tailscale.com/types/opt" "tailscale.com/types/ptr" "tailscale.com/types/views" + "tailscale.com/util/set" "tailscale.com/version" ) @@ -43,28 +45,30 @@ Only settings explicitly mentioned will be set. 
There are no default values.`, } type setArgsT struct { - acceptRoutes bool - acceptDNS bool - exitNodeIP string - exitNodeAllowLANAccess bool - shieldsUp bool - runSSH bool - runWebClient bool - hostname string - advertiseRoutes string - advertiseDefaultRoute bool - advertiseConnector bool - opUser string - acceptedRisks string - profileName string - forceDaemon bool - updateCheck bool - updateApply bool - reportPosture bool - snat bool - statefulFiltering bool - netfilterMode string - relayServerPort string + acceptRoutes bool + acceptDNS bool + exitNodeIP string + exitNodeAllowLANAccess bool + shieldsUp bool + runSSH bool + runWebClient bool + hostname string + advertiseRoutes string + advertiseDefaultRoute bool + advertiseConnector bool + opUser string + acceptedRisks string + profileName string + forceDaemon bool + updateCheck bool + updateApply bool + reportPosture bool + snat bool + statefulFiltering bool + sync bool + netfilterMode string + relayServerPort string + relayServerStaticEndpoints string } func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { @@ -85,7 +89,9 @@ func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { setf.BoolVar(&setArgs.updateApply, "auto-update", false, "automatically update to the latest available version") setf.BoolVar(&setArgs.reportPosture, "report-posture", false, "allow management plane to gather device posture information") setf.BoolVar(&setArgs.runWebClient, "webclient", false, "expose the web interface for managing this node over Tailscale at port 5252") + setf.BoolVar(&setArgs.sync, "sync", false, hidden+"actively sync configuration from the control plane (set to false only for network failure testing)") setf.StringVar(&setArgs.relayServerPort, "relay-server-port", "", "UDP port number (0 will pick a random unused port) for the relay server to bind to, on all interfaces, or empty string to disable relay server functionality") + setf.StringVar(&setArgs.relayServerStaticEndpoints, "relay-server-static-endpoints", "", "static IP:port endpoints to advertise as candidates for relay connections (comma-separated, e.g. 
\"[2001:db8::1]:40000,192.0.2.1:40000\") or empty string to not advertise any static endpoints") ffcomplete.Flag(setf, "exit-node", func(args []string) ([]string, ffcomplete.ShellCompDirective, error) { st, err := localClient.Status(context.Background()) @@ -108,7 +114,7 @@ func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { switch goos { case "linux": setf.BoolVar(&setArgs.snat, "snat-subnet-routes", true, "source NAT traffic to local routes advertised with --advertise-routes") - setf.BoolVar(&setArgs.statefulFiltering, "stateful-filtering", false, "apply stateful filtering to forwarded packets (subnet routers, exit nodes, etc.)") + setf.BoolVar(&setArgs.statefulFiltering, "stateful-filtering", false, "apply stateful filtering to forwarded packets (subnet routers, exit nodes, and so on)") setf.StringVar(&setArgs.netfilterMode, "netfilter-mode", defaultNetfilterMode(), "netfilter mode (one of on, nodivert, off)") case "windows": setf.BoolVar(&setArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)") @@ -149,6 +155,7 @@ func runSet(ctx context.Context, args []string) (retErr error) { OperatorUser: setArgs.opUser, NoSNAT: !setArgs.snat, ForceDaemon: setArgs.forceDaemon, + Sync: opt.NewBool(setArgs.sync), AutoUpdate: ipn.AutoUpdatePrefs{ Check: setArgs.updateCheck, Apply: opt.NewBool(setArgs.updateApply), @@ -242,7 +249,22 @@ func runSet(ctx context.Context, args []string) (retErr error) { if err != nil { return fmt.Errorf("failed to set relay server port: %v", err) } - maskedPrefs.Prefs.RelayServerPort = ptr.To(int(uport)) + maskedPrefs.Prefs.RelayServerPort = ptr.To(uint16(uport)) + } + + if setArgs.relayServerStaticEndpoints != "" { + endpointsSet := make(set.Set[netip.AddrPort]) + endpointsSplit := strings.Split(setArgs.relayServerStaticEndpoints, ",") + for _, s := range endpointsSplit { + ap, err := netip.ParseAddrPort(s) + if err != nil { + return fmt.Errorf("failed to set relay server static endpoints: %q is not a valid IP:port", s) + } + endpointsSet.Add(ap) + } + endpoints := endpointsSet.Slice() + slices.SortFunc(endpoints, netip.AddrPort.Compare) + maskedPrefs.Prefs.RelayServerStaticEndpoints = endpoints } checkPrefs := curPrefs.Clone() diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 61cade8de68d0..72515400d8fa1 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -122,7 +122,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { switch goos { case "linux": upf.BoolVar(&upArgs.snat, "snat-subnet-routes", true, "source NAT traffic to local routes advertised with --advertise-routes") - upf.BoolVar(&upArgs.statefulFiltering, "stateful-filtering", false, "apply stateful filtering to forwarded packets (subnet routers, exit nodes, etc.)") + upf.BoolVar(&upArgs.statefulFiltering, "stateful-filtering", false, "apply stateful filtering to forwarded packets (subnet routers, exit nodes, and so on)") upf.StringVar(&upArgs.netfilterMode, "netfilter-mode", defaultNetfilterMode(), "netfilter mode (one of on, nodivert, off)") case "windows": upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)") @@ -388,7 +388,8 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus if !env.upArgs.reset { applyImplicitPrefs(prefs, curPrefs, env) - if err := checkForAccidentalSettingReverts(prefs, 
curPrefs, env); err != nil { + simpleUp, err = checkForAccidentalSettingReverts(prefs, curPrefs, env) + if err != nil { return false, nil, err } } @@ -420,11 +421,6 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus tagsChanged := !reflect.DeepEqual(curPrefs.AdvertiseTags, prefs.AdvertiseTags) - simpleUp = env.flagSet.NFlag() == 0 && - curPrefs.Persist != nil && - curPrefs.Persist.UserProfile.LoginName != "" && - env.backendState != ipn.NeedsLogin.String() - justEdit := env.backendState == ipn.Running.String() && !env.upArgs.forceReauth && env.upArgs.authKeyOrFile == "" && @@ -890,6 +886,8 @@ func init() { addPrefFlagMapping("advertise-connector", "AppConnector") addPrefFlagMapping("report-posture", "PostureChecking") addPrefFlagMapping("relay-server-port", "RelayServerPort") + addPrefFlagMapping("sync", "Sync") + addPrefFlagMapping("relay-server-static-endpoints", "RelayServerStaticEndpoints") } func addPrefFlagMapping(flagName string, prefNames ...string) { @@ -925,7 +923,7 @@ func updateMaskedPrefsFromUpOrSetFlag(mp *ipn.MaskedPrefs, flagName string) { if prefs, ok := prefsOfFlag[flagName]; ok { for _, pref := range prefs { f := reflect.ValueOf(mp).Elem() - for _, name := range strings.Split(pref, ".") { + for name := range strings.SplitSeq(pref, ".") { f = f.FieldByName(name + "Set") } f.SetBool(true) @@ -967,10 +965,10 @@ type upCheckEnv struct { // // mp is the mask of settings actually set, where mp.Prefs is the new // preferences to set, including any values set from implicit flags. -func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheckEnv) error { +func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, err error) { if curPrefs.ControlURL == "" { // Don't validate things on initial "up" before a control URL has been set. - return nil + return false, nil } flagIsSet := map[string]bool{} @@ -978,10 +976,13 @@ func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheck flagIsSet[f.Name] = true }) - if len(flagIsSet) == 0 { + if len(flagIsSet) == 0 && + curPrefs.Persist != nil && + curPrefs.Persist.UserProfile.LoginName != "" && + env.backendState != ipn.NeedsLogin.String() { // A bare "tailscale up" is a special case to just // mean bringing the network up without any changes. - return nil + return true, nil } // flagsCur is what flags we'd need to use to keep the exact @@ -1023,7 +1024,7 @@ func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheck missing = append(missing, fmtFlagValueArg(flagName, valCur)) } if len(missing) == 0 { - return nil + return false, nil } // Some previously provided flags are missing. This run of 'tailscale @@ -1056,7 +1057,7 @@ func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheck fmt.Fprintf(&sb, " %s", a) } sb.WriteString("\n\n") - return errors.New(sb.String()) + return false, errors.New(sb.String()) } // applyImplicitPrefs mutates prefs to add implicit preferences for the user operator. 
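Note on the cmd/tailscale/cli/set.go hunk above: the new --relay-server-static-endpoints value is parsed by splitting on commas, validating each entry with netip.ParseAddrPort, de-duplicating through a set, and sorting with netip.AddrPort.Compare so the stored prefs are stable. A minimal standalone sketch of that flow, using only the standard library (parseStaticEndpoints is an illustrative name, not a helper in this diff):

	package main

	import (
		"fmt"
		"net/netip"
		"slices"
		"strings"
	)

	// parseStaticEndpoints validates and normalizes a comma-separated list
	// of IP:port endpoints: invalid entries fail fast, duplicates collapse,
	// and the result is sorted so repeated runs produce identical output.
	func parseStaticEndpoints(s string) ([]netip.AddrPort, error) {
		seen := make(map[netip.AddrPort]bool)
		var endpoints []netip.AddrPort
		for _, part := range strings.Split(s, ",") {
			ap, err := netip.ParseAddrPort(part)
			if err != nil {
				return nil, fmt.Errorf("%q is not a valid IP:port", part)
			}
			if !seen[ap] {
				seen[ap] = true
				endpoints = append(endpoints, ap)
			}
		}
		slices.SortFunc(endpoints, netip.AddrPort.Compare)
		return endpoints, nil
	}

For example, parseStaticEndpoints("[2001:db8::1]:40000,192.0.2.1:40000") accepts both IPv6 and IPv4 endpoints, matching the flag's documented syntax.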
diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index b249639bc80bc..8b576ffc3a4dd 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -85,6 +85,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/cmd/tailscale/cli from tailscale.com/cmd/tailscale tailscale.com/cmd/tailscale/cli/ffcomplete from tailscale.com/cmd/tailscale/cli tailscale.com/cmd/tailscale/cli/ffcomplete/internal from tailscale.com/cmd/tailscale/cli/ffcomplete + tailscale.com/cmd/tailscale/cli/jsonoutput from tailscale.com/cmd/tailscale/cli tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlhttp from tailscale.com/control/ts2021 tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp @@ -171,7 +172,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/types/structs from tailscale.com/ipn+ tailscale.com/types/tkatype from tailscale.com/types/key+ tailscale.com/types/views from tailscale.com/tailcfg+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/net/netcheck+ tailscale.com/util/cloudenv from tailscale.com/net/dnscache+ tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+ diff --git a/cmd/tailscaled/depaware-min.txt b/cmd/tailscaled/depaware-min.txt index 224026f25368d..69e6559a0173b 100644 --- a/cmd/tailscaled/depaware-min.txt +++ b/cmd/tailscaled/depaware-min.txt @@ -16,6 +16,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + 💣 github.com/klauspost/compress/internal/le from github.com/klauspost/compress/huff0+ github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd @@ -69,7 +70,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/ipn/ipnstate from tailscale.com/control/controlclient+ tailscale.com/ipn/localapi from tailscale.com/ipn/ipnserver tailscale.com/ipn/store from tailscale.com/cmd/tailscaled - tailscale.com/ipn/store/mem from tailscale.com/ipn/store + tailscale.com/ipn/store/mem from tailscale.com/ipn/store+ tailscale.com/kube/kubetypes from tailscale.com/envknob tailscale.com/log/filelogger from tailscale.com/logpolicy tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal @@ -144,7 +145,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/control/controlclient+ tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/appc+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ diff --git a/cmd/tailscaled/depaware-minbox.txt b/cmd/tailscaled/depaware-minbox.txt index 9633e73989046..55a21c426b5d5 
100644 --- a/cmd/tailscaled/depaware-minbox.txt +++ b/cmd/tailscaled/depaware-minbox.txt @@ -20,6 +20,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + 💣 github.com/klauspost/compress/internal/le from github.com/klauspost/compress/huff0+ github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd @@ -92,7 +93,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/ipn/ipnstate from tailscale.com/control/controlclient+ tailscale.com/ipn/localapi from tailscale.com/ipn/ipnserver tailscale.com/ipn/store from tailscale.com/cmd/tailscaled - tailscale.com/ipn/store/mem from tailscale.com/ipn/store + tailscale.com/ipn/store/mem from tailscale.com/ipn/store+ tailscale.com/kube/kubetypes from tailscale.com/envknob tailscale.com/licenses from tailscale.com/cmd/tailscale/cli tailscale.com/log/filelogger from tailscale.com/logpolicy @@ -171,7 +172,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/control/controlclient+ tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/appc+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ tailscale.com/util/cmpver from tailscale.com/clientupdate diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index e92d41b9855df..79f92deb92f38 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -86,6 +86,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw + github.com/creachadair/msync/trigger from tailscale.com/logtail LD 💣 github.com/creack/pty from tailscale.com/ssh/tailssh W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com+ W 💣 github.com/dblohm7/wingoes/com from tailscale.com/cmd/tailscaled+ @@ -138,6 +139,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + 💣 github.com/klauspost/compress/internal/le from github.com/klauspost/compress/huff0+ github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd @@ -155,6 +157,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/pierrec/lz4/v4/internal/lz4errors from github.com/pierrec/lz4/v4+ L 
github.com/pierrec/lz4/v4/internal/lz4stream from github.com/pierrec/lz4/v4 L github.com/pierrec/lz4/v4/internal/xxh32 from github.com/pierrec/lz4/v4/internal/lz4stream + github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal LD github.com/pkg/sftp from tailscale.com/ssh/tailssh LD github.com/pkg/sftp/internal/encoding/ssh/filexfer from github.com/pkg/sftp D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack @@ -391,6 +394,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/tsweb from tailscale.com/util/eventbus tailscale.com/tsweb/varz from tailscale.com/cmd/tailscaled+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/bools from tailscale.com/wgengine/netlog tailscale.com/types/dnstype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/empty from tailscale.com/ipn+ tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled @@ -414,7 +418,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/views from tailscale.com/ipn/ipnlocal+ tailscale.com/util/backoff from tailscale.com/cmd/tailscaled+ tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/control/controlclient+ tailscale.com/util/cloudenv from tailscale.com/net/dns/resolver+ tailscale.com/util/cmpver from tailscale.com/net/dns+ @@ -567,7 +571,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ - crypto/fips140 from crypto/tls/internal/fips140tls + crypto/fips140 from crypto/tls/internal/fips140tls+ crypto/hkdf from crypto/internal/hpke+ crypto/hmac from crypto/tls+ crypto/internal/boring from crypto/aes+ diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index f14cdcff072b1..203b5a0acba7a 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -33,12 +33,14 @@ import ( "tailscale.com/feature" "tailscale.com/feature/buildfeatures" _ "tailscale.com/feature/condregister" + "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/conffile" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnserver" "tailscale.com/ipn/store" + "tailscale.com/ipn/store/mem" "tailscale.com/logpolicy" "tailscale.com/logtail" "tailscale.com/net/dns" @@ -207,7 +209,10 @@ func main() { flag.BoolVar(&args.disableLogs, "no-logs-no-support", false, "disable log uploads; this also disables any technical support") flag.StringVar(&args.confFile, "config", "", "path to config file, or 'vm:user-data' to use the VM's user-data (EC2)") if buildfeatures.HasTPM { - flag.Var(&args.hardwareAttestation, "hardware-attestation", "use hardware-backed keys to bind node identity to this device when supported by the OS and hardware. Uses TPM 2.0 on Linux and Windows; SecureEnclave on macOS and iOS; and Keystore on Android") + flag.Var(&args.hardwareAttestation, "hardware-attestation", `use hardware-backed keys to bind node identity to this device when supported +by the OS and hardware. Uses TPM 2.0 on Linux and Windows; SecureEnclave on +macOS and iOS; and Keystore on Android. 
Only supported for Tailscale nodes that
+store state on the filesystem.`)
 	}
 	if f, ok := hookRegisterOutboundProxyFlags.GetOk(); ok {
 		f()
@@ -644,7 +649,16 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID
 	store, err := store.New(logf, statePathOrDefault())
 	if err != nil {
-		return nil, fmt.Errorf("store.New: %w", err)
+		// If we can't create the store (for example if it's TPM-sealed and the
+		// TPM is reset), create a dummy in-memory store to propagate the error
+		// to the user.
+		ht, ok := sys.HealthTracker.GetOK()
+		if !ok {
+			return nil, fmt.Errorf("store.New: %w", err)
+		}
+		logf("store.New failed: %v; starting with in-memory store with a health warning", err)
+		store = new(mem.Store)
+		ht.SetUnhealthy(ipn.StateStoreHealth, health.Args{health.ArgError: err.Error()})
 	}
 	sys.Set(store)
@@ -893,17 +907,17 @@ func applyIntegrationTestEnvKnob() {
 func handleTPMFlags() {
 	switch {
 	case args.hardwareAttestation.v:
-		if _, err := key.NewEmptyHardwareAttestationKey(); err == key.ErrUnsupported {
+		if err := canUseHardwareAttestation(); err != nil {
 			log.SetFlags(0)
-			log.Fatalf("--hardware-attestation is not supported on this platform or in this build of tailscaled")
+			log.Fatal(err)
 		}
 	case !args.hardwareAttestation.set:
-		policyHWAttestation, _ := policyclient.Get().GetBoolean(pkey.HardwareAttestation, feature.HardwareAttestationAvailable())
-		if !policyHWAttestation {
-			break
-		}
-		if feature.TPMAvailable() {
-			args.hardwareAttestation.v = true
+		policyHWAttestation, _ := policyclient.Get().GetBoolean(pkey.HardwareAttestation, false)
+		if err := canUseHardwareAttestation(); err != nil {
+			log.Printf("[unexpected] policy requires hardware attestation, but device does not support it: %v", err)
+			args.hardwareAttestation.v = false
+		} else {
+			args.hardwareAttestation.v = policyHWAttestation
 		}
 	}
@@ -915,18 +929,46 @@ func handleTPMFlags() {
 			log.Fatal(err)
 		}
 	case !args.encryptState.set:
-		policyEncrypt, _ := policyclient.Get().GetBoolean(pkey.EncryptState, feature.TPMAvailable())
-		if !policyEncrypt {
-			// Default disabled, no need to validate.
-			return
-		}
-		// Default enabled if available.
-		if err := canEncryptState(); err == nil {
+		policyEncrypt, _ := policyclient.Get().GetBoolean(pkey.EncryptState, false)
+		if err := canEncryptState(); policyEncrypt && err == nil {
 			args.encryptState.v = true
 		}
 	}
 }
+
+// canUseHardwareAttestation returns an error if hardware attestation can't be
+// enabled, either due to availability or compatibility with other settings.
+func canUseHardwareAttestation() error {
+	if _, err := key.NewEmptyHardwareAttestationKey(); err == key.ErrUnsupported {
+		return errors.New("--hardware-attestation is not supported on this platform or in this build of tailscaled")
+	}
+	// Hardware attestation keys are TPM-bound and cannot be migrated between
+	// machines. Disable when using portable state stores like kube: or arn:
+	// where state may be loaded on a different machine.
+	if args.statepath != "" && isPortableStore(args.statepath) {
+		return errors.New("--hardware-attestation cannot be used with portable state stores (kube:, arn:) because TPM-bound keys cannot be migrated between machines")
+	}
+	return nil
+}
+
+// isPortableStore reports whether the given state path refers to a portable
+// state store where state may be loaded on different machines.
+// All stores apart from the file store and TPM store are portable.
+func isPortableStore(path string) bool {
+	if store.HasKnownProviderPrefix(path) && !strings.HasPrefix(path, store.TPMPrefix) {
+		return true
+	}
+	// In most cases Kubernetes Secret and AWS SSM stores would have been caught
+	// by the earlier check, but that check relies on those stores having been
+	// registered. This additional check is here to ensure that if we ever
+	// produce a faulty build that failed to register some store, users who
+	// upgrade to it don't get hardware keys generated.
+	if strings.HasPrefix(path, "kube:") || strings.HasPrefix(path, "arn:") {
+		return true
+	}
+	return false
+}
+
 // canEncryptState returns an error if state encryption can't be enabled,
 // either due to availability or compatibility with other settings.
 func canEncryptState() error {
diff --git a/cmd/tailscaled/tailscaled_test.go b/cmd/tailscaled/tailscaled_test.go
index c50c237591170..36327cccc7bc7 100644
--- a/cmd/tailscaled/tailscaled_test.go
+++ b/cmd/tailscaled/tailscaled_test.go
@@ -4,9 +4,17 @@
 package main // import "tailscale.com/cmd/tailscaled"
 
 import (
+	"os"
+	"strings"
 	"testing"
 
+	"tailscale.com/envknob"
+	"tailscale.com/ipn"
+	"tailscale.com/net/netmon"
+	"tailscale.com/tsd"
 	"tailscale.com/tstest/deptest"
+	"tailscale.com/types/logid"
+	"tailscale.com/util/must"
 )
 
 func TestNothing(t *testing.T) {
@@ -38,3 +46,98 @@ func TestDeps(t *testing.T) {
 		},
 	}.Check(t)
 }
+
+func TestStateStoreError(t *testing.T) {
+	logID, err := logid.NewPrivateID()
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Don't upload any logs from tests.
+	envknob.SetNoLogsNoSupport()
+
+	args.statedir = t.TempDir()
+	args.tunname = "userspace-networking"
+
+	t.Run("new state", func(t *testing.T) {
+		sys := tsd.NewSystem()
+		sys.NetMon.Set(must.Get(netmon.New(sys.Bus.Get(), t.Logf)))
+		lb, err := getLocalBackend(t.Context(), t.Logf, logID.Public(), sys)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer lb.Shutdown()
+		if lb.HealthTracker().IsUnhealthy(ipn.StateStoreHealth) {
+			t.Errorf("StateStoreHealth is unhealthy on fresh LocalBackend:\n%s", strings.Join(lb.HealthTracker().Strings(), "\n"))
+		}
+	})
+	t.Run("corrupt state", func(t *testing.T) {
+		sys := tsd.NewSystem()
+		sys.NetMon.Set(must.Get(netmon.New(sys.Bus.Get(), t.Logf)))
+		// Populate the state file with something that will fail to parse,
+		// to trigger an error from store.New.
+ if err := os.WriteFile(statePathOrDefault(), []byte("bad json"), 0644); err != nil { + t.Fatal(err) + } + lb, err := getLocalBackend(t.Context(), t.Logf, logID.Public(), sys) + if err != nil { + t.Fatal(err) + } + defer lb.Shutdown() + if !lb.HealthTracker().IsUnhealthy(ipn.StateStoreHealth) { + t.Errorf("StateStoreHealth is healthy when state file is corrupt") + } + }) +} + +func TestIsPortableStore(t *testing.T) { + tests := []struct { + name string + path string + want bool + }{ + { + name: "kube_store", + path: "kube:my-secret", + want: true, + }, + { + name: "aws_arn_store", + path: "arn:aws:ssm:us-east-1:123456789012:parameter/tailscale/state", + want: true, + }, + { + name: "tpm_store", + path: "tpmseal:/var/lib/tailscale/tailscaled.state", + want: false, + }, + { + name: "local_file_store", + path: "/var/lib/tailscale/tailscaled.state", + want: false, + }, + { + name: "empty_path", + path: "", + want: false, + }, + { + name: "mem_store", + path: "mem:", + want: true, + }, + { + name: "windows_file_store", + path: `C:\ProgramData\Tailscale\server-state.conf`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isPortableStore(tt.path) + if got != tt.want { + t.Errorf("isPortableStore(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} diff --git a/cmd/tl-longchain/tl-longchain.go b/cmd/tl-longchain/tl-longchain.go index 2a4dc10ba331c..384d24222e6d5 100644 --- a/cmd/tl-longchain/tl-longchain.go +++ b/cmd/tl-longchain/tl-longchain.go @@ -75,8 +75,8 @@ func peerInfo(peer *ipnstate.TKAPeer) string { // print prints a message about a node key signature and a re-signing command if needed. func print(info string, nodeKey key.NodePublic, sig tka.NodeKeySignature) { - if l := chainLength(sig); l > *maxRotations { - log.Printf("%s: chain length %d, printing command to re-sign", info, l) + if ln := chainLength(sig); ln > *maxRotations { + log.Printf("%s: chain length %d, printing command to re-sign", info, ln) wrapping, _ := sig.UnverifiedWrappingPublic() fmt.Printf("tailscale lock sign %s %s\n", nodeKey, key.NLPublicFromEd25519Unsafe(wrapping).CLIString()) } else { diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 2e81fa4a8a2e7..c7aa00d1d794f 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -261,7 +261,7 @@ func (i *jsIPN) run(jsCallbacks js.Value) { jsNetMap := jsNetMap{ Self: jsNetMapSelfNode{ jsNetMapNode: jsNetMapNode{ - Name: nm.Name, + Name: nm.SelfName(), Addresses: mapSliceView(nm.GetAddresses(), func(a netip.Prefix) string { return a.Addr().String() }), NodeKey: nm.NodeKey.String(), MachineKey: nm.MachineKey.String(), diff --git a/cmd/tsidp/depaware.txt b/cmd/tsidp/depaware.txt index a2a473a5068ec..5c6aae5121196 100644 --- a/cmd/tsidp/depaware.txt +++ b/cmd/tsidp/depaware.txt @@ -9,6 +9,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket + github.com/creachadair/msync/trigger from tailscale.com/logtail W 💣 github.com/dblohm7/wingoes from tailscale.com/net/tshttpproxy+ W 💣 github.com/dblohm7/wingoes/com from tailscale.com/util/osdiag+ W 💣 github.com/dblohm7/wingoes/com/automation from tailscale.com/util/osdiag/internal/wsc @@ -35,6 +36,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by 
github.com/tailscale/depawar github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + 💣 github.com/klauspost/compress/internal/le from github.com/klauspost/compress/huff0+ github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd @@ -42,6 +44,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink+ 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket + github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack L 💣 github.com/safchain/ethtool from tailscale.com/net/netkernelconf W 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient @@ -229,7 +232,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/tsweb from tailscale.com/util/eventbus tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ - tailscale.com/types/bools from tailscale.com/tsnet + tailscale.com/types/bools from tailscale.com/tsnet+ tailscale.com/types/dnstype from tailscale.com/client/local+ tailscale.com/types/empty from tailscale.com/ipn+ tailscale.com/types/ipproto from tailscale.com/ipn+ @@ -252,7 +255,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/control/controlclient+ tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/appc+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ LW tailscale.com/util/cmpver from tailscale.com/net/dns+ @@ -397,7 +400,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ - crypto/fips140 from crypto/tls/internal/fips140tls + crypto/fips140 from crypto/tls/internal/fips140tls+ crypto/hkdf from crypto/internal/hpke+ crypto/hmac from crypto/tls+ crypto/internal/boring from crypto/aes+ diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index c02b09745aec8..7093ab9ee193a 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -287,7 +287,7 @@ func serveOnLocalTailscaled(ctx context.Context, lc *local.Client, st *ipnstate. // We watch the IPN bus just to get a session ID. The session expires // when we stop watching the bus, and that auto-deletes the foreground // serve/funnel configs we are creating below. 
- watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) + watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState) if err != nil { return nil, nil, fmt.Errorf("could not set up ipn bus watcher: %v", err) } diff --git a/cmd/vet/jsontags/analyzer.go b/cmd/vet/jsontags/analyzer.go new file mode 100644 index 0000000000000..d799b66cbb583 --- /dev/null +++ b/cmd/vet/jsontags/analyzer.go @@ -0,0 +1,201 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsontags checks for incompatible usage of JSON struct tags. +package jsontags + +import ( + "go/ast" + "go/types" + "reflect" + "strings" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +var Analyzer = &analysis.Analyzer{ + Name: "jsonvet", + Doc: "check for incompatible usages of JSON struct tags", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, +} + +func run(pass *analysis.Pass) (any, error) { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + // TODO: Report byte arrays fields without an explicit `format` tag option. + + inspect.Preorder([]ast.Node{(*ast.StructType)(nil)}, func(n ast.Node) { + structType, ok := pass.TypesInfo.Types[n.(*ast.StructType)].Type.(*types.Struct) + if !ok { + return // type information may be incomplete + } + for i := range structType.NumFields() { + fieldVar := structType.Field(i) + tag := reflect.StructTag(structType.Tag(i)).Get("json") + if tag == "" { + continue + } + var seenName, hasFormat bool + for opt := range strings.SplitSeq(tag, ",") { + if !seenName { + seenName = true + continue + } + switch opt { + case "omitempty": + // For bools, ints, uints, floats, strings, and interfaces, + // it is always safe to migrate from `omitempty` to `omitzero` + // so long as the type does not have an IsZero method or + // the IsZero method is identical to reflect.Value.IsZero. + // + // For pointers, it is only safe to migrate from `omitempty` to `omitzero` + // so long as the type does not have an IsZero method, regardless of + // whether the IsZero method is identical to reflect.Value.IsZero. + // + // For pointers, `omitempty` behaves identically on both v1 and v2 + // so long as the type does not implement a Marshal method that + // might serialize as an empty JSON value (i.e., null, "", [], or {}). + hasIsZero := hasIsZeroMethod(fieldVar.Type()) && !hasPureIsZeroMethod(fieldVar.Type()) + underType := fieldVar.Type().Underlying() + basic, isBasic := underType.(*types.Basic) + array, isArrayKind := underType.(*types.Array) + _, isMapKind := underType.(*types.Map) + _, isSliceKind := underType.(*types.Slice) + _, isPointerKind := underType.(*types.Pointer) + _, isInterfaceKind := underType.(*types.Interface) + supportedInV1 := isNumericKind(underType) || + isBasic && basic.Kind() == types.Bool || + isBasic && basic.Kind() == types.String || + isArrayKind && array.Len() == 0 || + isMapKind || isSliceKind || isPointerKind || isInterfaceKind + notSupportedInV2 := isNumericKind(underType) || + isBasic && basic.Kind() == types.Bool + switch { + case isMapKind, isSliceKind: + // This operates the same under both v1 and v2 so long as + // the map or slice type does not implement Marshal + // that could emit an empty JSON value for cases + // other than when the map or slice are empty. + // This is very rare. + case isString(fieldVar.Type()): + // This operates the same under both v1 and v2. 
+ // These are safe to migrate to `omitzero`, + // but doing so is probably unnecessary churn. + // Note that this is only for a unnamed string type. + case !supportedInV1: + // This never worked in v1. Switching to `omitzero` + // may lead to unexpected behavior changes. + report(pass, structType, fieldVar, OmitEmptyUnsupportedInV1) + case notSupportedInV2: + // This does not work in v2. Switching to `omitzero` + // may lead to unexpected behavior changes. + report(pass, structType, fieldVar, OmitEmptyUnsupportedInV2) + case !hasIsZero: + // These are safe to migrate to `omitzero` such that + // it behaves identically under v1 and v2. + report(pass, structType, fieldVar, OmitEmptyShouldBeOmitZero) + case isPointerKind: + // This operates the same under both v1 and v2 so long as + // the pointer type does not implement Marshal that + // could emit an empty JSON value. + // For example, time.Time is safe since the zero value + // never marshals as an empty JSON string. + default: + // This is a non-pointer type with an IsZero method. + // If IsZero is not identical to reflect.Value.IsZero, + // omission may behave slightly differently when using + // `omitzero` instead of `omitempty`. + // Thus the finding uses the word "should". + report(pass, structType, fieldVar, OmitEmptyShouldBeOmitZeroButHasIsZero) + } + case "string": + if !isNumericKind(fieldVar.Type()) { + report(pass, structType, fieldVar, StringOnNonNumericKind) + } + default: + key, _, ok := strings.Cut(opt, ":") + hasFormat = key == "format" && ok + } + } + if !hasFormat && isTimeDuration(mayPointerElem(fieldVar.Type())) { + report(pass, structType, fieldVar, FormatMissingOnTimeDuration) + } + } + }) + return nil, nil +} + +// hasIsZeroMethod reports whether t has an IsZero method. +func hasIsZeroMethod(t types.Type) bool { + for method := range types.NewMethodSet(t).Methods() { + if fn, ok := method.Type().(*types.Signature); ok && method.Obj().Name() == "IsZero" { + if fn.Params().Len() == 0 && fn.Results().Len() == 1 && isBool(fn.Results().At(0).Type()) { + return true + } + } + } + return false +} + +// isBool reports whether t is a bool type. +func isBool(t types.Type) bool { + basic, ok := t.(*types.Basic) + return ok && basic.Kind() == types.Bool +} + +// isString reports whether t is a string type. +func isString(t types.Type) bool { + basic, ok := t.(*types.Basic) + return ok && basic.Kind() == types.String +} + +// isTimeDuration reports whether t is a time.Duration type. +func isTimeDuration(t types.Type) bool { + return isNamed(t, "time", "Duration") +} + +// mayPointerElem returns the pointed-at type if t is a pointer, +// otherwise it returns t as-is. +func mayPointerElem(t types.Type) types.Type { + if pointer, ok := t.(*types.Pointer); ok { + return pointer.Elem() + } + return t +} + +// isNamed reports t is a named typed of the given path and name. +func isNamed(t types.Type, path, name string) bool { + gotPath, gotName := typeName(t) + return gotPath == path && gotName == name +} + +// typeName reports the pkgPath and name of the type. +// It recursively follows type aliases to get the underlying named type. +func typeName(t types.Type) (pkgPath, name string) { + if named, ok := types.Unalias(t).(*types.Named); ok { + obj := named.Obj() + if pkg := obj.Pkg(); pkg != nil { + return pkg.Path(), obj.Name() + } + return "", obj.Name() + } + return "", "" +} + +// isNumericKind reports whether t is a numeric kind. 
+func isNumericKind(t types.Type) bool { + if basic, ok := t.Underlying().(*types.Basic); ok { + switch basic.Kind() { + case types.Int, types.Int8, types.Int16, types.Int32, types.Int64: + case types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64, types.Uintptr: + case types.Float32, types.Float64: + default: + return false + } + return true + } + return false +} diff --git a/cmd/vet/jsontags/iszero.go b/cmd/vet/jsontags/iszero.go new file mode 100644 index 0000000000000..77520d72c66f3 --- /dev/null +++ b/cmd/vet/jsontags/iszero.go @@ -0,0 +1,75 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsontags + +import ( + "go/types" + "reflect" + + "tailscale.com/util/set" +) + +var _ = reflect.Value.IsZero // refer for hot-linking purposes + +var pureIsZeroMethods map[string]set.Set[string] + +// hasPureIsZeroMethod reports whether the IsZero method is truly +// identical to [reflect.Value.IsZero]. +func hasPureIsZeroMethod(t types.Type) bool { + // TODO: Detect this automatically by checking the method AST? + path, name := typeName(t) + return pureIsZeroMethods[path].Contains(name) +} + +// PureIsZeroMethodsInTailscaleModule is a list of known IsZero methods +// in the "tailscale.com" module that are pure. +var PureIsZeroMethodsInTailscaleModule = map[string]set.Set[string]{ + "tailscale.com/net/packet": set.Of( + "TailscaleRejectReason", + ), + "tailscale.com/tailcfg": set.Of( + "UserID", + "LoginID", + "NodeID", + "StableNodeID", + ), + "tailscale.com/tka": set.Of( + "AUMHash", + ), + "tailscale.com/types/geo": set.Of( + "Point", + ), + "tailscale.com/tstime/mono": set.Of( + "Time", + ), + "tailscale.com/types/key": set.Of( + "NLPrivate", + "NLPublic", + "DERPMesh", + "MachinePrivate", + "MachinePublic", + "ControlPrivate", + "DiscoPrivate", + "DiscoPublic", + "DiscoShared", + "HardwareAttestationPublic", + "ChallengePublic", + "NodePrivate", + "NodePublic", + ), + "tailscale.com/types/netlogtype": set.Of( + "Connection", + "Counts", + ), +} + +// RegisterPureIsZeroMethods specifies a list of pure IsZero methods +// where it is identical to calling [reflect.Value.IsZero] on the receiver. +// This is not strictly necessary, but allows for more accurate +// detection of improper use of `json` tags. +// +// This must be called at init and the input must not be mutated. +func RegisterPureIsZeroMethods(methods map[string]set.Set[string]) { + pureIsZeroMethods = methods +} diff --git a/cmd/vet/jsontags/report.go b/cmd/vet/jsontags/report.go new file mode 100644 index 0000000000000..8e5869060799c --- /dev/null +++ b/cmd/vet/jsontags/report.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsontags + +import ( + "fmt" + "go/types" + "os" + "strings" + + _ "embed" + + "golang.org/x/tools/go/analysis" + "tailscale.com/util/set" +) + +var jsontagsAllowlist map[ReportKind]set.Set[string] + +// ParseAllowlist parses an allowlist of reports to ignore, +// which is a newline-delimited list of tuples separated by a tab, +// where each tuple is a [ReportKind] and a fully-qualified field name. +// +// For example: +// +// OmitEmptyUnsupportedInV1 tailscale.com/path/to/package.StructType.FieldName +// OmitEmptyUnsupportedInV1 tailscale.com/path/to/package.*.FieldName +// +// The struct type name may be "*" for anonymous struct types such +// as those declared within a function or as a type literal in a variable. 
+func ParseAllowlist(s string) map[ReportKind]set.Set[string] { + var allowlist map[ReportKind]set.Set[string] + for line := range strings.SplitSeq(s, "\n") { + kind, field, _ := strings.Cut(strings.TrimSpace(line), "\t") + if allowlist == nil { + allowlist = make(map[ReportKind]set.Set[string]) + } + fields := allowlist[ReportKind(kind)] + if fields == nil { + fields = make(set.Set[string]) + } + fields.Add(field) + allowlist[ReportKind(kind)] = fields + } + return allowlist +} + +// RegisterAllowlist registers an allowlist of reports to ignore, +// which is represented by a set of fully-qualified field names +// for each [ReportKind]. +// +// For example: +// +// { +// "OmitEmptyUnsupportedInV1": set.Of( +// "tailscale.com/path/to/package.StructType.FieldName", +// "tailscale.com/path/to/package.*.FieldName", +// ), +// } +// +// The struct type name may be "*" for anonymous struct types such +// as those declared within a function or as a type literal in a variable. +// +// This must be called at init and the input must not be mutated. +func RegisterAllowlist(allowlist map[ReportKind]set.Set[string]) { + jsontagsAllowlist = allowlist +} + +type ReportKind string + +const ( + OmitEmptyUnsupportedInV1 ReportKind = "OmitEmptyUnsupportedInV1" + OmitEmptyUnsupportedInV2 ReportKind = "OmitEmptyUnsupportedInV2" + OmitEmptyShouldBeOmitZero ReportKind = "OmitEmptyShouldBeOmitZero" + OmitEmptyShouldBeOmitZeroButHasIsZero ReportKind = "OmitEmptyShouldBeOmitZeroButHasIsZero" + StringOnNonNumericKind ReportKind = "StringOnNonNumericKind" + FormatMissingOnTimeDuration ReportKind = "FormatMissingOnTimeDuration" +) + +func (k ReportKind) message() string { + switch k { + case OmitEmptyUnsupportedInV1: + return "uses `omitempty` on an unsupported type in json/v1; should probably use `omitzero` instead" + case OmitEmptyUnsupportedInV2: + return "uses `omitempty` on an unsupported type in json/v2; should probably use `omitzero` instead" + case OmitEmptyShouldBeOmitZero: + return "should use `omitzero` instead of `omitempty`" + case OmitEmptyShouldBeOmitZeroButHasIsZero: + return "should probably use `omitzero` instead of `omitempty`" + case StringOnNonNumericKind: + return "must not use `string` on non-numeric types" + case FormatMissingOnTimeDuration: + return "must use an explicit `format` tag (e.g., `format:nano`) on a time.Duration type; see https://go.dev/issue/71631" + default: + return string(k) + } +} + +func report(pass *analysis.Pass, structType *types.Struct, fieldVar *types.Var, k ReportKind) { + // Lookup the full name of the struct type. + var fullName string + for _, name := range pass.Pkg.Scope().Names() { + if obj := pass.Pkg.Scope().Lookup(name); obj != nil { + if named, ok := obj.(*types.TypeName); ok { + if types.Identical(named.Type().Underlying(), structType) { + fullName = fmt.Sprintf("%v.%v.%v", named.Pkg().Path(), named.Name(), fieldVar.Name()) + break + } + } + } + } + if fullName == "" { + // Full name could not be found since this is probably an anonymous type + // or locally declared within a function scope. + // Use just the package path and field name instead. + // This is imprecise, but better than nothing. 
+ fullName = fmt.Sprintf("%s.*.%s", fieldVar.Pkg().Path(), fieldVar.Name()) + } + if jsontagsAllowlist[k].Contains(fullName) { + return + } + + const appendAllowlist = "" + if appendAllowlist != "" { + if f, err := os.OpenFile(appendAllowlist, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0664); err == nil { + fmt.Fprintf(f, "%v\t%v\n", k, fullName) + f.Close() + } + } + + pass.Report(analysis.Diagnostic{ + Pos: fieldVar.Pos(), + Message: fmt.Sprintf("field %q %s", fieldVar.Name(), k.message()), + }) +} diff --git a/cmd/vet/jsontags_allowlist b/cmd/vet/jsontags_allowlist new file mode 100644 index 0000000000000..9526f44ef9d9a --- /dev/null +++ b/cmd/vet/jsontags_allowlist @@ -0,0 +1,315 @@ +OmitEmptyShouldBeOmitZero tailscale.com/client/web.authResponse.ViewerIdentity +OmitEmptyShouldBeOmitZero tailscale.com/cmd/k8s-operator.OwnerRef.Resource +OmitEmptyShouldBeOmitZero tailscale.com/cmd/tailscale/cli.apiResponse.Error +OmitEmptyShouldBeOmitZero tailscale.com/health.UnhealthyState.PrimaryAction +OmitEmptyShouldBeOmitZero tailscale.com/internal/client/tailscale.VIPService.Name +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AcceptDNS +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AcceptRoutes +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AllowLANWhileUsingExitNode +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AppConnector +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AuthKey +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AutoUpdate +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.DisableSNAT +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.Enabled +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.ExitNode +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.Hostname +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.Locked +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.NetfilterMode +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.NoStatefulFiltering +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.OperatorUser +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.PostureChecking +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.RunSSHServer +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.RunWebClient +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.ServeConfigTemp +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.ServerURL +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.ShieldsUp +OmitEmptyShouldBeOmitZero tailscale.com/ipn.OutgoingFile.PeerID +OmitEmptyShouldBeOmitZero tailscale.com/ipn.Prefs.AutoExitNode +OmitEmptyShouldBeOmitZero tailscale.com/ipn.Prefs.NoStatefulFiltering +OmitEmptyShouldBeOmitZero tailscale.com/ipn.Prefs.RelayServerPort +OmitEmptyShouldBeOmitZero tailscale.com/ipn/auditlog.transaction.Action +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.PeerStatus.AllowedIPs +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.PeerStatus.Location +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.PeerStatus.PrimaryRoutes +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.PeerStatus.Tags +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.Status.ExitNodeStatus +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.UpdateProgress.Status +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.AppConnector +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.Hostname +OmitEmptyShouldBeOmitZero 
tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.HostnamePrefix +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.Replicas +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.SubnetRouter +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Container.Debug +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Container.ImagePullPolicy +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Container.SecurityContext +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.KubeAPIServerConfig.Mode +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Nameserver.Image +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Nameserver.Pod +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Nameserver.Replicas +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Nameserver.Service +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.Affinity +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.DNSConfig +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.DNSPolicy +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.SecurityContext +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.TailscaleContainer +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.TailscaleInitContainer +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyClassSpec.Metrics +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyClassSpec.StaticEndpoints +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyClassSpec.TailscaleConfig +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyGroupSpec.HostnamePrefix +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyGroupSpec.KubeAPIServer +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyGroupSpec.Replicas +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.RecorderContainer.ImagePullPolicy +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.RecorderContainer.SecurityContext +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.RecorderPod.Affinity +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.RecorderPod.SecurityContext +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.StatefulSet.Pod +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Storage.S3 +OmitEmptyShouldBeOmitZero tailscale.com/kube/ingressservices.Config.IPv4Mapping +OmitEmptyShouldBeOmitZero tailscale.com/kube/ingressservices.Config.IPv6Mapping +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.APIServerProxyConfig.Enabled +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.APIServerProxyConfig.IssueCerts +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.APIServerProxyConfig.Mode +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.APIServerProxyConfig.ServiceName +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.AcceptRoutes +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.APIServerProxy +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.App +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.AuthKey +OmitEmptyShouldBeOmitZero 
tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.HealthCheckEnabled +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.Hostname +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.LocalAddr +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.LocalPort +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.LogLevel +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.MetricsEnabled +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.ServerURL +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.State +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.VersionedConfig.V1Alpha1 +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubeapi.ObjectMeta.DeletionGracePeriodSeconds +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubeapi.Status.Details +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubeclient.JSONPatch.Value +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubetypes.*.Mode +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubetypes.KubernetesCapRule.Impersonate +OmitEmptyShouldBeOmitZero tailscale.com/sessionrecording.CastHeader.Kubernetes +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.AuditLogRequest.Action +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Debug.Exit +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.DERPMap.HomeParams +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.DisplayMessage.PrimaryAction +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.AppConnector +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.Container +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.Desktop +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.Location +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.NetInfo +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.StateEncrypted +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.TPM +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.Userspace +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.UserspaceRouter +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.ClientVersion +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.CollectServices +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.ControlDialPlan +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.Debug +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.DeprecatedDefaultAutoUpdate +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.DERPMap +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.DNSConfig +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.Node +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.PingRequest +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.SSHPolicy +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.TKAInfo +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.NetPortRange.Bits +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Node.Online +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Node.SelfNodeV4MasqAddrForThisPeer +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Node.SelfNodeV6MasqAddrForThisPeer +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.PeerChange.Online +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.RegisterRequest.Auth +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.RegisterResponseAuth.Oauth2Token +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.SSHAction.OnRecordingFailure +OmitEmptyShouldBeOmitZero 
tailscale.com/tailcfg.SSHPrincipal.Node +OmitEmptyShouldBeOmitZero tailscale.com/tempfork/acme.*.ExternalAccountBinding +OmitEmptyShouldBeOmitZero tailscale.com/tsweb.AccessLogRecord.RequestID +OmitEmptyShouldBeOmitZero tailscale.com/types/opt.*.Unset +OmitEmptyShouldBeOmitZero tailscale.com/types/views.viewStruct.AddrsPtr +OmitEmptyShouldBeOmitZero tailscale.com/types/views.viewStruct.StringsPtr +OmitEmptyShouldBeOmitZero tailscale.com/wgengine/magicsock.EndpointChange.From +OmitEmptyShouldBeOmitZero tailscale.com/wgengine/magicsock.EndpointChange.To +OmitEmptyShouldBeOmitZeroButHasIsZero tailscale.com/types/persist.Persist.AttestationKey +OmitEmptyUnsupportedInV1 tailscale.com/client/tailscale.KeyCapabilities.Devices +OmitEmptyUnsupportedInV1 tailscale.com/client/tailscale/apitype.ExitNodeSuggestionResponse.Location +OmitEmptyUnsupportedInV1 tailscale.com/cmd/k8s-operator.ServiceMonitorSpec.NamespaceSelector +OmitEmptyUnsupportedInV1 tailscale.com/derp.ClientInfo.MeshKey +OmitEmptyUnsupportedInV1 tailscale.com/ipn.MaskedPrefs.AutoUpdateSet +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.Connector.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.Container.Resources +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.DNSConfig.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.ProxyClass.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.ProxyGroup.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.Recorder.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderContainer.Resources +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderPod.Container +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderPod.ServiceAccount +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderSpec.Storage +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderStatefulSet.Pod +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.S3.Credentials +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.S3Credentials.Secret +OmitEmptyUnsupportedInV1 tailscale.com/kube/kubeapi.Event.FirstTimestamp +OmitEmptyUnsupportedInV1 tailscale.com/kube/kubeapi.Event.LastTimestamp +OmitEmptyUnsupportedInV1 tailscale.com/kube/kubeapi.Event.Source +OmitEmptyUnsupportedInV1 tailscale.com/kube/kubeapi.ObjectMeta.CreationTimestamp +OmitEmptyUnsupportedInV1 tailscale.com/tailcfg_test.*.Groups +OmitEmptyUnsupportedInV1 tailscale.com/tailcfg.Oauth2Token.Expiry +OmitEmptyUnsupportedInV1 tailscale.com/tailcfg.QueryFeatureRequest.NodeKey +OmitEmptyUnsupportedInV2 tailscale.com/client/tailscale.*.ExpirySeconds +OmitEmptyUnsupportedInV2 tailscale.com/client/tailscale.DerpRegion.Preferred +OmitEmptyUnsupportedInV2 tailscale.com/client/tailscale.DevicePostureIdentity.Disabled +OmitEmptyUnsupportedInV2 tailscale.com/client/tailscale/apitype.DNSResolver.UseWithExitNode +OmitEmptyUnsupportedInV2 tailscale.com/client/web.authResponse.NeedsSynoAuth +OmitEmptyUnsupportedInV2 tailscale.com/cmd/tsidp.tailscaleClaims.UserID +OmitEmptyUnsupportedInV2 tailscale.com/derp.ClientInfo.IsProber +OmitEmptyUnsupportedInV2 tailscale.com/derp.ClientInfo.Version +OmitEmptyUnsupportedInV2 tailscale.com/derp.ServerInfo.TokenBucketBytesBurst +OmitEmptyUnsupportedInV2 tailscale.com/derp.ServerInfo.TokenBucketBytesPerSecond +OmitEmptyUnsupportedInV2 tailscale.com/derp.ServerInfo.Version 
+OmitEmptyUnsupportedInV2 tailscale.com/health.UnhealthyState.ImpactsConnectivity +OmitEmptyUnsupportedInV2 tailscale.com/ipn.AutoUpdatePrefsMask.ApplySet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.AutoUpdatePrefsMask.CheckSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AdvertiseRoutesSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AdvertiseServicesSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AdvertiseTagsSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AppConnectorSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AutoExitNodeSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ControlURLSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.CorpDNSSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.DriveSharesSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.EggSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ExitNodeAllowLANAccessSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ExitNodeIDSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ExitNodeIPSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ForceDaemonSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.HostnameSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.InternalExitNodePriorSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.LoggedOutSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NetfilterKindSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NetfilterModeSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NoSNATSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NoStatefulFilteringSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NotepadURLsSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.OperatorUserSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.PostureCheckingSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ProfileNameSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.RelayServerPortSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.RouteAllSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.RunSSHSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.RunWebClientSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ShieldsUpSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.WantRunningSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.PartialFile.Done +OmitEmptyUnsupportedInV2 tailscale.com/ipn.Prefs.Egg +OmitEmptyUnsupportedInV2 tailscale.com/ipn.Prefs.ForceDaemon +OmitEmptyUnsupportedInV2 tailscale.com/ipn.ServiceConfig.Tun +OmitEmptyUnsupportedInV2 tailscale.com/ipn.TCPPortHandler.HTTP +OmitEmptyUnsupportedInV2 tailscale.com/ipn.TCPPortHandler.HTTPS +OmitEmptyUnsupportedInV2 tailscale.com/ipn/auditlog.transaction.Retries +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PeerStatus.AltSharerUserID +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PeerStatus.Expired +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PeerStatus.ShareeNode +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PingResult.IsLocalIP +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PingResult.PeerAPIPort +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.Status.HaveNodeKey +OmitEmptyUnsupportedInV2 tailscale.com/k8s-operator/apis/v1alpha1.PortRange.EndPort +OmitEmptyUnsupportedInV2 tailscale.com/k8s-operator/apis/v1alpha1.ProxyClassSpec.UseLetsEncryptStagingEnvironment +OmitEmptyUnsupportedInV2 
tailscale.com/k8s-operator/apis/v1alpha1.RecorderSpec.EnableUI +OmitEmptyUnsupportedInV2 tailscale.com/k8s-operator/apis/v1alpha1.TailscaleConfig.AcceptRoutes +OmitEmptyUnsupportedInV2 tailscale.com/kube/kubeapi.Event.Count +OmitEmptyUnsupportedInV2 tailscale.com/kube/kubeapi.ObjectMeta.Generation +OmitEmptyUnsupportedInV2 tailscale.com/kube/kubeapi.Status.Code +OmitEmptyUnsupportedInV2 tailscale.com/kube/kubetypes.KubernetesCapRule.EnforceRecorder +OmitEmptyUnsupportedInV2 tailscale.com/log/sockstatlog.event.IsCellularInterface +OmitEmptyUnsupportedInV2 tailscale.com/sessionrecording.CastHeader.SrcNodeUserID +OmitEmptyUnsupportedInV2 tailscale.com/sessionrecording.Source.NodeUserID +OmitEmptyUnsupportedInV2 tailscale.com/sessionrecording.v2ResponseFrame.Ack +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg_test.*.ToggleOn +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.AuditLogRequest.Version +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NPostureIdentityResponse.PostureDisabled +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NSSHUsernamesRequest.Max +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NTLSCertInfo.Expired +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NTLSCertInfo.Missing +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NTLSCertInfo.Valid +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ClientVersion.Notify +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ClientVersion.RunningLatest +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ClientVersion.UrgentSecurityUpdate +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ControlIPCandidate.DialStartDelaySec +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ControlIPCandidate.DialTimeoutSec +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ControlIPCandidate.Priority +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Debug.DisableLogTail +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Debug.SleepSeconds +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPMap.OmitDefaultRegions +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.CanPort80 +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.DERPPort +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.InsecureForTests +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.STUNOnly +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.STUNPort +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPRegion.Avoid +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPRegion.Latitude +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPRegion.Longitude +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPRegion.NoMeasureNoHome +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DisplayMessage.ImpactsConnectivity +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DNSConfig.Proxied +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.AllowsUpdate +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.IngressEnabled +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.NoLogsNoSupport +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.ShareeNode +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.ShieldsUp +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.WireIngress +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Location.Latitude +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Location.Longitude +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Location.Priority +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapRequest.MapSessionSeq +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapRequest.OmitPeers +OmitEmptyUnsupportedInV2 
tailscale.com/tailcfg.MapRequest.ReadOnly +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapResponse.KeepAlive +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapResponse.Seq +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.NetInfo.HavePortMap +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.Cap +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.Expired +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.HomeDERP +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.IsJailed +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.IsWireGuardOnly +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.MachineAuthorized +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.Sharer +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.UnsignedPeerAPIOnly +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PeerChange.Cap +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PeerChange.DERPRegion +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingRequest.Log +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingRequest.URLIsNoise +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingResponse.DERPRegionID +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingResponse.IsLocalIP +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingResponse.LatencySeconds +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingResponse.PeerAPIPort +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.QueryFeatureResponse.Complete +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.QueryFeatureResponse.ShouldWait +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.RegisterRequest.Ephemeral +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.RegisterRequest.SignatureType +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.Accept +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.AllowAgentForwarding +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.AllowLocalPortForwarding +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.AllowRemotePortForwarding +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.Reject +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.SessionDuration +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHPrincipal.Any +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.TKAInfo.Disabled +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.TPMInfo.FirmwareVersion +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.TPMInfo.Model +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.TPMInfo.SpecRevision +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.WebClientAuthResponse.Complete +OmitEmptyUnsupportedInV2 tailscale.com/tempfork/acme.*.TermsAgreed +OmitEmptyUnsupportedInV2 tailscale.com/tstime/rate.jsonValue.Updated +OmitEmptyUnsupportedInV2 tailscale.com/tstime/rate.jsonValue.Value +OmitEmptyUnsupportedInV2 tailscale.com/tsweb.AccessLogRecord.Bytes +OmitEmptyUnsupportedInV2 tailscale.com/tsweb.AccessLogRecord.Code +OmitEmptyUnsupportedInV2 tailscale.com/tsweb.AccessLogRecord.Seconds +OmitEmptyUnsupportedInV2 tailscale.com/tsweb.AccessLogRecord.TLS +OmitEmptyUnsupportedInV2 tailscale.com/tsweb/varz.SomeStats.TotalY +OmitEmptyUnsupportedInV2 tailscale.com/types/appctype.AppConnectorConfig.AdvertiseRoutes +OmitEmptyUnsupportedInV2 tailscale.com/types/dnstype.Resolver.UseWithExitNode +OmitEmptyUnsupportedInV2 tailscale.com/types/opt.testStruct.Int +OmitEmptyUnsupportedInV2 tailscale.com/version.Meta.GitDirty +OmitEmptyUnsupportedInV2 tailscale.com/version.Meta.IsDev +OmitEmptyUnsupportedInV2 tailscale.com/version.Meta.UnstableBranch diff --git a/cmd/vet/vet.go b/cmd/vet/vet.go new file mode 100644 index 
0000000000000..45473af48f0ee --- /dev/null +++ b/cmd/vet/vet.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package vet is a tool to statically check Go source code. +package main + +import ( + _ "embed" + + "golang.org/x/tools/go/analysis/unitchecker" + "tailscale.com/cmd/vet/jsontags" +) + +//go:embed jsontags_allowlist +var jsontagsAllowlistSource string + +func init() { + jsontags.RegisterAllowlist(jsontags.ParseAllowlist(jsontagsAllowlistSource)) + jsontags.RegisterPureIsZeroMethods(jsontags.PureIsZeroMethodsInTailscaleModule) +} + +func main() { + unitchecker.Main(jsontags.Analyzer) +} diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 4020e5651978a..d1c753db78710 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -13,7 +13,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct --clone-only-type=OnlyGetClone +//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct,StructWithMapOfViews --clone-only-type=OnlyGetClone type StructWithoutPtrs struct { Int int @@ -238,3 +238,7 @@ type GenericTypeAliasStruct[T integer, T2 views.ViewCloner[T2, V2], V2 views.Str NonCloneable T Cloneable T2 } + +type StructWithMapOfViews struct { + MapOfViews map[string]StructWithoutPtrsView +} diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 106a9b6843b56..4602b9d887d2b 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -547,3 +547,20 @@ func _GenericTypeAliasStructCloneNeedsRegeneration[T integer, T2 views.ViewClone Cloneable T2 }{}) } + +// Clone makes a deep copy of StructWithMapOfViews. +// The result aliases no memory with the original. +func (src *StructWithMapOfViews) Clone() *StructWithMapOfViews { + if src == nil { + return nil + } + dst := new(StructWithMapOfViews) + *dst = *src + dst.MapOfViews = maps.Clone(src.MapOfViews) + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithMapOfViewsCloneNeedsRegeneration = StructWithMapOfViews(struct { + MapOfViews map[string]StructWithoutPtrsView +}{}) diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index e50a71c9e0220..495281c23b3aa 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -16,7 +16,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct,StructWithMapOfViews // View returns a read-only view of StructWithPtrs. 
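// Aside: the jsontags_allowlist embedded above is a plain two-column format,
// one "<Kind> <pkgpath.Type.Field>" entry per line ("*" stands in for a type
// name the analyzer doesn't spell out, and the appendAllowlist debug hook
// writes new entries tab-separated). The real jsontags.ParseAllowlist
// signature is not shown in this diff; a hypothetical parser for the visible
// format could be as small as:
//
//	func parseAllowlist(src string) map[string][]string {
//		byKind := make(map[string][]string)
//		for _, line := range strings.Split(src, "\n") {
//			f := strings.Fields(line) // tolerate tab- or space-separated columns
//			if len(f) != 2 {
//				continue // skip blank or malformed lines
//			}
//			byKind[f[0]] = append(byKind[f[0]], f[1])
//		}
//		return byKind
//	}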
func (p *StructWithPtrs) View() StructWithPtrsView { @@ -1053,3 +1053,79 @@ func _GenericTypeAliasStructViewNeedsRegeneration[T integer, T2 views.ViewCloner Cloneable T2 }{}) } + +// View returns a read-only view of StructWithMapOfViews. +func (p *StructWithMapOfViews) View() StructWithMapOfViewsView { + return StructWithMapOfViewsView{ж: p} +} + +// StructWithMapOfViewsView provides a read-only view over StructWithMapOfViews. +// +// Its methods should only be called if `Valid()` returns true. +type StructWithMapOfViewsView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *StructWithMapOfViews +} + +// Valid reports whether v's underlying value is non-nil. +func (v StructWithMapOfViewsView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v StructWithMapOfViewsView) AsStruct() *StructWithMapOfViews { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithMapOfViewsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.ж) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithMapOfViewsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.ж) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *StructWithMapOfViewsView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x StructWithMapOfViews + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithMapOfViewsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.ж != nil { + return errors.New("already initialized") + } + var x StructWithMapOfViews + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v StructWithMapOfViewsView) MapOfViews() views.Map[string, StructWithoutPtrsView] { + return views.MapOf(v.ж.MapOfViews) +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithMapOfViewsViewNeedsRegeneration = StructWithMapOfViews(struct { + MapOfViews map[string]StructWithoutPtrsView +}{}) diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 4fd81ea510d40..3fae737cde692 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -367,14 +367,21 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, fie case *types.Struct, *types.Named, *types.Alias: strucT := u args.FieldType = it.QualifiedName(fieldType) - if codegen.ContainsPointers(strucT) { + + // We need to call View() unless the type is + // either a View itself or does not contain + // pointers (and can thus be shallow-copied). + // + // Otherwise, we need to create a View of the + // map value. 
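// Aside, illustrating the rule in the comment above with the new
// StructWithMapOfViews test type from this change: its map values are already
// view types, so the generator emits a direct views.MapOf wrapper (the
// MapOfViews accessor in tests_view.go) rather than converting each value
// with View(). Caller-side, a sketch (values hypothetical):
//
//	s := &tests.StructWithMapOfViews{
//		MapOfViews: map[string]tests.StructWithoutPtrsView{},
//	}
//	m := s.View().MapOfViews() // views.Map[string, tests.StructWithoutPtrsView]
//	_ = m.Len()                // read-only access; the underlying map never escapes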
+ if codegen.IsViewType(strucT) || !codegen.ContainsPointers(strucT) { + template = "mapField" + args.MapValueType = it.QualifiedName(mElem) + } else { args.MapFn = "t.View()" template = "mapFnField" args.MapValueType = it.QualifiedName(mElem) args.MapValueView = appendNameSuffix(args.MapValueType, "View") - } else { - template = "mapField" - args.MapValueType = it.QualifiedName(mElem) } case *types.Basic: template = "mapField" diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index dc22212e887cb..78ef73f71000b 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -18,6 +18,7 @@ import ( "golang.org/x/crypto/blake2s" chp "golang.org/x/crypto/chacha20poly1305" + "tailscale.com/syncs" "tailscale.com/types/key" ) @@ -48,7 +49,7 @@ type Conn struct { // rxState is all the Conn state that Read uses. type rxState struct { - sync.Mutex + syncs.Mutex cipher cipher.AEAD nonce nonce buf *maxMsgBuffer // or nil when reads exhausted diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index 9f5bf38aeecc6..336a8d491bc9c 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -23,6 +23,7 @@ import ( "tailscale.com/util/backoff" "tailscale.com/util/clientmetric" "tailscale.com/util/execqueue" + "tailscale.com/util/testenv" ) type LoginGoal struct { @@ -117,12 +118,13 @@ type Auto struct { logf logger.Logf closed bool updateCh chan struct{} // readable when we should inform the server of a change - observer Observer // called to update Client status; always non-nil + observer Observer // if non-nil, called to update Client status observerQueue execqueue.ExecQueue shutdownFn func() // to be called prior to shutdown or nil mu sync.Mutex // mutex guards the following fields + started bool // whether [Auto.Start] has been called wantLoggedIn bool // whether the user wants to be logged in per last method call urlToVisit string // the last url we were told to visit expiry time.Time @@ -138,7 +140,6 @@ type Auto struct { loggedIn bool // true if currently logged in loginGoal *LoginGoal // non-nil if some login activity is desired inMapPoll bool // true once we get the first MapResponse in a stream; false when HTTP response ends - state State // TODO(bradfitz): delete this, make it computed by method from other state authCtx context.Context // context used for auth requests mapCtx context.Context // context used for netmap and update requests @@ -151,15 +152,21 @@ type Auto struct { // New creates and starts a new Auto. func New(opts Options) (*Auto, error) { - c, err := NewNoStart(opts) - if c != nil { - c.Start() + c, err := newNoStart(opts) + if err != nil { + return nil, err + } + if opts.StartPaused { + c.SetPaused(true) + } + if !opts.SkipStartForTests { + c.start() } return c, err } -// NewNoStart creates a new Auto, but without calling Start on it. -func NewNoStart(opts Options) (_ *Auto, err error) { +// newNoStart creates a new Auto, but without calling Start on it. +func newNoStart(opts Options) (_ *Auto, err error) { direct, err := NewDirect(opts) if err != nil { return nil, err @@ -170,9 +177,6 @@ func NewNoStart(opts Options) (_ *Auto, err error) { } }() - if opts.Observer == nil { - return nil, errors.New("missing required Options.Observer") - } if opts.Logf == nil { opts.Logf = func(fmt string, args ...any) {} } @@ -222,10 +226,21 @@ func (c *Auto) SetPaused(paused bool) { c.unpauseWaiters = nil } -// Start starts the client's goroutines. +// StartForTest starts the client's goroutines. 
// -// It should only be called for clients created by NewNoStart. -func (c *Auto) Start() { +// It should only be called for clients created with [Options.SkipStartForTests]. +func (c *Auto) StartForTest() { + testenv.AssertInTest() + c.start() +} + +func (c *Auto) start() { + c.mu.Lock() + defer c.mu.Unlock() + if c.started { + return + } + c.started = true go c.authRoutine() go c.mapRoutine() go c.updateRoutine() @@ -299,10 +314,11 @@ func (c *Auto) authRoutine() { c.mu.Lock() goal := c.loginGoal ctx := c.authCtx + loggedIn := c.loggedIn if goal != nil { - c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, true) + c.logf("[v1] authRoutine: loggedIn=%v; wantLoggedIn=%v", loggedIn, true) } else { - c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused) + c.logf("[v1] authRoutine: loggedIn=%v; goal=nil paused=%v", loggedIn, c.paused) } c.mu.Unlock() @@ -325,11 +341,6 @@ func (c *Auto) authRoutine() { c.mu.Lock() c.urlToVisit = goal.url - if goal.url != "" { - c.state = StateURLVisitRequired - } else { - c.state = StateAuthenticating - } c.mu.Unlock() var url string @@ -363,7 +374,6 @@ func (c *Auto) authRoutine() { flags: LoginDefault, url: url, } - c.state = StateURLVisitRequired c.mu.Unlock() c.sendStatus("authRoutine-url", err, url, nil) @@ -383,7 +393,6 @@ func (c *Auto) authRoutine() { c.urlToVisit = "" c.loggedIn = true c.loginGoal = nil - c.state = StateAuthenticated c.mu.Unlock() c.sendStatus("authRoutine-success", nil, "", nil) @@ -433,21 +442,17 @@ func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) { c := mrs.c c.mu.Lock() - ctx := c.mapCtx c.inMapPoll = true - if c.loggedIn { - c.state = StateSynchronized - } - c.expiry = nm.Expiry + c.expiry = nm.SelfKeyExpiry() stillAuthed := c.loggedIn - c.logf("[v1] mapRoutine: netmap received: %s", c.state) + c.logf("[v1] mapRoutine: netmap received: loggedIn=%v inMapPoll=true", stillAuthed) c.mu.Unlock() if stillAuthed { c.sendStatus("mapRoutine-got-netmap", nil, "", nm) } // Reset the backoff timer if we got a netmap. 
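// Aside: with the Start changes above, tests opt out of the background
// goroutines at construction time and start them explicitly. A minimal
// sketch (other required Options fields elided):
//
//	c, err := controlclient.New(controlclient.Options{
//		SkipStartForTests: true, // don't launch the auth/map/update routines yet
//		StartPaused:       true, // begin paused: no network requests until unpaused
//	})
//	if err != nil {
//		t.Fatal(err)
//	}
//	c.StartForTest() // testenv.AssertInTest panics outside "go test"
//
// Because start() now checks the c.started flag under c.mu, a second
// StartForTest call is a harmless no-op.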
- mrs.bo.BackOff(ctx, nil) + mrs.bo.Reset() } func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool { @@ -488,8 +493,8 @@ func (c *Auto) mapRoutine() { } c.mu.Lock() - c.logf("[v1] mapRoutine: %s", c.state) loggedIn := c.loggedIn + c.logf("[v1] mapRoutine: loggedIn=%v", loggedIn) ctx := c.mapCtx c.mu.Unlock() @@ -520,9 +525,6 @@ func (c *Auto) mapRoutine() { c.direct.health.SetOutOfPollNetMap() c.mu.Lock() c.inMapPoll = false - if c.state == StateSynchronized { - c.state = StateAuthenticated - } paused := c.paused c.mu.Unlock() @@ -588,12 +590,12 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM c.mu.Unlock() return } - state := c.state loggedIn := c.loggedIn inMapPoll := c.inMapPoll + loginGoal := c.loginGoal c.mu.Unlock() - c.logf("[v1] sendStatus: %s: %v", who, state) + c.logf("[v1] sendStatus: %s: loggedIn=%v inMapPoll=%v", who, loggedIn, inMapPoll) var p persist.PersistView if nm != nil && loggedIn && inMapPoll { @@ -604,18 +606,31 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM nm = nil } newSt := &Status{ - URL: url, - Persist: p, - NetMap: nm, - Err: err, - state: state, + URL: url, + Persist: p, + NetMap: nm, + Err: err, + LoggedIn: loggedIn && loginGoal == nil, + InMapPoll: inMapPoll, } + + if c.observer == nil { + return + } + c.lastStatus.Store(newSt) // Launch a new goroutine to avoid blocking the caller while the observer // does its thing, which may result in a call back into the client. metricQueued.Add(1) c.observerQueue.Add(func() { + c.mu.Lock() + closed := c.closed + c.mu.Unlock() + if closed { + return + } + if canSkipStatus(newSt, c.lastStatus.Load()) { metricSkippable.Add(1) if !c.direct.controlKnobs.DisableSkipStatusQueue.Load() { @@ -659,14 +674,15 @@ func canSkipStatus(s1, s2 *Status) bool { // we can't skip it. return false } - if s1.Err != nil || s1.URL != "" { - // If s1 has an error or a URL, we shouldn't skip it, lest the error go - // away in s2 or in-between. We want to make sure all the subsystems see - // it. Plus there aren't many of these, so not worth skipping. + if s1.Err != nil || s1.URL != "" || s1.LoggedIn { + // If s1 has an error, a URL, or LoggedIn set, we shouldn't skip it, + // lest the error go away in s2 or in-between. We want to make sure all + // the subsystems see it. Plus there aren't many of these, so not worth + // skipping. return false } - if !s1.Persist.Equals(s2.Persist) || s1.state != s2.state { - // If s1 has a different Persist or state than s2, + if !s1.Persist.Equals(s2.Persist) || s1.LoggedIn != s2.LoggedIn || s1.InMapPoll != s2.InMapPoll || s1.URL != s2.URL { + // If s1 has a different Persist, LoggedIn, InMapPoll, or URL than s2, // don't skip it. We only care about skipping the typical // entries where the only difference is the NetMap. return false @@ -728,7 +744,6 @@ func (c *Auto) Logout(ctx context.Context) error { } c.mu.Lock() c.loggedIn = false - c.state = StateNotAuthenticated c.cancelAuthCtxLocked() c.cancelMapCtxLocked() c.mu.Unlock() @@ -752,6 +767,13 @@ func (c *Auto) UpdateEndpoints(endpoints []tailcfg.Endpoint) { } } +// SetDiscoPublicKey sets the client's Disco public key to key and sends the change +// to the control server. 
+func (c *Auto) SetDiscoPublicKey(key key.DiscoPublic) { + c.direct.SetDiscoPublicKey(key) + c.updateControl() +} + func (c *Auto) Shutdown() { c.mu.Lock() if c.closed { diff --git a/control/controlclient/client.go b/control/controlclient/client.go index d0aa129ae95b4..41b39622b0199 100644 --- a/control/controlclient/client.go +++ b/control/controlclient/client.go @@ -12,6 +12,7 @@ import ( "context" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) // LoginFlags is a bitmask of options to change the behavior of Client.Login @@ -80,7 +81,12 @@ type Client interface { // TODO: a server-side change would let us simply upload this // in a separate http request. It has nothing to do with the rest of // the state machine. + // Note: the auto client uploads the new endpoints to control immediately. UpdateEndpoints(endpoints []tailcfg.Endpoint) + // SetDiscoPublicKey updates the disco public key that will be sent in + // future map requests. This should be called after rotating the discovery key. + // Note: the auto client uploads the new key to control immediately. + SetDiscoPublicKey(key.DiscoPublic) // ClientID returns the ClientID of a client. This ID is meant to // distinguish one client from another. ClientID() int64 diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index 3914d10ef8310..bc301122673f7 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -15,7 +15,6 @@ import ( "net/netip" "net/url" "reflect" - "slices" "sync/atomic" "testing" "time" @@ -49,7 +48,7 @@ func fieldsOf(t reflect.Type) (fields []string) { func TestStatusEqual(t *testing.T) { // Verify that the Equal method stays in sync with reality - equalHandles := []string{"Err", "URL", "NetMap", "Persist", "state"} + equalHandles := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"} if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) { t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n", have, equalHandles) @@ -81,7 +80,7 @@ func TestStatusEqual(t *testing.T) { }, { &Status{}, - &Status{state: StateAuthenticated}, + &Status{LoggedIn: true, Persist: new(persist.Persist).View()}, false, }, } @@ -135,8 +134,20 @@ func TestCanSkipStatus(t *testing.T) { want: false, }, { - name: "s1-state-diff", - s1: &Status{state: 123, NetMap: nm1}, + name: "s1-login-finished-diff", + s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-login-finished", + s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-synced-diff", + s1: &Status{InMapPoll: true, LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1}, s2: &Status{NetMap: nm2}, want: false, }, @@ -167,10 +178,11 @@ func TestCanSkipStatus(t *testing.T) { }) } - want := []string{"Err", "URL", "NetMap", "Persist", "state"} - if f := fieldsOf(reflect.TypeFor[Status]()); !slices.Equal(f, want) { - t.Errorf("Status fields = %q; this code was only written to handle fields %q", f, want) + coveredFields := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"} + if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, coveredFields) { + t.Errorf("Status fields = %q; this code was only written to handle fields %q", have, coveredFields) } + } func TestRetryableErrors(t *testing.T) { diff --git 
a/control/controlclient/direct.go b/control/controlclient/direct.go index fe7cc235b05f8..d5cd6a13e5120 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -7,6 +7,8 @@ import ( "bytes" "cmp" "context" + "crypto" + "crypto/sha256" "encoding/binary" "encoding/json" "errors" @@ -21,7 +23,6 @@ import ( "runtime" "slices" "strings" - "sync" "sync/atomic" "time" @@ -42,6 +43,7 @@ import ( "tailscale.com/net/netx" "tailscale.com/net/tlsdial" "tailscale.com/net/tsdial" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/tstime" @@ -72,7 +74,6 @@ type Direct struct { logf logger.Logf netMon *netmon.Monitor // non-nil health *health.Tracker - discoPubKey key.DiscoPublic busClient *eventbus.Client clientVersionPub *eventbus.Publisher[tailcfg.ClientVersion] autoUpdatePub *eventbus.Publisher[AutoUpdate] @@ -90,9 +91,10 @@ type Direct struct { dialPlan ControlDialPlanner // can be nil - mu sync.Mutex // mutex guards the following fields + mu syncs.Mutex // mutex guards the following fields serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now serverNoiseKey key.MachinePublic + discoPubKey key.DiscoPublic // protected by mu; can be updated via [SetDiscoPublicKey] sfGroup singleflight.Group[struct{}, *ts2021.Client] // protects noiseClient creation. noiseClient *ts2021.Client // also protected by mu @@ -113,6 +115,9 @@ type Direct struct { // Observer is implemented by users of the control client (such as LocalBackend) // to get notified of changes in the control client's status. +// +// If an implementation of Observer also implements [NetmapDeltaUpdater], they get +// delta updates as well as full netmap updates. type Observer interface { // SetControlClientStatus is called when the client has a new status to // report. The Client is provided to allow the Observer to track which @@ -141,8 +146,17 @@ type Options struct { ControlKnobs *controlknobs.Knobs // or nil to ignore Bus *eventbus.Bus // non-nil, for setting up publishers + SkipStartForTests bool // if true, don't call [Auto.Start] to avoid any background goroutines (for tests only) + + // StartPaused indicates whether the client should start in a paused state + // where it doesn't do network requests. This primarily exists for testing + // but not necessarily "go test" tests, so it isn't restricted to only + // being used in tests. + StartPaused bool + // Observer is called when there's a change in status to report // from the control client. + // If nil, no status updates are reported. Observer Observer // SkipIPForwardingCheck declares that the host's IP @@ -302,7 +316,6 @@ func NewDirect(opts Options) (*Direct, error) { logf: opts.Logf, persist: opts.Persist.View(), authKey: opts.AuthKey, - discoPubKey: opts.DiscoPublicKey, debugFlags: opts.DebugFlags, netMon: netMon, health: opts.HealthTracker, @@ -315,6 +328,7 @@ func NewDirect(opts Options) (*Direct, error) { dnsCache: dnsCache, dialPlan: opts.DialPlan, } + c.discoPubKey = opts.DiscoPublicKey c.closedCtx, c.closeCtx = context.WithCancel(context.Background()) c.controlClientID = nextControlClientID.Add(1) @@ -839,6 +853,14 @@ func (c *Direct) SendUpdate(ctx context.Context) error { return c.sendMapRequest(ctx, false, nil) } +// SetDiscoPublicKey updates the disco public key in local state. +// It does not implicitly trigger [SendUpdate]; callers should arrange for that. 
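// For example (a sketch; ctx and the Direct client d are assumed to exist),
// a caller rotating its disco key pairs the two calls itself, whereas the
// Auto wrapper above does the equivalent internally via updateControl:
//
//	newKey := key.NewDisco().Public()
//	d.SetDiscoPublicKey(newKey)
//	if err := d.SendUpdate(ctx); err != nil {
//		// control won't see the rotated key until a later map request
//	}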
+func (c *Direct) SetDiscoPublicKey(key key.DiscoPublic) { + c.mu.Lock() + defer c.mu.Unlock() + c.discoPubKey = key +} + // ClientID returns the controlClientID of the controlClient. func (c *Direct) ClientID() int64 { return c.controlClientID @@ -888,6 +910,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap persist := c.persist serverURL := c.serverURL serverNoiseKey := c.serverNoiseKey + discoKey := c.discoPubKey hi := c.hostInfoLocked() backendLogID := hi.BackendLogID connectionHandleForTest := c.connectionHandleForTest @@ -931,11 +954,12 @@ } nodeKey := persist.PublicNodeKey() + request := &tailcfg.MapRequest{ Version: tailcfg.CurrentCapabilityVersion, KeepAlive: true, NodeKey: nodeKey, - DiscoKey: c.discoPubKey, + DiscoKey: discoKey, Endpoints: eps, EndpointTypes: epTypes, Stream: isStreaming, @@ -946,6 +970,26 @@ ConnectionHandleForTest: connectionHandleForTest, } + // If we have a hardware attestation key, sign the node key with it and send + // the key & signature in the map request. + if buildfeatures.HasTPM { + if k := persist.AsStruct().AttestationKey; k != nil && !k.IsZero() { + hwPub := key.HardwareAttestationPublicFromPlatformKey(k) + request.HardwareAttestationKey = hwPub + + t := c.clock.Now() + msg := fmt.Sprintf("%d|%s", t.Unix(), nodeKey.String()) + digest := sha256.Sum256([]byte(msg)) + sig, err := k.Sign(nil, digest[:], crypto.SHA256) + if err != nil { + c.logf("failed to sign node key with hardware attestation key: %v", err) + } else { + request.HardwareAttestationKeySignature = sig + request.HardwareAttestationKeySignatureTimestamp = t + } + } + } + var extraDebugFlags []string if buildfeatures.HasAdvertiseRoutes && hi != nil && c.netMon != nil && !c.skipIPForwardingCheck && ipForwardingBroken(hi.RoutableIPs, c.netMon.InterfaceState()) { @@ -1059,7 +1103,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap c.persist = newPersist.View() persist = c.persist } - c.expiry = nm.Expiry + c.expiry = nm.SelfKeyExpiry() } // gotNonKeepAliveMessage is whether we've yet received a MapResponse message without @@ -1140,7 +1184,19 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap metricMapResponseKeepAlives.Add(1) continue } - if au, ok := resp.DefaultAutoUpdate.Get(); ok { + + // Handle DefaultAutoUpdate in both its CapMap form and its deprecated top-level field form. 
+ if self := resp.Node; self != nil { + for _, v := range self.CapMap[tailcfg.NodeAttrDefaultAutoUpdate] { + switch v { + case "true", "false": + c.autoUpdatePub.Publish(AutoUpdate{c.controlClientID, v == "true"}) + default: + c.logf("netmap: [unexpected] unknown %s in CapMap: %q", tailcfg.NodeAttrDefaultAutoUpdate, v) + } + } + } + if au, ok := resp.DeprecatedDefaultAutoUpdate.Get(); ok { c.autoUpdatePub.Publish(AutoUpdate{c.controlClientID, au}) } diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index dd93dc7b33d61..4329fc878ceb3 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -20,6 +20,32 @@ import ( "tailscale.com/util/eventbus/eventbustest" ) +func TestSetDiscoPublicKey(t *testing.T) { + initialKey := key.NewDisco().Public() + + c := &Direct{ + discoPubKey: initialKey, + } + + c.mu.Lock() + if c.discoPubKey != initialKey { + t.Fatalf("initial disco key mismatch: got %v, want %v", c.discoPubKey, initialKey) + } + c.mu.Unlock() + + newKey := key.NewDisco().Public() + c.SetDiscoPublicKey(newKey) + + c.mu.Lock() + if c.discoPubKey != newKey { + t.Fatalf("disco key not updated: got %v, want %v", c.discoPubKey, newKey) + } + if c.discoPubKey == initialKey { + t.Fatal("disco key should have changed") + } + c.mu.Unlock() +} + func TestNewDirect(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} diff --git a/control/controlclient/map.go b/control/controlclient/map.go index eafdb2d565a76..9aa8e37107a99 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -867,7 +867,6 @@ func (ms *mapSession) netmap() *netmap.NetworkMap { nm := &netmap.NetworkMap{ NodeKey: ms.publicNodeKey, - PrivateKey: ms.privateNodeKey, MachineKey: ms.machinePubKey, Peers: peerViews, UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfileView), @@ -892,8 +891,6 @@ func (ms *mapSession) netmap() *netmap.NetworkMap { if node := ms.lastNode; node.Valid() { nm.SelfNode = node - nm.Expiry = node.KeyExpiry() - nm.Name = node.Name() nm.AllCaps = ms.lastCapSet } diff --git a/control/controlclient/status.go b/control/controlclient/status.go index d0fdf80d745e3..65afb7a5011f2 100644 --- a/control/controlclient/status.go +++ b/control/controlclient/status.go @@ -4,8 +4,6 @@ package controlclient import ( - "encoding/json" - "fmt" "reflect" "tailscale.com/types/netmap" @@ -13,57 +11,6 @@ import ( "tailscale.com/types/structs" ) -// State is the high-level state of the client. It is used only in -// unit tests for proper sequencing, don't depend on it anywhere else. -// -// TODO(apenwarr): eliminate the state, as it's now obsolete. -// -// apenwarr: Historical note: controlclient.Auto was originally -// intended to be the state machine for the whole tailscale client, but that -// turned out to not be the right abstraction layer, and it moved to -// ipn.Backend. Since ipn.Backend now has a state machine, it would be -// much better if controlclient could be a simple stateless API. But the -// current server-side API (two interlocking polling https calls) makes that -// very hard to implement. A server side API change could untangle this and -// remove all the statefulness. 
-type State int - -const ( - StateNew = State(iota) - StateNotAuthenticated - StateAuthenticating - StateURLVisitRequired - StateAuthenticated - StateSynchronized // connected and received map update -) - -func (s State) AppendText(b []byte) ([]byte, error) { - return append(b, s.String()...), nil -} - -func (s State) MarshalText() ([]byte, error) { - return []byte(s.String()), nil -} - -func (s State) String() string { - switch s { - case StateNew: - return "state:new" - case StateNotAuthenticated: - return "state:not-authenticated" - case StateAuthenticating: - return "state:authenticating" - case StateURLVisitRequired: - return "state:url-visit-required" - case StateAuthenticated: - return "state:authenticated" - case StateSynchronized: - return "state:synchronized" - default: - return fmt.Sprintf("state:unknown:%d", int(s)) - } -} - type Status struct { _ structs.Incomparable @@ -76,6 +23,14 @@ type Status struct { // URL, if non-empty, is the interactive URL to visit to finish logging in. URL string + // LoggedIn, if true, indicates that serveRegister has completed and no + // other login change is in progress. + LoggedIn bool + + // InMapPoll, if true, indicates that we've received at least one netmap + // and are connected to receive updates. + InMapPoll bool + // NetMap is the latest server-pushed state of the tailnet network. NetMap *netmap.NetworkMap @@ -83,26 +38,8 @@ type Status struct { // // TODO(bradfitz,maisem): clarify this. Persist persist.PersistView - - // state is the internal state. It should not be exposed outside this - // package, but we have some automated tests elsewhere that need to - // use it via the StateForTest accessor. - // TODO(apenwarr): Unexport or remove these. - state State } -// LoginFinished reports whether the controlclient is in its "StateAuthenticated" -// state where it's in a happy register state but not yet in a map poll. -// -// TODO(bradfitz): delete this and everything around Status.state. -func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } - -// StateForTest returns the internal state of s for tests only. -func (s *Status) StateForTest() State { return s.state } - -// SetStateForTest sets the internal state of s for tests only. -func (s *Status) SetStateForTest(state State) { s.state = state } - // Equal reports whether s and s2 are equal. 
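// Aside: callers that previously branched on the deleted State enum can
// derive the same phases from the remaining Status fields. An informal
// mapping, inferred from the removed assignments in auto.go (not part of
// the package API):
//
//	func describe(st *controlclient.Status) string {
//		switch {
//		case st.URL != "":
//			return "url-visit-required"
//		case st.LoggedIn && st.InMapPoll:
//			return "synchronized"
//		case st.LoggedIn:
//			return "authenticated"
//		default:
//			return "not-authenticated"
//		}
//	}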
func (s *Status) Equal(s2 *Status) bool { @@ -111,15 +48,8 @@ func (s *Status) Equal(s2 *Status) bool { if s == nil && s2 == nil { return s != nil && s2 != nil && s.Err == s2.Err && s.URL == s2.URL && - s.state == s2.state && + s.LoggedIn == s2.LoggedIn && + s.InMapPoll == s2.InMapPoll && reflect.DeepEqual(s.Persist, s2.Persist) && reflect.DeepEqual(s.NetMap, s2.NetMap) } - -func (s Status) String() string { - b, err := json.MarshalIndent(s, "", "\t") - if err != nil { - panic(err) - } - return s.state.String() + " " + string(b) -} diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index 76681d4984252..5208481ed7258 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -620,6 +620,9 @@ func TestURLDial(t *testing.T) { } netMon := netmon.NewStatic() c, err := derphttp.NewClient(key.NewNode(), "https://"+hostname+"/", t.Logf, netMon) + if err != nil { + t.Fatalf("NewClient: %v", err) + } defer c.Close() if err := c.Connect(context.Background()); err != nil { diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go index 31cf9363a43bf..0bbc667806a5a 100644 --- a/derp/derpserver/derpserver.go +++ b/derp/derpserver/derpserver.go @@ -177,7 +177,7 @@ type Server struct { verifyClientsURL string verifyClientsURLFailOpen bool - mu sync.Mutex + mu syncs.Mutex closed bool netConns map[derp.Conn]chan struct{} // chan is closed when conn closes clients map[key.NodePublic]*clientSet diff --git a/drive/driveimpl/connlistener.go b/drive/driveimpl/connlistener.go index e1fcb3b675924..ff60f73404230 100644 --- a/drive/driveimpl/connlistener.go +++ b/drive/driveimpl/connlistener.go @@ -25,12 +25,12 @@ func newConnListener() *connListener { } } -func (l *connListener) Accept() (net.Conn, error) { +func (ln *connListener) Accept() (net.Conn, error) { select { - case <-l.closedCh: + case <-ln.closedCh: // TODO(oxtoacart): make this error match what a regular net.Listener does return nil, syscall.EINVAL - case conn := <-l.ch: + case conn := <-ln.ch: return conn, nil } } @@ -38,32 +38,32 @@ func (l *connListener) Accept() (net.Conn, error) { // Addr implements net.Listener. This always returns nil. It is assumed that // this method is currently unused, so it logs a warning if it ever does get // called. -func (l *connListener) Addr() net.Addr { +func (ln *connListener) Addr() net.Addr { log.Println("warning: unexpected call to connListener.Addr()") return nil } -func (l *connListener) Close() error { - l.closeMu.Lock() - defer l.closeMu.Unlock() +func (ln *connListener) Close() error { + ln.closeMu.Lock() + defer ln.closeMu.Unlock() select { - case <-l.closedCh: + case <-ln.closedCh: // Already closed. return syscall.EINVAL default: // We don't close l.ch because someone may be trying to send to that, // which would cause a panic. - close(l.closedCh) + close(ln.closedCh) return nil } } -func (l *connListener) HandleConn(c net.Conn, remoteAddr net.Addr) error { +func (ln *connListener) HandleConn(c net.Conn, remoteAddr net.Addr) error { select { - case <-l.closedCh: + case <-ln.closedCh: return syscall.EINVAL - case l.ch <- &connWithRemoteAddr{Conn: c, remoteAddr: remoteAddr}: + case ln.ch <- &connWithRemoteAddr{Conn: c, remoteAddr: remoteAddr}: + // Connection has been accepted. 
} return nil diff --git a/drive/driveimpl/connlistener_test.go b/drive/driveimpl/connlistener_test.go index d8666448af6ef..6adf15acbd56f 100644 --- a/drive/driveimpl/connlistener_test.go +++ b/drive/driveimpl/connlistener_test.go @@ -10,20 +10,20 @@ import ( ) func TestConnListener(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:") + ln, err := net.Listen("tcp", "127.0.0.1:") if err != nil { t.Fatalf("failed to Listen: %s", err) } cl := newConnListener() // Test that we can accept a connection - cc, err := net.Dial("tcp", l.Addr().String()) + cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatalf("failed to Dial: %s", err) } defer cc.Close() - sc, err := l.Accept() + sc, err := ln.Accept() if err != nil { t.Fatalf("failed to Accept: %s", err) } diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go index cff55fbb2c858..818e84990baef 100644 --- a/drive/driveimpl/drive_test.go +++ b/drive/driveimpl/drive_test.go @@ -467,14 +467,14 @@ func newSystem(t *testing.T) *system { tstest.ResourceCheck(t) fs := newFileSystemForLocal(log.Printf, nil) - l, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to Listen: %s", err) } - t.Logf("FileSystemForLocal listening at %s", l.Addr()) + t.Logf("FileSystemForLocal listening at %s", ln.Addr()) go func() { for { - conn, err := l.Accept() + conn, err := ln.Accept() if err != nil { t.Logf("Accept: %v", err) return @@ -483,11 +483,11 @@ func newSystem(t *testing.T) *system { } }() - client := gowebdav.NewAuthClient(fmt.Sprintf("http://%s", l.Addr()), &noopAuthorizer{}) + client := gowebdav.NewAuthClient(fmt.Sprintf("http://%s", ln.Addr()), &noopAuthorizer{}) client.SetTransport(&http.Transport{DisableKeepAlives: true}) s := &system{ t: t, - local: &local{l: l, fs: fs}, + local: &local{l: ln, fs: fs}, client: client, remotes: make(map[string]*remote), } @@ -496,11 +496,11 @@ func newSystem(t *testing.T) *system { } func (s *system) addRemote(name string) string { - l, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { s.t.Fatalf("failed to Listen: %s", err) } - s.t.Logf("Remote for %v listening at %s", name, l.Addr()) + s.t.Logf("Remote for %v listening at %s", name, ln.Addr()) fileServer, err := NewFileServer() if err != nil { @@ -510,14 +510,14 @@ func (s *system) addRemote(name string) string { s.t.Logf("FileServer for %v listening at %s", name, fileServer.Addr()) r := &remote{ - l: l, + l: ln, fileServer: fileServer, fs: NewFileSystemForRemote(log.Printf), shares: make(map[string]string), permissions: make(map[string]drive.Permission), } r.fs.SetFileServerAddr(fileServer.Addr()) - go http.Serve(l, r) + go http.Serve(ln, r) s.remotes[name] = r remotes := make([]*drive.Remote, 0, len(s.remotes)) diff --git a/drive/driveimpl/fileserver.go b/drive/driveimpl/fileserver.go index 113cb3b440218..d448d83af761d 100644 --- a/drive/driveimpl/fileserver.go +++ b/drive/driveimpl/fileserver.go @@ -20,7 +20,7 @@ import ( // It's typically used in a separate process from the actual Taildrive server to // serve up files as an unprivileged user. type FileServer struct { - l net.Listener + ln net.Listener secretToken string shareHandlers map[string]http.Handler sharesMu sync.RWMutex @@ -41,10 +41,10 @@ type FileServer struct { // called. 
func NewFileServer() (*FileServer, error) { // path := filepath.Join(os.TempDir(), fmt.Sprintf("%v.socket", uuid.New().String())) - // l, err := safesocket.Listen(path) + // ln, err := safesocket.Listen(path) // if err != nil { // TODO(oxtoacart): actually get safesocket working in more environments (MacOS Sandboxed, Windows, ???) - l, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, fmt.Errorf("listen: %w", err) } @@ -55,7 +55,7 @@ func NewFileServer() (*FileServer, error) { } return &FileServer{ - l: l, + ln: ln, secretToken: secretToken, shareHandlers: make(map[string]http.Handler), }, nil @@ -74,12 +74,12 @@ func generateSecretToken() (string, error) { // Addr returns the address at which this FileServer is listening. This // includes the secret token in front of the address, delimited by a pipe |. func (s *FileServer) Addr() string { - return fmt.Sprintf("%s|%s", s.secretToken, s.l.Addr().String()) + return fmt.Sprintf("%s|%s", s.secretToken, s.ln.Addr().String()) } // Serve() starts serving files and blocks until it encounters a fatal error. func (s *FileServer) Serve() error { - return http.Serve(s.l, s) + return http.Serve(s.ln, s) } // LockShares locks the map of shares in preparation for manipulating it. @@ -162,5 +162,5 @@ func (s *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (s *FileServer) Close() error { - return s.l.Close() + return s.ln.Close() } diff --git a/envknob/envknob.go b/envknob/envknob.go index 9dea8f74d15df..17a21387ecaea 100644 --- a/envknob/envknob.go +++ b/envknob/envknob.go @@ -28,19 +28,19 @@ import ( "slices" "strconv" "strings" - "sync" "sync/atomic" "time" "tailscale.com/feature/buildfeatures" "tailscale.com/kube/kubetypes" + "tailscale.com/syncs" "tailscale.com/types/opt" "tailscale.com/version" "tailscale.com/version/distro" ) var ( - mu sync.Mutex + mu syncs.Mutex // +checklocks:mu set = map[string]string{} // +checklocks:mu diff --git a/feature/buildfeatures/feature_cachenetmap_disabled.go b/feature/buildfeatures/feature_cachenetmap_disabled.go new file mode 100644 index 0000000000000..22407fe38a57f --- /dev/null +++ b/feature/buildfeatures/feature_cachenetmap_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_cachenetmap + +package buildfeatures + +// HasCacheNetMap is whether the binary was built with support for modular feature "Cache the netmap on disk between runs". +// Specifically, it's whether the binary was NOT built with the "ts_omit_cachenetmap" build tag. +// It's a const so it can be used for dead code elimination. +const HasCacheNetMap = false diff --git a/feature/buildfeatures/feature_cachenetmap_enabled.go b/feature/buildfeatures/feature_cachenetmap_enabled.go new file mode 100644 index 0000000000000..02663c416bcbb --- /dev/null +++ b/feature/buildfeatures/feature_cachenetmap_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_cachenetmap + +package buildfeatures + +// HasCacheNetMap is whether the binary was built with support for modular feature "Cache the netmap on disk between runs". +// Specifically, it's whether the binary was NOT built with the "ts_omit_cachenetmap" build tag. +// It's a const so it can be used for dead code elimination. 
+const HasCacheNetMap = true diff --git a/feature/buildfeatures/feature_identity_federation_disabled.go b/feature/buildfeatures/feature_identityfederation_disabled.go similarity index 70% rename from feature/buildfeatures/feature_identity_federation_disabled.go rename to feature/buildfeatures/feature_identityfederation_disabled.go index c7b16f729cbc5..94488adc8637c 100644 --- a/feature/buildfeatures/feature_identity_federation_disabled.go +++ b/feature/buildfeatures/feature_identityfederation_disabled.go @@ -3,11 +3,11 @@ // Code generated by gen.go; DO NOT EDIT. -//go:build ts_omit_identity_federation +//go:build ts_omit_identityfederation package buildfeatures -// HasIdentityFederation is whether the binary was built with support for modular feature "Identity token exchange for auth key support". -// Specifically, it's whether the binary was NOT built with the "ts_omit_identity_federation" build tag. +// HasIdentityFederation is whether the binary was built with support for modular feature "Auth key generation via identity federation support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_identityfederation" build tag. // It's a const so it can be used for dead code elimination. const HasIdentityFederation = false diff --git a/feature/buildfeatures/feature_identity_federation_enabled.go b/feature/buildfeatures/feature_identityfederation_enabled.go similarity index 70% rename from feature/buildfeatures/feature_identity_federation_enabled.go rename to feature/buildfeatures/feature_identityfederation_enabled.go index 1f7cf17423c96..892d62d66c37c 100644 --- a/feature/buildfeatures/feature_identity_federation_enabled.go +++ b/feature/buildfeatures/feature_identityfederation_enabled.go @@ -3,11 +3,11 @@ // Code generated by gen.go; DO NOT EDIT. -//go:build !ts_omit_identity_federation +//go:build !ts_omit_identityfederation package buildfeatures -// HasIdentityFederation is whether the binary was built with support for modular feature "Identity token exchange for auth key support". -// Specifically, it's whether the binary was NOT built with the "ts_omit_identity_federation" build tag. +// HasIdentityFederation is whether the binary was built with support for modular feature "Auth key generation via identity federation support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_identityfederation" build tag. // It's a const so it can be used for dead code elimination. const HasIdentityFederation = true diff --git a/feature/feature.go b/feature/feature.go index 110b104daae00..48a4aff43b84d 100644 --- a/feature/feature.go +++ b/feature/feature.go @@ -7,6 +7,8 @@ package feature import ( "errors" "reflect" + + "tailscale.com/util/testenv" ) var ErrUnavailable = errors.New("feature not included in this build") @@ -55,6 +57,19 @@ func (h *Hook[Func]) Set(f Func) { h.ok = true } +// SetForTest sets the hook function for tests, blowing +// away any previous value. It will panic if called from +// non-test code. +// +// It returns a restore function that resets the hook +// to its previous value. +func (h *Hook[Func]) SetForTest(f Func) (restore func()) { + testenv.AssertInTest() + old := *h + h.f, h.ok = f, true + return func() { *h = old } +} + // Get returns the hook function, or panics if it hasn't been set. // Use IsSet to check if it's been set, or use GetOrNil if you're // okay with a nil return value. 
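The Hook.SetForTest helper added above overwrites the hook unconditionally and hands back a closure that restores the previous state, so a test can scope an override with defer. A minimal sketch of the intended usage against the HookCanAutoUpdate hook from feature/hooks.go (the test itself is hypothetical, not part of this change):

package feature_test

import (
	"testing"

	"tailscale.com/feature"
)

func TestForcedAutoUpdateHook(t *testing.T) {
	// SetForTest panics outside test binaries (via testenv.AssertInTest),
	// so this override is only possible from a test like this one.
	restore := feature.HookCanAutoUpdate.SetForTest(func() bool { return true })
	defer restore() // reinstate whatever hook state existed before

	if !feature.CanAutoUpdate() {
		t.Fatal("expected CanAutoUpdate to report true while the hook is overridden")
	}
}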
diff --git a/feature/featuretags/featuretags.go b/feature/featuretags/featuretags.go index c93e8b15b1001..44b1295769c56 100644 --- a/feature/featuretags/featuretags.go +++ b/feature/featuretags/featuretags.go @@ -123,6 +123,10 @@ var Features = map[FeatureTag]FeatureMeta{ Desc: "Control-to-node (C2N) support", ImplementationDetail: true, }, + "cachenetmap": { + Sym: "CacheNetMap", + Desc: "Cache the netmap on disk between runs", + }, "captiveportal": {Sym: "CaptivePortal", Desc: "Captive portal detection"}, "capture": {Sym: "Capture", Desc: "Packet capture"}, "cli": {Sym: "CLI", Desc: "embed the CLI into the tailscaled binary"}, diff --git a/feature/hooks.go b/feature/hooks.go index a3c6c0395ee81..7e31061a7eaac 100644 --- a/feature/hooks.go +++ b/feature/hooks.go @@ -6,6 +6,8 @@ package feature import ( "net/http" "net/url" + "os" + "sync" "tailscale.com/types/logger" "tailscale.com/types/persist" @@ -15,9 +17,16 @@ import ( // to conditionally initialize. var HookCanAutoUpdate Hook[func() bool] +var testAllowAutoUpdate = sync.OnceValue(func() bool { + return os.Getenv("TS_TEST_ALLOW_AUTO_UPDATE") == "1" +}) + // CanAutoUpdate reports whether the current binary is built with auto-update // support and, if so, whether the current platform supports it. func CanAutoUpdate() bool { + if testAllowAutoUpdate() { + return true + } if f, ok := HookCanAutoUpdate.GetOk(); ok { return f() } diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 868d5f61a2fa7..4f23ae18e4248 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -8,14 +8,10 @@ package relayserver import ( "encoding/json" "fmt" - "log" "net/http" "net/netip" - "strings" - "sync" "tailscale.com/disco" - "tailscale.com/envknob" "tailscale.com/feature" "tailscale.com/ipn" "tailscale.com/ipn/ipnext" @@ -23,10 +19,12 @@ import ( "tailscale.com/net/udprelay" "tailscale.com/net/udprelay/endpoint" "tailscale.com/net/udprelay/status" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/eventbus" "tailscale.com/wgengine/magicsock" ) @@ -71,8 +69,8 @@ func servePeerRelayDebugSessions(h *localapi.Handler, w http.ResponseWriter, r * // imported. func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { e := &extension{ - newServerFn: func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { - return udprelay.NewServer(logf, port, overrideAddrs) + newServerFn: func(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (relayServer, error) { + return udprelay.NewServer(logf, port, onlyStaticAddrPorts) }, logf: logger.WithPrefix(logf, featureName+": "), } @@ -89,22 +87,24 @@ type relayServer interface { AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.ServerEndpoint, error) GetSessions() []status.ServerSession SetDERPMapView(tailcfg.DERPMapView) + SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) } // extension is an [ipnext.Extension] managing the relay server on platforms // that import this package. 
type extension struct { - newServerFn func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) // swappable for tests + newServerFn func(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (relayServer, error) // swappable for tests logf logger.Logf ec *eventbus.Client respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp] - mu sync.Mutex // guards the following fields - shutdown bool // true if Shutdown() has been called - rs relayServer // nil when disabled - port *int // ipn.Prefs.RelayServerPort, nil if disabled - derpMapView tailcfg.DERPMapView // latest seen over the eventbus - hasNodeAttrDisableRelayServer bool // [tailcfg.NodeAttrDisableRelayServer] + mu syncs.Mutex // guards the following fields + shutdown bool // true if Shutdown() has been called + rs relayServer // nil when disabled + port *uint16 // ipn.Prefs.RelayServerPort, nil if disabled + staticEndpoints views.Slice[netip.AddrPort] // ipn.Prefs.RelayServerStaticEndpoints + derpMapView tailcfg.DERPMapView // latest seen over the eventbus + hasNodeAttrDisableRelayServer bool // [tailcfg.NodeAttrDisableRelayServer] } // Name implements [ipnext.Extension]. @@ -151,7 +151,12 @@ func (e *extension) onAllocReq(req magicsock.UDPRelayAllocReq) { e.logf("error allocating endpoint: %v", err) return } - e.respPub.Publish(magicsock.UDPRelayAllocResp{ + // Take a defensive stance around publishing from within an + // [*eventbus.SubscribeFunc] by publishing from a separate goroutine. At the + // time of writing (2025-11-21), publishing from within the + // [*eventbus.SubscribeFunc] goroutine is potentially unsafe if publisher + // and subscriber share a lock. + go e.respPub.Publish(magicsock.UDPRelayAllocResp{ ReqRxFromNodeKey: req.RxFromNodeKey, ReqRxFromDiscoKey: req.RxFromDiscoKey, Message: &disco.AllocateUDPRelayEndpointResponse{ @@ -170,7 +175,7 @@ func (e *extension) onAllocReq(req magicsock.UDPRelayAllocReq) { } func (e *extension) tryStartRelayServerLocked() { - rs, err := e.newServerFn(e.logf, *e.port, overrideAddrs()) + rs, err := e.newServerFn(e.logf, *e.port, false) if err != nil { e.logf("error initializing server: %v", err) return @@ -185,6 +190,7 @@ func (e *extension) relayServerShouldBeRunningLocked() bool { // handleRelayServerLifetimeLocked handles the lifetime of [e.rs]. func (e *extension) handleRelayServerLifetimeLocked() { + defer e.handleRelayServerStaticAddrPortsLocked() if !e.relayServerShouldBeRunningLocked() { e.stopRelayServerLocked() return @@ -194,6 +200,13 @@ func (e *extension) handleRelayServerLifetimeLocked() { e.tryStartRelayServerLocked() } +func (e *extension) handleRelayServerStaticAddrPortsLocked() { + if e.rs != nil { + // TODO(jwhited): env var support + e.rs.SetStaticAddrPorts(e.staticEndpoints) + } +} + func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) { e.mu.Lock() defer e.mu.Unlock() @@ -204,6 +217,7 @@ func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) { func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { e.mu.Lock() defer e.mu.Unlock() + e.staticEndpoints = prefs.RelayServerStaticEndpoints() newPort, ok := prefs.RelayServerPort().GetOk() enableOrDisableServer := ok != (e.port != nil) portChanged := ok && e.port != nil && newPort != *e.port @@ -217,26 +231,6 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV e.handleRelayServerLifetimeLocked() } -// overrideAddrs returns TS_DEBUG_RELAY_SERVER_ADDRS as []netip.Addr, if set. 
It -// can be between 0 and 3 comma-separated Addrs. TS_DEBUG_RELAY_SERVER_ADDRS is -// not a stable interface, and is subject to change. -var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) { - all := envknob.String("TS_DEBUG_RELAY_SERVER_ADDRS") - const max = 3 - remain := all - for remain != "" && len(ret) < max { - var s string - s, remain, _ = strings.Cut(remain, ",") - addr, err := netip.ParseAddr(s) - if err != nil { - log.Printf("ignoring invalid Addr %q in TS_DEBUG_RELAY_SERVER_ADDRS %q: %v", s, all, err) - continue - } - ret = append(ret, addr) - } - return -}) - func (e *extension) stopRelayServerLocked() { if e.rs != nil { e.rs.Close() diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go index 2184b51759b61..807306c707bc1 100644 --- a/feature/relayserver/relayserver_test.go +++ b/feature/relayserver/relayserver_test.go @@ -7,6 +7,7 @@ import ( "errors" "net/netip" "reflect" + "slices" "testing" "tailscale.com/ipn" @@ -18,15 +19,21 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/ptr" + "tailscale.com/types/views" ) func Test_extension_profileStateChanged(t *testing.T) { - prefsWithPortOne := ipn.Prefs{RelayServerPort: ptr.To(1)} + prefsWithPortOne := ipn.Prefs{RelayServerPort: ptr.To(uint16(1))} prefsWithNilPort := ipn.Prefs{RelayServerPort: nil} + prefsWithPortOneRelayEndpoints := ipn.Prefs{ + RelayServerPort: ptr.To(uint16(1)), + RelayServerStaticEndpoints: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:7777")}, + } type fields struct { - port *int - rs relayServer + port *uint16 + staticEndpoints views.Slice[netip.AddrPort] + rs relayServer } type args struct { prefs ipn.PrefsView @@ -36,28 +43,75 @@ func Test_extension_profileStateChanged(t *testing.T) { name string fields fields args args - wantPort *int + wantPort *uint16 wantRelayServerFieldNonNil bool wantRelayServerFieldMutated bool + wantEndpoints []netip.AddrPort }{ { name: "no changes non-nil port previously running", fields: fields{ - port: ptr.To(1), + port: ptr.To(uint16(1)), rs: mockRelayServerNotZeroVal(), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(1), + wantPort: ptr.To(uint16(1)), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: false, + }, + { + name: "set addr ports unchanged port previously running", + fields: fields{ + port: ptr.To(uint16(1)), + rs: mockRelayServerNotZeroVal(), + }, + args: args{ + prefs: prefsWithPortOneRelayEndpoints.View(), + sameNode: true, + }, + wantPort: ptr.To(uint16(1)), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: false, + wantEndpoints: prefsWithPortOneRelayEndpoints.RelayServerStaticEndpoints, + }, + { + name: "set addr ports not previously running", + fields: fields{ + port: nil, + rs: nil, + }, + args: args{ + prefs: prefsWithPortOneRelayEndpoints.View(), + sameNode: true, + }, + wantPort: ptr.To(uint16(1)), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: true, + wantEndpoints: prefsWithPortOneRelayEndpoints.RelayServerStaticEndpoints, + }, + { + name: "clear addr ports unchanged port previously running", + fields: fields{ + port: ptr.To(uint16(1)), + staticEndpoints: views.SliceOf(prefsWithPortOneRelayEndpoints.RelayServerStaticEndpoints), + rs: mockRelayServerNotZeroVal(), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: true, + }, + wantPort: ptr.To(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: false, + wantEndpoints: nil, }, { name: 
"prefs port nil", fields: fields{ - port: ptr.To(1), + port: ptr.To(uint16(1)), }, args: args{ prefs: prefsWithNilPort.View(), @@ -70,7 +124,7 @@ func Test_extension_profileStateChanged(t *testing.T) { { name: "prefs port nil previously running", fields: fields{ - port: ptr.To(1), + port: ptr.To(uint16(1)), rs: mockRelayServerNotZeroVal(), }, args: args{ @@ -84,54 +138,54 @@ func Test_extension_profileStateChanged(t *testing.T) { { name: "prefs port changed", fields: fields{ - port: ptr.To(2), + port: ptr.To(uint16(2)), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(1), + wantPort: ptr.To(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, { name: "prefs port changed previously running", fields: fields{ - port: ptr.To(2), + port: ptr.To(uint16(2)), rs: mockRelayServerNotZeroVal(), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(1), + wantPort: ptr.To(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, { name: "sameNode false", fields: fields{ - port: ptr.To(1), + port: ptr.To(uint16(1)), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: false, }, - wantPort: ptr.To(1), + wantPort: ptr.To(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, { name: "sameNode false previously running", fields: fields{ - port: ptr.To(1), + port: ptr.To(uint16(1)), rs: mockRelayServerNotZeroVal(), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: false, }, - wantPort: ptr.To(1), + wantPort: ptr.To(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, @@ -144,7 +198,7 @@ func Test_extension_profileStateChanged(t *testing.T) { prefs: prefsWithPortOne.View(), sameNode: false, }, - wantPort: ptr.To(1), + wantPort: ptr.To(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, @@ -157,10 +211,11 @@ func Test_extension_profileStateChanged(t *testing.T) { t.Fatal(err) } e := ipne.(*extension) - e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { + e.newServerFn = func(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (relayServer, error) { return &mockRelayServer{}, nil } e.port = tt.fields.port + e.staticEndpoints = tt.fields.staticEndpoints e.rs = tt.fields.rs defer e.Shutdown() e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode) @@ -175,24 +230,34 @@ func Test_extension_profileStateChanged(t *testing.T) { if tt.wantRelayServerFieldMutated != !reflect.DeepEqual(tt.fields.rs, e.rs) { t.Errorf("wantRelayServerFieldMutated: %v != !reflect.DeepEqual(tt.fields.rs, e.rs): %v", tt.wantRelayServerFieldMutated, !reflect.DeepEqual(tt.fields.rs, e.rs)) } + if !slices.Equal(tt.wantEndpoints, e.staticEndpoints.AsSlice()) { + t.Errorf("wantEndpoints: %v != %v", tt.wantEndpoints, e.staticEndpoints.AsSlice()) + } + if e.rs != nil && !slices.Equal(tt.wantEndpoints, e.rs.(*mockRelayServer).addrPorts.AsSlice()) { + t.Errorf("wantEndpoints: %v != %v", tt.wantEndpoints, e.rs.(*mockRelayServer).addrPorts.AsSlice()) + } }) } } func mockRelayServerNotZeroVal() *mockRelayServer { - return &mockRelayServer{true} + return &mockRelayServer{set: true} } type mockRelayServer struct { - set bool + set bool + addrPorts views.Slice[netip.AddrPort] } -func (mockRelayServer) Close() error { return nil } -func (mockRelayServer) AllocateEndpoint(_, _ key.DiscoPublic) (endpoint.ServerEndpoint, error) { +func (m *mockRelayServer) 
Close() error { return nil } +func (m *mockRelayServer) AllocateEndpoint(_, _ key.DiscoPublic) (endpoint.ServerEndpoint, error) { return endpoint.ServerEndpoint{}, errors.New("not implemented") } -func (mockRelayServer) GetSessions() []status.ServerSession { return nil } -func (mockRelayServer) SetDERPMapView(tailcfg.DERPMapView) { return } +func (m *mockRelayServer) GetSessions() []status.ServerSession { return nil } +func (m *mockRelayServer) SetDERPMapView(tailcfg.DERPMapView) { return } +func (m *mockRelayServer) SetStaticAddrPorts(aps views.Slice[netip.AddrPort]) { + m.addrPorts = aps +} type mockSafeBackend struct { sys *tsd.System @@ -206,7 +271,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { tests := []struct { name string shutdown bool - port *int + port *uint16 rs relayServer hasNodeAttrDisableRelayServer bool wantRelayServerFieldNonNil bool @@ -215,7 +280,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { { name: "want running", shutdown: false, - port: ptr.To(1), + port: ptr.To(uint16(1)), hasNodeAttrDisableRelayServer: false, wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, @@ -223,7 +288,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { { name: "want running previously running", shutdown: false, - port: ptr.To(1), + port: ptr.To(uint16(1)), rs: mockRelayServerNotZeroVal(), hasNodeAttrDisableRelayServer: false, wantRelayServerFieldNonNil: true, @@ -232,7 +297,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { { name: "shutdown true", shutdown: true, - port: ptr.To(1), + port: ptr.To(uint16(1)), hasNodeAttrDisableRelayServer: false, wantRelayServerFieldNonNil: false, wantRelayServerFieldMutated: false, @@ -240,7 +305,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { { name: "shutdown true previously running", shutdown: true, - port: ptr.To(1), + port: ptr.To(uint16(1)), rs: mockRelayServerNotZeroVal(), hasNodeAttrDisableRelayServer: false, wantRelayServerFieldNonNil: false, @@ -289,7 +354,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { t.Fatal(err) } e := ipne.(*extension) - e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { + e.newServerFn = func(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (relayServer, error) { return &mockRelayServer{}, nil } e.shutdown = tt.shutdown diff --git a/feature/sdnotify.go b/feature/sdnotify.go index e785dc1acc09a..7a786dfabd519 100644 --- a/feature/sdnotify.go +++ b/feature/sdnotify.go @@ -23,10 +23,17 @@ var HookSystemdStatus Hook[func(format string, args ...any)] // It does nothing on non-Linux systems or if the binary was built without // the sdnotify feature. func SystemdStatus(format string, args ...any) { - if runtime.GOOS != "linux" || !buildfeatures.HasSDNotify { + if !CanSystemdStatus { // mid-stack inlining DCE return } if f, ok := HookSystemdStatus.GetOk(); ok { f(format, args...) } } + +// CanSystemdStatus reports whether the current build has systemd notifications +// linked in. +// +// It's effectively the same as HookSystemdStatus.IsSet(), but a constant for +// dead code elimination reasons. 
+const CanSystemdStatus = runtime.GOOS == "linux" && buildfeatures.HasSDNotify diff --git a/feature/sdnotify/sdnotify_linux.go b/feature/sdnotify/sdnotify_linux.go index b005f1bdb2bb2..2b13e24bbe509 100644 --- a/feature/sdnotify/sdnotify_linux.go +++ b/feature/sdnotify/sdnotify_linux.go @@ -29,8 +29,8 @@ type logOnce struct { sync.Once } -func (l *logOnce) logf(format string, args ...any) { - l.Once.Do(func() { +func (lg *logOnce) logf(format string, args ...any) { + lg.Once.Do(func() { log.Printf(format, args...) }) } diff --git a/feature/tpm/attestation.go b/feature/tpm/attestation.go index 5fbda3b17bab3..197a8d6b8798a 100644 --- a/feature/tpm/attestation.go +++ b/feature/tpm/attestation.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "log" + "sync" "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" @@ -19,7 +20,8 @@ import ( ) type attestationKey struct { - tpm transport.TPMCloser + tpmMu sync.Mutex + tpm transport.TPMCloser // private and public parts of the TPM key as returned from tpm2.Create. // These are used for serialization. tpmPrivate tpm2.TPM2BPrivate @@ -57,10 +59,12 @@ func newAttestationKey() (ak *attestationKey, retErr error) { SensitiveDataOrigin: true, UserWithAuth: true, AdminWithPolicy: true, - NoDA: true, - FixedTPM: true, - FixedParent: true, - SignEncrypt: true, + // We don't set an authorization policy on this key, so + // DA isn't helpful. + NoDA: true, + FixedTPM: true, + FixedParent: true, + SignEncrypt: true, }, Parameters: tpm2.NewTPMUPublicParms( tpm2.TPMAlgECC, @@ -144,7 +148,7 @@ type attestationKeySerialized struct { // MarshalJSON implements json.Marshaler. func (ak *attestationKey) MarshalJSON() ([]byte, error) { - if ak == nil || ak.IsZero() { + if ak == nil || len(ak.tpmPublic.Bytes()) == 0 || len(ak.tpmPrivate.Buffer) == 0 { return []byte("null"), nil } return json.Marshal(attestationKeySerialized{ @@ -163,6 +167,13 @@ func (ak *attestationKey) UnmarshalJSON(data []byte) (retErr error) { ak.tpmPrivate = tpm2.TPM2BPrivate{Buffer: aks.TPMPrivate} ak.tpmPublic = tpm2.BytesAs2B[tpm2.TPMTPublic, *tpm2.TPMTPublic](aks.TPMPublic) + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + if ak.tpm != nil { + ak.tpm.Close() + ak.tpm = nil + } + tpm, err := open() if err != nil { return key.ErrUnsupported @@ -182,6 +193,9 @@ func (ak *attestationKey) Public() crypto.PublicKey { } func (ak *attestationKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + if !ak.loaded() { return nil, errors.New("tpm2 attestation key is not loaded during Sign") } @@ -247,6 +261,9 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { } func (ak *attestationKey) Close() error { + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + var errs []error if ak.handle != nil && ak.tpm != nil { _, err := tpm2.FlushContext{FlushHandle: ak.handle.Handle}.Execute(ak.tpm) @@ -259,21 +276,34 @@ func (ak *attestationKey) Close() error { } func (ak *attestationKey) Clone() key.HardwareAttestationKey { - if ak == nil { + if ak.IsZero() { + return nil + } + + tpm, err := open() + if err != nil { + log.Printf("[unexpected] failed to open a TPM connection in feature/tpm.attestationKey.Clone: %v", err) return nil } - return &attestationKey{ - tpm: ak.tpm, + akc := &attestationKey{ + tpm: tpm, tpmPrivate: ak.tpmPrivate, tpmPublic: ak.tpmPublic, - handle: ak.handle, - pub: ak.pub, } + if err := akc.load(); err != nil { + log.Printf("[unexpected] failed to load TPM key in feature/tpm.attestationKey.Clone: %v", 
err) + tpm.Close() + return nil + } + return akc } func (ak *attestationKey) IsZero() bool { if ak == nil { return true } + + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() return !ak.loaded() } diff --git a/feature/tpm/attestation_test.go b/feature/tpm/attestation_test.go index ead88c955aeea..e7ff729871230 100644 --- a/feature/tpm/attestation_test.go +++ b/feature/tpm/attestation_test.go @@ -10,6 +10,8 @@ import ( "crypto/rand" "crypto/sha256" "encoding/json" + "runtime" + "sync" "testing" ) @@ -62,6 +64,37 @@ func TestAttestationKeySign(t *testing.T) { } } +func TestAttestationKeySignConcurrent(t *testing.T) { + skipWithoutTPM(t) + ak, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ak.Close(); err != nil { + t.Errorf("ak.Close: %v", err) + } + }) + + data := []byte("secrets") + digest := sha256.Sum256(data) + + wg := sync.WaitGroup{} + for range runtime.GOMAXPROCS(-1) { + wg.Go(func() { + // Check signature/validation round trip. + sig, err := ak.Sign(rand.Reader, digest[:], crypto.SHA256) + if err != nil { + t.Fatal(err) + } + if !ecdsa.VerifyASN1(ak.Public().(*ecdsa.PublicKey), digest[:], sig) { + t.Errorf("ecdsa.VerifyASN1 failed") + } + }) + } + wg.Wait() +} + func TestAttestationKeyUnmarshal(t *testing.T) { skipWithoutTPM(t) ak, err := newAttestationKey() @@ -96,3 +129,36 @@ func TestAttestationKeyUnmarshal(t *testing.T) { t.Error("unmarshalled public key is not the same as the original public key") } } + +func TestAttestationKeyClone(t *testing.T) { + skipWithoutTPM(t) + ak, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + + ak2 := ak.Clone() + if ak2 == nil { + t.Fatal("Clone failed") + } + t.Cleanup(func() { + if err := ak2.Close(); err != nil { + t.Errorf("ak2.Close: %v", err) + } + }) + // Close the original key, ak2 should remain open and usable. + if err := ak.Close(); err != nil { + t.Fatal(err) + } + + data := []byte("secrets") + digest := sha256.Sum256(data) + // Check signature/validation round trip using cloned key. + sig, err := ak2.Sign(rand.Reader, digest[:], crypto.SHA256) + if err != nil { + t.Fatal(err) + } + if !ecdsa.VerifyASN1(ak2.Public().(*ecdsa.PublicKey), digest[:], sig) { + t.Errorf("ecdsa.VerifyASN1 failed") + } +} diff --git a/feature/tpm/tpm.go b/feature/tpm/tpm.go index 4b27a241fa255..8df269b95bc2e 100644 --- a/feature/tpm/tpm.go +++ b/feature/tpm/tpm.go @@ -35,12 +35,15 @@ import ( "tailscale.com/util/testenv" ) -var infoOnce = sync.OnceValue(info) +var ( + infoOnce = sync.OnceValue(info) + tpmSupportedOnce = sync.OnceValue(tpmSupported) +) func init() { feature.Register("tpm") - feature.HookTPMAvailable.Set(tpmSupported) - feature.HookHardwareAttestationAvailable.Set(tpmSupported) + feature.HookTPMAvailable.Set(tpmSupportedOnce) + feature.HookHardwareAttestationAvailable.Set(tpmSupportedOnce) hostinfo.RegisterHostinfoNewHook(func(hi *tailcfg.Hostinfo) { hi.TPM = infoOnce() @@ -411,6 +414,9 @@ func tpmSeal(logf logger.Logf, data []byte) (*tpmSealedData, error) { FixedTPM: true, FixedParent: true, UserWithAuth: true, + // We don't set an authorization policy on this key, so DA + // isn't helpful. 
+ NoDA: true, }, }), } diff --git a/feature/tpm/tpm_linux.go b/feature/tpm/tpm_linux.go index 6c8131e8d8a28..3f05c9a8c38ad 100644 --- a/feature/tpm/tpm_linux.go +++ b/feature/tpm/tpm_linux.go @@ -4,6 +4,8 @@ package tpm import ( + "errors" + "github.com/google/go-tpm/tpm2/transport" "github.com/google/go-tpm/tpm2/transport/linuxtpm" ) @@ -13,5 +15,10 @@ func open() (transport.TPMCloser, error) { if err == nil { return tpm, nil } - return linuxtpm.Open("/dev/tpm0") + errs := []error{err} + tpm, err = linuxtpm.Open("/dev/tpm0") + if err == nil { + return tpm, nil + } + return nil, errors.Join(errs...) } diff --git a/flake.nix b/flake.nix index e8b5420c55cef..505061a765362 100644 --- a/flake.nix +++ b/flake.nix @@ -151,5 +151,4 @@ }); }; } -# nix-direnv cache busting line: sha256-AUOjLomba75qfzb9Vxc0Sktyeces6hBSuOMgboWcDnE= - +# nix-direnv cache busting line: sha256-jJSSXMyUqcJoZuqfSlBsKDQezyqS+jDkRglMMjG1K8g= diff --git a/go.mod b/go.mod index 96df00f65044f..aee9fa2d76bf9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module tailscale.com -go 1.25.3 +go 1.25.5 require ( filippo.io/mkcert v1.4.4 @@ -16,11 +16,13 @@ require ( github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 + github.com/bradfitz/go-tool-cache v0.0.0-20251113223507-0124e698e0bd github.com/bramvdbogaerde/go-scp v1.4.0 github.com/cilium/ebpf v0.15.0 github.com/coder/websocket v1.8.12 github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf + github.com/creachadair/msync v0.7.1 github.com/creachadair/taskgroup v0.13.2 github.com/creack/pty v1.1.23 github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa @@ -59,7 +61,7 @@ require ( github.com/jellydator/ttlcache/v3 v3.1.0 github.com/jsimonetti/rtnetlink v1.4.0 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 - github.com/klauspost/compress v1.17.11 + github.com/klauspost/compress v1.18.0 github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 @@ -69,11 +71,12 @@ require ( github.com/miekg/dns v1.1.58 github.com/mitchellh/go-ps v1.0.0 github.com/peterbourgon/ff/v3 v3.4.0 + github.com/pires/go-proxyproto v0.8.1 github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.6 github.com/prometheus-community/pro-bing v0.4.0 - github.com/prometheus/client_golang v1.20.5 - github.com/prometheus/common v0.55.0 + github.com/prometheus/client_golang v1.23.0 + github.com/prometheus/common v0.65.0 github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff github.com/safchain/ethtool v0.3.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e @@ -100,21 +103,21 @@ require ( go.uber.org/zap v1.27.0 go4.org/mem v0.0.0-20240501181205-ae6ca9944745 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba - golang.org/x/crypto v0.38.0 - golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac - golang.org/x/mod v0.24.0 - golang.org/x/net v0.40.0 + golang.org/x/crypto v0.45.0 + golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b + golang.org/x/mod v0.30.0 + golang.org/x/net v0.47.0 golang.org/x/oauth2 v0.30.0 - golang.org/x/sync v0.14.0 - golang.org/x/sys v0.33.0 - golang.org/x/term v0.32.0 + golang.org/x/sync v0.18.0 + golang.org/x/sys v0.38.0 + golang.org/x/term v0.37.0 golang.org/x/time v0.11.0 - golang.org/x/tools v0.33.0 + golang.org/x/tools v0.39.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 
golang.zx2c4.com/wireguard/windows v0.5.3 gopkg.in/square/go-jose.v2 v2.6.0 gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 - honnef.co/go/tools v0.5.1 + honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0 k8s.io/api v0.32.0 k8s.io/apimachinery v0.32.0 k8s.io/apiserver v0.32.0 @@ -185,6 +188,9 @@ require ( go.opentelemetry.io/otel/metric v1.33.0 // indirect go.opentelemetry.io/otel/trace v1.33.0 // indirect go.uber.org/automaxprocs v1.5.3 // indirect + golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 // indirect + golang.org/x/tools/go/expect v0.1.1-deprecated // indirect + golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect k8s.io/component-base v0.32.0 // indirect @@ -350,8 +356,8 @@ require ( github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/polyfloyd/go-errorlint v1.4.8 // indirect - github.com/prometheus/client_model v0.6.1 - github.com/prometheus/procfs v0.15.1 // indirect + github.com/prometheus/client_model v0.6.2 + github.com/prometheus/procfs v0.16.1 // indirect github.com/quasilyte/go-ruleguard v0.4.2 // indirect github.com/quasilyte/gogrep v0.5.0 // indirect github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect @@ -407,9 +413,9 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f // indirect golang.org/x/image v0.27.0 // indirect - golang.org/x/text v0.25.0 // indirect + golang.org/x/text v0.31.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/protobuf v1.36.3 // indirect + google.golang.org/protobuf v1.36.6 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect diff --git a/go.mod.sri b/go.mod.sri index 790d851a19a94..66422652e2262 100644 --- a/go.mod.sri +++ b/go.mod.sri @@ -1 +1 @@ -sha256-AUOjLomba75qfzb9Vxc0Sktyeces6hBSuOMgboWcDnE= +sha256-jJSSXMyUqcJoZuqfSlBsKDQezyqS+jDkRglMMjG1K8g= diff --git a/go.sum b/go.sum index 9864eab71aa9a..f70fe9159f614 100644 --- a/go.sum +++ b/go.sum @@ -186,6 +186,8 @@ github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4= github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/bombsimon/wsl/v4 v4.2.1 h1:Cxg6u+XDWff75SIFFmNsqnIOgob+Q9hG6y/ioKbRFiM= github.com/bombsimon/wsl/v4 v4.2.1/go.mod h1:Xu/kDxGZTofQcDGCtQe9KCzhHphIe0fDuyWTxER9Feo= +github.com/bradfitz/go-tool-cache v0.0.0-20251113223507-0124e698e0bd h1:1Df3FBmfyUCIQ4eKzAPXIWTfewY89L0fWPWO56zWCyI= +github.com/bradfitz/go-tool-cache v0.0.0-20251113223507-0124e698e0bd/go.mod h1:2+xptBAd0m2kZ1wLO4AYZhldLEFPy+KeGwmnlXLvy+w= github.com/bramvdbogaerde/go-scp v1.4.0 h1:jKMwpwCbcX1KyvDbm/PDJuXcMuNVlLGi0Q0reuzjyKY= github.com/bramvdbogaerde/go-scp v1.4.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ= github.com/breml/bidichk v0.2.7 h1:dAkKQPLl/Qrk7hnP6P+E0xOodrq8Us7+U0o4UBOAlQY= @@ -244,8 +246,10 @@ github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8 github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod 
h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/creachadair/mds v0.17.1 h1:lXQbTGKmb3nE3aK6OEp29L1gCx6B5ynzlQ6c1KOBurc= -github.com/creachadair/mds v0.17.1/go.mod h1:4b//mUiL8YldH6TImXjmW45myzTLNS1LLjOmrk888eg= +github.com/creachadair/mds v0.25.9 h1:080Hr8laN2h+l3NeVCGMBpXtIPnl9mz8e4HLraGPqtA= +github.com/creachadair/mds v0.25.9/go.mod h1:4hatI3hRM+qhzuAmqPRFvaBM8mONkS7nsLxkcuTYUIs= +github.com/creachadair/msync v0.7.1 h1:SeZmuEBXQPe5GqV/C94ER7QIZPwtvFbeQiykzt/7uho= +github.com/creachadair/msync v0.7.1/go.mod h1:8CcFlLsSujfHE5wWm19uUBLHIPDAUr6LXDwneVMO008= github.com/creachadair/taskgroup v0.13.2 h1:3KyqakBuFsm3KkXi/9XIb0QcA8tEzLHLgaoidf0MdVc= github.com/creachadair/taskgroup v0.13.2/go.mod h1:i3V1Zx7H8RjwljUEeUWYT30Lmb9poewSb2XI1yTwD0g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -660,8 +664,8 @@ github.com/kisielk/errcheck v1.7.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkHAIKE/contextcheck v1.1.4 h1:B6zAaLhOEEcjvUgIYEqystmnFk1Oemn8bvJhbt0GMb8= github.com/kkHAIKE/contextcheck v1.1.4/go.mod h1:1+i/gWqokIa+dm31mqGLZhZJ7Uh44DJGZVmr6QRBNJg= -github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= -github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -807,6 +811,8 @@ github.com/peterbourgon/ff/v3 v3.4.0 h1:QBvM/rizZM1cB0p0lGMdmR7HxZeI/ZrBWB4DqLkM github.com/peterbourgon/ff/v3 v3.4.0/go.mod h1:zjJVUhx+twciwfDl0zBcFzl4dW8axCRyXE/eKY9RztQ= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= +github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= github.com/pkg/diff v0.0.0-20200914180035-5b29258ca4f7/go.mod h1:zO8QMzTeZd5cpnIkz/Gn6iK0jDfGicM1nynOkkPIl28= @@ -836,29 +842,29 @@ github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.1/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= -github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= -github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= +github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod 
h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= -github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= -github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= -github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= +github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= +github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= -github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff h1:X1Tly81aZ22DA1fxBdfvR3iw8+yFoUBUHMEd+AX/ZXI= github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff/go.mod h1:FvE8dtQ1Ww63IlyKBn1V4s+zMwF9kHkVNkQBR1pM4CU= github.com/puzpuzpuz/xsync v1.5.2 h1:yRAP4wqSOZG+/4pxJ08fPTwrfL0IzE/LKQ/cw509qGY= @@ -1124,8 +1130,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= -golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= 
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1136,8 +1142,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac h1:l5+whBCLH3iH2ZNHYLbAe58bo7yrN4mVcnkHDYz5vvs= -golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac/go.mod h1:hH+7mtFmImwwcMvScyxUhjuVHR3HGaDPMn9rMSUUbxo= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/exp/typeparams v0.0.0-20220428152302-39d4317da171/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20230203172020-98cc5a0785f9/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= @@ -1173,8 +1179,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= -golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1214,8 +1220,8 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= -golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1237,8 +1243,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod 
h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1301,16 +1307,18 @@ golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 h1:E2/AqCUMZGgd73TQkxUMcMla25GB9i/5HOdLr+uH7Vo= +golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1321,8 +1329,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= 
+golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1392,8 +1400,12 @@ golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= -golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= +golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= +golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= +golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1488,8 +1500,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU= -google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -1534,8 +1546,8 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.5.1 h1:4bH5o3b5ZULQ4UrBmP+63W9r7qIkqJClEA9ko5YKx+I= -honnef.co/go/tools v0.5.1/go.mod h1:e9irvo83WDG9/irijV44wr3tbhcFeRnfpVlRqVwpzMs= +honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0 h1:5SXjd4ET5dYijLaf0O3aOenC0Z4ZafIWSpjUzsQaNho= +honnef.co/go/tools 
v0.7.0-0.dev.0.20251022135355-8273271481d0/go.mod h1:EPDDhEZqVHhWuPI5zPAsjU0U7v9xNIWjoOVyZ5ZcniQ= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= k8s.io/api v0.32.0 h1:OL9JpbvAU5ny9ga2fb24X8H6xQlVp+aJMFlgtQjR9CE= diff --git a/go.toolchain.rev b/go.toolchain.rev index 9ea6b37dcbc32..16058a407c704 100644 --- a/go.toolchain.rev +++ b/go.toolchain.rev @@ -1 +1 @@ -5c01b77ad0d27a8bd4ef89ef7e713fd7043c5a91 +0bab982699fa5903259ba9b4cba3e5fd6cb3baf2 diff --git a/go.toolchain.rev.sri b/go.toolchain.rev.sri index a62a525998ac7..310dcf87fcf1c 100644 --- a/go.toolchain.rev.sri +++ b/go.toolchain.rev.sri @@ -1 +1 @@ -sha256-2TYziJLJrFOW2FehhahKficnDACJEwjuvVYyeQZbrcc= +sha256-fBezkBGRHCnfJiOUmMMqBCPCqjlGC4F6KEt5h1JhsCg= diff --git a/go.toolchain.version b/go.toolchain.version index 5bb76b575e1f5..b45fe310644f7 100644 --- a/go.toolchain.version +++ b/go.toolchain.version @@ -1 +1 @@ -1.25.3 +1.25.5 diff --git a/health/health.go b/health/health.go index cbfa599c56eaf..f0f6a6ffbb162 100644 --- a/health/health.go +++ b/health/health.go @@ -20,6 +20,7 @@ import ( "tailscale.com/envknob" "tailscale.com/feature/buildfeatures" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/opt" @@ -30,7 +31,7 @@ import ( ) var ( - mu sync.Mutex + mu syncs.Mutex debugHandler map[string]http.Handler ) diff --git a/ipn/auditlog/extension.go b/ipn/auditlog/extension.go index f73681db073c1..ae2a296b2c420 100644 --- a/ipn/auditlog/extension.go +++ b/ipn/auditlog/extension.go @@ -7,7 +7,6 @@ import ( "context" "errors" "fmt" - "sync" "time" "tailscale.com/control/controlclient" @@ -15,6 +14,7 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnext" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/lazy" "tailscale.com/types/logger" @@ -40,7 +40,7 @@ type extension struct { store lazy.SyncValue[LogStore] // mu protects all following fields. - mu sync.Mutex + mu syncs.Mutex // logger is the current audit logger, or nil if it is not set up, // such as before the first control client is created, or after // a profile change and before the new control client is created. 
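Several files in this change (envknob, health, and the auditlog extension above) replace sync.Mutex with syncs.Mutex and annotate guarded fields with +checklocks comments, letting static analysis verify that those fields are only touched while the named mutex is held. A minimal sketch of the pattern, assuming only the Lock/Unlock methods that this diff itself calls on syncs.Mutex:

package example

import "tailscale.com/syncs"

// counter is a hypothetical type illustrating the annotation pattern.
type counter struct {
	mu syncs.Mutex

	// +checklocks:mu
	n int // may only be read or written while mu is held
}

func (c *counter) inc() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.n++ // OK: mu is held, satisfying the +checklocks:mu annotation
}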
diff --git a/ipn/backend.go b/ipn/backend.go index 91cf81ca52962..b4ba958c5dd1e 100644 --- a/ipn/backend.go +++ b/ipn/backend.go @@ -74,7 +74,7 @@ const ( NotifyInitialPrefs NotifyWatchOpt = 1 << 2 // if set, the first Notify message (sent immediately) will contain the current Prefs NotifyInitialNetMap NotifyWatchOpt = 1 << 3 // if set, the first Notify message (sent immediately) will contain the current NetMap - NotifyNoPrivateKeys NotifyWatchOpt = 1 << 4 // if set, private keys that would normally be sent in updates are zeroed out + NotifyNoPrivateKeys NotifyWatchOpt = 1 << 4 // (no-op) formerly it redacted private keys from updates; they are now always redacted, so this option has no effect NotifyInitialDriveShares NotifyWatchOpt = 1 << 5 // if set, the first Notify message (sent immediately) will contain the current Taildrive Shares NotifyInitialOutgoingFiles NotifyWatchOpt = 1 << 6 // if set, the first Notify message (sent immediately) will contain the current Taildrop OutgoingFiles diff --git a/ipn/desktop/zsyscall_windows.go b/ipn/desktop/zsyscall_windows.go index 535274016f9ca..8d97c4d8089ef 100644 --- a/ipn/desktop/zsyscall_windows.go +++ b/ipn/desktop/zsyscall_windows.go @@ -57,12 +57,12 @@ var ( ) func setLastError(dwErrorCode uint32) { - syscall.Syscall(procSetLastError.Addr(), 1, uintptr(dwErrorCode), 0, 0) + syscall.SyscallN(procSetLastError.Addr(), uintptr(dwErrorCode)) return } func createWindowEx(dwExStyle uint32, lpClassName *uint16, lpWindowName *uint16, dwStyle uint32, x int32, y int32, nWidth int32, nHeight int32, hWndParent windows.HWND, hMenu windows.Handle, hInstance windows.Handle, lpParam unsafe.Pointer) (hWnd windows.HWND, err error) { - r0, _, e1 := syscall.Syscall12(procCreateWindowExW.Addr(), 12, uintptr(dwExStyle), uintptr(unsafe.Pointer(lpClassName)), uintptr(unsafe.Pointer(lpWindowName)), uintptr(dwStyle), uintptr(x), uintptr(y), uintptr(nWidth), uintptr(nHeight), uintptr(hWndParent), uintptr(hMenu), uintptr(hInstance), uintptr(lpParam)) + r0, _, e1 := syscall.SyscallN(procCreateWindowExW.Addr(), uintptr(dwExStyle), uintptr(unsafe.Pointer(lpClassName)), uintptr(unsafe.Pointer(lpWindowName)), uintptr(dwStyle), uintptr(x), uintptr(y), uintptr(nWidth), uintptr(nHeight), uintptr(hWndParent), uintptr(hMenu), uintptr(hInstance), uintptr(lpParam)) hWnd = windows.HWND(r0) if hWnd == 0 { err = errnoErr(e1) @@ -71,13 +71,13 @@ func createWindowEx(dwExStyle uint32, lpClassName *uint16, lpWindowName *uint16, } func defWindowProc(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) { - r0, _, _ := syscall.Syscall6(procDefWindowProcW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0) + r0, _, _ := syscall.SyscallN(procDefWindowProcW.Addr(), uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam)) res = uintptr(r0) return } func destroyWindow(hwnd windows.HWND) (err error) { - r1, _, e1 := syscall.Syscall(procDestroyWindow.Addr(), 1, uintptr(hwnd), 0, 0) + r1, _, e1 := syscall.SyscallN(procDestroyWindow.Addr(), uintptr(hwnd)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -85,24 +85,24 @@ func destroyWindow(hwnd windows.HWND) (err error) { } func dispatchMessage(lpMsg *_MSG) (res uintptr) { - r0, _, _ := syscall.Syscall(procDispatchMessageW.Addr(), 1, uintptr(unsafe.Pointer(lpMsg)), 0, 0) + r0, _, _ := syscall.SyscallN(procDispatchMessageW.Addr(), uintptr(unsafe.Pointer(lpMsg))) res = uintptr(r0) return } func getMessage(lpMsg *_MSG, hwnd windows.HWND, msgMin uint32, msgMax uint32) (ret int32) { - r0, _, _ := syscall.Syscall6(procGetMessageW.Addr(),
4, uintptr(unsafe.Pointer(lpMsg)), uintptr(hwnd), uintptr(msgMin), uintptr(msgMax), 0, 0) + r0, _, _ := syscall.SyscallN(procGetMessageW.Addr(), uintptr(unsafe.Pointer(lpMsg)), uintptr(hwnd), uintptr(msgMin), uintptr(msgMax)) ret = int32(r0) return } func postQuitMessage(exitCode int32) { - syscall.Syscall(procPostQuitMessage.Addr(), 1, uintptr(exitCode), 0, 0) + syscall.SyscallN(procPostQuitMessage.Addr(), uintptr(exitCode)) return } func registerClassEx(windowClass *_WNDCLASSEX) (atom uint16, err error) { - r0, _, e1 := syscall.Syscall(procRegisterClassExW.Addr(), 1, uintptr(unsafe.Pointer(windowClass)), 0, 0) + r0, _, e1 := syscall.SyscallN(procRegisterClassExW.Addr(), uintptr(unsafe.Pointer(windowClass))) atom = uint16(r0) if atom == 0 { err = errnoErr(e1) @@ -111,19 +111,19 @@ func registerClassEx(windowClass *_WNDCLASSEX) (atom uint16, err error) { } func sendMessage(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) { - r0, _, _ := syscall.Syscall6(procSendMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0) + r0, _, _ := syscall.SyscallN(procSendMessageW.Addr(), uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam)) res = uintptr(r0) return } func translateMessage(lpMsg *_MSG) (res bool) { - r0, _, _ := syscall.Syscall(procTranslateMessage.Addr(), 1, uintptr(unsafe.Pointer(lpMsg)), 0, 0) + r0, _, _ := syscall.SyscallN(procTranslateMessage.Addr(), uintptr(unsafe.Pointer(lpMsg))) res = r0 != 0 return } func registerSessionNotification(hServer windows.Handle, hwnd windows.HWND, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procWTSRegisterSessionNotificationEx.Addr(), 3, uintptr(hServer), uintptr(hwnd), uintptr(flags)) + r1, _, e1 := syscall.SyscallN(procWTSRegisterSessionNotificationEx.Addr(), uintptr(hServer), uintptr(hwnd), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -131,7 +131,7 @@ func registerSessionNotification(hServer windows.Handle, hwnd windows.HWND, flag } func unregisterSessionNotification(hServer windows.Handle, hwnd windows.HWND) (err error) { - r1, _, e1 := syscall.Syscall(procWTSUnRegisterSessionNotificationEx.Addr(), 2, uintptr(hServer), uintptr(hwnd), 0) + r1, _, e1 := syscall.SyscallN(procWTSUnRegisterSessionNotificationEx.Addr(), uintptr(hServer), uintptr(hwnd)) if int32(r1) == 0 { err = errnoErr(e1) } diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index 3d67efc6fd33b..4bf78b40b022b 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -64,45 +64,48 @@ func (src *Prefs) Clone() *Prefs { if dst.RelayServerPort != nil { dst.RelayServerPort = ptr.To(*src.RelayServerPort) } + dst.RelayServerStaticEndpoints = append(src.RelayServerStaticEndpoints[:0:0], src.RelayServerStaticEndpoints...) dst.Persist = src.Persist.Clone() return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
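The zsyscall_windows.go rewrite above replaces the fixed-arity syscall.Syscall, Syscall6, and Syscall12 helpers with syscall.SyscallN, which takes its arguments variadically and drops both the explicit argument count and the zero padding. A minimal sketch of the pattern, using a hypothetical user32 wrapper that is not part of this diff:

```go
//go:build windows

package winapi

import (
	"syscall"

	"golang.org/x/sys/windows"
)

var (
	moduser32       = windows.NewLazySystemDLL("user32.dll")
	procMessageBeep = moduser32.NewProc("MessageBeep")
)

// messageBeep is an illustrative wrapper, not from this diff.
func messageBeep(uType uint32) error {
	// Old style, with a manual argument count and zero padding:
	//   syscall.Syscall(procMessageBeep.Addr(), 1, uintptr(uType), 0, 0)
	// SyscallN infers the count from the variadic arguments:
	r1, _, e1 := syscall.SyscallN(procMessageBeep.Addr(), uintptr(uType))
	if r1 == 0 { // MessageBeep returns FALSE (0) on failure
		return e1
	}
	return nil
}
```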
var _PrefsCloneNeedsRegeneration = Prefs(struct { - ControlURL string - RouteAll bool - ExitNodeID tailcfg.StableNodeID - ExitNodeIP netip.Addr - AutoExitNode ExitNodeExpression - InternalExitNodePrior tailcfg.StableNodeID - ExitNodeAllowLANAccess bool - CorpDNS bool - RunSSH bool - RunWebClient bool - WantRunning bool - LoggedOut bool - ShieldsUp bool - AdvertiseTags []string - Hostname string - NotepadURLs bool - ForceDaemon bool - Egg bool - AdvertiseRoutes []netip.Prefix - AdvertiseServices []string - NoSNAT bool - NoStatefulFiltering opt.Bool - NetfilterMode preftype.NetfilterMode - OperatorUser string - ProfileName string - AutoUpdate AutoUpdatePrefs - AppConnector AppConnectorPrefs - PostureChecking bool - NetfilterKind string - DriveShares []*drive.Share - RelayServerPort *int - AllowSingleHosts marshalAsTrueInJSON - Persist *persist.Persist + ControlURL string + RouteAll bool + ExitNodeID tailcfg.StableNodeID + ExitNodeIP netip.Addr + AutoExitNode ExitNodeExpression + InternalExitNodePrior tailcfg.StableNodeID + ExitNodeAllowLANAccess bool + CorpDNS bool + RunSSH bool + RunWebClient bool + WantRunning bool + LoggedOut bool + ShieldsUp bool + AdvertiseTags []string + Hostname string + NotepadURLs bool + ForceDaemon bool + Egg bool + AdvertiseRoutes []netip.Prefix + AdvertiseServices []string + Sync opt.Bool + NoSNAT bool + NoStatefulFiltering opt.Bool + NetfilterMode preftype.NetfilterMode + OperatorUser string + ProfileName string + AutoUpdate AutoUpdatePrefs + AppConnector AppConnectorPrefs + PostureChecking bool + NetfilterKind string + DriveShares []*drive.Share + RelayServerPort *uint16 + RelayServerStaticEndpoints []netip.AddrPort + AllowSingleHosts marshalAsTrueInJSON + Persist *persist.Persist }{}) // Clone makes a deep copy of ServeConfig. @@ -218,10 +221,11 @@ func (src *TCPPortHandler) Clone() *TCPPortHandler { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _TCPPortHandlerCloneNeedsRegeneration = TCPPortHandler(struct { - HTTPS bool - HTTP bool - TCPForward string - TerminateTLS string + HTTPS bool + HTTP bool + TCPForward string + TerminateTLS string + ProxyProtocol int }{}) // Clone makes a deep copy of HTTPHandler. @@ -232,14 +236,17 @@ func (src *HTTPHandler) Clone() *HTTPHandler { } dst := new(HTTPHandler) *dst = *src + dst.AcceptAppCaps = append(src.AcceptAppCaps[:0:0], src.AcceptAppCaps...) return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _HTTPHandlerCloneNeedsRegeneration = HTTPHandler(struct { - Path string - Proxy string - Text string + Path string + Proxy string + Text string + AcceptAppCaps []tailcfg.PeerCapability + Redirect string }{}) // Clone makes a deep copy of WebServerConfig. @@ -256,7 +263,7 @@ func (src *WebServerConfig) Clone() *WebServerConfig { if v == nil { dst.Handlers[k] = nil } else { - dst.Handlers[k] = ptr.To(*v) + dst.Handlers[k] = v.Clone() } } } diff --git a/ipn/ipn_view.go b/ipn/ipn_view.go index 1c7639f6ff932..4157ec76e61a8 100644 --- a/ipn/ipn_view.go +++ b/ipn/ipn_view.go @@ -363,6 +363,12 @@ func (v PrefsView) AdvertiseServices() views.Slice[string] { return views.SliceOf(v.ж.AdvertiseServices) } +// Sync is whether this node should sync its configuration from +// the control plane. If unset, this defaults to true. +// This exists primarily for testing, to verify that netmap caching +// and offline operation work correctly. 
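The regenerated Clone methods deep-copy the new slice fields (RelayServerStaticEndpoints and AcceptAppCaps above) with the append(s[:0:0], s...) idiom. The full slice expression s[:0:0] yields a zero-length, zero-capacity slice, so append must allocate fresh backing storage, and cloning a nil slice stays nil. A standalone illustration:

```go
package main

import "fmt"

func cloneInts(src []int) []int {
	// src[:0:0] has len 0 and cap 0, so append always allocates a new
	// backing array; a nil input yields nil, preserving the nil-vs-empty
	// distinction the generated code relies on.
	return append(src[:0:0], src...)
}

func main() {
	orig := []int{1, 2, 3}
	copied := cloneInts(orig)
	copied[0] = 99
	fmt.Println(orig[0], copied[0]) // 1 99: backing arrays are distinct

	var empty []int
	fmt.Println(cloneInts(empty) == nil) // true
}
```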
+func (v PrefsView) Sync() opt.Bool { return v.ж.Sync } + // NoSNAT specifies whether to source NAT traffic going to // destinations in AdvertiseRoutes. The default is to apply source // NAT, which makes the traffic appear to come from the router @@ -435,16 +441,21 @@ func (v PrefsView) DriveShares() views.SliceView[*drive.Share, drive.ShareView] // RelayServerPort is the UDP port number for the relay server to bind to, // on all interfaces. A non-nil zero value signifies a random unused port // should be used. A nil value signifies relay server functionality -// should be disabled. This field is currently experimental, and therefore -// no guarantees are made about its current naming and functionality when -// non-nil/enabled. -func (v PrefsView) RelayServerPort() views.ValuePointer[int] { +// should be disabled. +func (v PrefsView) RelayServerPort() views.ValuePointer[uint16] { return views.ValuePointerOf(v.ж.RelayServerPort) } +// RelayServerStaticEndpoints are static IP:port endpoints to advertise as +// candidates for relay connections. Only relevant when RelayServerPort is +// non-nil. +func (v PrefsView) RelayServerStaticEndpoints() views.Slice[netip.AddrPort] { + return views.SliceOf(v.ж.RelayServerStaticEndpoints) +} + // AllowSingleHosts was a legacy field that was always true // for the past 4.5 years. It controlled whether Tailscale -// peers got /32 or /127 routes for each other. +// peers got /32 or /128 routes for each other. // As of 2024-05-17 we're starting to ignore it, but to let // people still downgrade Tailscale versions and not break // all peer-to-peer networking we still write it to disk (as JSON) @@ -462,39 +473,41 @@ func (v PrefsView) Persist() persist.PersistView { return v.ж.Persist.View() } // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
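Each _*NeedsRegeneration variable, like the _PrefsViewNeedsRegeneration guard that follows, is a compile-time check: converting an anonymous struct literal to the named type compiles only if both types have identical fields in identical order, so editing Prefs without re-running the generator breaks the build. A toy version of the trick, with hypothetical names:

```go
package main

type Point struct {
	X, Y int
}

// A compilation failure here means the generated code for Point is stale:
// the conversion is legal only if the anonymous struct's fields exactly
// match Point's (names, types, and order; struct tags are ignored).
var _PointNeedsRegeneration = Point(struct {
	X, Y int
}{})

func main() {}
```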
var _PrefsViewNeedsRegeneration = Prefs(struct { - ControlURL string - RouteAll bool - ExitNodeID tailcfg.StableNodeID - ExitNodeIP netip.Addr - AutoExitNode ExitNodeExpression - InternalExitNodePrior tailcfg.StableNodeID - ExitNodeAllowLANAccess bool - CorpDNS bool - RunSSH bool - RunWebClient bool - WantRunning bool - LoggedOut bool - ShieldsUp bool - AdvertiseTags []string - Hostname string - NotepadURLs bool - ForceDaemon bool - Egg bool - AdvertiseRoutes []netip.Prefix - AdvertiseServices []string - NoSNAT bool - NoStatefulFiltering opt.Bool - NetfilterMode preftype.NetfilterMode - OperatorUser string - ProfileName string - AutoUpdate AutoUpdatePrefs - AppConnector AppConnectorPrefs - PostureChecking bool - NetfilterKind string - DriveShares []*drive.Share - RelayServerPort *int - AllowSingleHosts marshalAsTrueInJSON - Persist *persist.Persist + ControlURL string + RouteAll bool + ExitNodeID tailcfg.StableNodeID + ExitNodeIP netip.Addr + AutoExitNode ExitNodeExpression + InternalExitNodePrior tailcfg.StableNodeID + ExitNodeAllowLANAccess bool + CorpDNS bool + RunSSH bool + RunWebClient bool + WantRunning bool + LoggedOut bool + ShieldsUp bool + AdvertiseTags []string + Hostname string + NotepadURLs bool + ForceDaemon bool + Egg bool + AdvertiseRoutes []netip.Prefix + AdvertiseServices []string + Sync opt.Bool + NoSNAT bool + NoStatefulFiltering opt.Bool + NetfilterMode preftype.NetfilterMode + OperatorUser string + ProfileName string + AutoUpdate AutoUpdatePrefs + AppConnector AppConnectorPrefs + PostureChecking bool + NetfilterKind string + DriveShares []*drive.Share + RelayServerPort *uint16 + RelayServerStaticEndpoints []netip.AddrPort + AllowSingleHosts marshalAsTrueInJSON + Persist *persist.Persist }{}) // View returns a read-only view of ServeConfig. @@ -807,12 +820,19 @@ func (v TCPPortHandlerView) TCPForward() string { return v.ж.TCPForward } // (the HTTPS mode uses ServeConfig.Web) func (v TCPPortHandlerView) TerminateTLS() string { return v.ж.TerminateTLS } +// ProxyProtocol indicates whether to send a PROXY protocol header +// before forwarding the connection to TCPForward. +// +// This is only valid if TCPForward is non-empty. +func (v TCPPortHandlerView) ProxyProtocol() int { return v.ж.ProxyProtocol } + // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _TCPPortHandlerViewNeedsRegeneration = TCPPortHandler(struct { - HTTPS bool - HTTP bool - TCPForward string - TerminateTLS string + HTTPS bool + HTTP bool + TCPForward string + TerminateTLS string + ProxyProtocol int }{}) // View returns a read-only view of HTTPHandler. @@ -891,11 +911,27 @@ func (v HTTPHandlerView) Proxy() string { return v.ж.Proxy } // plaintext to serve (primarily for testing) func (v HTTPHandlerView) Text() string { return v.ж.Text } +// peer capabilities to forward in grant header, e.g. example.com/cap/mon +func (v HTTPHandlerView) AcceptAppCaps() views.Slice[tailcfg.PeerCapability] { + return views.SliceOf(v.ж.AcceptAppCaps) +} + +// Redirect, if not empty, is the target URL to redirect requests to. +// By default, we redirect with HTTP 302 (Found) status. +// If Redirect starts with ':', then we use that status instead. 
+// +// The target URL supports the following expansion variables: +// - ${HOST}: replaced with the request's Host header value +// - ${REQUEST_URI}: replaced with the request's full URI (path and query string) +func (v HTTPHandlerView) Redirect() string { return v.ж.Redirect } + // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _HTTPHandlerViewNeedsRegeneration = HTTPHandler(struct { - Path string - Proxy string - Text string + Path string + Proxy string + Text string + AcceptAppCaps []tailcfg.PeerCapability + Redirect string }{}) // View returns a read-only view of WebServerConfig. diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go index 4ff37dc8e3775..fc93cc8760a0b 100644 --- a/ipn/ipnext/ipnext.go +++ b/ipn/ipnext/ipnext.go @@ -323,7 +323,8 @@ type ProfileStateChangeCallback func(_ ipn.LoginProfileView, _ ipn.PrefsView, sa // [ProfileStateChangeCallback]s are called first. // // It returns a function to be called when the cc is being shut down, -// or nil if no cleanup is needed. +// or nil if no cleanup is needed. That cleanup function should not call +// back into LocalBackend, which may be locked during shutdown. type NewControlClientCallback func(controlclient.Client, ipn.LoginProfileView) (cleanup func()) // Hooks is a collection of hooks that extensions can add to (non-concurrently) diff --git a/ipn/ipnlocal/c2n.go b/ipn/ipnlocal/c2n.go index 0c228060faf63..b5e722b97c4a4 100644 --- a/ipn/ipnlocal/c2n.go +++ b/ipn/ipnlocal/c2n.go @@ -179,7 +179,6 @@ func handleC2NDebugNetMap(b *LocalBackend, w http.ResponseWriter, r *http.Reques } field.SetZero() } - nm, _ = redactNetmapPrivateKeys(nm) return json.Marshal(nm) } diff --git a/ipn/ipnlocal/c2n_test.go b/ipn/ipnlocal/c2n_test.go index 95cd5fa6995bc..86cc6a5490865 100644 --- a/ipn/ipnlocal/c2n_test.go +++ b/ipn/ipnlocal/c2n_test.go @@ -13,21 +13,17 @@ import ( "os" "path/filepath" "reflect" - "strings" "testing" "time" "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" "tailscale.com/tstest" - "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/views" "tailscale.com/util/must" - "tailscale.com/util/set" - "tailscale.com/wgengine/filter/filtertype" gcmp "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -144,342 +140,8 @@ func TestHandleC2NTLSCertStatus(t *testing.T) { } -// eachStructField calls cb for each struct field in struct type tp, recursively. -func eachStructField(tp reflect.Type, cb func(reflect.Type, reflect.StructField)) { - if !strings.HasPrefix(tp.PkgPath(), "tailscale.com/") { - // Stop traversing when we reach a non-tailscale type. - return - } - - for i := range tp.NumField() { - cb(tp, tp.Field(i)) - - switch tp.Field(i).Type.Kind() { - case reflect.Struct: - eachStructField(tp.Field(i).Type, cb) - case reflect.Slice, reflect.Array, reflect.Ptr, reflect.Map: - if tp.Field(i).Type.Elem().Kind() == reflect.Struct { - eachStructField(tp.Field(i).Type.Elem(), cb) - } - } - } -} - -// eachStructValue calls cb for each struct field in the struct value v, recursively. 
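Returning to the Redirect handler documented above: a rough sketch of how the ':NNN' status prefix and the ${HOST}/${REQUEST_URI} expansions could be interpreted. This is illustrative only (it assumes a three-digit status code), not the serve implementation:

```go
package main

import (
	"net/http"
	"strconv"
	"strings"
)

// redirectHandler interprets a Redirect target such as
// ":301https://example.com${REQUEST_URI}". Hypothetical helper.
func redirectHandler(target string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t := target              // don't mutate the captured config across requests
		code := http.StatusFound // 302, the documented default
		if strings.HasPrefix(t, ":") && len(t) >= 4 {
			if n, err := strconv.Atoi(t[1:4]); err == nil {
				code, t = n, t[4:]
			}
		}
		url := strings.NewReplacer(
			"${HOST}", r.Host,
			"${REQUEST_URI}", r.URL.RequestURI(),
		).Replace(t)
		http.Redirect(w, r, url, code)
	})
}

func main() {
	http.Handle("/", redirectHandler(":301https://example.com${REQUEST_URI}"))
	http.ListenAndServe("127.0.0.1:8080", nil)
}
```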
-func eachStructValue(v reflect.Value, cb func(reflect.Type, reflect.StructField, reflect.Value)) { - if v.IsZero() { - return - } - - for i := range v.NumField() { - cb(v.Type(), v.Type().Field(i), v.Field(i)) - - switch v.Type().Field(i).Type.Kind() { - case reflect.Struct: - eachStructValue(v.Field(i), cb) - case reflect.Slice, reflect.Array, reflect.Ptr, reflect.Map: - if v.Field(i).Type().Elem().Kind() == reflect.Struct { - eachStructValue(v.Field(i).Addr().Elem(), cb) - } - } - } -} - -// TestRedactNetmapPrivateKeys tests that redactNetmapPrivateKeys redacts all private keys -// and other private fields from a netmap.NetworkMap, and only those fields. -func TestRedactNetmapPrivateKeys(t *testing.T) { - type field struct { - t reflect.Type - f string - } - f := func(t any, f string) field { - return field{reflect.TypeOf(t), f} - } - // fields is a map of all struct fields in netmap.NetworkMap and its - // sub-structs, marking each field as private (true) or public (false). - // If you add a new field to netmap.NetworkMap or its sub-structs, - // you must add it to this list, marking it as private or public. - fields := map[field]bool{ - // Private fields to be redacted. - f(netmap.NetworkMap{}, "PrivateKey"): true, - - // All other fields are public. - f(netmap.NetworkMap{}, "AllCaps"): false, - f(netmap.NetworkMap{}, "CollectServices"): false, - f(netmap.NetworkMap{}, "DERPMap"): false, - f(netmap.NetworkMap{}, "DNS"): false, - f(netmap.NetworkMap{}, "DisplayMessages"): false, - f(netmap.NetworkMap{}, "Domain"): false, - f(netmap.NetworkMap{}, "DomainAuditLogID"): false, - f(netmap.NetworkMap{}, "Expiry"): false, - f(netmap.NetworkMap{}, "MachineKey"): false, - f(netmap.NetworkMap{}, "Name"): false, - f(netmap.NetworkMap{}, "NodeKey"): false, - f(netmap.NetworkMap{}, "PacketFilter"): false, - f(netmap.NetworkMap{}, "PacketFilterRules"): false, - f(netmap.NetworkMap{}, "Peers"): false, - f(netmap.NetworkMap{}, "SSHPolicy"): false, - f(netmap.NetworkMap{}, "SelfNode"): false, - f(netmap.NetworkMap{}, "TKAEnabled"): false, - f(netmap.NetworkMap{}, "TKAHead"): false, - f(netmap.NetworkMap{}, "UserProfiles"): false, - f(filtertype.CapMatch{}, "Cap"): false, - f(filtertype.CapMatch{}, "Dst"): false, - f(filtertype.CapMatch{}, "Values"): false, - f(filtertype.Match{}, "Caps"): false, - f(filtertype.Match{}, "Dsts"): false, - f(filtertype.Match{}, "IPProto"): false, - f(filtertype.Match{}, "SrcCaps"): false, - f(filtertype.Match{}, "Srcs"): false, - f(filtertype.Match{}, "SrcsContains"): false, - f(filtertype.NetPortRange{}, "Net"): false, - f(filtertype.NetPortRange{}, "Ports"): false, - f(filtertype.PortRange{}, "First"): false, - f(filtertype.PortRange{}, "Last"): false, - f(key.DiscoPublic{}, "k"): false, - f(key.MachinePublic{}, "k"): false, - f(key.NodePrivate{}, "_"): false, - f(key.NodePrivate{}, "k"): false, - f(key.NodePublic{}, "k"): false, - f(tailcfg.CapGrant{}, "CapMap"): false, - f(tailcfg.CapGrant{}, "Caps"): false, - f(tailcfg.CapGrant{}, "Dsts"): false, - f(tailcfg.DERPHomeParams{}, "RegionScore"): false, - f(tailcfg.DERPMap{}, "HomeParams"): false, - f(tailcfg.DERPMap{}, "OmitDefaultRegions"): false, - f(tailcfg.DERPMap{}, "Regions"): false, - f(tailcfg.DNSConfig{}, "CertDomains"): false, - f(tailcfg.DNSConfig{}, "Domains"): false, - f(tailcfg.DNSConfig{}, "ExitNodeFilteredSet"): false, - f(tailcfg.DNSConfig{}, "ExtraRecords"): false, - f(tailcfg.DNSConfig{}, "FallbackResolvers"): false, - f(tailcfg.DNSConfig{}, "Nameservers"): false, - f(tailcfg.DNSConfig{}, "Proxied"): 
false, - f(tailcfg.DNSConfig{}, "Resolvers"): false, - f(tailcfg.DNSConfig{}, "Routes"): false, - f(tailcfg.DNSConfig{}, "TempCorpIssue13969"): false, - f(tailcfg.DNSRecord{}, "Name"): false, - f(tailcfg.DNSRecord{}, "Type"): false, - f(tailcfg.DNSRecord{}, "Value"): false, - f(tailcfg.DisplayMessageAction{}, "Label"): false, - f(tailcfg.DisplayMessageAction{}, "URL"): false, - f(tailcfg.DisplayMessage{}, "ImpactsConnectivity"): false, - f(tailcfg.DisplayMessage{}, "PrimaryAction"): false, - f(tailcfg.DisplayMessage{}, "Severity"): false, - f(tailcfg.DisplayMessage{}, "Text"): false, - f(tailcfg.DisplayMessage{}, "Title"): false, - f(tailcfg.FilterRule{}, "CapGrant"): false, - f(tailcfg.FilterRule{}, "DstPorts"): false, - f(tailcfg.FilterRule{}, "IPProto"): false, - f(tailcfg.FilterRule{}, "SrcBits"): false, - f(tailcfg.FilterRule{}, "SrcIPs"): false, - f(tailcfg.HostinfoView{}, "ж"): false, - f(tailcfg.Hostinfo{}, "AllowsUpdate"): false, - f(tailcfg.Hostinfo{}, "App"): false, - f(tailcfg.Hostinfo{}, "AppConnector"): false, - f(tailcfg.Hostinfo{}, "BackendLogID"): false, - f(tailcfg.Hostinfo{}, "Cloud"): false, - f(tailcfg.Hostinfo{}, "Container"): false, - f(tailcfg.Hostinfo{}, "Desktop"): false, - f(tailcfg.Hostinfo{}, "DeviceModel"): false, - f(tailcfg.Hostinfo{}, "Distro"): false, - f(tailcfg.Hostinfo{}, "DistroCodeName"): false, - f(tailcfg.Hostinfo{}, "DistroVersion"): false, - f(tailcfg.Hostinfo{}, "Env"): false, - f(tailcfg.Hostinfo{}, "ExitNodeID"): false, - f(tailcfg.Hostinfo{}, "FrontendLogID"): false, - f(tailcfg.Hostinfo{}, "GoArch"): false, - f(tailcfg.Hostinfo{}, "GoArchVar"): false, - f(tailcfg.Hostinfo{}, "GoVersion"): false, - f(tailcfg.Hostinfo{}, "Hostname"): false, - f(tailcfg.Hostinfo{}, "IPNVersion"): false, - f(tailcfg.Hostinfo{}, "IngressEnabled"): false, - f(tailcfg.Hostinfo{}, "Location"): false, - f(tailcfg.Hostinfo{}, "Machine"): false, - f(tailcfg.Hostinfo{}, "NetInfo"): false, - f(tailcfg.Hostinfo{}, "NoLogsNoSupport"): false, - f(tailcfg.Hostinfo{}, "OS"): false, - f(tailcfg.Hostinfo{}, "OSVersion"): false, - f(tailcfg.Hostinfo{}, "Package"): false, - f(tailcfg.Hostinfo{}, "PushDeviceToken"): false, - f(tailcfg.Hostinfo{}, "RequestTags"): false, - f(tailcfg.Hostinfo{}, "RoutableIPs"): false, - f(tailcfg.Hostinfo{}, "SSH_HostKeys"): false, - f(tailcfg.Hostinfo{}, "Services"): false, - f(tailcfg.Hostinfo{}, "ServicesHash"): false, - f(tailcfg.Hostinfo{}, "ShareeNode"): false, - f(tailcfg.Hostinfo{}, "ShieldsUp"): false, - f(tailcfg.Hostinfo{}, "StateEncrypted"): false, - f(tailcfg.Hostinfo{}, "TPM"): false, - f(tailcfg.Hostinfo{}, "Userspace"): false, - f(tailcfg.Hostinfo{}, "UserspaceRouter"): false, - f(tailcfg.Hostinfo{}, "WireIngress"): false, - f(tailcfg.Hostinfo{}, "WoLMACs"): false, - f(tailcfg.Location{}, "City"): false, - f(tailcfg.Location{}, "CityCode"): false, - f(tailcfg.Location{}, "Country"): false, - f(tailcfg.Location{}, "CountryCode"): false, - f(tailcfg.Location{}, "Latitude"): false, - f(tailcfg.Location{}, "Longitude"): false, - f(tailcfg.Location{}, "Priority"): false, - f(tailcfg.NetInfo{}, "DERPLatency"): false, - f(tailcfg.NetInfo{}, "FirewallMode"): false, - f(tailcfg.NetInfo{}, "HairPinning"): false, - f(tailcfg.NetInfo{}, "HavePortMap"): false, - f(tailcfg.NetInfo{}, "LinkType"): false, - f(tailcfg.NetInfo{}, "MappingVariesByDestIP"): false, - f(tailcfg.NetInfo{}, "OSHasIPv6"): false, - f(tailcfg.NetInfo{}, "PCP"): false, - f(tailcfg.NetInfo{}, "PMP"): false, - f(tailcfg.NetInfo{}, "PreferredDERP"): false, - f(tailcfg.NetInfo{}, 
"UPnP"): false, - f(tailcfg.NetInfo{}, "WorkingICMPv4"): false, - f(tailcfg.NetInfo{}, "WorkingIPv6"): false, - f(tailcfg.NetInfo{}, "WorkingUDP"): false, - f(tailcfg.NetPortRange{}, "Bits"): false, - f(tailcfg.NetPortRange{}, "IP"): false, - f(tailcfg.NetPortRange{}, "Ports"): false, - f(tailcfg.NetPortRange{}, "_"): false, - f(tailcfg.NodeView{}, "ж"): false, - f(tailcfg.Node{}, "Addresses"): false, - f(tailcfg.Node{}, "AllowedIPs"): false, - f(tailcfg.Node{}, "Cap"): false, - f(tailcfg.Node{}, "CapMap"): false, - f(tailcfg.Node{}, "Capabilities"): false, - f(tailcfg.Node{}, "ComputedName"): false, - f(tailcfg.Node{}, "ComputedNameWithHost"): false, - f(tailcfg.Node{}, "Created"): false, - f(tailcfg.Node{}, "DataPlaneAuditLogID"): false, - f(tailcfg.Node{}, "DiscoKey"): false, - f(tailcfg.Node{}, "Endpoints"): false, - f(tailcfg.Node{}, "ExitNodeDNSResolvers"): false, - f(tailcfg.Node{}, "Expired"): false, - f(tailcfg.Node{}, "HomeDERP"): false, - f(tailcfg.Node{}, "Hostinfo"): false, - f(tailcfg.Node{}, "ID"): false, - f(tailcfg.Node{}, "IsJailed"): false, - f(tailcfg.Node{}, "IsWireGuardOnly"): false, - f(tailcfg.Node{}, "Key"): false, - f(tailcfg.Node{}, "KeyExpiry"): false, - f(tailcfg.Node{}, "KeySignature"): false, - f(tailcfg.Node{}, "LastSeen"): false, - f(tailcfg.Node{}, "LegacyDERPString"): false, - f(tailcfg.Node{}, "Machine"): false, - f(tailcfg.Node{}, "MachineAuthorized"): false, - f(tailcfg.Node{}, "Name"): false, - f(tailcfg.Node{}, "Online"): false, - f(tailcfg.Node{}, "PrimaryRoutes"): false, - f(tailcfg.Node{}, "SelfNodeV4MasqAddrForThisPeer"): false, - f(tailcfg.Node{}, "SelfNodeV6MasqAddrForThisPeer"): false, - f(tailcfg.Node{}, "Sharer"): false, - f(tailcfg.Node{}, "StableID"): false, - f(tailcfg.Node{}, "Tags"): false, - f(tailcfg.Node{}, "UnsignedPeerAPIOnly"): false, - f(tailcfg.Node{}, "User"): false, - f(tailcfg.Node{}, "computedHostIfDifferent"): false, - f(tailcfg.PortRange{}, "First"): false, - f(tailcfg.PortRange{}, "Last"): false, - f(tailcfg.SSHPolicy{}, "Rules"): false, - f(tailcfg.Service{}, "Description"): false, - f(tailcfg.Service{}, "Port"): false, - f(tailcfg.Service{}, "Proto"): false, - f(tailcfg.Service{}, "_"): false, - f(tailcfg.TPMInfo{}, "FamilyIndicator"): false, - f(tailcfg.TPMInfo{}, "FirmwareVersion"): false, - f(tailcfg.TPMInfo{}, "Manufacturer"): false, - f(tailcfg.TPMInfo{}, "Model"): false, - f(tailcfg.TPMInfo{}, "SpecRevision"): false, - f(tailcfg.TPMInfo{}, "Vendor"): false, - f(tailcfg.UserProfileView{}, "ж"): false, - f(tailcfg.UserProfile{}, "DisplayName"): false, - f(tailcfg.UserProfile{}, "ID"): false, - f(tailcfg.UserProfile{}, "LoginName"): false, - f(tailcfg.UserProfile{}, "ProfilePicURL"): false, - f(views.Slice[ipproto.Proto]{}, "ж"): false, - f(views.Slice[tailcfg.FilterRule]{}, "ж"): false, - } - - t.Run("field_list_is_complete", func(t *testing.T) { - seen := set.Set[field]{} - eachStructField(reflect.TypeOf(netmap.NetworkMap{}), func(rt reflect.Type, sf reflect.StructField) { - f := field{rt, sf.Name} - seen.Add(f) - if _, ok := fields[f]; !ok { - // Fail the test if netmap has a field not in the list. If you see this test - // failure, please add the new field to the fields map above, marking it as private or public. - t.Errorf("netmap field has not been declared as private or public: %v.%v", rt, sf.Name) - } - }) - - for want := range fields { - if !seen.Contains(want) { - // Fail the test if the list has a field not in netmap. If you see this test - // failure, please remove the field from the fields map above. 
- t.Errorf("field declared that has not been found in netmap: %v.%v", want.t, want.f) - } - } - }) - - // tests is a list of test cases, each with a non-redacted netmap and the expected redacted netmap. - // If you add a new private field to netmap.NetworkMap or its sub-structs, please add a test case - // here that has that field set in nm, and the expected redacted value in wantRedacted. - tests := []struct { - name string - nm *netmap.NetworkMap - wantRedacted *netmap.NetworkMap - }{ - { - name: "redact_private_key", - nm: &netmap.NetworkMap{ - PrivateKey: key.NewNode(), - }, - wantRedacted: &netmap.NetworkMap{}, - }, - } - - // confirmedRedacted is a set of all private fields that have been covered by the tests above. - confirmedRedacted := set.Set[field]{} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - // Record which of the private fields are set in the non-redacted netmap. - eachStructValue(reflect.ValueOf(tt.nm).Elem(), func(tt reflect.Type, sf reflect.StructField, v reflect.Value) { - f := field{tt, sf.Name} - if shouldRedact := fields[f]; shouldRedact && !v.IsZero() { - confirmedRedacted.Add(f) - } - }) - - got, _ := redactNetmapPrivateKeys(tt.nm) - if !reflect.DeepEqual(got, tt.wantRedacted) { - t.Errorf("unexpected redacted netmap: %+v", got) - } - - // Check that all private fields in the redacted netmap are zero. - eachStructValue(reflect.ValueOf(got).Elem(), func(tt reflect.Type, sf reflect.StructField, v reflect.Value) { - f := field{tt, sf.Name} - if shouldRedact := fields[f]; shouldRedact && !v.IsZero() { - t.Errorf("field not redacted: %v.%v", tt, sf.Name) - } - }) - }) - } - - // Check that all private fields in netmap.NetworkMap and its sub-structs - // are covered by the tests above. If you see a test failure here, - // please add a test case above that has that field set in nm. 
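The test being removed here used a pattern worth keeping in mind: walk every struct field with reflection, require each one to be explicitly classified in a map, and fail when either side drifts. A condensed, generic version of that audit using only the standard library:

```go
package audit

import (
	"reflect"
	"testing"
)

type Config struct {
	Name   string
	Secret string
}

func TestEveryFieldClassified(t *testing.T) {
	// Every field must be listed here as sensitive (true) or not (false).
	classified := map[string]bool{
		"Name":   false,
		"Secret": true,
	}

	tp := reflect.TypeOf(Config{})
	seen := map[string]bool{}
	for i := range tp.NumField() {
		name := tp.Field(i).Name
		seen[name] = true
		if _, ok := classified[name]; !ok {
			t.Errorf("new field %q: add it to the classification map", name)
		}
	}
	for name := range classified {
		if !seen[name] {
			t.Errorf("stale entry %q: field no longer exists", name)
		}
	}
}
```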
- for f, shouldRedact := range fields { - if shouldRedact { - if !confirmedRedacted.Contains(f) { - t.Errorf("field not covered by tests: %v.%v", f.t, f.f) - } - } - } -} - func TestHandleC2NDebugNetmap(t *testing.T) { nm := &netmap.NetworkMap{ - Name: "myhost", SelfNode: (&tailcfg.Node{ ID: 100, Name: "myhost", @@ -496,10 +158,7 @@ func TestHandleC2NDebugNetmap(t *testing.T) { Hostinfo: (&tailcfg.Hostinfo{Hostname: "peer1"}).View(), }).View(), }, - PrivateKey: key.NewNode(), } - withoutPrivateKey := *nm - withoutPrivateKey.PrivateKey = key.NodePrivate{} for _, tt := range []struct { name string @@ -508,12 +167,12 @@ func TestHandleC2NDebugNetmap(t *testing.T) { }{ { name: "simple_get", - want: &withoutPrivateKey, + want: nm, }, { name: "post_no_omit", req: &tailcfg.C2NDebugNetmapRequest{}, - want: &withoutPrivateKey, + want: nm, }, { name: "post_omit_peers_and_name", @@ -525,7 +184,7 @@ func TestHandleC2NDebugNetmap(t *testing.T) { { name: "post_omit_nonexistent_field", req: &tailcfg.C2NDebugNetmapRequest{OmitFields: []string{"ThisFieldDoesNotExist"}}, - want: &withoutPrivateKey, + want: nm, }, } { t.Run(tt.name, func(t *testing.T) { diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index ab49976c8aeea..8804fcb5ce2e8 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -30,7 +30,6 @@ import ( "runtime" "slices" "strings" - "sync" "time" "tailscale.com/atomicfile" @@ -42,6 +41,7 @@ import ( "tailscale.com/ipn/store" "tailscale.com/ipn/store/mem" "tailscale.com/net/bakedroots" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tempfork/acme" "tailscale.com/types/logger" @@ -60,9 +60,9 @@ var ( // acmeMu guards all ACME operations, so concurrent requests // for certs don't slam ACME. The first will go through and // populate the on-disk cache and the rest should use that. - acmeMu sync.Mutex + acmeMu syncs.Mutex - renewMu sync.Mutex // lock order: acmeMu before renewMu + renewMu syncs.Mutex // lock order: acmeMu before renewMu renewCertAt = map[string]time.Time{} ) @@ -144,7 +144,11 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string if minValidity == 0 { logf("starting async renewal") // Start renewal in the background, return current valid cert. - b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) }) + b.goTracker.Go(func() { + if _, err := getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity); err != nil { + logf("async renewal failed: getCertPem: %v", err) + } + }) return pair, nil } // If the caller requested a specific validity duration, fall through @@ -547,8 +551,11 @@ var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf l // If we have a previous cert, include it in the order. Assuming we're // within the ARI renewal window this should exclude us from LE rate // limits. 
+ // Note that this order extension will fail renewals if the ACME account key has changed + // since the last issuance, see + // https://github.com/tailscale/tailscale/issues/18251 var opts []acme.OrderOption - if previous != nil { + if previous != nil && !envknob.Bool("TS_DEBUG_ACME_FORCE_RENEWAL") { prevCrt, err := previous.parseCertificate() if err == nil { opts = append(opts, acme.WithOrderReplacesCert(prevCrt)) diff --git a/ipn/ipnlocal/dnsconfig_test.go b/ipn/ipnlocal/dnsconfig_test.go index 71f1751488788..e23d8a057546f 100644 --- a/ipn/ipnlocal/dnsconfig_test.go +++ b/ipn/ipnlocal/dnsconfig_test.go @@ -70,8 +70,8 @@ func TestDNSConfigForNetmap(t *testing.T) { { name: "self_name_and_peers", nm: &netmap.NetworkMap{ - Name: "myname.net", SelfNode: (&tailcfg.Node{ + Name: "myname.net.", Addresses: ipps("100.101.101.101"), }).View(), }, @@ -109,15 +109,15 @@ func TestDNSConfigForNetmap(t *testing.T) { // even if they have IPv4. name: "v6_only_self", nm: &netmap.NetworkMap{ - Name: "myname.net", SelfNode: (&tailcfg.Node{ + Name: "myname.net.", Addresses: ipps("fe75::1"), }).View(), }, peers: nodeViews([]*tailcfg.Node{ { ID: 1, - Name: "peera.net", + Name: "peera.net.", Addresses: ipps("100.102.0.1", "100.102.0.2", "fe75::1001"), }, { @@ -146,8 +146,8 @@ func TestDNSConfigForNetmap(t *testing.T) { { name: "extra_records", nm: &netmap.NetworkMap{ - Name: "myname.net", SelfNode: (&tailcfg.Node{ + Name: "myname.net.", Addresses: ipps("100.101.101.101"), }).View(), DNS: tailcfg.DNSConfig{ @@ -171,7 +171,9 @@ func TestDNSConfigForNetmap(t *testing.T) { { name: "corp_dns_misc", nm: &netmap.NetworkMap{ - Name: "host.some.domain.net.", + SelfNode: (&tailcfg.Node{ + Name: "host.some.domain.net.", + }).View(), DNS: tailcfg.DNSConfig{ Proxied: true, Domains: []string{"foo.com", "bar.com"}, @@ -331,8 +333,8 @@ func TestDNSConfigForNetmap(t *testing.T) { { name: "self_expired", nm: &netmap.NetworkMap{ - Name: "myname.net", SelfNode: (&tailcfg.Node{ + Name: "myname.net.", Addresses: ipps("100.101.101.101"), }).View(), }, diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index 7d6dc2427adae..456cd45441ba9 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -433,7 +433,7 @@ func (rbw *responseBodyWrapper) Close() error { // b.Dialer().PeerAPITransport() with metrics tracking. type driveTransport struct { b *LocalBackend - tr *http.Transport + tr http.RoundTripper } func (b *LocalBackend) newDriveTransport() *driveTransport { @@ -443,7 +443,7 @@ func (b *LocalBackend) newDriveTransport() *driveTransport { } } -func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { +func (dt *driveTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Some WebDAV clients include origin and refer headers, which peerapi does // not like. Remove them. 
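The new TS_DEBUG_ACME_FORCE_RENEWAL check earlier in this hunk uses tailscale.com/envknob to let an operator skip the ARI order-replacement path after the account-key mismatch described in issue 18251. The general shape of an envknob gate, with a made-up knob name:

```go
package main

import (
	"fmt"

	"tailscale.com/envknob"
)

func main() {
	// envknob.Bool reads a boolean TS_* environment variable;
	// TS_DEBUG_EXAMPLE is a hypothetical knob for illustration.
	if envknob.Bool("TS_DEBUG_EXAMPLE") {
		fmt.Println("debug path forced on via environment")
		return
	}
	fmt.Println("default path")
}
```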
req.Header.Del("origin") @@ -455,42 +455,45 @@ func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err req.Body = bw } - defer func() { - contentType := "unknown" - if ct := req.Header.Get("Content-Type"); ct != "" { - contentType = ct - } + resp, err := dt.tr.RoundTrip(req) + if err != nil { + return nil, err + } - dt.b.mu.Lock() - selfNodeKey := dt.b.currentNode().Self().Key().ShortString() - dt.b.mu.Unlock() - n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host)) - shareNodeKey := "unknown" - if ok { - shareNodeKey = string(n.Key().ShortString()) - } + contentType := "unknown" + if ct := req.Header.Get("Content-Type"); ct != "" { + contentType = ct + } - rbw := responseBodyWrapper{ - log: dt.b.logf, - logVerbose: req.Method != httpm.GET && req.Method != httpm.PUT, // other requests like PROPFIND are quite chatty, so we log those at verbose level - method: req.Method, - bytesTx: int64(bw.bytesRead), - selfNodeKey: selfNodeKey, - shareNodeKey: shareNodeKey, - contentType: contentType, - contentLength: resp.ContentLength, - fileExtension: parseDriveFileExtensionForLog(req.URL.Path), - statusCode: resp.StatusCode, - ReadCloser: resp.Body, - } + dt.b.mu.Lock() + selfNodeKey := dt.b.currentNode().Self().Key().ShortString() + dt.b.mu.Unlock() + n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host)) + shareNodeKey := "unknown" + if ok { + shareNodeKey = string(n.Key().ShortString()) + } - if resp.StatusCode >= 400 { - // in case of error response, just log immediately - rbw.logAccess("") - } else { - resp.Body = &rbw - } - }() + rbw := responseBodyWrapper{ + log: dt.b.logf, + logVerbose: req.Method != httpm.GET && req.Method != httpm.PUT, // other requests like PROPFIND are quite chatty, so we log those at verbose level + method: req.Method, + bytesTx: int64(bw.bytesRead), + selfNodeKey: selfNodeKey, + shareNodeKey: shareNodeKey, + contentType: contentType, + contentLength: resp.ContentLength, + fileExtension: parseDriveFileExtensionForLog(req.URL.Path), + statusCode: resp.StatusCode, + ReadCloser: resp.Body, + } + + if resp.StatusCode >= 400 { + // in case of error response, just log immediately + rbw.logAccess("") + } else { + resp.Body = &rbw + } - return dt.tr.RoundTrip(req) + return resp, nil } diff --git a/ipn/ipnlocal/drive_test.go b/ipn/ipnlocal/drive_test.go new file mode 100644 index 0000000000000..323c3821499ed --- /dev/null +++ b/ipn/ipnlocal/drive_test.go @@ -0,0 +1,50 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_drive + +package ipnlocal + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" +) + +// TestDriveTransportRoundTrip_NetworkError tests that driveTransport.RoundTrip +// doesn't panic when the underlying transport returns a nil response with an +// error. 
+// +// See: https://github.com/tailscale/tailscale/issues/17306 +func TestDriveTransportRoundTrip_NetworkError(t *testing.T) { + b := newTestLocalBackend(t) + + testErr := errors.New("network connection failed") + mockTransport := &mockRoundTripper{ + err: testErr, + } + dt := &driveTransport{ + b: b, + tr: mockTransport, + } + + req := httptest.NewRequest("GET", "http://100.64.0.1:1234/some/path", nil) + resp, err := dt.RoundTrip(req) + if err == nil { + t.Fatal("got nil error, expected non-nil") + } else if !errors.Is(err, testErr) { + t.Errorf("got error %v, expected %v", err, testErr) + } + if resp != nil { + t.Errorf("wanted nil response, got %v", resp) + } +} + +type mockRoundTripper struct { + err error +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, m.err +} diff --git a/ipn/ipnlocal/hwattest.go b/ipn/ipnlocal/hwattest.go new file mode 100644 index 0000000000000..2c93cad4c97ff --- /dev/null +++ b/ipn/ipnlocal/hwattest.go @@ -0,0 +1,48 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_tpm + +package ipnlocal + +import ( + "errors" + + "tailscale.com/feature" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/persist" +) + +func init() { + feature.HookGenerateAttestationKeyIfEmpty.Set(generateAttestationKeyIfEmpty) +} + +// generateAttestationKeyIfEmpty generates a new hardware attestation key if +// none exists. It returns true if a new key was generated and stored in +// p.AttestationKey. +func generateAttestationKeyIfEmpty(p *persist.Persist, logf logger.Logf) (bool, error) { + // attempt to generate a new hardware attestation key if none exists + var ak key.HardwareAttestationKey + if p != nil { + ak = p.AttestationKey + } + + if ak == nil || ak.IsZero() { + var err error + ak, err = key.NewHardwareAttestationKey() + if err != nil { + if !errors.Is(err, key.ErrUnsupported) { + logf("failed to create hardware attestation key: %v", err) + } + } else if ak != nil { + logf("using new hardware attestation key: %v", ak.Public()) + if p == nil { + p = &persist.Persist{} + } + p.AttestationKey = ak + return true, nil + } + } + return false, nil +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 1ffbbbca624a3..2ea8c62391e4f 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -10,12 +10,10 @@ import ( "context" "crypto/sha256" "encoding/binary" - "encoding/hex" "encoding/json" "errors" "fmt" "io" - "log" "math" "math/rand/v2" "net" @@ -181,14 +179,14 @@ var ( // state machine generates events back out to zero or more components. type LocalBackend struct { // Elements that are thread-safe or constant after construction. 
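The new hwattest.go above registers its implementation from init() through a feature hook, so builds tagged ts_omit_tpm simply never populate the hook. A simplified stand-in for that hook shape (not the actual tailscale.com/feature API):

```go
package main

import "fmt"

// Hook holds at most one implementation of an optional feature.
// This is a stripped-down illustration of the hook pattern used above.
type Hook[F any] struct {
	f   F
	set bool
}

func (h *Hook[F]) Set(f F)          { h.f, h.set = f, true }
func (h *Hook[F]) GetOk() (F, bool) { return h.f, h.set }

var hookGreet Hook[func(string) string]

func init() {
	// A build-tagged file (e.g. compiled only without ts_omit_tpm)
	// would perform this registration.
	hookGreet.Set(func(name string) string { return "hello, " + name })
}

func main() {
	if greet, ok := hookGreet.GetOk(); ok {
		fmt.Println(greet("world"))
	} else {
		fmt.Println("feature compiled out")
	}
}
```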
- ctx context.Context // canceled by [LocalBackend.Shutdown] - ctxCancel context.CancelCauseFunc // cancels ctx - logf logger.Logf // general logging - keyLogf logger.Logf // for printing list of peers on change - statsLogf logger.Logf // for printing peers stats on change - sys *tsd.System - eventSubs eventbus.Monitor - appcTask execqueue.ExecQueue // handles updates from appc + ctx context.Context // canceled by [LocalBackend.Shutdown] + ctxCancel context.CancelCauseFunc // cancels ctx + logf logger.Logf // general logging + keyLogf logger.Logf // for printing list of peers on change + statsLogf logger.Logf // for printing peers stats on change + sys *tsd.System + eventClient *eventbus.Client + appcTask execqueue.ExecQueue // handles updates from appc health *health.Tracker // always non-nil polc policyclient.Client // always non-nil @@ -247,8 +245,10 @@ type LocalBackend struct { // to prevent state changes while invoking callbacks. extHost *ExtensionHost + peerAPIPorts syncs.AtomicValue[map[netip.Addr]int] // can be read without b.mu held; TODO(nickkhyl): remove or move to nodeBackend? + // The mutex protects the following elements. - mu sync.Mutex + mu syncs.Mutex // currentNodeAtomic is the current node context. It is always non-nil. // It must be re-created when [LocalBackend] switches to a different profile/node @@ -272,9 +272,14 @@ type LocalBackend struct { sshServer SSHServer // or nil, initialized lazily. appConnector *appc.AppConnector // or nil, initialized when configured. // notifyCancel cancels notifications to the current SetNotifyCallback. - notifyCancel context.CancelFunc - cc controlclient.Client // TODO(nickkhyl): move to nodeBackend - ccAuto *controlclient.Auto // if cc is of type *controlclient.Auto; TODO(nickkhyl): move to nodeBackend + notifyCancel context.CancelFunc + cc controlclient.Client // TODO(nickkhyl): move to nodeBackend + ccAuto *controlclient.Auto // if cc is of type *controlclient.Auto; TODO(nickkhyl): move to nodeBackend + + // ignoreControlClientUpdates indicates whether we want to ignore SetControlClientStatus updates + // before acquiring b.mu. This is used during shutdown to avoid deadlocks. + ignoreControlClientUpdates atomic.Bool + machinePrivKey key.MachinePrivate tka *tkaState // TODO(nickkhyl): move to nodeBackend state ipn.State // TODO(nickkhyl): move to nodeBackend @@ -292,8 +297,8 @@ type LocalBackend struct { authActor ipnauth.Actor // an actor who called [LocalBackend.StartLoginInteractive] last, or nil; TODO(nickkhyl): move to nodeBackend egg bool prevIfState *netmon.State - peerAPIServer *peerAPIServer // or nil - peerAPIListeners []*peerAPIListener + peerAPIServer *peerAPIServer // or nil + peerAPIListeners []*peerAPIListener // TODO(nickkhyl): move to nodeBackend loginFlags controlclient.LoginFlags notifyWatchers map[string]*watchSession // by session ID lastStatusTime time.Time // status.AsOf value of the last processed status update @@ -315,10 +320,6 @@ type LocalBackend struct { serveListeners map[netip.AddrPort]*localListener // listeners for local serve traffic serveProxyHandlers sync.Map // string (HTTPHandler.Proxy) => *reverseProxy - // mu must be held before calling statusChanged.Wait() or - // statusChanged.Broadcast(). - statusChanged *sync.Cond - // dialPlan is any dial plan that we've received from the control // server during a previous connection; it is cleared on logout. dialPlan atomic.Pointer[tailcfg.ControlDialPlan] // TODO(nickkhyl): maybe move to nodeBackend? 
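Several LocalBackend fields, such as peerAPIPorts (a syncs.AtomicValue) and dialPlan (an atomic.Pointer), are deliberately readable without holding b.mu. A minimal sketch of that publish-and-read pattern using the standard library's atomic.Pointer:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

type dialPlan struct{ candidates []string }

type backend struct {
	// plan may be read by many goroutines without a mutex; writers
	// publish a fresh value and never mutate a published one.
	plan atomic.Pointer[dialPlan]
}

func (b *backend) setPlan(p *dialPlan) { b.plan.Store(p) }

func (b *backend) firstCandidate() string {
	p := b.plan.Load()
	if p == nil || len(p.candidates) == 0 {
		return ""
	}
	return p.candidates[0]
}

func main() {
	b := new(backend)
	fmt.Println(b.firstCandidate() == "") // true: no plan published yet
	b.setPlan(&dialPlan{candidates: []string{"derp1"}})
	fmt.Println(b.firstCandidate()) // derp1
}
```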
@@ -329,14 +330,14 @@ type LocalBackend struct { // // tkaSyncLock MUST be taken before mu (or inversely, mu must not be held // at the moment that tkaSyncLock is taken). - tkaSyncLock sync.Mutex + tkaSyncLock syncs.Mutex clock tstime.Clock // Last ClientVersion received in MapResponse, guarded by mu. lastClientVersion *tailcfg.ClientVersion // lastNotifiedDriveSharesMu guards lastNotifiedDriveShares - lastNotifiedDriveSharesMu sync.Mutex + lastNotifiedDriveSharesMu syncs.Mutex // lastNotifiedDriveShares keeps track of the last set of shares that we // notified about. @@ -521,8 +522,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo b.currentNodeAtomic.Store(nb) nb.ready() - mConn.SetNetInfoCallback(b.setNetInfo) - if sys.InitialConfig != nil { if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil { return nil, err @@ -546,7 +545,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo netMon := sys.NetMon.Get() b.sockstatLogger, err = sockstatlog.NewLogger(logpolicy.LogsDir(logf), logf, logID, netMon, sys.HealthTracker.Get(), sys.Bus.Get()) if err != nil { - log.Printf("error setting up sockstat logger: %v", err) + logf("error setting up sockstat logger: %v", err) } // Enable sockstats logs only on non-mobile unstable builds if version.IsUnstableBuild() && !version.IsMobile() && b.sockstatLogger != nil { @@ -560,7 +559,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo b.setTCPPortsIntercepted(nil) - b.statusChanged = sync.NewCond(&b.mu) b.e.SetStatusCallback(b.setWgengineStatus) b.prevIfState = netMon.InterfaceState() @@ -591,76 +589,47 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo // Start the event bus late, once all the assignments above are done. // (See previous race in tailscale/tailscale#17252) ec := b.Sys().Bus.Get().Client("ipnlocal.LocalBackend") - b.eventSubs = ec.Monitor(b.consumeEventbusTopics(ec)) - - return b, nil -} - -// consumeEventbusTopics consumes events from all relevant -// [eventbus.Subscriber]'s and passes them to their related handler. Events are -// always handled in the order they are received, i.e. the next event is not -// read until the previous event's handler has returned. It returns when the -// [eventbus.Client] is closed. 
-func (b *LocalBackend) consumeEventbusTopics(ec *eventbus.Client) func(*eventbus.Client) { - clientVersionSub := eventbus.Subscribe[tailcfg.ClientVersion](ec) - autoUpdateSub := eventbus.Subscribe[controlclient.AutoUpdate](ec) - - var healthChange <-chan health.Change + b.eventClient = ec + eventbus.SubscribeFunc(ec, b.onClientVersion) + eventbus.SubscribeFunc(ec, func(au controlclient.AutoUpdate) { + b.onTailnetDefaultAutoUpdate(au.Value) + }) + eventbus.SubscribeFunc(ec, func(cd netmon.ChangeDelta) { b.linkChange(&cd) }) if buildfeatures.HasHealth { - healthChangeSub := eventbus.Subscribe[health.Change](ec) - healthChange = healthChangeSub.Events() + eventbus.SubscribeFunc(ec, b.onHealthChange) } - changeDeltaSub := eventbus.Subscribe[netmon.ChangeDelta](ec) - routeUpdateSub := eventbus.Subscribe[appctype.RouteUpdate](ec) - storeRoutesSub := eventbus.Subscribe[appctype.RouteInfo](ec) - - var portlist <-chan PortlistServices if buildfeatures.HasPortList { - portlistSub := eventbus.Subscribe[PortlistServices](ec) - portlist = portlistSub.Events() + eventbus.SubscribeFunc(ec, b.setPortlistServices) } + eventbus.SubscribeFunc(ec, b.onAppConnectorRouteUpdate) + eventbus.SubscribeFunc(ec, b.onAppConnectorStoreRoutes) + mConn.SetNetInfoCallback(b.setNetInfo) // TODO(tailscale/tailscale#17887): move to eventbus - return func(ec *eventbus.Client) { - for { - select { - case <-ec.Done(): - return - case clientVersion := <-clientVersionSub.Events(): - b.onClientVersion(&clientVersion) - case au := <-autoUpdateSub.Events(): - b.onTailnetDefaultAutoUpdate(au.Value) - case change := <-healthChange: - b.onHealthChange(change) - case changeDelta := <-changeDeltaSub.Events(): - b.linkChange(&changeDelta) - - case pl := <-portlist: - if buildfeatures.HasPortList { // redundant, but explicit for linker deadcode and humans - b.setPortlistServices(pl) - } - case ru := <-routeUpdateSub.Events(): - // TODO(creachadair, 2025-10-02): It is currently possible for updates produced under - // one profile to arrive and be applied after a switch to another profile. - // We need to find a way to ensure that changes to the backend state are applied - // consistently in the presnce of profile changes, which currently may not happen in - // a single atomic step. See: https://github.com/tailscale/tailscale/issues/17414 - b.appcTask.Add(func() { - if err := b.AdvertiseRoute(ru.Advertise...); err != nil { - b.logf("appc: failed to advertise routes: %v: %v", ru.Advertise, err) - } - if err := b.UnadvertiseRoute(ru.Unadvertise...); err != nil { - b.logf("appc: failed to unadvertise routes: %v: %v", ru.Unadvertise, err) - } - }) - case ri := <-storeRoutesSub.Events(): - // Whether or not routes should be stored can change over time. - shouldStoreRoutes := b.ControlKnobs().AppCStoreRoutes.Load() - if shouldStoreRoutes { - if err := b.storeRouteInfo(ri); err != nil { - b.logf("appc: failed to store route info: %v", err) - } - } - } + return b, nil +} + +func (b *LocalBackend) onAppConnectorRouteUpdate(ru appctype.RouteUpdate) { + // TODO(creachadair, 2025-10-02): It is currently possible for updates produced under + // one profile to arrive and be applied after a switch to another profile. + // We need to find a way to ensure that changes to the backend state are applied + // consistently in the presnce of profile changes, which currently may not happen in + // a single atomic step. 
See: https://github.com/tailscale/tailscale/issues/17414 + b.appcTask.Add(func() { + if err := b.AdvertiseRoute(ru.Advertise...); err != nil { + b.logf("appc: failed to advertise routes: %v: %v", ru.Advertise, err) + } + if err := b.UnadvertiseRoute(ru.Unadvertise...); err != nil { + b.logf("appc: failed to unadvertise routes: %v: %v", ru.Unadvertise, err) + } + }) +} + +func (b *LocalBackend) onAppConnectorStoreRoutes(ri appctype.RouteInfo) { + // Whether or not routes should be stored can change over time. + shouldStoreRoutes := b.ControlKnobs().AppCStoreRoutes.Load() + if shouldStoreRoutes { + if err := b.storeRouteInfo(ri); err != nil { + b.logf("appc: failed to store route info: %v", err) } } } @@ -869,8 +838,8 @@ func (b *LocalBackend) Dialer() *tsdial.Dialer { // It returns (false, nil) if not running in declarative mode, (true, nil) on // success, or (false, error) on failure. func (b *LocalBackend) ReloadConfig() (ok bool, err error) { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() if b.conf == nil { return false, nil } @@ -878,7 +847,7 @@ if err != nil { return false, err } - if err := b.setConfigLockedOnEntry(conf, unlock); err != nil { + if err := b.setConfigLocked(conf); err != nil { return false, fmt.Errorf("error setting config: %w", err) } @@ -902,6 +871,7 @@ func (b *LocalBackend) initPrefsFromConfig(conf *conffile.Config) error { if err := b.pm.SetPrefs(p.View(), ipn.NetworkProfile{}); err != nil { return err } + b.updateWarnSync(p.View()) b.setStaticEndpointsFromConfigLocked(conf) b.conf = conf return nil } @@ -935,10 +905,9 @@ func (b *LocalBackend) setStateLocked(state ipn.State) { } } -// setConfigLockedOnEntry uses the provided config to update the backend's prefs +// setConfigLocked uses the provided config to update the backend's prefs // and other state. -func (b *LocalBackend) setConfigLockedOnEntry(conf *conffile.Config, unlock unlockOnce) error { - defer unlock() +func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { p := b.pm.CurrentPrefs().AsStruct() mp, err := conf.Parsed.ToPrefs() if err != nil { @@ -946,7 +915,7 @@ } p.ApplyEdits(&mp) b.setStaticEndpointsFromConfigLocked(conf) - b.setPrefsLockedOnEntry(p, unlock) + b.setPrefsLocked(p) b.conf = conf return nil @@ -964,7 +933,12 @@ func (b *LocalBackend) pauseOrResumeControlClientLocked() { return } networkUp := b.prevIfState.AnyInterfaceUp() - b.cc.SetPaused((b.state == ipn.Stopped && b.NetMap() != nil) || (!networkUp && !testenv.InTest() && !assumeNetworkUpdateForTest())) + pauseForNetwork := (b.state == ipn.Stopped && b.NetMap() != nil) || (!networkUp && !testenv.InTest() && !assumeNetworkUpdateForTest()) + + prefs := b.pm.CurrentPrefs() + pauseForSyncPref := prefs.Valid() && prefs.Sync().EqualBool(false) + + b.cc.SetPaused(pauseForNetwork || pauseForSyncPref) } // DisconnectControl shuts down control client. This can be run before node shutdown to force control to consider this node @@ -972,12 +946,12 @@ func (b *LocalBackend) pauseOrResumeControlClientLocked() { // down, clients switch over to other replicas whilst the existing connections are kept alive for some period of time.
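The refactor above replaces one goroutine that multiplexed every topic through a select loop with per-topic eventbus.SubscribeFunc callbacks. A toy bus showing the shape of the callback style (a simplified stand-in, not the real eventbus API):

```go
package main

import "fmt"

// miniBus is a toy stand-in for the event bus: subscribeFunc registers a
// callback per message type, and publish fans a value out to its callbacks.
type miniBus struct{ subs map[string][]func(any) }

func newMiniBus() *miniBus { return &miniBus{subs: map[string][]func(any){}} }

func subscribeFunc[T any](b *miniBus, fn func(T)) {
	key := fmt.Sprintf("%T", *new(T))
	b.subs[key] = append(b.subs[key], func(v any) { fn(v.(T)) })
}

func publish[T any](b *miniBus, v T) {
	for _, fn := range b.subs[fmt.Sprintf("%T", v)] {
		fn(v) // callbacks run one at a time, in order, per topic
	}
}

type clientVersion struct{ latest string }

func main() {
	bus := newMiniBus()
	// Instead of one goroutine select-ing over every topic's channel,
	// each topic gets its own handler:
	subscribeFunc(bus, func(cv clientVersion) {
		fmt.Println("new client version:", cv.latest)
	})
	publish(bus, clientVersion{latest: "1.90.0"})
}
```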
func (b *LocalBackend) DisconnectControl() { b.mu.Lock() - defer b.mu.Unlock() cc := b.resetControlClientLocked() - if cc == nil { - return + b.mu.Unlock() + + if cc != nil { + cc.Shutdown() } - cc.Shutdown() } // linkChange is our network monitor callback, called whenever the network changes. @@ -1111,14 +1085,13 @@ func (b *LocalBackend) ClearCaptureSink() { // Shutdown halts the backend and all its sub-components. The backend // can no longer be used after Shutdown returns. func (b *LocalBackend) Shutdown() { - // Close the [eventbus.Client] and wait for LocalBackend.consumeEventbusTopics - // to return. Do this before acquiring b.mu: - // 1. LocalBackend.consumeEventbusTopics event handlers also acquire b.mu, - // they can deadlock with c.Shutdown(). - // 2. LocalBackend.consumeEventbusTopics event handlers may not guard against - // undesirable post/in-progress LocalBackend.Shutdown() behaviors. + // Close the [eventbus.Client] to wait for subscribers to + // return before acquiring b.mu: + // 1. Event handlers also acquire b.mu, they can deadlock with c.Shutdown(). + // 2. Event handlers may not guard against undesirable post/in-progress + // LocalBackend.Shutdown() behaviors. b.appcTask.Shutdown() - b.eventSubs.Close() + b.eventClient.Close() b.em.close() @@ -1221,6 +1194,7 @@ func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView { p2.Persist.PrivateNodeKey = key.NodePrivate{} p2.Persist.OldPrivateNodeKey = key.NodePrivate{} p2.Persist.NetworkLockKey = key.NLPrivate{} + p2.Persist.AttestationKey = nil return p2.View() } @@ -1328,7 +1302,7 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { if hi := nm.SelfNode.Hostinfo(); hi.Valid() { ss.HostName = hi.Hostname() } - ss.DNSName = nm.Name + ss.DNSName = nm.SelfName() ss.UserID = nm.User() if sn := nm.SelfNode; sn.Valid() { peerStatusFromNode(ss, sn) @@ -1552,19 +1526,37 @@ func (b *LocalBackend) GetFilterForTest() *filter.Filter { return nb.filterAtomic.Load() } +func (b *LocalBackend) settleEventBus() { + // The move to eventbus made some things racy that + // weren't before so we have to wait for it to all be settled + // before we call certain things. + // See https://github.com/tailscale/tailscale/issues/16369 + // But we can't do this while holding b.mu without deadlocks, + // (https://github.com/tailscale/tailscale/pull/17804#issuecomment-3514426485) so + // now we just do it in lots of places before acquiring b.mu. + // Is this winning?? + if b.sys != nil { + if ms, ok := b.sys.MagicSock.GetOK(); ok { + ms.Synchronize() + } + } +} + // SetControlClientStatus is the callback invoked by the control client whenever it posts a new status. // Among other things, this is where we update the netmap, packet filters, DNS and DERP maps. func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st controlclient.Status) { - unlock := b.lockAndGetUnlock() - defer unlock() + if b.ignoreControlClientUpdates.Load() { + b.logf("ignoring SetControlClientStatus during controlclient shutdown") + return + } + b.mu.Lock() + defer b.mu.Unlock() if b.cc != c { b.logf("Ignoring SetControlClientStatus from old client") return } if st.Err != nil { - // The following do not depend on any data for which we need b locked. 
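DisconnectControl was reworked above to grab the client under b.mu, release the lock, and only then call Shutdown, since shutting down the control client can call back into LocalBackend and deadlock on a held mutex. The pattern in isolation:

```go
package main

import (
	"fmt"
	"sync"
)

type client struct{ name string }

func (c *client) Shutdown() { fmt.Println("shutting down", c.name) }

type backend struct {
	mu sync.Mutex
	cc *client
}

// disconnect clears b.cc under the lock, but runs Shutdown after
// releasing it: Shutdown may call back into methods that lock b.mu.
func (b *backend) disconnect() {
	b.mu.Lock()
	cc := b.cc
	b.cc = nil
	b.mu.Unlock()

	if cc != nil {
		cc.Shutdown()
	}
}

func main() {
	b := &backend{cc: &client{name: "control"}}
	b.disconnect()
	b.disconnect() // safe: nothing left to shut down
}
```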
- unlock.UnlockEarly() if errors.Is(st.Err, io.EOF) { b.logf("[v1] Received error: EOF") return @@ -1573,7 +1565,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control var uerr controlclient.UserVisibleError if errors.As(st.Err, &uerr) { s := uerr.UserVisibleError() - b.send(ipn.Notify{ErrMessage: &s}) + b.sendLocked(ipn.Notify{ErrMessage: &s}) } return } @@ -1626,32 +1618,27 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control keyExpiryExtended := false if st.NetMap != nil { wasExpired := b.keyExpired - isExpired := !st.NetMap.Expiry.IsZero() && st.NetMap.Expiry.Before(b.clock.Now()) + isExpired := !st.NetMap.SelfKeyExpiry().IsZero() && st.NetMap.SelfKeyExpiry().Before(b.clock.Now()) if wasExpired && !isExpired { keyExpiryExtended = true } b.keyExpired = isExpired } - unlock.UnlockEarly() - if keyExpiryExtended && wasBlocked { // Key extended, unblock the engine - b.blockEngineUpdates(false) + b.blockEngineUpdatesLocked(false) } - if st.LoginFinished() && (wasBlocked || authWasInProgress) { + if st.LoggedIn && (wasBlocked || authWasInProgress) { if wasBlocked { // Auth completed, unblock the engine - b.blockEngineUpdates(false) + b.blockEngineUpdatesLocked(false) } - b.authReconfig() - b.send(ipn.Notify{LoginFinished: &empty.Message{}}) + b.authReconfigLocked() + b.sendLocked(ipn.Notify{LoginFinished: &empty.Message{}}) } - // Lock b again and do only the things that require locking. - b.mu.Lock() - prefsChanged := false cn := b.currentNode() prefs := b.pm.CurrentPrefs().AsStruct() @@ -1678,8 +1665,8 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control prefs.Persist = st.Persist.AsStruct() } } - if st.LoginFinished() { - if b.authURL != "" { + if st.LoggedIn { + if authWasInProgress { b.resetAuthURLLocked() // Interactive login finished successfully (URL visited). // After an interactive login, the user always wants @@ -1764,15 +1751,12 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control b.setNetMapLocked(st.NetMap) b.updateFilterLocked(prefs.View()) } - b.mu.Unlock() // Now complete the lock-free parts of what we started while locked. if st.NetMap != nil { if envknob.NoLogsNoSupport() && st.NetMap.HasCap(tailcfg.CapabilityDataPlaneAuditLogs) { msg := "tailnet requires logging to be enabled. Remove --no-logs-no-support from tailscaled command line." b.health.SetLocalLogConfigHealth(errors.New(msg)) - // Connecting to this tailnet without logging is forbidden; boot us outta here. - b.mu.Lock() // Get the current prefs again, since we unlocked above. prefs := b.pm.CurrentPrefs().AsStruct() prefs.WantRunning = false @@ -1784,8 +1768,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control }); err != nil { b.logf("Failed to save new controlclient state: %v", err) } - b.mu.Unlock() - b.send(ipn.Notify{ErrMessage: &msg, Prefs: &p}) + b.sendLocked(ipn.Notify{ErrMessage: &msg, Prefs: &p}) return } if oldNetMap != nil { @@ -1807,11 +1790,11 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control // Update the DERP map in the health package, which uses it for health notifications b.health.SetDERPMap(st.NetMap.DERPMap) - b.send(ipn.Notify{NetMap: st.NetMap}) + b.sendLocked(ipn.Notify{NetMap: st.NetMap}) // The error here is unimportant as is the result. This will recalculate the suggested exit node // cache the value and push any changes to the IPN bus. 
- b.SuggestExitNode() + b.suggestExitNodeLocked() // Check and update the exit node if needed, now that we have a new netmap. // @@ -1821,16 +1804,16 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control // // Otherwise, it might briefly show the exit node as offline and display a warning, // if the node wasn't online or wasn't advertising default routes in the previous netmap. - b.RefreshExitNode() + b.refreshExitNodeLocked() } if st.URL != "" { b.logf("Received auth URL: %.20v...", st.URL) - b.setAuthURL(st.URL) + b.setAuthURLLocked(st.URL) } - b.stateMachine() + b.stateMachineLocked() // This is currently (2020-07-28) necessary; conditionally disabling it is fragile! // This is where netmap information gets propagated to router and magicsock. - b.authReconfig() + b.authReconfigLocked() } type preferencePolicyInfo struct { @@ -2036,13 +2019,14 @@ func (b *LocalBackend) registerSysPolicyWatch() (unregister func(), err error) { // // b.mu must not be held. func (b *LocalBackend) reconcilePrefs() (_ ipn.PrefsView, anyChange bool) { - unlock := b.lockAndGetUnlock() + b.mu.Lock() + defer b.mu.Unlock() + prefs := b.pm.CurrentPrefs().AsStruct() if !b.reconcilePrefsLocked(prefs) { - unlock.UnlockEarly() return prefs.View(), false } - return b.setPrefsLockedOnEntry(prefs, unlock), true + return b.setPrefsLocked(prefs), true } // sysPolicyChanged is a callback triggered by syspolicy when it detects @@ -2090,6 +2074,11 @@ func (b *LocalBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bo b.send(*notify) } }() + + // Gross. See https://github.com/tailscale/tailscale/issues/16369 + b.settleEventBus() + defer b.settleEventBus() + b.mu.Lock() defer b.mu.Unlock() @@ -2110,7 +2099,7 @@ func (b *LocalBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bo if !ok || n.StableID() != exitNodeID { continue } - b.goTracker.Go(b.RefreshExitNode) + b.refreshExitNodeLocked() break } } @@ -2274,51 +2263,60 @@ func (b *LocalBackend) resolveExitNodeIPLocked(prefs *ipn.Prefs) (prefsChanged b func (b *LocalBackend) setWgengineStatus(s *wgengine.Status, err error) { if err != nil { b.logf("wgengine status error: %v", err) - b.broadcastStatusChanged() return } if s == nil { b.logf("[unexpected] non-error wgengine update with status=nil: %v", s) - b.broadcastStatusChanged() return } b.mu.Lock() + defer b.mu.Unlock() + + // For now, only check this in the callback, but don't check it in setWgengineStatusLocked if s.AsOf.Before(b.lastStatusTime) { // Don't process a status update that is older than the one we have // already processed. (corp#2579) - b.mu.Unlock() return } b.lastStatusTime = s.AsOf + + b.setWgengineStatusLocked(s) +} + +// setWgengineStatusLocked updates LocalBackend's view of the engine status and +// updates the endpoints both in the backend and in the control client. +// +// Unlike setWgengineStatus it does not discard out-of-order updates, so +// statuses sent here are always processed. This is useful for ensuring we don't +// miss a "we shut down" status during backend shutdown even if other statuses +// arrive out of order. +// +// TODO(zofrex): we should ensure updates actually do arrive in order and move +// the out-of-order check into this function. +// +// b.mu must be held. 
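+//
+// Its callers in this change are setWgengineStatus (after its out-of-order check) and
+// stopEngineAndWaitLocked, which relies on observing the final "stopped" status during shutdown.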
+func (b *LocalBackend) setWgengineStatusLocked(s *wgengine.Status) { es := b.parseWgStatusLocked(s) cc := b.cc + + // TODO(zofrex): the only reason we even write this is to transition from + // "Starting" to "Running" in the call to state machine a few lines below + // this. Maybe we don't even need to store it at all. b.engineStatus = es + needUpdateEndpoints := !slices.Equal(s.LocalAddrs, b.endpoints) if needUpdateEndpoints { b.endpoints = append([]tailcfg.Endpoint{}, s.LocalAddrs...) } - b.mu.Unlock() if cc != nil { if needUpdateEndpoints { cc.UpdateEndpoints(s.LocalAddrs) } - b.stateMachine() + b.stateMachineLocked() } - b.broadcastStatusChanged() - b.send(ipn.Notify{Engine: &es}) -} - -// broadcastStatusChanged must not be called with b.mu held. -func (b *LocalBackend) broadcastStatusChanged() { - // The sync.Cond docs say: "It is allowed but not required for the caller to hold c.L during the call." - // In this particular case, we must acquire b.mu. Otherwise we might broadcast before - // the waiter (in requestEngineStatusAndWait) starts to wait, in which case - // the waiter can get stuck indefinitely. See PR 2865. - b.mu.Lock() - b.statusChanged.Broadcast() - b.mu.Unlock() + b.sendLocked(ipn.Notify{Engine: &es}) } // SetNotifyCallback sets the function to call when the backend has something to @@ -2398,18 +2396,24 @@ func (b *LocalBackend) initOnce() { // actually a supported operation (it should be, but it's very unclear // from the following whether or not that is a safe transition). func (b *LocalBackend) Start(opts ipn.Options) error { - b.logf("Start") + defer b.settleEventBus() // with b.mu unlocked + b.mu.Lock() + defer b.mu.Unlock() + return b.startLocked(opts) +} +func (b *LocalBackend) startLocked(opts ipn.Options) error { + b.logf("Start") + logf := logger.WithPrefix(b.logf, "Start: ") b.startOnce.Do(b.initOnce) var clientToShutdown controlclient.Client defer func() { if clientToShutdown != nil { - clientToShutdown.Shutdown() + // Shutdown outside of b.mu to avoid deadlocks. + b.goTracker.Go(clientToShutdown.Shutdown) } }() - unlock := b.lockAndGetUnlock() - defer unlock() if opts.UpdatePrefs != nil { if err := b.checkPrefsLocked(opts.UpdatePrefs); err != nil { @@ -2431,7 +2435,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if b.state != ipn.Running && b.conf == nil && opts.AuthKey == "" { sysak, _ := b.polc.GetString(pkey.AuthKey, "") if sysak != "" { - b.logf("Start: setting opts.AuthKey by syspolicy, len=%v", len(sysak)) + logf("setting opts.AuthKey by syspolicy, len=%v", len(sysak)) opts.AuthKey = strings.TrimSpace(sysak) } } @@ -2464,11 +2468,13 @@ func (b *LocalBackend) Start(opts ipn.Options) error { cn := b.currentNode() - prefsChanged := false + var prefsChanged bool + var prefsChangedWhy []string newPrefs := b.pm.CurrentPrefs().AsStruct() if opts.UpdatePrefs != nil { newPrefs = opts.UpdatePrefs.Clone() prefsChanged = true + prefsChangedWhy = append(prefsChangedWhy, "opts.UpdatePrefs") } // Apply any syspolicy overrides, resolve exit node ID, etc. 
// As of 2025-07-03, this is primarily needed in two cases: @@ -2476,26 +2482,35 @@ func (b *LocalBackend) Start(opts ipn.Options) error { // - when Always Mode is enabled and we need to set WantRunning to true if b.reconcilePrefsLocked(newPrefs) { prefsChanged = true + prefsChangedWhy = append(prefsChangedWhy, "reconcilePrefsLocked") } // neither UpdatePrefs nor reconciliation should change Persist newPrefs.Persist = b.pm.CurrentPrefs().Persist().AsStruct() - if buildfeatures.HasTPM { + if buildfeatures.HasTPM && b.HardwareAttested() { if genKey, ok := feature.HookGenerateAttestationKeyIfEmpty.GetOk(); ok { - newKey, err := genKey(newPrefs.Persist, b.logf) + newKey, err := genKey(newPrefs.Persist, logf) if err != nil { - b.logf("failed to populate attestation key from TPM: %v", err) + logf("failed to populate attestation key from TPM: %v", err) } if newKey { prefsChanged = true + prefsChangedWhy = append(prefsChangedWhy, "newKey") } } } + // Remove any existing attestation key if HardwareAttested is false. + if !b.HardwareAttested() && newPrefs.Persist != nil && newPrefs.Persist.AttestationKey != nil && !newPrefs.Persist.AttestationKey.IsZero() { + newPrefs.Persist.AttestationKey = nil + prefsChanged = true + prefsChangedWhy = append(prefsChangedWhy, "removeAttestationKey") } if prefsChanged { + logf("updated prefs: %v, reason: %v", newPrefs.Pretty(), prefsChangedWhy) if err := b.pm.SetPrefs(newPrefs.View(), cn.NetworkProfile()); err != nil { - b.logf("failed to save updated and reconciled prefs: %v", err) + logf("failed to save updated and reconciled prefs (but still using updated prefs in memory): %v", err) } } prefs := newPrefs.View() @@ -2515,9 +2530,10 @@ func (b *LocalBackend) Start(opts ipn.Options) error { serverURL := prefs.ControlURLOrDefault(b.polc) if inServerMode := prefs.ForceDaemon(); inServerMode || runtime.GOOS == "windows" { - b.logf("Start: serverMode=%v", inServerMode) + logf("serverMode=%v", inServerMode) } b.applyPrefsToHostinfoLocked(hostinfo, prefs) + b.updateWarnSync(prefs) persistv := prefs.Persist().AsStruct() if persistv == nil { @@ -2569,6 +2585,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { ControlKnobs: b.sys.ControlKnobs(), Shutdown: ccShutdown, Bus: b.sys.Bus.Get(), + StartPaused: prefs.Sync().EqualBool(false), // Don't warn about broken Linux IP forwarding when // netstack is being used. @@ -2583,7 +2600,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { endpoints := b.endpoints if err := b.initTKALocked(); err != nil { - b.logf("initTKALocked: %v", err) + logf("initTKALocked: %v", err) } var tkaHead string if b.tka != nil { @@ -2624,7 +2641,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { // regress tsnet.Server restarts. cc.Login(controlclient.LoginDefault) } - b.stateMachineLockedOnEntry(unlock) + b.stateMachineLocked() return nil } @@ -3056,9 +3073,6 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.Actor, mask ipn.NotifyWatchOpt, onWatchAdded func(), fn func(roNotify *ipn.Notify) (keepGoing bool)) { ch := make(chan *ipn.Notify, 128) sessionID := rands.HexString(16) - if mask&ipn.NotifyNoPrivateKeys != 0 { - fn = filterPrivateKeys(fn) - } if mask&ipn.NotifyHealthActions == 0 { // if UI does not support PrimaryAction in health warnings, append // action URLs to the warning text instead.
@@ -3158,39 +3172,6 @@ func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.A sender.Run(ctx, ch) } -// filterPrivateKeys returns an IPN listener func that wraps the supplied IPN -// listener and zeroes out the PrivateKey in the NetMap passed to the wrapped -// listener. -func filterPrivateKeys(fn func(roNotify *ipn.Notify) (keepGoing bool)) func(*ipn.Notify) bool { - return func(n *ipn.Notify) bool { - redacted, changed := redactNetmapPrivateKeys(n.NetMap) - if !changed { - return fn(n) - } - - // The netmap in n is shared across all watchers, so to mutate it for a - // single watcher we have to clone the notify and the netmap. We can - // make shallow clones, at least. - n2 := *n - n2.NetMap = redacted - return fn(&n2) - } -} - -// redactNetmapPrivateKeys returns a copy of nm with private keys zeroed out. -// If no change was needed, it returns nm unmodified. -func redactNetmapPrivateKeys(nm *netmap.NetworkMap) (redacted *netmap.NetworkMap, changed bool) { - if nm == nil || nm.PrivateKey.IsZero() { - return nm, false - } - - // The netmap might be shared across watchers, so make at least a shallow - // clone before mutating it. - nm2 := *nm - nm2.PrivateKey = key.NodePrivate{} - return &nm2, true -} - // appendHealthActions returns an IPN listener func that wraps the supplied IPN // listener func and transforms health messages passed to the wrapped listener. // If health messages with PrimaryActions are present, it appends the label & @@ -3293,6 +3274,10 @@ func (b *LocalBackend) send(n ipn.Notify) { b.sendTo(n, allClients) } +func (b *LocalBackend) sendLocked(n ipn.Notify) { + b.sendToLocked(n, allClients) +} + // SendNotify sends a notification to the IPN bus, // typically to the GUI client. func (b *LocalBackend) SendNotify(n ipn.Notify) { @@ -3383,21 +3368,22 @@ func (b *LocalBackend) sendToLocked(n ipn.Notify, recipient notificationTarget) } } -// setAuthURL sets the authURL and triggers [LocalBackend.popBrowserAuthNow] if the URL has changed. +// setAuthURLLocked sets the authURL and triggers [LocalBackend.popBrowserAuthNow] if the URL has changed. // This method is called when a new authURL is received from the control plane, meaning that either a user // has started a new interactive login (e.g., by running `tailscale login` or clicking Login in the GUI), // or the control plane was unable to authenticate this node non-interactively (e.g., due to key expiration). // A non-nil b.authActor indicates that an interactive login is in progress and was initiated by the specified actor. +// +// b.mu must be held. +// // If url is "", it is equivalent to calling [LocalBackend.resetAuthURLLocked] with b.mu held. -func (b *LocalBackend) setAuthURL(url string) { +func (b *LocalBackend) setAuthURLLocked(url string) { var popBrowser, keyExpired bool var recipient ipnauth.Actor - b.mu.Lock() switch { case url == "": b.resetAuthURLLocked() - b.mu.Unlock() return case b.authURL != url: b.authURL = url @@ -3414,33 +3400,33 @@ func (b *LocalBackend) setAuthURL(url string) { // Consume the StartLoginInteractive call, if any, that caused the control // plane to send us this URL. 
b.authActor = nil - b.mu.Unlock() if popBrowser { - b.popBrowserAuthNow(url, keyExpired, recipient) + b.popBrowserAuthNowLocked(url, keyExpired, recipient) } } -// popBrowserAuthNow shuts down the data plane and sends the URL to the recipient's +// popBrowserAuthNowLocked shuts down the data plane and sends the URL to the recipient's // [watchSession]s if the recipient is non-nil; otherwise, it sends the URL to all watchSessions. // keyExpired is the value of b.keyExpired upon entry and indicates // whether the node's key has expired. -// It must not be called with b.mu held. -func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool, recipient ipnauth.Actor) { +// +// b.mu must be held. +func (b *LocalBackend) popBrowserAuthNowLocked(url string, keyExpired bool, recipient ipnauth.Actor) { b.logf("popBrowserAuthNow(%q): url=%v, key-expired=%v, seamless-key-renewal=%v", maybeUsernameOf(recipient), url != "", keyExpired, b.seamlessRenewalEnabled()) // Deconfigure the local network data plane if: // - seamless key renewal is not enabled; // - key is expired (in which case tailnet connectivity is down anyway). if !b.seamlessRenewalEnabled() || keyExpired { - b.blockEngineUpdates(true) - b.stopEngineAndWait() + b.blockEngineUpdatesLocked(true) + b.stopEngineAndWaitLocked() - if b.State() == ipn.Running { - b.enterState(ipn.Starting) + if b.state == ipn.Running { + b.enterStateLocked(ipn.Starting) } } - b.tellRecipientToBrowseToURL(url, toNotificationTarget(recipient)) + b.tellRecipientToBrowseToURLLocked(url, toNotificationTarget(recipient)) } // validPopBrowserURL reports whether urlStr is a valid value for a @@ -3488,13 +3474,16 @@ func (b *LocalBackend) validPopBrowserURLLocked(urlStr string) bool { } func (b *LocalBackend) tellClientToBrowseToURL(url string) { - b.tellRecipientToBrowseToURL(url, allClients) + b.mu.Lock() + defer b.mu.Unlock() + b.tellRecipientToBrowseToURLLocked(url, allClients) } -// tellRecipientToBrowseToURL is like tellClientToBrowseToURL but allows specifying a recipient. -func (b *LocalBackend) tellRecipientToBrowseToURL(url string, recipient notificationTarget) { - if b.validPopBrowserURL(url) { - b.sendTo(ipn.Notify{BrowseToURL: &url}, recipient) +// tellRecipientToBrowseToURLLocked is like tellClientToBrowseToURL but allows specifying a recipient +// and b.mu must be held. +func (b *LocalBackend) tellRecipientToBrowseToURLLocked(url string, recipient notificationTarget) { + if b.validPopBrowserURLLocked(url) { + b.sendToLocked(ipn.Notify{BrowseToURL: &url}, recipient) } } @@ -3509,8 +3498,8 @@ func (b *LocalBackend) onClientVersion(v *tailcfg.ClientVersion) { } func (b *LocalBackend) onTailnetDefaultAutoUpdate(au bool) { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() prefs := b.pm.CurrentPrefs() if !prefs.Valid() { @@ -3532,14 +3521,14 @@ func (b *LocalBackend) onTailnetDefaultAutoUpdate(au bool) { b.logf("using tailnet default auto-update setting: %v", au) prefsClone := prefs.AsStruct() prefsClone.AutoUpdate.Apply = opt.NewBool(au) - _, err := b.editPrefsLockedOnEntry( + _, err := b.editPrefsLocked( ipnauth.Self, &ipn.MaskedPrefs{ Prefs: *prefsClone, AutoUpdateSet: ipn.AutoUpdatePrefsMask{ ApplySet: true, }, - }, unlock) + }) if err != nil { b.logf("failed to apply tailnet-wide default for auto-updates (%v): %v", au, err) return @@ -3771,7 +3760,11 @@ func (b *LocalBackend) StartLoginInteractive(ctx context.Context) error { // the control plane sends us one. 
Otherwise, the notification will be delivered to all // active [watchSession]s. func (b *LocalBackend) StartLoginInteractiveAs(ctx context.Context, user ipnauth.Actor) error { + if b.health.IsUnhealthy(ipn.StateStoreHealth) { + return errors.New("cannot log in when state store is unhealthy") + } b.mu.Lock() + defer b.mu.Unlock() if b.cc == nil { panic("LocalBackend.assertClient: b.cc == nil") } @@ -3789,12 +3782,11 @@ func (b *LocalBackend) StartLoginInteractiveAs(ctx context.Context, user ipnauth b.authActor = user } cc := b.cc - b.mu.Unlock() b.logf("StartLoginInteractiveAs(%q): url=%v", maybeUsernameOf(user), hasValidURL) if hasValidURL { - b.popBrowserAuthNow(url, keyExpired, user) + b.popBrowserAuthNowLocked(url, keyExpired, user) } else { cc.Login(b.loginFlags | controlclient.LoginInteractive) } @@ -3924,8 +3916,8 @@ func (b *LocalBackend) parseWgStatusLocked(s *wgengine.Status) (ret ipn.EngineSt // // On non-multi-user systems, the actor should be set to nil. func (b *LocalBackend) SetCurrentUser(actor ipnauth.Actor) { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() var userIdentifier string if user := cmp.Or(actor, b.currentUser); user != nil { @@ -3947,7 +3939,7 @@ func (b *LocalBackend) SetCurrentUser(actor ipnauth.Actor) { action = "connected" } reason := fmt.Sprintf("client %s (%s)", action, userIdentifier) - b.switchToBestProfileLockedOnEntry(reason, unlock) + b.switchToBestProfileLocked(reason) } // SwitchToBestProfile selects the best profile to use, @@ -3957,13 +3949,14 @@ func (b *LocalBackend) SetCurrentUser(actor ipnauth.Actor) { // or disconnecting, or a change in the desktop session state, and is used // for logging. func (b *LocalBackend) SwitchToBestProfile(reason string) { - b.switchToBestProfileLockedOnEntry(reason, b.lockAndGetUnlock()) + b.mu.Lock() + defer b.mu.Unlock() + b.switchToBestProfileLocked(reason) } -// switchToBestProfileLockedOnEntry is like [LocalBackend.SwitchToBestProfile], -// but b.mu must held on entry. It is released on exit. -func (b *LocalBackend) switchToBestProfileLockedOnEntry(reason string, unlock unlockOnce) { - defer unlock() +// switchToBestProfileLocked is like [LocalBackend.SwitchToBestProfile], +// but b.mu must be held on entry. +func (b *LocalBackend) switchToBestProfileLocked(reason string) { oldControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(b.polc) profile, background := b.resolveBestProfileLocked() cp, switched, err := b.pm.SwitchToProfile(profile) @@ -3994,7 +3987,7 @@ func (b *LocalBackend) switchToBestProfileLockedOnEntry(reason string, unlock un if newControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(b.polc); oldControlURL != newControlURL { b.resetDialPlan() } - if err := b.resetForProfileChangeLockedOnEntry(unlock); err != nil { + if err := b.resetForProfileChangeLocked(); err != nil { // TODO(nickkhyl): The actual reset cannot fail. However, // the TKA initialization or [LocalBackend.Start] can fail. // These errors are not critical as far as we're concerned. @@ -4242,8 +4235,8 @@ func (b *LocalBackend) checkAutoUpdatePrefsLocked(p *ipn.Prefs) error { // Setting the value to false when use of an exit node is already false is not an error, // nor is true when the exit node is already in use.
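+// The edit is applied via editPrefsLocked while b.mu is held, so it follows the same
+// access checks and notification path as any other prefs change.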
func (b *LocalBackend) SetUseExitNodeEnabled(actor ipnauth.Actor, v bool) (ipn.PrefsView, error) { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() p0 := b.pm.CurrentPrefs() if !buildfeatures.HasUseExitNode { @@ -4287,7 +4280,7 @@ func (b *LocalBackend) SetUseExitNodeEnabled(actor ipnauth.Actor, v bool) (ipn.P mp.InternalExitNodePrior = p0.ExitNodeID() } } - return b.editPrefsLockedOnEntry(actor, mp, unlock) + return b.editPrefsLocked(actor, mp) } // MaybeClearAppConnector clears the routes from any AppConnector if @@ -4318,8 +4311,11 @@ func (b *LocalBackend) EditPrefsAs(mp *ipn.MaskedPrefs, actor ipnauth.Actor) (ip if mp.SetsInternal() { return ipn.PrefsView{}, errors.New("can't set Internal fields") } + defer b.settleEventBus() - return b.editPrefsLockedOnEntry(actor, mp, b.lockAndGetUnlock()) + b.mu.Lock() + defer b.mu.Unlock() + return b.editPrefsLocked(actor, mp) } // checkEditPrefsAccessLocked checks whether the current user has access @@ -4509,8 +4505,8 @@ func (b *LocalBackend) startReconnectTimerLocked(d time.Duration) { profileID := b.pm.CurrentProfile().ID() var reconnectTimer tstime.TimerController reconnectTimer = b.clock.AfterFunc(d, func() { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() if b.reconnectTimer != reconnectTimer { // We're either not the most recent timer, or we lost the race when @@ -4528,7 +4524,7 @@ func (b *LocalBackend) startReconnectTimerLocked(d time.Duration) { } mp := &ipn.MaskedPrefs{WantRunningSet: true, Prefs: ipn.Prefs{WantRunning: true}} - if _, err := b.editPrefsLockedOnEntry(ipnauth.Self, mp, unlock); err != nil { + if _, err := b.editPrefsLocked(ipnauth.Self, mp); err != nil { b.logf("failed to automatically reconnect as %q after %v: %v", cp.Name(), d, err) } else { b.logf("automatically reconnected as %q after %v", cp.Name(), d) @@ -4557,11 +4553,8 @@ func (b *LocalBackend) stopReconnectTimerLocked() { } } -// Warning: b.mu must be held on entry, but it unlocks it on the way out. -// TODO(bradfitz): redo the locking on all these weird methods like this. -func (b *LocalBackend) editPrefsLockedOnEntry(actor ipnauth.Actor, mp *ipn.MaskedPrefs, unlock unlockOnce) (ipn.PrefsView, error) { - defer unlock() // for error paths - +// b.mu must be held. +func (b *LocalBackend) editPrefsLocked(actor ipnauth.Actor, mp *ipn.MaskedPrefs) (ipn.PrefsView, error) { p0 := b.pm.CurrentPrefs() // Check if the changes in mp are allowed. @@ -4598,11 +4591,11 @@ func (b *LocalBackend) editPrefsLockedOnEntry(actor ipnauth.Actor, mp *ipn.Maske // before the modified prefs are actually set for the current profile. b.onEditPrefsLocked(actor, mp, p0, p1.View()) - newPrefs := b.setPrefsLockedOnEntry(p1, unlock) + newPrefs := b.setPrefsLocked(p1) // Note: don't perform any actions for the new prefs here. Not // every prefs change goes through EditPrefs. Put your actions - // in setPrefsLocksOnEntry instead. + // in setPrefsLocked instead. // This should return the public prefs, not the private ones. return stripKeysFromPrefs(newPrefs), nil @@ -4625,12 +4618,10 @@ func (b *LocalBackend) checkProfileNameLocked(p *ipn.Prefs) error { return nil } -// setPrefsLockedOnEntry requires b.mu be held to call it, but it -// unlocks b.mu when done. newp ownership passes to this function. +// setPrefsLocked requires b.mu be held to call it. +// newp ownership passes to this function. // It returns a read-only copy of the new prefs. 
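+// Unlike the setPrefsLockedOnEntry it replaces, it leaves b.mu locked on return:
+// the caller keeps the lock for the whole call.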
-func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) ipn.PrefsView { - defer unlock() - +func (b *LocalBackend) setPrefsLocked(newp *ipn.Prefs) ipn.PrefsView { cn := b.currentNode() netMap := cn.NetMap() b.setAtomicValuesFromPrefsLocked(newp.View()) @@ -4691,10 +4682,11 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) b.resetAlwaysOnOverrideLocked() } - unlock.UnlockEarly() + b.pauseOrResumeControlClientLocked() // for prefs.Sync changes + b.updateWarnSync(prefs) if oldp.ShieldsUp() != newp.ShieldsUp || hostInfoChanged { - b.doSetHostinfoFilterServices() + b.doSetHostinfoFilterServicesLocked() } if netMap != nil { @@ -4707,12 +4699,12 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) } if oldp.WantRunning() != newp.WantRunning { - b.stateMachine() + b.stateMachineLocked() } else { - b.authReconfig() + b.authReconfigLocked() } - b.send(ipn.Notify{Prefs: &prefs}) + b.sendLocked(ipn.Notify{Prefs: &prefs}) return prefs } @@ -4722,14 +4714,8 @@ func (b *LocalBackend) GetPeerAPIPort(ip netip.Addr) (port uint16, ok bool) { if !buildfeatures.HasPeerAPIServer { return 0, false } - b.mu.Lock() - defer b.mu.Unlock() - for _, pln := range b.peerAPIListeners { - if pln.ip == ip { - return uint16(pln.port), true - } - } - return 0, false + portInt, ok := b.peerAPIPorts.Load()[ip] + return uint16(portInt), ok } // handlePeerAPIConn serves an already-accepted connection c. @@ -4830,9 +4816,13 @@ func (b *LocalBackend) setPortlistServices(sl []tailcfg.Service) { // TODO(danderson): we shouldn't be mangling hostinfo here after // painstakingly constructing it in twelvety other places. func (b *LocalBackend) doSetHostinfoFilterServices() { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() + b.doSetHostinfoFilterServicesLocked() +} +// b.mu must be held +func (b *LocalBackend) doSetHostinfoFilterServicesLocked() { cc := b.cc if cc == nil { // Control client isn't up yet. @@ -4856,8 +4846,6 @@ func (b *LocalBackend) doSetHostinfoFilterServices() { hi.Services = []tailcfg.Service{} } - unlock.UnlockEarly() - // Don't mutate hi.Service's underlying array. Append to // the slice with no free capacity. c := len(hi.Services) @@ -4903,15 +4891,15 @@ func (b *LocalBackend) isEngineBlocked() bool { return b.blocked } -// blockEngineUpdate sets b.blocked to block, while holding b.mu. Its -// indirect effect is to turn b.authReconfig() into a no-op if block -// is true. -func (b *LocalBackend) blockEngineUpdates(block bool) { +// blockEngineUpdatesLocked sets b.blocked to block. +// +// Its indirect effect is to turn b.authReconfig() into a no-op if block is +// true. +// +// b.mu must be held. +func (b *LocalBackend) blockEngineUpdatesLocked(block bool) { b.logf("blockEngineUpdates(%v)", block) - - b.mu.Lock() b.blocked = block - b.mu.Unlock() } // reconfigAppConnectorLocked updates the app connector state based on the @@ -5022,38 +5010,41 @@ func (b *LocalBackend) readvertiseAppConnectorRoutes() { // updates are not currently blocked, based on the cached netmap and // user prefs. func (b *LocalBackend) authReconfig() { - // Wait for magicsock to process pending [eventbus] events, - // such as netmap updates. This should be completed before - // wireguard-go is reconfigured. See tailscale/tailscale#16369. 
- b.MagicConn().Synchronize() - b.mu.Lock() - blocked := b.blocked - prefs := b.pm.CurrentPrefs() - cn := b.currentNode() - nm := cn.NetMap() - hasPAC := b.prevIfState.HasPAC() - disableSubnetsIfPAC := cn.SelfHasCap(tailcfg.NodeAttrDisableSubnetsIfPAC) - dohURL, dohURLOK := cn.exitNodeCanProxyDNS(prefs.ExitNodeID()) - dcfg := cn.dnsConfigForNetmap(prefs, b.keyExpired, version.OS()) - // If the current node is an app connector, ensure the app connector machine is started - b.reconfigAppConnectorLocked(nm, prefs) - closing := b.shutdownCalled - b.mu.Unlock() + defer b.mu.Unlock() + b.authReconfigLocked() +} + +// authReconfigLocked is the locked version of [LocalBackend.authReconfig]. +// +// b.mu must be held. +func (b *LocalBackend) authReconfigLocked() { - if closing { + if b.shutdownCalled { b.logf("[v1] authReconfig: skipping because in shutdown") return } - - if blocked { + if b.blocked { b.logf("[v1] authReconfig: blocked, skipping.") return } + + cn := b.currentNode() + + nm := cn.NetMap() if nm == nil { b.logf("[v1] authReconfig: netmap not yet valid. Skipping.") return } + + prefs := b.pm.CurrentPrefs() + hasPAC := b.prevIfState.HasPAC() + disableSubnetsIfPAC := cn.SelfHasCap(tailcfg.NodeAttrDisableSubnetsIfPAC) + dohURL, dohURLOK := cn.exitNodeCanProxyDNS(prefs.ExitNodeID()) + dcfg := cn.dnsConfigForNetmap(prefs, b.keyExpired, version.OS()) + // If the current node is an app connector, ensure the app connector machine is started + b.reconfigAppConnectorLocked(nm, prefs) + if !prefs.WantRunning() { b.logf("[v1] authReconfig: skipping because !WantRunning.") return @@ -5081,14 +5072,19 @@ func (b *LocalBackend) authReconfig() { } } - cfg, err := nmcfg.WGCfg(nm, b.logf, flags, prefs.ExitNodeID()) + priv := b.pm.CurrentPrefs().Persist().PrivateNodeKey() + if !priv.IsZero() && priv.Public() != nm.NodeKey { + priv = key.NodePrivate{} + } + + cfg, err := nmcfg.WGCfg(priv, nm, b.logf, flags, prefs.ExitNodeID()) if err != nil { b.logf("wgcfg: %v", err) return } oneCGNATRoute := shouldUseOneCGNATRoute(b.logf, b.sys.NetMon.Get(), b.sys.ControlKnobs(), version.OS()) - rcfg := b.routerConfig(cfg, prefs, oneCGNATRoute) + rcfg := b.routerConfigLocked(cfg, prefs, oneCGNATRoute) err = b.e.Reconfig(cfg, rcfg, dcfg) if err == wgengine.ErrNoChanges { @@ -5096,9 +5092,9 @@ func (b *LocalBackend) authReconfig() { } b.logf("[v1] authReconfig: ra=%v dns=%v 0x%02x: %v", prefs.RouteAll(), prefs.CorpDNS(), flags, err) - b.initPeerAPIListener() + b.initPeerAPIListenerLocked() if buildfeatures.HasAppConnectors { - b.readvertiseAppConnectorRoutes() + b.goTracker.Go(b.readvertiseAppConnectorRoutes) } } @@ -5211,6 +5207,7 @@ func (b *LocalBackend) closePeerAPIListenersLocked() { pln.Close() } b.peerAPIListeners = nil + b.peerAPIPorts.Store(nil) } // peerAPIListenAsync is whether the operating system requires that we @@ -5221,12 +5218,18 @@ func (b *LocalBackend) closePeerAPIListenersLocked() { const peerAPIListenAsync = runtime.GOOS == "windows" || runtime.GOOS == "android" func (b *LocalBackend) initPeerAPIListener() { + b.mu.Lock() + defer b.mu.Unlock() + b.initPeerAPIListenerLocked() +} + +// b.mu must be held.
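+//
+// It also publishes the active listener ports to b.peerAPIPorts, so that
+// GetPeerAPIPort can read them without acquiring b.mu.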
+func (b *LocalBackend) initPeerAPIListenerLocked() { if !buildfeatures.HasPeerAPIServer { return } b.logf("[v1] initPeerAPIListener: entered") - b.mu.Lock() - defer b.mu.Unlock() + if b.shutdownCalled { b.logf("[v1] initPeerAPIListener: shutting down") return @@ -5277,6 +5280,7 @@ func (b *LocalBackend) initPeerAPIListener() { b.peerAPIServer = ps isNetstack := b.sys.IsNetstack() + peerAPIPorts := make(map[netip.Addr]int) for i, a := range addrs.All() { var ln net.Listener var err error @@ -5309,7 +5313,9 @@ func (b *LocalBackend) initPeerAPIListener() { b.logf("peerapi: serving on %s", pln.urlStr) go pln.serve() b.peerAPIListeners = append(b.peerAPIListeners, pln) + peerAPIPorts[a.Addr()] = pln.port } + b.peerAPIPorts.Store(peerAPIPorts) b.goTracker.Go(b.doSetHostinfoFilterServices) } @@ -5389,15 +5395,15 @@ func peerRoutes(logf logger.Logf, peers []wgcfg.Peer, cgnatThreshold int) (route } // routerConfig produces a router.Config from a wireguard config and IPN prefs. -func (b *LocalBackend) routerConfig(cfg *wgcfg.Config, prefs ipn.PrefsView, oneCGNATRoute bool) *router.Config { +// +// b.mu must be held. +func (b *LocalBackend) routerConfigLocked(cfg *wgcfg.Config, prefs ipn.PrefsView, oneCGNATRoute bool) *router.Config { singleRouteThreshold := 10_000 if oneCGNATRoute { singleRouteThreshold = 1 } - b.mu.Lock() - netfilterKind := b.capForcedNetfilter // protected by b.mu - b.mu.Unlock() + netfilterKind := b.capForcedNetfilter // protected by b.mu (hence the Locked suffix) if prefs.NetfilterKind() != "" { if netfilterKind != "" { @@ -5526,20 +5532,9 @@ func (b *LocalBackend) applyPrefsToHostinfoLocked(hi *tailcfg.Hostinfo, prefs ip } hi.SSH_HostKeys = sshHostKeys - hi.ServicesHash = b.vipServiceHash(b.vipServicesFromPrefsLocked(prefs)) - - // The Hostinfo.IngressEnabled field is used to communicate to control whether - // the node has funnel enabled. - hi.IngressEnabled = b.hasIngressEnabledLocked() - // The Hostinfo.WantIngress field tells control whether the user intends - // to use funnel with this node even though it is not currently enabled. - // This is an optimization to control- Funnel requires creation of DNS - // records and because DNS propagation can take time, we want to ensure - // that the records exist for any node that intends to use funnel even - // if it's not enabled. If hi.IngressEnabled is true, control knows that - // DNS records are needed, so we can save bandwidth and not send - // WireIngress. - hi.WireIngress = b.shouldWireInactiveIngressLocked() + for _, f := range hookMaybeMutateHostinfoLocked { + f(b, hi, prefs) + } if buildfeatures.HasAppConnectors { hi.AppConnector.Set(prefs.AppConnector().Advertise) @@ -5566,21 +5561,16 @@ func (b *LocalBackend) applyPrefsToHostinfoLocked(hi *tailcfg.Hostinfo, prefs ip } } -// enterState transitions the backend into newState, updating internal +// enterStateLocked transitions the backend into newState, updating internal // state and propagating events out as needed. // // TODO(danderson): while this isn't a lie, exactly, a ton of other // places twiddle IPN internal state without going through here, so // really this is more "one of several places in which random things // happen". -func (b *LocalBackend) enterState(newState ipn.State) { - unlock := b.lockAndGetUnlock() - b.enterStateLockedOnEntry(newState, unlock) -} - -// enterStateLockedOnEntry is like enterState but requires b.mu be held to call -// it, but it unlocks b.mu when done (via unlock, a once func). 
-func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlockOnce) { +// +// b.mu must be held. +func (b *LocalBackend) enterStateLocked(newState ipn.State) { cn := b.currentNode() oldState := b.state b.setStateLocked(newState) @@ -5632,17 +5622,16 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock } b.pauseOrResumeControlClientLocked() - unlock.UnlockEarly() - // prefs may change irrespective of state; WantRunning should be explicitly // set before potential early return even if the state is unchanged. b.health.SetIPNState(newState.String(), prefs.Valid() && prefs.WantRunning()) if oldState == newState { return } + b.logf("Switching ipn state %v -> %v (WantRunning=%v, nm=%v)", oldState, newState, prefs.WantRunning(), netMap != nil) - b.send(ipn.Notify{State: &newState}) + b.sendLocked(ipn.Notify{State: &newState}) switch newState { case ipn.NeedsLogin: @@ -5650,7 +5639,7 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock // always block updates on NeedsLogin even if seamless renewal is enabled, // to prevent calls to authReconfig from reconfiguring the engine when our // key has expired and we're waiting to authenticate to use the new key. - b.blockEngineUpdates(true) + b.blockEngineUpdatesLocked(true) fallthrough case ipn.Stopped, ipn.NoState: // Unconfigure the engine if it has stopped (WantRunning is set to false) @@ -5664,16 +5653,18 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock feature.SystemdStatus("Stopped; run 'tailscale up' to log in") } case ipn.Starting, ipn.NeedsMachineAuth: - b.authReconfig() + b.authReconfigLocked() // Needed so that UpdateEndpoints can run - b.e.RequestStatus() + b.goTracker.Go(b.e.RequestStatus) case ipn.Running: - var addrStrs []string - addrs := netMap.GetAddresses() - for _, p := range addrs.All() { - addrStrs = append(addrStrs, p.Addr().String()) + if feature.CanSystemdStatus { + var addrStrs []string + addrs := netMap.GetAddresses() + for _, p := range addrs.All() { + addrStrs = append(addrStrs, p.Addr().String()) + } + feature.SystemdStatus("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) } - feature.SystemdStatus("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) default: b.logf("[unexpected] unknown newState %#v", newState) } @@ -5700,6 +5691,9 @@ func (b *LocalBackend) NodeKey() key.NodePublic { // // b.mu must be held func (b *LocalBackend) nextStateLocked() ipn.State { + if b.health.IsUnhealthy(ipn.StateStoreHealth) { + return ipn.NoState + } var ( cc = b.cc cn = b.currentNode() @@ -5773,107 +5767,28 @@ func (b *LocalBackend) nextStateLocked() ipn.State { // that have happened. It is invoked from the various callbacks that // feed events into LocalBackend. // -// TODO(apenwarr): use a channel or something to prevent reentrancy? -// Or maybe just call the state machine from fewer places. -func (b *LocalBackend) stateMachine() { - unlock := b.lockAndGetUnlock() - b.stateMachineLockedOnEntry(unlock) -} - -// stateMachineLockedOnEntry is like stateMachine but requires b.mu be held to -// call it, but it unlocks b.mu when done (via unlock, a once func). -func (b *LocalBackend) stateMachineLockedOnEntry(unlock unlockOnce) { - b.enterStateLockedOnEntry(b.nextStateLocked(), unlock) -} - -// lockAndGetUnlock locks b.mu and returns a sync.OnceFunc function that will -// unlock it at most once. 
-// -// This is all very unfortunate but exists as a guardrail against the -// unfortunate "lockedOnEntry" methods in this package (primarily -// enterStateLockedOnEntry) that require b.mu held to be locked on entry to the -// function but unlock the mutex on their way out. As a stepping stone to -// cleaning things up (as of 2024-04-06), we at least pass the unlock func -// around now and defer unlock in the caller to avoid missing unlocks and double -// unlocks. TODO(bradfitz,maisem): make the locking in this package more -// traditional (simple). See https://github.com/tailscale/tailscale/issues/11649 -func (b *LocalBackend) lockAndGetUnlock() (unlock unlockOnce) { - b.mu.Lock() - var unlocked atomic.Bool - return func() bool { - if unlocked.CompareAndSwap(false, true) { - b.mu.Unlock() - return true - } - return false - } -} - -// unlockOnce is a func that unlocks only b.mu the first time it's called. -// Therefore it can be safely deferred to catch error paths, without worrying -// about double unlocks if a different point in the code later needs to explicitly -// unlock it first as well. It reports whether it was unlocked. -type unlockOnce func() bool - -// UnlockEarly unlocks the LocalBackend.mu. It panics if u returns false, -// indicating that this unlocker was already used. -// -// We're using this method to help us document & find the places that have -// atypical locking patterns. See -// https://github.com/tailscale/tailscale/issues/11649 for background. -// -// A normal unlock is a deferred one or an explicit b.mu.Unlock a few lines -// after the lock, without lots of control flow in-between. An "early" unlock is -// one that happens in weird places, like in various "LockedOnEntry" methods in -// this package that require the mutex to be locked on entry but unlock it -// somewhere in the middle (maybe several calls away) and then sometimes proceed -// to lock it again. -// -// The reason UnlockeEarly panics if already called is because these are the -// points at which it's assumed that the mutex is already held and it now needs -// to be released. If somebody already released it, that invariant was violated. -// On the other hand, simply calling u only returns false instead of panicking -// so you can defer it without care, confident you got all the error return -// paths which were previously done by hand. -func (u unlockOnce) UnlockEarly() { - if !u() { - panic("Unlock on already-called unlockOnce") - } +// requires b.mu to be held. +func (b *LocalBackend) stateMachineLocked() { + b.enterStateLocked(b.nextStateLocked()) } // stopEngineAndWait deconfigures the local network data plane, and // waits for it to deliver a status update indicating it has stopped // before returning. -func (b *LocalBackend) stopEngineAndWait() { +// +// b.mu must be held. +func (b *LocalBackend) stopEngineAndWaitLocked() { b.logf("stopEngineAndWait...") - b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) - b.requestEngineStatusAndWaitForStopped() - b.logf("stopEngineAndWait: done.") -} - -// Requests the wgengine status, and does not return until a status was -// delivered (to the usual callback) that indicates the engine is stopped. 
-func (b *LocalBackend) requestEngineStatusAndWaitForStopped() { - b.logf("requestEngineStatusAndWaitForStopped") - - b.mu.Lock() - defer b.mu.Unlock() - - b.goTracker.Go(b.e.RequestStatus) - b.logf("requestEngineStatusAndWaitForStopped: waiting...") - for { - b.statusChanged.Wait() // temporarily releases lock while waiting - - if !b.blocked { - b.logf("requestEngineStatusAndWaitForStopped: engine is no longer blocked, must have stopped and started again, not safe to wait.") - break - } - if b.engineStatus.NumLive == 0 && b.engineStatus.LiveDERPs == 0 { - b.logf("requestEngineStatusAndWaitForStopped: engine is stopped.") - break - } - b.logf("requestEngineStatusAndWaitForStopped: engine is still running. Waiting...") + st, err := b.e.ResetAndStop() + if err != nil { + // TODO(bradfitz): our caller, popBrowserAuthNowLocked, probably + // should handle this somehow. For now, just log it. + // See tailscale/tailscale#18187 + b.logf("stopEngineAndWait: ResetAndStop error: %v", err) + return } + b.setWgengineStatusLocked(st) + b.logf("stopEngineAndWait: done.") } // setControlClientLocked sets the control client to cc, @@ -5883,6 +5798,7 @@ func (b *LocalBackend) requestEngineStatusAndWaitForStopped() { func (b *LocalBackend) setControlClientLocked(cc controlclient.Client) { b.cc = cc b.ccAuto, _ = cc.(*controlclient.Auto) + b.ignoreControlClientUpdates.Store(cc == nil) } // resetControlClientLocked sets b.cc to nil and returns the old value. If the @@ -5976,11 +5892,11 @@ func (b *LocalBackend) ShouldHandleViaIP(ip netip.Addr) bool { // Logout logs out the current profile, if any, and waits for the logout to // complete. func (b *LocalBackend) Logout(ctx context.Context, actor ipnauth.Actor) error { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() if !b.hasNodeKeyLocked() { // Already logged out. + b.mu.Unlock() return nil } cc := b.cc @@ -5989,17 +5905,17 @@ func (b *LocalBackend) Logout(ctx context.Context, actor ipnauth.Actor) error { // delete it later. profile := b.pm.CurrentProfile() - _, err := b.editPrefsLockedOnEntry( + _, err := b.editPrefsLocked( actor, &ipn.MaskedPrefs{ WantRunningSet: true, LoggedOutSet: true, Prefs: ipn.Prefs{WantRunning: false, LoggedOut: true}, - }, unlock) + }) + b.mu.Unlock() if err != nil { return err } - // b.mu is now unlocked, after editPrefsLockedOnEntry. // Clear any previous dial plan(s), if set. b.resetDialPlan() @@ -6019,14 +5935,14 @@ func (b *LocalBackend) Logout(ctx context.Context, actor ipnauth.Actor) error { return err } - unlock = b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() if err := b.pm.DeleteProfile(profile.ID()); err != nil { b.logf("error deleting profile: %v", err) return err } - return b.resetForProfileChangeLockedOnEntry(unlock) + return b.resetForProfileChangeLocked() } // setNetInfo sets b.hostinfo.NetInfo to ni, and passes ni along to the @@ -6077,12 +5993,19 @@ func (b *LocalBackend) RefreshExitNode() { if !buildfeatures.HasUseExitNode { return } - if b.resolveExitNode() { - b.authReconfig() + b.mu.Lock() + defer b.mu.Unlock() + b.refreshExitNodeLocked() +} + +// refreshExitNodeLocked is like RefreshExitNode but requires b.mu be held. +func (b *LocalBackend) refreshExitNodeLocked() { + if b.resolveExitNodeLocked() { + b.authReconfigLocked() } } -// resolveExitNode determines which exit node to use based on the current prefs +// resolveExitNodeLocked determines which exit node to use based on the current prefs // and netmap.
It updates the exit node ID in the prefs if needed, updates the // exit node ID in the hostinfo if needed, sends a notification to clients, and // returns true if the exit node has changed. @@ -6090,13 +6013,11 @@ func (b *LocalBackend) RefreshExitNode() { // It is the caller's responsibility to reconfigure routes and actually // start using the selected exit node, if needed. // -// b.mu must not be held. -func (b *LocalBackend) resolveExitNode() (changed bool) { +// b.mu must be held. +func (b *LocalBackend) resolveExitNodeLocked() (changed bool) { if !buildfeatures.HasUseExitNode { return false } - b.mu.Lock() - defer b.mu.Unlock() nm := b.currentNode().NetMap() prefs := b.pm.CurrentPrefs().AsStruct() @@ -6218,9 +6139,10 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { b.setDebugLogsByCapabilityLocked(nm) } - // See the netns package for documentation on what this capability does. - netns.SetBindToInterfaceByRoute(nm.HasCap(tailcfg.CapabilityBindToInterfaceByRoute)) - netns.SetDisableBindConnToInterface(nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterface)) + // See the netns package for documentation on what these capabilities do. + netns.SetBindToInterfaceByRoute(b.logf, nm.HasCap(tailcfg.CapabilityBindToInterfaceByRoute)) + netns.SetDisableBindConnToInterface(b.logf, nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterface)) + netns.SetDisableBindConnToInterfaceAppleExt(b.logf, nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterfaceAppleExt)) b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(b.pm.CurrentPrefs()) if buildfeatures.HasServe { @@ -6321,36 +6243,34 @@ func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn. } // Update funnel and service hash info in hostinfo and kick off control update if needed. - b.updateIngressAndServiceHashLocked(prefs) + b.maybeSentHostinfoIfChangedLocked(prefs) b.setTCPPortsIntercepted(handlePorts) } -// updateIngressAndServiceHashLocked updates the hostinfo.ServicesHash, hostinfo.WireIngress and +// hookMaybeMutateHostinfoLocked is a hook that allows conditional features +// to mutate the provided hostinfo before it is sent to control. +// +// The hook function should return true if it mutated the hostinfo. +// +// The LocalBackend's mutex is held during the call. +var hookMaybeMutateHostinfoLocked feature.Hooks[func(*LocalBackend, *tailcfg.Hostinfo, ipn.PrefsView) bool] + +// maybeSentHostinfoIfChangedLocked runs the registered hostinfo hooks, which update the hostinfo.ServicesHash, hostinfo.WireIngress and // hostinfo.IngressEnabled fields, and kicks off a Hostinfo update if the values have changed. // // b.mu must be held.
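+//
+// The ingress and VIP-service hash updates that previously lived inline here are
+// expected to be registered through hookMaybeMutateHostinfoLocked by the
+// corresponding feature packages.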
-func (b *LocalBackend) updateIngressAndServiceHashLocked(prefs ipn.PrefsView) { +func (b *LocalBackend) maybeSentHostinfoIfChangedLocked(prefs ipn.PrefsView) { if b.hostinfo == nil { return } - hostInfoChanged := false - if ie := b.hasIngressEnabledLocked(); b.hostinfo.IngressEnabled != ie { - b.logf("Hostinfo.IngressEnabled changed to %v", ie) - b.hostinfo.IngressEnabled = ie - hostInfoChanged = true - } - if wire := b.shouldWireInactiveIngressLocked(); b.hostinfo.WireIngress != wire { - b.logf("Hostinfo.WireIngress changed to %v", wire) - b.hostinfo.WireIngress = wire - hostInfoChanged = true - } - latestHash := b.vipServiceHash(b.vipServicesFromPrefsLocked(prefs)) - if b.hostinfo.ServicesHash != latestHash { - b.hostinfo.ServicesHash = latestHash - hostInfoChanged = true + changed := false + for _, f := range hookMaybeMutateHostinfoLocked { + if f(b, b.hostinfo, prefs) { + changed = true + } } // Kick off a Hostinfo update to control if ingress status has changed. - if hostInfoChanged { + if changed { b.goTracker.Go(b.doSetHostinfoFilterServices) } } @@ -6724,6 +6644,30 @@ func (b *LocalBackend) DebugReSTUN() error { return nil } +func (b *LocalBackend) DebugRotateDiscoKey() error { + if !buildfeatures.HasDebug { + return nil + } + + mc := b.MagicConn() + mc.RotateDiscoKey() + + newDiscoKey := mc.DiscoPublicKey() + + if tunWrap, ok := b.sys.Tun.GetOK(); ok { + tunWrap.SetDiscoKey(newDiscoKey) + } + + b.mu.Lock() + cc := b.cc + b.mu.Unlock() + if cc != nil { + cc.SetDiscoPublicKey(newDiscoKey) + } + + return nil +} + func (b *LocalBackend) DebugPeerRelayServers() set.Set[netip.Addr] { return b.MagicConn().PeerRelays() } @@ -6780,6 +6724,13 @@ func (b *LocalBackend) sshServerOrInit() (_ SSHServer, err error) { return b.sshServer, nil } +var warnSyncDisabled = health.Register(&health.Warnable{ + Code: "sync-disabled", + Title: "Tailscale Sync is Disabled", + Severity: health.SeverityHigh, + Text: health.StaticMessage("Tailscale control plane syncing is disabled; run `tailscale set --sync` to restore"), +}) + var warnSSHSELinuxWarnable = health.Register(&health.Warnable{ Code: "ssh-unavailable-selinux-enabled", Title: "Tailscale SSH and SELinux", @@ -6795,6 +6746,14 @@ func (b *LocalBackend) updateSELinuxHealthWarning() { } } +func (b *LocalBackend) updateWarnSync(prefs ipn.PrefsView) { + if prefs.Sync().EqualBool(false) { + b.health.SetUnhealthy(warnSyncDisabled, nil) + } else { + b.health.SetHealthy(warnSyncDisabled) + } +} + func (b *LocalBackend) handleSSHConn(c net.Conn) (err error) { s, err := b.sshServerOrInit() if err != nil { @@ -6904,8 +6863,8 @@ func (b *LocalBackend) ShouldInterceptVIPServiceTCPPort(ap netip.AddrPort) bool // It will restart the backend on success. // If the profile is not known, it returns an errProfileNotFound. func (b *LocalBackend) SwitchProfile(profile ipn.ProfileID) error { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() oldControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(b.polc) if _, changed, err := b.pm.SwitchToProfileByID(profile); !changed || err != nil { @@ -6917,7 +6876,7 @@ func (b *LocalBackend) SwitchProfile(profile ipn.ProfileID) error { b.resetDialPlan() } - return b.resetForProfileChangeLockedOnEntry(unlock) + return b.resetForProfileChangeLocked() } // resetDialPlan resets the dialPlan for this LocalBackend. It will log if @@ -6931,12 +6890,10 @@ func (b *LocalBackend) resetDialPlan() { } } -// resetForProfileChangeLockedOnEntry resets the backend for a profile change. 
+// resetForProfileChangeLocked resets the backend for a profile change. // -// b.mu must held on entry. It is released on exit. -func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) error { - defer unlock() - +// b.mu must be held. +func (b *LocalBackend) resetForProfileChangeLocked() error { if b.shutdownCalled { // Prevent a call back to Start during Shutdown, which calls Logout for // ephemeral nodes, which can then call back here. But we're shutting @@ -6953,8 +6910,8 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err // Reset the NetworkMap in the engine b.e.SetNetworkMap(new(netmap.NetworkMap)) if prevCC := b.resetControlClientLocked(); prevCC != nil { - // Needs to happen without b.mu held. - defer prevCC.Shutdown() + // Shutdown outside of b.mu to avoid deadlocks. + b.goTracker.Go(prevCC.Shutdown) } // TKA errors should not prevent resetting the backend state. // However, we should still return the error to the caller. @@ -6967,19 +6924,19 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err b.resetAlwaysOnOverrideLocked() b.extHost.NotifyProfileChange(b.pm.CurrentProfile(), b.pm.CurrentPrefs(), false) b.setAtomicValuesFromPrefsLocked(b.pm.CurrentPrefs()) - b.enterStateLockedOnEntry(ipn.NoState, unlock) // Reset state; releases b.mu + b.enterStateLocked(ipn.NoState) b.health.SetLocalLogConfigHealth(nil) if tkaErr != nil { return tkaErr } - return b.Start(ipn.Options{}) + return b.startLocked(ipn.Options{}) } // DeleteProfile deletes a profile with the given ID. // If the profile is not known, it is a no-op. func (b *LocalBackend) DeleteProfile(p ipn.ProfileID) error { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() needToRestart := b.pm.CurrentProfile().ID() == p if err := b.pm.DeleteProfile(p); err != nil { @@ -6991,7 +6948,7 @@ func (b *LocalBackend) DeleteProfile(p ipn.ProfileID) error { if !needToRestart { return nil } - return b.resetForProfileChangeLockedOnEntry(unlock) + return b.resetForProfileChangeLocked() } // CurrentProfile returns the current LoginProfile. @@ -7004,8 +6961,11 @@ func (b *LocalBackend) CurrentProfile() ipn.LoginProfileView { // NewProfile creates and switches to the new profile. func (b *LocalBackend) NewProfile() error { - unlock := b.lockAndGetUnlock() - defer unlock() + if b.health.IsUnhealthy(ipn.StateStoreHealth) { + return errors.New("cannot log in when state store is unhealthy") + } + b.mu.Lock() + defer b.mu.Unlock() b.pm.SwitchToNewProfile() @@ -7013,7 +6973,7 @@ func (b *LocalBackend) NewProfile() error { // set. Conservatively reset the dialPlan. b.resetDialPlan() - return b.resetForProfileChangeLockedOnEntry(unlock) + return b.resetForProfileChangeLocked() } // ListProfiles returns a list of all LoginProfiles. @@ -7028,12 +6988,12 @@ func (b *LocalBackend) ListProfiles() []ipn.LoginProfileView { // backend is left with a new profile, ready for StartLoginInterative to be // called to register it as new node. func (b *LocalBackend) ResetAuth() error { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() - prevCC := b.resetControlClientLocked() - if prevCC != nil { - defer prevCC.Shutdown() // call must happen after release b.mu + if prevCC := b.resetControlClientLocked(); prevCC != nil { + // Shutdown outside of b.mu to avoid deadlocks. 
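+ // goTracker.Go runs prevCC.Shutdown on a tracked goroutine, letting it proceed without b.mu while still being accounted for.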
+ b.goTracker.Go(prevCC.Shutdown) } if err := b.clearMachineKeyLocked(); err != nil { return err @@ -7042,7 +7002,7 @@ func (b *LocalBackend) ResetAuth() error { return err } b.resetDialPlan() // always reset if we're removing everything - return b.resetForProfileChangeLockedOnEntry(unlock) + return b.resetForProfileChangeLocked() } func (b *LocalBackend) GetPeerEndpointChanges(ctx context.Context, ip netip.Addr) ([]magicsock.EndpointChange, error) { @@ -7273,7 +7233,7 @@ var ErrNoPreferredDERP = errors.New("no preferred DERP, try again later") // be selected at random, so the result is not stable. To be eligible for // consideration, the peer must have NodeAttrSuggestExitNode in its CapMap. // -// b.mu.lock() must be held. +// b.mu must be held. func (b *LocalBackend) suggestExitNodeLocked() (response apitype.ExitNodeSuggestionResponse, err error) { if !buildfeatures.HasUseExitNode { return response, feature.ErrUnavailable @@ -7319,7 +7279,12 @@ func (b *LocalBackend) refreshAllowedSuggestions() { } b.allowedSuggestedExitNodesMu.Lock() defer b.allowedSuggestedExitNodesMu.Unlock() - b.allowedSuggestedExitNodes = fillAllowedSuggestions(b.polc) + + var err error + b.allowedSuggestedExitNodes, err = fillAllowedSuggestions(b.polc) + if err != nil { + b.logf("error refreshing allowed suggestions: %v", err) + } } // selectRegionFunc returns a DERP region from the slice of candidate regions. @@ -7331,20 +7296,19 @@ type selectRegionFunc func(views.Slice[int]) int // choice. type selectNodeFunc func(nodes views.Slice[tailcfg.NodeView], last tailcfg.StableNodeID) tailcfg.NodeView -func fillAllowedSuggestions(polc policyclient.Client) set.Set[tailcfg.StableNodeID] { +func fillAllowedSuggestions(polc policyclient.Client) (set.Set[tailcfg.StableNodeID], error) { nodes, err := polc.GetStringArray(pkey.AllowedSuggestedExitNodes, nil) if err != nil { - log.Printf("fillAllowedSuggestions: unable to look up %q policy: %v", pkey.AllowedSuggestedExitNodes, err) - return nil + return nil, fmt.Errorf("fillAllowedSuggestions: unable to look up %q policy: %w", pkey.AllowedSuggestedExitNodes, err) } if nodes == nil { - return nil + return nil, nil } s := make(set.Set[tailcfg.StableNodeID], len(nodes)) for _, n := range nodes { s.Add(tailcfg.StableNodeID(n)) } - return s + return s, nil } // suggestExitNode returns a suggestion for reasonably good exit node based on @@ -7355,6 +7319,9 @@ func suggestExitNode(report *netcheck.Report, nb *nodeBackend, prevSuggestion ta // The traffic-steering feature flag is enabled on this tailnet. return suggestExitNodeUsingTrafficSteering(nb, allowList) default: + // The control plane will always strip the `traffic-steering` + // node attribute if it isn’t enabled for this tailnet, even if + // it is set in the policy file: tailscale/corp#34401 return suggestExitNodeUsingDERP(report, nb, prevSuggestion, selectRegion, selectNode, allowList) } } @@ -7483,6 +7450,16 @@ func suggestExitNodeUsingDERP(report *netcheck.Report, nb *nodeBackend, prevSugg } } bestCandidates := pickWeighted(pickFrom) + + // We may have an empty list of candidates here, if none of the candidates + // have home DERP info. + // + // We know that candidates is non-empty or we'd already have returned, so if + // we've filtered everything out of bestCandidates, just use candidates. 
+ if len(bestCandidates) == 0 { + bestCandidates = candidates + } + chosen := selectNode(views.SliceOf(bestCandidates), prevSuggestion) if !chosen.Valid() { return res, errors.New("chosen candidate invalid: this is a bug") @@ -7744,19 +7721,6 @@ func maybeUsernameOf(actor ipnauth.Actor) string { return username } -func (b *LocalBackend) vipServiceHash(services []*tailcfg.VIPService) string { - if len(services) == 0 { - return "" - } - buf, err := json.Marshal(services) - if err != nil { - b.logf("vipServiceHashLocked: %v", err) - return "" - } - hash := sha256.Sum256(buf) - return hex.EncodeToString(hash[:]) -} - var ( metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus") ) diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 33ecb688c52a3..02997a0e12fce 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -20,6 +20,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "testing" "time" @@ -49,6 +50,7 @@ import ( "tailscale.com/tsd" "tailscale.com/tstest" "tailscale.com/tstest/deptest" + "tailscale.com/tstest/typewalk" "tailscale.com/types/appctype" "tailscale.com/types/dnstype" "tailscale.com/types/ipproto" @@ -57,6 +59,7 @@ import ( "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/opt" + "tailscale.com/types/persist" "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/dnsname" @@ -1503,15 +1506,6 @@ func wantExitNodeIDNotify(want tailcfg.StableNodeID) wantedNotification { } } -func wantStateNotify(want ipn.State) wantedNotification { - return wantedNotification{ - name: "State=" + want.String(), - cond: func(_ testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { - return n.State != nil && *n.State == want - }, - } -} - func TestInternalAndExternalInterfaces(t *testing.T) { type interfacePrefix struct { i netmon.Interface @@ -2718,8 +2712,8 @@ func TestSetExitNodeIDPolicy(t *testing.T) { exitNodeIPWant: "127.0.0.1", prefsChanged: false, nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100.102.103.104/32"), pfx("100::123/128"), @@ -2755,8 +2749,8 @@ func TestSetExitNodeIDPolicy(t *testing.T) { exitNodeIDWant: "123", prefsChanged: true, nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100.102.103.104/32"), pfx("100::123/128"), @@ -2793,8 +2787,8 @@ func TestSetExitNodeIDPolicy(t *testing.T) { exitNodeIDWant: "123", prefsChanged: true, nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100.102.103.104/32"), pfx("100::123/128"), @@ -2833,8 +2827,8 @@ func TestSetExitNodeIDPolicy(t *testing.T) { exitNodeIDWant: "123", prefsChanged: true, nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100.102.103.104/32"), pfx("100::123/128"), @@ -4318,9 +4312,9 @@ func (b *LocalBackend) SetPrefsForTest(newp *ipn.Prefs) { if newp == nil { panic("SetPrefsForTest got nil prefs") } - unlock := b.lockAndGetUnlock() - defer unlock() - b.setPrefsLockedOnEntry(newp, unlock) + b.mu.Lock() + defer b.mu.Unlock() + b.setPrefsLocked(newp) } type peerOptFunc func(*tailcfg.Node) @@ -4442,6 +4436,14 @@ func deterministicRegionForTest(t testing.TB, want views.Slice[int], use int) se } } +// deterministicNodeForTest returns a deterministic selectNodeFunc, which +// allows us to make 
stable assertions about which exit node will be chosen +// from a list of possible candidates. +// +// When given a list of candidates, it checks that `use` is in the list and +// returns that. +// +// It verifies that `wantLast` was passed to `selectNode(…, wantLast)`. func deterministicNodeForTest(t testing.TB, want views.Slice[tailcfg.StableNodeID], wantLast tailcfg.StableNodeID, use tailcfg.StableNodeID) selectNodeFunc { t.Helper() @@ -4450,6 +4452,16 @@ } return func(got views.Slice[tailcfg.NodeView], last tailcfg.StableNodeID) tailcfg.NodeView { + // In the tests, we choose nodes deterministically so we can get + // stable results, but in the real code, we choose nodes randomly. + // + // Call the randomNode function anyway, and ensure it returns + // a sensible result. + view := randomNode(got, last) + if !views.SliceContains(got, view) { + t.Fatalf("randomNode returned an unexpected node") + } + var ret tailcfg.NodeView gotIDs := make([]tailcfg.StableNodeID, got.Len()) @@ -4535,6 +4547,7 @@ func TestSuggestExitNode(t *testing.T) { Longitude: -97.3325, Priority: 100, } + var emptyLocation *tailcfg.Location peer1 := makePeer(1, withExitRoutes(), @@ -4574,6 +4587,18 @@ withExitRoutes(), withSuggest(), withLocation(fortWorthLowPriority.View())) + emptyLocationPeer9 := makePeer(9, + withoutDERP(), + withExitRoutes(), + withSuggest(), + withLocation(emptyLocation.View()), + ) + emptyLocationPeer10 := makePeer(10, + withoutDERP(), + withExitRoutes(), + withSuggest(), + withLocation(emptyLocation.View()), + ) selfNode := tailcfg.Node{ Addresses: []netip.Prefix{ @@ -4904,6 +4929,31 @@ wantName: "San Jose", wantLocation: sanJose.View(), }, + { + // Regression test for https://github.com/tailscale/tailscale/issues/17661 + name: "exit nodes with no home DERP, randomly selected", + lastReport: &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 10, + 2: 20, + 3: 10, + }, + PreferredDERP: 1, + }, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + emptyLocationPeer9, + emptyLocationPeer10, + }, + }, + wantRegions: []int{1, 2}, + wantName: "peer9", + wantNodes: []tailcfg.StableNodeID{"stable9", "stable10"}, + wantID: "stable9", + useRegion: 1, + }, } for _, tt := range tests { @@ -5179,6 +5229,26 @@ func TestSuggestExitNodeTrafficSteering(t *testing.T) { wantID: "stable3", wantName: "peer3", }, + { + name: "exit-nodes-without-priority-for-suggestions", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest()), + makePeer(2, + withExitRoutes(), + withSuggest()), + makePeer(3, + withExitRoutes(), + withLocationPriority(1)), + }, + }, + wantID: "stable1", + wantName: "peer1", + wantPri: 0, + }, { name: "exit-nodes-with-and-without-priority", netMap: &netmap.NetworkMap{ @@ -5596,7 +5666,10 @@ func TestFillAllowedSuggestions(t *testing.T) { var pol policytest.Config pol.Set(pkey.AllowedSuggestedExitNodes, tt.allowPolicy) - got := fillAllowedSuggestions(pol) + got, err := fillAllowedSuggestions(pol) + if err != nil { + t.Fatal(err) + } if got == nil { if tt.want == nil { return } @@ -5808,12 +5881,12 @@ func TestNotificationTargetMatch(t *testing.T) { type newTestControlFn func(tb testing.TB, opts controlclient.Options) controlclient.Client -func newLocalBackendWithTestControl(t *testing.T, 
enableLogging bool, newControl newTestControlFn) *LocalBackend { bus := eventbustest.NewBus(t) return newLocalBackendWithSysAndTestControl(t, enableLogging, tsd.NewSystemWithBus(bus), newControl) } -func newLocalBackendWithSysAndTestControl(t *testing.T, enableLogging bool, sys *tsd.System, newControl newTestControlFn) *LocalBackend { +func newLocalBackendWithSysAndTestControl(t testing.TB, enableLogging bool, sys *tsd.System, newControl newTestControlFn) *LocalBackend { logf := logger.Discard if enableLogging { logf = tstest.WhileTestRunningLogger(t) } @@ -6745,7 +6818,7 @@ func TestUpdateIngressAndServiceHashLocked(t *testing.T) { if tt.hasPreviousSC { b.mu.Lock() b.serveConfig = previousSC.View() - b.hostinfo.ServicesHash = b.vipServiceHash(b.vipServicesFromPrefsLocked(prefs)) + b.hostinfo.ServicesHash = vipServiceHash(b.logf, b.vipServicesFromPrefsLocked(prefs)) b.mu.Unlock() } b.serveConfig = tt.sc.View() @@ -6763,7 +6836,7 @@ })() was := b.goTracker.StartedGoroutines() - b.updateIngressAndServiceHashLocked(prefs) + b.maybeSentHostinfoIfChangedLocked(prefs) if tt.hi != nil { if tt.hi.IngressEnabled != tt.wantIngress { @@ -6773,7 +6846,7 @@ t.Errorf("WireIngress = %v, want %v", tt.hi.WireIngress, tt.wantWireIngress) } b.mu.Lock() - svcHash := b.vipServiceHash(b.vipServicesFromPrefsLocked(prefs)) + svcHash := vipServiceHash(b.logf, b.vipServicesFromPrefsLocked(prefs)) b.mu.Unlock() if tt.hi.ServicesHash != svcHash { t.Errorf("ServicesHash = %v, want %v", tt.hi.ServicesHash, svcHash) @@ -7121,3 +7194,104 @@ return nil } } + +type fakeAttestationKey struct{ key.HardwareAttestationKey } + +func (f *fakeAttestationKey) Clone() key.HardwareAttestationKey { + return &fakeAttestationKey{} +} + +// TestStripKeysFromPrefs tests that LocalBackend's [stripKeysFromPrefs] (as used +// by sendNotify etc) correctly removes all private keys from an ipn.Notify. +// +// It does so by testing the two ways that Notifys are sent: via sendNotify, +// and via extension hooks. +func TestStripKeysFromPrefs(t *testing.T) { + // genNotify generates a sample ipn.Notify with various private keys set + // at a certain path through the Notify data structure. 
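The test being added here asserts, for every key-bearing path through ipn.Notify that tstest/typewalk enumerates, that the value at that path is zero after sanitizing. The core of that assertion can be sketched with only the stdlib reflect package; the `notify`/`persist` types below are hypothetical stand-ins:

```go
package main

import (
	"fmt"
	"reflect"
)

// Hypothetical stand-ins for persist.Persist and ipn.Notify.
type persist struct{ PrivateNodeKey string }
type notify struct{ Persist *persist }

// isZeroAtPath walks v by field names and reports whether the value at
// the end of the path is the zero value, i.e. the key was stripped.
func isZeroAtPath(v reflect.Value, path ...string) bool {
	for _, name := range path {
		for v.Kind() == reflect.Pointer {
			if v.IsNil() {
				return true // nothing stored along this path at all
			}
			v = v.Elem()
		}
		v = v.FieldByName(name)
		if !v.IsValid() {
			return true
		}
	}
	return v.IsZero()
}

func main() {
	withKey := notify{Persist: &persist{PrivateNodeKey: "secret"}}
	stripped := notify{Persist: &persist{}}
	fmt.Println(isZeroAtPath(reflect.ValueOf(withKey), "Persist", "PrivateNodeKey"))  // false
	fmt.Println(isZeroAtPath(reflect.ValueOf(stripped), "Persist", "PrivateNodeKey")) // true
}
```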
+ genNotify := map[string]func() ipn.Notify{ + "Notify.Prefs.ж.Persist.PrivateNodeKey": func() ipn.Notify { + return ipn.Notify{ + Prefs: ptr.To((&ipn.Prefs{ + Persist: &persist.Persist{PrivateNodeKey: key.NewNode()}, + }).View()), + } + }, + "Notify.Prefs.ж.Persist.OldPrivateNodeKey": func() ipn.Notify { + return ipn.Notify{ + Prefs: ptr.To((&ipn.Prefs{ + Persist: &persist.Persist{OldPrivateNodeKey: key.NewNode()}, + }).View()), + } + }, + "Notify.Prefs.ж.Persist.NetworkLockKey": func() ipn.Notify { + return ipn.Notify{ + Prefs: ptr.To((&ipn.Prefs{ + Persist: &persist.Persist{NetworkLockKey: key.NewNLPrivate()}, + }).View()), + } + }, + "Notify.Prefs.ж.Persist.AttestationKey": func() ipn.Notify { + return ipn.Notify{ + Prefs: ptr.To((&ipn.Prefs{ + Persist: &persist.Persist{AttestationKey: new(fakeAttestationKey)}, + }).View()), + } + }, + } + + private := key.PrivateTypesForTest() + + for path := range typewalk.MatchingPaths(reflect.TypeFor[ipn.Notify](), private.Contains) { + t.Run(path.Name, func(t *testing.T) { + gen, ok := genNotify[path.Name] + if !ok { + t.Fatalf("no genNotify function for path %q", path.Name) + } + withKey := gen() + + if path.Walk(reflect.ValueOf(withKey)).IsZero() { + t.Fatalf("generated notify does not have non-zero value at path %q", path.Name) + } + + h := &ExtensionHost{} + ch := make(chan *ipn.Notify, 1) + b := &LocalBackend{ + extHost: h, + notifyWatchers: map[string]*watchSession{ + "test": {ch: ch}, + }, + } + + var okay atomic.Int32 + testNotify := func(via string) func(*ipn.Notify) { + return func(n *ipn.Notify) { + if n == nil { + t.Errorf("notify from %s is nil", via) + return + } + if !path.Walk(reflect.ValueOf(*n)).IsZero() { + t.Errorf("notify from %s has non-zero value at path %q; key not stripped", via, path.Name) + } else { + okay.Add(1) + } + } + } + + h.Hooks().MutateNotifyLocked.Add(testNotify("MutateNotifyLocked hook")) + + b.send(withKey) + + select { + case n := <-ch: + testNotify("watchSession")(n) + default: + t.Errorf("no notify sent to watcher channel") + } + + if got := okay.Load(); got != 2 { + t.Errorf("notify passed validation %d times; want 2", got) + } + }) + } +} diff --git a/ipn/ipnlocal/network-lock.go b/ipn/ipnlocal/network-lock.go index c769e242d4405..f25c6fa9b5e36 100644 --- a/ipn/ipnlocal/network-lock.go +++ b/ipn/ipnlocal/network-lock.go @@ -300,8 +300,11 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie return nil } - if b.tka != nil || nm.TKAEnabled { - b.logf("tkaSyncIfNeeded: enabled=%v, head=%v", nm.TKAEnabled, nm.TKAHead) + isEnabled := b.tka != nil + wantEnabled := nm.TKAEnabled + + if isEnabled || wantEnabled { + b.logf("tkaSyncIfNeeded: isEnabled=%t, wantEnabled=%t, head=%v", isEnabled, wantEnabled, nm.TKAHead) } ourNodeKey, ok := prefs.Persist().PublicNodeKeyOK() @@ -309,8 +312,6 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie return errors.New("tkaSyncIfNeeded: no node key in prefs") } - isEnabled := b.tka != nil - wantEnabled := nm.TKAEnabled didJustEnable := false if isEnabled != wantEnabled { var ourHead tka.AUMHash @@ -355,25 +356,18 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie if err := b.tkaSyncLocked(ourNodeKey); err != nil { return fmt.Errorf("tka sync: %w", err) } + // Try to compact the TKA state, to avoid unbounded storage on nodes. + // + // We run this on every sync so that clients compact consistently. In many + // cases this will be a no-op. 
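Per the comment above, compaction now runs on every successful sync, on the theory that it is a cheap no-op when there is nothing to trim. A toy sketch of that control flow, with a plain slice standing in for tka's AUM storage:

```go
package main

import "fmt"

// storage stands in for the node's AUM store.
type storage struct{ aums []string }

// compact trims all but the most recent window of AUMs; when nothing is
// over the limit it is a no-op, which is what makes running it on every
// sync cheap and keeps clients compacting consistently.
func compact(s *storage, keep int) {
	if len(s.aums) > keep {
		s.aums = s.aums[len(s.aums)-keep:]
	}
}

// sync appends newly learned AUMs and then always compacts, mirroring
// the control flow added to tkaSyncIfNeeded.
func sync(s *storage, incoming ...string) {
	s.aums = append(s.aums, incoming...)
	compact(s, 4) // stand-in for tkaCompactionDefaults
}

func main() {
	s := &storage{}
	for i := range 10 {
		sync(s, fmt.Sprintf("aum-%d", i))
	}
	fmt.Println(s.aums) // [aum-6 aum-7 aum-8 aum-9]
}
```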
+ if err := b.tka.authority.Compact(b.tka.storage, tkaCompactionDefaults); err != nil { + return fmt.Errorf("tka compact: %w", err) + } } return nil } -func toSyncOffer(head string, ancestors []string) (tka.SyncOffer, error) { - var out tka.SyncOffer - if err := out.Head.UnmarshalText([]byte(head)); err != nil { - return tka.SyncOffer{}, fmt.Errorf("head.UnmarshalText: %v", err) - } - out.Ancestors = make([]tka.AUMHash, len(ancestors)) - for i, a := range ancestors { - if err := out.Ancestors[i].UnmarshalText([]byte(a)); err != nil { - return tka.SyncOffer{}, fmt.Errorf("ancestor[%d].UnmarshalText: %v", i, err) - } - } - return out, nil -} - // tkaSyncLocked synchronizes TKA state with control. b.mu must be held // and tka must be initialized. b.mu will be stepped out of (and back into) // during network RPCs. @@ -391,7 +385,7 @@ func (b *LocalBackend) tkaSyncLocked(ourNodeKey key.NodePublic) error { if err != nil { return fmt.Errorf("offer RPC: %w", err) } - controlOffer, err := toSyncOffer(offerResp.Head, offerResp.Ancestors) + controlOffer, err := tka.ToSyncOffer(offerResp.Head, offerResp.Ancestors) if err != nil { return fmt.Errorf("control offer: %v", err) } @@ -476,10 +470,6 @@ func (b *LocalBackend) chonkPathLocked() string { // // b.mu must be held. func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, persist persist.PersistView) error { - if err := b.CanSupportNetworkLock(); err != nil { - return err - } - var genesis tka.AUM if err := genesis.Unserialize(g); err != nil { return fmt.Errorf("reading genesis: %v", err) @@ -503,7 +493,7 @@ func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, per if root == "" { b.health.SetUnhealthy(noNetworkLockStateDirWarnable, nil) b.logf("network-lock using in-memory storage; no state directory") - storage = &tka.Mem{} + storage = tka.ChonkMem() } else { chonkDir := b.chonkPathLocked() chonk, err := tka.ChonkDir(chonkDir) @@ -525,20 +515,6 @@ func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, per return nil } -// CanSupportNetworkLock returns nil if tailscaled is able to operate -// a local tailnet key authority (and hence enforce network lock). -func (b *LocalBackend) CanSupportNetworkLock() error { - if b.tka != nil { - // If the TKA is being used, it is supported. - return nil - } - - // There's a var root (aka --statedir), so if network lock gets - // initialized we have somewhere to store our AUMs. That's all - // we need. - return nil -} - // NetworkLockStatus returns a structure describing the state of the // tailnet key authority, if any. func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus { @@ -652,12 +628,7 @@ func tkaStateFromPeer(p tailcfg.NodeView) ipnstate.TKAPeer { // needing signatures is returned as a response. // The Finish RPC submits signatures for all these nodes, at which point // Control has everything it needs to atomically enable network lock. -// TODO(alexc): Only with persistent backend func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byte, supportDisablement []byte) error { - if err := b.CanSupportNetworkLock(); err != nil { - return err - } - var ourNodeKey key.NodePublic var nlPriv key.NLPrivate @@ -681,7 +652,7 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt // We use an in-memory tailchonk because we don't want to commit to // the filesystem until we've finished the initialization sequence, // just in case something goes wrong. 
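The in-memory tailchonk mentioned above is a staging pattern: run the fallible initialization against a throwaway store and only commit durable state once the whole sequence succeeds. A small illustrative sketch (the `memStore` type is hypothetical):

```go
package main

import "fmt"

// memStore is a hypothetical throwaway store, playing the role of the
// in-memory tailchonk (tka.ChonkMem) during initialization.
type memStore map[string]string

// initialize stages all fallible writes against the in-memory store, so
// a failure part-way through leaves no durable debris behind.
func initialize(staging memStore) error {
	staging["genesis"] = "aum0"
	// ... more fallible steps would go here ...
	return nil
}

func main() {
	staging := memStore{}
	if err := initialize(staging); err != nil {
		fmt.Println("init failed; durable state untouched:", err)
		return
	}
	// Only after full success is state copied to the durable store
	// (standing in for the on-disk chonk directory).
	durable := memStore{}
	for k, v := range staging {
		durable[k] = v
	}
	fmt.Println("committed:", durable)
}
```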
- _, genesisAUM, err := tka.Create(&tka.Mem{}, tka.State{ + _, genesisAUM, err := tka.Create(tka.ChonkMem(), tka.State{ Keys: keys, // TODO(tom): s/tka.State.DisablementSecrets/tka.State.DisablementValues // This will center on consistent nomenclature: @@ -709,7 +680,7 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt // Our genesis AUM was accepted but before Control turns on enforcement of // node-key signatures, we need to sign keys for all the existing nodes. - // If we don't get these signatures ahead of time, everyone will loose + // If we don't get these signatures ahead of time, everyone will lose // connectivity because control won't have any signatures to send which // satisfy network-lock checks. sigs := make(map[tailcfg.NodeID]tkatype.MarshaledSignature, len(initResp.NeedSignatures)) @@ -782,7 +753,6 @@ func (b *LocalBackend) NetworkLockForceLocalDisable() error { // NetworkLockSign signs the given node-key and submits it to the control plane. // rotationPublic, if specified, must be an ed25519 public key. -// TODO(alexc): in-memory only func (b *LocalBackend) NetworkLockSign(nodeKey key.NodePublic, rotationPublic []byte) error { ourNodeKey, sig, err := func(nodeKey key.NodePublic, rotationPublic []byte) (key.NodePublic, tka.NodeKeySignature, error) { b.mu.Lock() @@ -960,7 +930,7 @@ func (b *LocalBackend) NetworkLockLog(maxEntries int) ([]ipnstate.NetworkLockUpd if err == os.ErrNotExist { break } - return out, fmt.Errorf("reading AUM: %w", err) + return out, fmt.Errorf("reading AUM (%v): %w", cursor, err) } update := ipnstate.NetworkLockUpdate{ @@ -1310,27 +1280,10 @@ func (b *LocalBackend) tkaFetchBootstrap(ourNodeKey key.NodePublic, head tka.AUM return a, nil } -func fromSyncOffer(offer tka.SyncOffer) (head string, ancestors []string, err error) { - headBytes, err := offer.Head.MarshalText() - if err != nil { - return "", nil, fmt.Errorf("head.MarshalText: %v", err) - } - - ancestors = make([]string, len(offer.Ancestors)) - for i, ancestor := range offer.Ancestors { - hash, err := ancestor.MarshalText() - if err != nil { - return "", nil, fmt.Errorf("ancestor[%d].MarshalText: %v", i, err) - } - ancestors[i] = string(hash) - } - return string(headBytes), ancestors, nil -} - // tkaDoSyncOffer sends a /machine/tka/sync/offer RPC to the control plane // over noise. This is the first of two RPCs implementing tka synchronization. 
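The removed fromSyncOffer/toSyncOffer helpers now live in the tka package (tka.FromSyncOffer and tka.ToSyncOffer), shared by the client and the new test helpers. The wire encoding is just TextMarshaler round-trips over the head and ancestor hashes; a stdlib sketch with a hypothetical 32-byte hash type:

```go
package main

import (
	"encoding/base64"
	"fmt"
)

// hash is a hypothetical stand-in for tka.AUMHash.
type hash [32]byte

func (h hash) MarshalText() ([]byte, error) {
	return []byte(base64.RawStdEncoding.EncodeToString(h[:])), nil
}

func (h *hash) UnmarshalText(b []byte) error {
	d, err := base64.RawStdEncoding.DecodeString(string(b))
	if err != nil || len(d) != len(h) {
		return fmt.Errorf("bad hash %q", b)
	}
	copy(h[:], d)
	return nil
}

// fromOffer flattens head+ancestors to strings for the JSON RPC body,
// mirroring the role of tka.FromSyncOffer.
func fromOffer(head hash, ancestors []hash) (string, []string) {
	h, _ := head.MarshalText()
	out := make([]string, len(ancestors))
	for i, a := range ancestors {
		t, _ := a.MarshalText()
		out[i] = string(t)
	}
	return string(h), out
}

// toOffer is the inverse, mirroring the role of tka.ToSyncOffer.
func toOffer(head string, ancestors []string) (h hash, as []hash, err error) {
	if err = h.UnmarshalText([]byte(head)); err != nil {
		return
	}
	as = make([]hash, len(ancestors))
	for i, a := range ancestors {
		if err = as[i].UnmarshalText([]byte(a)); err != nil {
			return
		}
	}
	return
}

func main() {
	head := hash{1, 2, 3}
	wireHead, wireAnc := fromOffer(head, []hash{{4}, {5}})
	h2, a2, err := toOffer(wireHead, wireAnc)
	fmt.Println(h2 == head, len(a2), err) // true 2 <nil>
}
```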
func (b *LocalBackend) tkaDoSyncOffer(ourNodeKey key.NodePublic, offer tka.SyncOffer) (*tailcfg.TKASyncOfferResponse, error) { - head, ancestors, err := fromSyncOffer(offer) + head, ancestors, err := tka.FromSyncOffer(offer) if err != nil { return nil, fmt.Errorf("encoding offer: %v", err) } diff --git a/ipn/ipnlocal/network-lock_test.go b/ipn/ipnlocal/network-lock_test.go index c7c4c905f5ca1..e5df38bdb6d76 100644 --- a/ipn/ipnlocal/network-lock_test.go +++ b/ipn/ipnlocal/network-lock_test.go @@ -17,6 +17,7 @@ import ( "path/filepath" "reflect" "testing" + "time" go4mem "go4.org/mem" @@ -31,23 +32,18 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/tstest/tkatest" "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/types/persist" "tailscale.com/types/tkatype" - "tailscale.com/util/eventbus" "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" "tailscale.com/util/set" ) -type observerFunc func(controlclient.Status) - -func (f observerFunc) SetControlClientStatus(_ controlclient.Client, s controlclient.Status) { - f(s) -} - -func fakeControlClient(t *testing.T, c *http.Client) (*controlclient.Auto, *eventbus.Bus) { +func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} hi.NetInfo = &ni @@ -55,7 +51,6 @@ func fakeControlClient(t *testing.T, c *http.Client) (*controlclient.Auto, *even k := key.NewMachine() dialer := tsdial.NewDialer(netmon.NewStatic()) - dialer.SetBus(bus) opts := controlclient.Options{ ServerURL: "https://example.com", Hostinfo: hi, @@ -64,19 +59,21 @@ func fakeControlClient(t *testing.T, c *http.Client) (*controlclient.Auto, *even }, HTTPTestClient: c, NoiseTestClient: c, - Observer: observerFunc(func(controlclient.Status) {}), Dialer: dialer, Bus: bus, + + SkipStartForTests: true, } - cc, err := controlclient.NewNoStart(opts) + cc, err := controlclient.New(opts) if err != nil { t.Fatal(err) } - return cc, bus + return cc } func fakeNoiseServer(t *testing.T, handler http.HandlerFunc) (*httptest.Server, *http.Client) { + t.Helper() ts := httptest.NewUnstartedServer(handler) ts.StartTLS() client := ts.Client() @@ -87,6 +84,17 @@ func fakeNoiseServer(t *testing.T, handler http.HandlerFunc) (*httptest.Server, return ts, client } +func setupProfileManager(t *testing.T, nodePriv key.NodePrivate, nlPriv key.NLPrivate) *profileManager { + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) + must.Do(pm.SetPrefs((&ipn.Prefs{ + Persist: &persist.Persist{ + PrivateNodeKey: nodePriv, + NetworkLockKey: nlPriv, + }, + }).View(), ipn.NetworkProfile{})) + return pm +} + func TestTKAEnablementFlow(t *testing.T) { nodePriv := key.NewNode() @@ -94,7 +102,8 @@ func TestTKAEnablementFlow(t *testing.T) { // our mock server can communicate. 
nlPriv := key.NewNLPrivate() key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - a1, genesisAUM, err := tka.Create(&tka.Mem{}, tka.State{ + chonk := tka.ChonkMem() + a1, genesisAUM, err := tka.Create(chonk, tka.State{ Keys: []tka.Key{key}, DisablementSecrets: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, }, nlPriv) @@ -106,51 +115,31 @@ defer r.Body.Close() switch r.URL.Path { case "/machine/tka/bootstrap": - body := new(tailcfg.TKABootstrapRequest) - if err := json.NewDecoder(r.Body).Decode(body); err != nil { - t.Fatal(err) - } - if body.Version != tailcfg.CurrentCapabilityVersion { - t.Errorf("bootstrap CapVer = %v, want %v", body.Version, tailcfg.CurrentCapabilityVersion) - } - if body.NodeKey != nodePriv.Public() { - t.Errorf("bootstrap nodeKey=%v, want %v", body.NodeKey, nodePriv.Public()) + resp := tailcfg.TKABootstrapResponse{ + GenesisAUM: genesisAUM.Serialize(), } - if body.Head != "" { - t.Errorf("bootstrap head=%s, want empty hash", body.Head) + req, err := tkatest.HandleTKABootstrap(w, r, resp) + if err != nil { + t.Errorf("HandleTKABootstrap: %v", err) } - - w.WriteHeader(200) - out := tailcfg.TKABootstrapResponse{ - GenesisAUM: genesisAUM.Serialize(), + if req.NodeKey != nodePriv.Public() { + t.Errorf("bootstrap nodeKey=%v, want %v", req.NodeKey, nodePriv.Public()) } - if err := json.NewEncoder(w).Encode(out); err != nil { - t.Fatal(err) + if req.Head != "" { + t.Errorf("bootstrap head=%s, want empty hash", req.Head) } // Sync offer/send endpoints are hit even though the node is up-to-date, // so we implement enough of a fake that the client doesn't explode. case "/machine/tka/sync/offer": - head, err := a1.Head().MarshalText() + err := tkatest.HandleTKASyncOffer(w, r, a1, chonk) if err != nil { - t.Fatal(err) - } - w.WriteHeader(200) - if err := json.NewEncoder(w).Encode(tailcfg.TKASyncOfferResponse{ - Head: string(head), - }); err != nil { - t.Fatal(err) + t.Errorf("HandleTKASyncOffer: %v", err) } case "/machine/tka/sync/send": - head, err := a1.Head().MarshalText() + err := tkatest.HandleTKASyncSend(w, r, a1, chonk) if err != nil { - t.Fatal(err) - } - w.WriteHeader(200) - if err := json.NewEncoder(w).Encode(tailcfg.TKASyncSendResponse{ - Head: string(head), - }); err != nil { - t.Fatal(err) + t.Errorf("HandleTKASyncSend: %v", err) } default: @@ -161,14 +150,8 @@ defer ts.Close() temp := t.TempDir() - cc, bus := fakeControlClient(t, client) - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(bus))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + cc := fakeControlClient(t, client) + pm := setupProfileManager(t, nodePriv, nlPriv) b := LocalBackend{ capTailnetLock: true, varRoot: temp, @@ -202,13 +185,7 @@ func TestTKADisablementFlow(t *testing.T) { nlPriv := key.NewNLPrivate() key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) temp := t.TempDir() tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) @@ -230,37 +207,28 @@ 
defer r.Body.Close() switch r.URL.Path { case "/machine/tka/bootstrap": - body := new(tailcfg.TKABootstrapRequest) - if err := json.NewDecoder(r.Body).Decode(body); err != nil { - t.Fatal(err) - } - if body.Version != tailcfg.CurrentCapabilityVersion { - t.Errorf("bootstrap CapVer = %v, want %v", body.Version, tailcfg.CurrentCapabilityVersion) - } - if body.NodeKey != nodePriv.Public() { - t.Errorf("nodeKey=%v, want %v", body.NodeKey, nodePriv.Public()) - } - var head tka.AUMHash - if err := head.UnmarshalText([]byte(body.Head)); err != nil { - t.Fatalf("failed unmarshal of body.Head: %v", err) - } - if head != authority.Head() { - t.Errorf("reported head = %x, want %x", head, authority.Head()) - } - var disablement []byte if returnWrongSecret { disablement = bytes.Repeat([]byte{0x42}, 32) // wrong secret } else { disablement = disablementSecret } - - w.WriteHeader(200) - out := tailcfg.TKABootstrapResponse{ + resp := tailcfg.TKABootstrapResponse{ DisablementSecret: disablement, } - if err := json.NewEncoder(w).Encode(out); err != nil { - t.Fatal(err) + req, err := tkatest.HandleTKABootstrap(w, r, resp) + if err != nil { + t.Errorf("HandleTKABootstrap: %v", err) + } + if req.NodeKey != nodePriv.Public() { + t.Errorf("nodeKey=%v, want %v", req.NodeKey, nodePriv.Public()) + } + var head tka.AUMHash + if err := head.UnmarshalText([]byte(req.Head)); err != nil { + t.Fatalf("failed unmarshal of body.Head: %v", err) + } + if head != authority.Head() { + t.Errorf("reported head = %x, want %x", head, authority.Head()) } default: @@ -270,7 +238,7 @@ func TestTKADisablementFlow(t *testing.T) { })) defer ts.Close() - cc, _ := fakeControlClient(t, client) + cc := fakeControlClient(t, client) b := LocalBackend{ varRoot: temp, cc: cc, @@ -394,17 +362,11 @@ func TestTKASync(t *testing.T) { t.Run(tc.name, func(t *testing.T) { nodePriv := key.NewNode() nlPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) // Setup the tka authority on the control plane. 
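The tkatest.HandleTKABootstrap calls in these tests replace hand-rolled decode/validate/respond blocks that were duplicated across mock servers. The general shape of such a helper, which returns the decoded request so each test keeps its own assertions, can be sketched generically (`handleJSON` is hypothetical, not the real tkatest API):

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// handleJSON decodes the request body into Req, writes resp as JSON, and
// hands the decoded request back so the caller can make its own
// assertions, the same split the tkatest helpers use.
func handleJSON[Req, Resp any](w http.ResponseWriter, r *http.Request, resp Resp) (Req, error) {
	var req Req
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		return req, fmt.Errorf("decode: %w", err)
	}
	w.WriteHeader(200)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		return req, fmt.Errorf("encode: %w", err)
	}
	return req, nil
}

type bootstrapReq struct{ NodeKey string }
type bootstrapResp struct{ Genesis string }

func main() {
	w := httptest.NewRecorder()
	r := httptest.NewRequest("POST", "/machine/tka/bootstrap", strings.NewReader(`{"NodeKey":"nk"}`))
	req, err := handleJSON[bootstrapReq](w, r, bootstrapResp{Genesis: "g"})
	fmt.Println(req.NodeKey, err, strings.TrimSpace(w.Body.String()))
}
```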
key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - controlStorage := &tka.Mem{} + controlStorage := tka.ChonkMem() controlAuthority, bootstrap, err := tka.Create(controlStorage, tka.State{ Keys: []tka.Key{key, someKey}, DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, @@ -441,76 +403,15 @@ func TestTKASync(t *testing.T) { defer r.Body.Close() switch r.URL.Path { case "/machine/tka/sync/offer": - body := new(tailcfg.TKASyncOfferRequest) - if err := json.NewDecoder(r.Body).Decode(body); err != nil { - t.Fatal(err) - } - t.Logf("got sync offer:\n%+v", body) - nodeOffer, err := toSyncOffer(body.Head, body.Ancestors) - if err != nil { - t.Fatal(err) - } - controlOffer, err := controlAuthority.SyncOffer(controlStorage) + err := tkatest.HandleTKASyncOffer(w, r, controlAuthority, controlStorage) if err != nil { - t.Fatal(err) - } - sendAUMs, err := controlAuthority.MissingAUMs(controlStorage, nodeOffer) - if err != nil { - t.Fatal(err) - } - - head, ancestors, err := fromSyncOffer(controlOffer) - if err != nil { - t.Fatal(err) - } - resp := tailcfg.TKASyncOfferResponse{ - Head: head, - Ancestors: ancestors, - MissingAUMs: make([]tkatype.MarshaledAUM, len(sendAUMs)), - } - for i, a := range sendAUMs { - resp.MissingAUMs[i] = a.Serialize() - } - - t.Logf("responding to sync offer with:\n%+v", resp) - w.WriteHeader(200) - if err := json.NewEncoder(w).Encode(resp); err != nil { - t.Fatal(err) + t.Errorf("HandleTKASyncOffer: %v", err) } case "/machine/tka/sync/send": - body := new(tailcfg.TKASyncSendRequest) - if err := json.NewDecoder(r.Body).Decode(body); err != nil { - t.Fatal(err) - } - t.Logf("got sync send:\n%+v", body) - - var remoteHead tka.AUMHash - if err := remoteHead.UnmarshalText([]byte(body.Head)); err != nil { - t.Fatalf("head unmarshal: %v", err) - } - toApply := make([]tka.AUM, len(body.MissingAUMs)) - for i, a := range body.MissingAUMs { - if err := toApply[i].Unserialize(a); err != nil { - t.Fatalf("decoding missingAUM[%d]: %v", i, err) - } - } - - if len(toApply) > 0 { - if err := controlAuthority.Inform(controlStorage, toApply); err != nil { - t.Fatalf("control.Inform(%+v) failed: %v", toApply, err) - } - } - head, err := controlAuthority.Head().MarshalText() + err := tkatest.HandleTKASyncSend(w, r, controlAuthority, controlStorage) if err != nil { - t.Fatal(err) - } - - w.WriteHeader(200) - if err := json.NewEncoder(w).Encode(tailcfg.TKASyncSendResponse{ - Head: string(head), - }); err != nil { - t.Fatal(err) + t.Errorf("HandleTKASyncSend: %v", err) } default: @@ -521,7 +422,7 @@ func TestTKASync(t *testing.T) { defer ts.Close() // Setup the client. - cc, _ := fakeControlClient(t, client) + cc := fakeControlClient(t, client) b := LocalBackend{ varRoot: temp, cc: cc, @@ -535,7 +436,7 @@ func TestTKASync(t *testing.T) { }, } - // Finally, lets trigger a sync. + // Finally, let's trigger a sync. err = b.tkaSyncIfNeeded(&netmap.NetworkMap{ TKAEnabled: true, TKAHead: controlAuthority.Head(), @@ -553,10 +454,159 @@ func TestTKASync(t *testing.T) { } } +// Whenever we run a TKA sync and get new state from control, we compact the +// local state. +func TestTKASyncTriggersCompact(t *testing.T) { + someKeyPriv := key.NewNLPrivate() + someKey := tka.Key{Kind: tka.Key25519, Public: someKeyPriv.Public().Verifier(), Votes: 1} + + disablementSecret := bytes.Repeat([]byte{0xa5}, 32) + + nodePriv := key.NewNode() + nlPriv := key.NewNLPrivate() + pm := setupProfileManager(t, nodePriv, nlPriv) + + // Create a clock, and roll it back by 30 days. 
+ // + // Our compaction algorithm preserves AUMs received in the last 14 days, so + // we need to backdate the commit times to make the AUMs eligible for compaction. + clock := tstest.NewClock(tstest.ClockOpts{}) + clock.Advance(-30 * 24 * time.Hour) + + // Set up the TKA authority on the control plane. + key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} + controlStorage := tka.ChonkMem() + controlStorage.SetClock(clock) + controlAuthority, bootstrap, err := tka.Create(controlStorage, tka.State{ + Keys: []tka.Key{key, someKey}, + DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + }, nlPriv) + if err != nil { + t.Fatalf("tka.Create() failed: %v", err) + } + + // Fill the control plane TKA authority with a lot of AUMs, enough so that: + // + // 1. the chain of AUMs includes some checkpoints + // 2. the chain is long enough it would be trimmed if we ran the compaction + // algorithm with the defaults + for range 100 { + upd := controlAuthority.NewUpdater(nlPriv) + if err := upd.RemoveKey(someKey.MustID()); err != nil { + t.Fatalf("RemoveKey: %v", err) + } + if err := upd.AddKey(someKey); err != nil { + t.Fatalf("AddKey: %v", err) + } + aums, err := upd.Finalize(controlStorage) + if err != nil { + t.Fatalf("Finalize: %v", err) + } + if err := controlAuthority.Inform(controlStorage, aums); err != nil { + t.Fatalf("controlAuthority.Inform() failed: %v", err) + } + } + + // Set up the TKA authority on the node. + nodeStorage := tka.ChonkMem() + nodeStorage.SetClock(clock) + nodeAuthority, err := tka.Bootstrap(nodeStorage, bootstrap) + if err != nil { + t.Fatalf("tka.Bootstrap() failed: %v", err) + } + + // Make a mock control server. + ts, client := fakeNoiseServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + switch r.URL.Path { + case "/machine/tka/sync/offer": + err := tkatest.HandleTKASyncOffer(w, r, controlAuthority, controlStorage) + if err != nil { + t.Errorf("HandleTKASyncOffer: %v", err) + } + + case "/machine/tka/sync/send": + err := tkatest.HandleTKASyncSend(w, r, controlAuthority, controlStorage) + if err != nil { + t.Errorf("HandleTKASyncSend: %v", err) + } + + default: + t.Errorf("unhandled endpoint path: %v", r.URL.Path) + w.WriteHeader(404) + } + })) + defer ts.Close() + + // Setup the client. + cc := fakeControlClient(t, client) + b := LocalBackend{ + cc: cc, + ccAuto: cc, + logf: t.Logf, + pm: pm, + store: pm.Store(), + tka: &tkaState{ + authority: nodeAuthority, + storage: nodeStorage, + }, + } + + // Trigger a sync. + err = b.tkaSyncIfNeeded(&netmap.NetworkMap{ + TKAEnabled: true, + TKAHead: controlAuthority.Head(), + }, pm.CurrentPrefs()) + if err != nil { + t.Errorf("tkaSyncIfNeeded() failed: %v", err) + } + + // Add a new AUM in control. + upd := controlAuthority.NewUpdater(nlPriv) + if err := upd.RemoveKey(someKey.MustID()); err != nil { + t.Fatalf("RemoveKey: %v", err) + } + aums, err := upd.Finalize(controlStorage) + if err != nil { + t.Fatalf("Finalize: %v", err) + } + if err := controlAuthority.Inform(controlStorage, aums); err != nil { + t.Fatalf("controlAuthority.Inform() failed: %v", err) + } + + // Run a second sync, which should trigger a compaction. + err = b.tkaSyncIfNeeded(&netmap.NetworkMap{ + TKAEnabled: true, + TKAHead: controlAuthority.Head(), + }, pm.CurrentPrefs()) + if err != nil { + t.Errorf("tkaSyncIfNeeded() failed: %v", err) + } + + // Check that the node and control plane are in sync. 
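The backdated clock is what makes this test work: AUMs are committed while the fake clock reads 30 days in the past, putting them outside the 14-day retention window and thus making them eligible for compaction. A stdlib-only sketch of the trick with a hypothetical injectable clock:

```go
package main

import (
	"fmt"
	"time"
)

// clock is a hypothetical injectable time source, like tstest's clock.
type clock struct{ now time.Time }

func (c *clock) Now() time.Time          { return c.now }
func (c *clock) Advance(d time.Duration) { c.now = c.now.Add(d) }

// eligible reports whether a record committed at the given time is old
// enough to compact, assuming a 14-day retention window like
// tkaCompactionDefaults.
func eligible(c *clock, committed time.Time) bool {
	return c.Now().Sub(committed) > 14*24*time.Hour
}

func main() {
	c := &clock{now: time.Now()}
	c.Advance(-30 * 24 * time.Hour) // backdate the clock by 30 days
	committed := c.Now()            // commits recorded "in the past"
	c.Advance(30 * 24 * time.Hour)  // return to the present
	fmt.Println(eligible(c, committed)) // true: outside the retention window
}
```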
+ if nodeHead, controlHead := b.tka.authority.Head(), controlAuthority.Head(); nodeHead != controlHead { + t.Errorf("node head = %v, want %v", nodeHead, controlHead) + } + + // Check that the node has compacted away some of its AUMs, i.e. that it has + // purged some AUMs which are still kept in the control plane. + nodeAUMs, err := b.tka.storage.AllAUMs() + if err != nil { + t.Errorf("AllAUMs() for node failed: %v", err) + } + controlAUMs, err := controlStorage.AllAUMs() + if err != nil { + t.Errorf("AllAUMs() for control failed: %v", err) + } + if len(nodeAUMs) == len(controlAUMs) { + t.Errorf("node has not compacted; it has the same number of AUMs as control (node = control = %d)", len(nodeAUMs)) + } +} + func TestTKAFilterNetmap(t *testing.T) { nlPriv := key.NewNLPrivate() nlKey := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - storage := &tka.Mem{} + storage := tka.ChonkMem() authority, _, err := tka.Create(storage, tka.State{ Keys: []tka.Key{nlKey}, DisablementSecrets: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, @@ -710,13 +760,7 @@ func TestTKADisable(t *testing.T) { disablementSecret := bytes.Repeat([]byte{0xa5}, 32) nlPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) temp := t.TempDir() tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) @@ -772,7 +816,7 @@ })) defer ts.Close() - cc, _ := fakeControlClient(t, client) + cc := fakeControlClient(t, client) b := LocalBackend{ varRoot: temp, cc: cc, @@ -801,13 +845,7 @@ func TestTKASign(t *testing.T) { toSign := key.NewNode() nlPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. 
disablementSecret := bytes.Repeat([]byte{0xa5}, 32) @@ -832,29 +870,9 @@ func TestTKASign(t *testing.T) { defer r.Body.Close() switch r.URL.Path { case "/machine/tka/sign": - body := new(tailcfg.TKASubmitSignatureRequest) - if err := json.NewDecoder(r.Body).Decode(body); err != nil { - t.Fatal(err) - } - if body.Version != tailcfg.CurrentCapabilityVersion { - t.Errorf("sign CapVer = %v, want %v", body.Version, tailcfg.CurrentCapabilityVersion) - } - if body.NodeKey != nodePriv.Public() { - t.Errorf("nodeKey = %v, want %v", body.NodeKey, nodePriv.Public()) - } - - var sig tka.NodeKeySignature - if err := sig.Unserialize(body.Signature); err != nil { - t.Fatalf("malformed signature: %v", err) - } - - if err := authority.NodeKeyAuthorized(toSign.Public(), body.Signature); err != nil { - t.Errorf("signature does not verify: %v", err) - } - - w.WriteHeader(200) - if err := json.NewEncoder(w).Encode(tailcfg.TKASubmitSignatureResponse{}); err != nil { - t.Fatal(err) + _, _, err := tkatest.HandleTKASign(w, r, authority) + if err != nil { + t.Errorf("HandleTKASign: %v", err) } default: @@ -863,7 +881,7 @@ func TestTKASign(t *testing.T) { } })) defer ts.Close() - cc, _ := fakeControlClient(t, client) + cc := fakeControlClient(t, client) b := LocalBackend{ varRoot: temp, cc: cc, @@ -890,13 +908,7 @@ func TestTKAForceDisable(t *testing.T) { nlPriv := key.NewNLPrivate() key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) temp := t.TempDir() tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) @@ -917,23 +929,15 @@ func TestTKAForceDisable(t *testing.T) { defer r.Body.Close() switch r.URL.Path { case "/machine/tka/bootstrap": - body := new(tailcfg.TKABootstrapRequest) - if err := json.NewDecoder(r.Body).Decode(body); err != nil { - t.Fatal(err) - } - if body.Version != tailcfg.CurrentCapabilityVersion { - t.Errorf("bootstrap CapVer = %v, want %v", body.Version, tailcfg.CurrentCapabilityVersion) - } - if body.NodeKey != nodePriv.Public() { - t.Errorf("nodeKey=%v, want %v", body.NodeKey, nodePriv.Public()) - } - - w.WriteHeader(200) - out := tailcfg.TKABootstrapResponse{ + resp := tailcfg.TKABootstrapResponse{ GenesisAUM: genesis.Serialize(), } - if err := json.NewEncoder(w).Encode(out); err != nil { - t.Fatal(err) + req, err := tkatest.HandleTKABootstrap(w, r, resp) + if err != nil { + t.Errorf("HandleTKABootstrap: %v", err) + } + if req.NodeKey != nodePriv.Public() { + t.Errorf("nodeKey=%v, want %v", req.NodeKey, nodePriv.Public()) } default: @@ -943,7 +947,7 @@ func TestTKAForceDisable(t *testing.T) { })) defer ts.Close() - cc, _ := fakeControlClient(t, client) + cc := fakeControlClient(t, client) sys := tsd.NewSystem() sys.Set(pm.Store()) @@ -988,13 +992,7 @@ func TestTKAAffectedSigs(t *testing.T) { // toSign := key.NewNode() nlPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. 
disablementSecret := bytes.Repeat([]byte{0xa5}, 32) @@ -1079,7 +1077,7 @@ func TestTKAAffectedSigs(t *testing.T) { } })) defer ts.Close() - cc, _ := fakeControlClient(t, client) + cc := fakeControlClient(t, client) b := LocalBackend{ varRoot: temp, cc: cc, @@ -1121,13 +1119,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { cosignPriv := key.NewNLPrivate() compromisedPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. disablementSecret := bytes.Repeat([]byte{0xa5}, 32) @@ -1154,44 +1146,23 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { defer r.Body.Close() switch r.URL.Path { case "/machine/tka/sync/send": - body := new(tailcfg.TKASyncSendRequest) - if err := json.NewDecoder(r.Body).Decode(body); err != nil { - t.Fatal(err) - } - t.Logf("got sync send:\n%+v", body) - - var remoteHead tka.AUMHash - if err := remoteHead.UnmarshalText([]byte(body.Head)); err != nil { - t.Fatalf("head unmarshal: %v", err) - } - toApply := make([]tka.AUM, len(body.MissingAUMs)) - for i, a := range body.MissingAUMs { - if err := toApply[i].Unserialize(a); err != nil { - t.Fatalf("decoding missingAUM[%d]: %v", i, err) - } + err := tkatest.HandleTKASyncSend(w, r, authority, chonk) + if err != nil { + t.Errorf("HandleTKASyncSend: %v", err) } - // Apply the recovery AUM to an authority to make sure it works. - if err := authority.Inform(chonk, toApply); err != nil { - t.Errorf("recovery AUM could not be applied: %v", err) - } // Make sure the key we removed isn't trusted. if authority.KeyTrusted(compromisedPriv.KeyID()) { t.Error("compromised key was not removed from tka") } - w.WriteHeader(200) - if err := json.NewEncoder(w).Encode(tailcfg.TKASubmitSignatureResponse{}); err != nil { - t.Fatal(err) - } - default: t.Errorf("unhandled endpoint path: %v", r.URL.Path) w.WriteHeader(404) } })) defer ts.Close() - cc, _ := fakeControlClient(t, client) + cc := fakeControlClient(t, client) b := LocalBackend{ varRoot: temp, cc: cc, @@ -1212,13 +1183,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { // Cosign using the cosigning key. { - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: cosignPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, cosignPriv) b := LocalBackend{ varRoot: temp, logf: t.Logf, diff --git a/ipn/ipnlocal/node_backend.go b/ipn/ipnlocal/node_backend.go index dbe23e4d5245a..efef57ea492e7 100644 --- a/ipn/ipnlocal/node_backend.go +++ b/ipn/ipnlocal/node_backend.go @@ -16,6 +16,7 @@ import ( "tailscale.com/ipn" "tailscale.com/net/dns" "tailscale.com/net/tsaddr" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" "tailscale.com/types/key" @@ -82,7 +83,7 @@ type nodeBackend struct { derpMapViewPub *eventbus.Publisher[tailcfg.DERPMapView] // TODO(nickkhyl): maybe use sync.RWMutex? 
- mu sync.Mutex // protects the following fields + mu syncs.Mutex // protects the following fields shutdownOnce sync.Once // guards calling [nodeBackend.shutdown] readyCh chan struct{} // closed by [nodeBackend.ready]; nil after shutdown @@ -747,7 +748,7 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg. } dcfg.Hosts[fqdn] = ips } - set(nm.Name, nm.GetAddresses()) + set(nm.SelfName(), nm.GetAddresses()) for _, peer := range peers { set(peer.Name(), peer.Addresses()) } diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index 3e80cdaa93d1f..7080e3c3edd50 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -5,10 +5,12 @@ package ipnlocal import ( "cmp" + "crypto" "crypto/rand" "encoding/json" "errors" "fmt" + "io" "runtime" "slices" "strings" @@ -19,9 +21,12 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" + "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/persist" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" + "tailscale.com/util/testenv" ) var debug = envknob.RegisterBool("TS_DEBUG_PROFILES") @@ -56,6 +61,9 @@ type profileManager struct { // extHost is the bridge between [profileManager] and the registered [ipnext.Extension]s. // It may be nil in tests. A nil pointer is a valid, no-op host. extHost *ExtensionHost + + // Override for key.NewEmptyHardwareAttestationKey used for testing. + newEmptyHardwareAttestationKey func() (key.HardwareAttestationKey, error) } // SetExtensionHost sets the [ExtensionHost] for the [profileManager]. @@ -654,8 +662,26 @@ func (pm *profileManager) loadSavedPrefs(k ipn.StateKey) (ipn.PrefsView, error) return ipn.PrefsView{}, err } savedPrefs := ipn.NewPrefs() + + // if supported by the platform, create an empty hardware attestation key to use when deserializing + // to avoid type exceptions from json.Unmarshaling into an interface{}. + hw, _ := pm.newEmptyHardwareAttestationKey() + savedPrefs.Persist = &persist.Persist{ + AttestationKey: hw, + } + if err := ipn.PrefsFromBytes(bs, savedPrefs); err != nil { - return ipn.PrefsView{}, fmt.Errorf("parsing saved prefs: %v", err) + // Try loading again, this time ignoring the AttestationKey contents. + // If that succeeds, there's something wrong with the underlying + // attestation key mechanism (most likely the TPM changed), but we + // should at least proceed with client startup. + origErr := err + savedPrefs.Persist.AttestationKey = &noopAttestationKey{} + if err := ipn.PrefsFromBytes(bs, savedPrefs); err != nil { + return ipn.PrefsView{}, fmt.Errorf("parsing saved prefs: %w", err) + } else { + pm.logf("failed to parse savedPrefs with attestation key (error: %v) but parsing without the attestation key succeeded; will proceed without using the old attestation key", origErr) + } } pm.logf("using backend prefs for %q: %v", k, savedPrefs.Pretty()) @@ -839,6 +865,7 @@ func (pm *profileManager) CurrentPrefs() ipn.PrefsView { // ReadStartupPrefsForTest reads the startup prefs from disk. It is only used for testing. 
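The loadSavedPrefs change above parses stored prefs twice: first with a real hardware attestation key, and again with a no-op key if that fails, so a changed TPM degrades gracefully rather than blocking startup. A compact sketch of the two-pass fallback with hypothetical `prefs` and key types; like the real code, it relies on encoding/json descending into a non-nil Unmarshaler held in an interface field:

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

type attestationKey interface{ json.Unmarshaler }

// realKey simulates an attestation key whose stored blob no longer
// deserializes (e.g. the TPM changed underneath it).
type realKey struct{}

func (*realKey) UnmarshalJSON([]byte) error { return errors.New("TPM changed") }

// noopKey always unmarshals successfully as a zero key.
type noopKey struct{}

func (*noopKey) UnmarshalJSON([]byte) error { return nil }

type prefs struct{ AttestationKey attestationKey }

func load(bs []byte) (*prefs, error) {
	p := &prefs{AttestationKey: &realKey{}}
	if err := json.Unmarshal(bs, p); err != nil {
		// Strict parse failed; retry ignoring the attestation key so the
		// client can still start (it just won't use the old key).
		origErr := err
		p = &prefs{AttestationKey: &noopKey{}}
		if err := json.Unmarshal(bs, p); err != nil {
			return nil, fmt.Errorf("parsing saved prefs: %w", err)
		}
		fmt.Println("proceeding without the old attestation key:", origErr)
	}
	return p, nil
}

func main() {
	p, err := load([]byte(`{"AttestationKey":{}}`))
	fmt.Println(p != nil, err) // true <nil>
}
```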
func ReadStartupPrefsForTest(logf logger.Logf, store ipn.StateStore) (ipn.PrefsView, error) { + testenv.AssertInTest() bus := eventbus.New() defer bus.Close() ht := health.NewTracker(bus) // in tests, don't care about the health status @@ -900,11 +927,12 @@ func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, ht *healt metricProfileCount.Set(int64(len(knownProfiles))) pm := &profileManager{ - goos: goos, - store: store, - knownProfiles: knownProfiles, - logf: logf, - health: ht, + goos: goos, + store: store, + knownProfiles: knownProfiles, + logf: logf, + health: ht, + newEmptyHardwareAttestationKey: key.NewEmptyHardwareAttestationKey, } var initialProfile ipn.LoginProfileView @@ -973,3 +1001,21 @@ var ( metricMigrationError = clientmetric.NewCounter("profiles_migration_error") metricMigrationSuccess = clientmetric.NewCounter("profiles_migration_success") ) + +// noopAttestationKey is a key.HardwareAttestationKey that always successfully +// unmarshals as a zero key. +type noopAttestationKey struct{} + +func (n noopAttestationKey) Public() crypto.PublicKey { + panic("noopAttestationKey.Public should not be called; missing IsZero check somewhere?") +} + +func (n noopAttestationKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + panic("noopAttestationKey.Sign should not be called; missing IsZero check somewhere?") +} + +func (n noopAttestationKey) MarshalJSON() ([]byte, error) { return nil, nil } +func (n noopAttestationKey) UnmarshalJSON([]byte) error { return nil } +func (n noopAttestationKey) Close() error { return nil } +func (n noopAttestationKey) Clone() key.HardwareAttestationKey { return n } +func (n noopAttestationKey) IsZero() bool { return true } diff --git a/ipn/ipnlocal/profiles_test.go b/ipn/ipnlocal/profiles_test.go index 60c92ff8d3493..6be7f0e53f59e 100644 --- a/ipn/ipnlocal/profiles_test.go +++ b/ipn/ipnlocal/profiles_test.go @@ -4,6 +4,7 @@ package ipnlocal import ( + "errors" "fmt" "os/user" "strconv" @@ -151,6 +152,7 @@ func TestProfileDupe(t *testing.T) { ID: tailcfg.UserID(user), LoginName: fmt.Sprintf("user%d@example.com", user), }, + AttestationKey: nil, } } user1Node1 := newPersist(1, 1) @@ -1128,10 +1130,12 @@ func TestProfileStateChangeCallback(t *testing.T) { } gotChanges := make([]stateChange, 0, len(tt.wantChanges)) - pm.StateChangeHook = func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + pm.StateChangeHook = func(profile ipn.LoginProfileView, prefView ipn.PrefsView, sameNode bool) { + prefs := prefView.AsStruct() + prefs.Sync = prefs.Sync.Normalized() gotChanges = append(gotChanges, stateChange{ Profile: profile.AsStruct(), - Prefs: prefs.AsStruct(), + Prefs: prefs, SameNode: sameNode, }) } @@ -1144,3 +1148,40 @@ func TestProfileStateChangeCallback(t *testing.T) { }) } } + +func TestProfileBadAttestationKey(t *testing.T) { + store := new(mem.Store) + pm, err := newProfileManagerWithGOOS(store, t.Logf, health.NewTracker(eventbustest.NewBus(t)), "linux") + if err != nil { + t.Fatal(err) + } + fk := new(failingHardwareAttestationKey) + pm.newEmptyHardwareAttestationKey = func() (key.HardwareAttestationKey, error) { + return fk, nil + } + sk := ipn.StateKey(t.Name()) + if err := pm.store.WriteState(sk, []byte(`{"Config": {"AttestationKey": {}}}`)); err != nil { + t.Fatal(err) + } + prefs, err := pm.loadSavedPrefs(sk) + if err != nil { + t.Fatal(err) + } + ak := prefs.Persist().AsStruct().AttestationKey + if _, ok := ak.(noopAttestationKey); !ok { + t.Errorf("loaded attestation 
key of type %T, want noopAttestationKey", ak) + } + if !fk.unmarshalCalled { + t.Error("UnmarshalJSON was not called on failingHardwareAttestationKey") + } +} + +type failingHardwareAttestationKey struct { + noopAttestationKey + unmarshalCalled bool +} + +func (k *failingHardwareAttestationKey) UnmarshalJSON([]byte) error { + k.unmarshalCalled = true + return errors.New("failed to unmarshal attestation key!") +} diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index 3c967fd1e6403..ef4e9154557a4 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -33,6 +33,7 @@ import ( "time" "unicode/utf8" + "github.com/pires/go-proxyproto" "go4.org/mem" "tailscale.com/ipn" "tailscale.com/net/netutil" @@ -40,6 +41,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/lazy" "tailscale.com/types/logger" + "tailscale.com/types/views" "tailscale.com/util/backoff" "tailscale.com/util/clientmetric" "tailscale.com/util/ctxkey" @@ -58,6 +60,9 @@ func init() { b.setVIPServicesTCPPortsInterceptedLocked(nil) }) + hookMaybeMutateHostinfoLocked.Add(maybeUpdateHostinfoServicesHashLocked) + hookMaybeMutateHostinfoLocked.Add(maybeUpdateHostinfoFunnelLocked) + RegisterC2N("GET /vip-services", handleC2NVIPServicesGet) } @@ -80,6 +85,8 @@ type serveHTTPContext struct { // provides funnel-specific context, nil if not funneled Funnel *funnelFlow + // AppCapabilities lists all PeerCapabilities that should be forwarded by serve + AppCapabilities views.Slice[tailcfg.PeerCapability] } // funnelFlow represents a funneled connection initiated via IngressPeer @@ -285,6 +292,10 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 // SetServeConfig establishes or replaces the current serve config. // ETag is an optional parameter to enforce Optimistic Concurrency Control. // If it is an empty string, then the config will be overwritten. +// +// New foreground config cannot override existing listeners--neither existing +// foreground listeners nor existing background listeners. Background config can +// change as long as the serve type (e.g. HTTP, TCP, etc.) remains the same. func (b *LocalBackend) SetServeConfig(config *ipn.ServeConfig, etag string) error { b.mu.Lock() defer b.mu.Unlock() @@ -300,12 +311,6 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string return errors.New("can't reconfigure tailscaled when using a config file; config file is locked") } - if config != nil { - if err := config.CheckValidServicesConfig(); err != nil { - return err - } - } - nm := b.NetMap() if nm == nil { return errors.New("netMap is nil") @@ -333,6 +338,10 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string } } + if err := validateServeConfigUpdate(prevConfig, config.View()); err != nil { + return err + } + var bs []byte if config != nil { j, err := json.Marshal(config) @@ -665,10 +674,81 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort, }) } + var proxyHeader []byte + if ver := tcph.ProxyProtocol(); ver > 0 { + // backAddr is the final "destination" of the connection, + // which is the connection to the proxied-to backend. + backAddr := backConn.RemoteAddr().(*net.TCPAddr) + + // We always want to format the PROXY protocol + // header based on the IPv4 or IPv6-ness of + // the client. The SourceAddr and + // DestinationAddr need to match in type, so we + // need to be careful to not e.g. set a + // SourceAddr of type IPv6 and DestinationAddr + // of type IPv4. 
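The hunk in progress builds a go-proxyproto header whose source and destination addresses are forced into the same family before being written to the backend. A minimal standalone sketch using the same library (the addresses are made up):

```go
package main

import (
	"bytes"
	"fmt"
	"net"

	"github.com/pires/go-proxyproto"
)

func main() {
	// Source and destination must share an address family; a 4-in-6 source
	// would otherwise pair an IPv6 source with an IPv4 destination.
	header := &proxyproto.Header{
		Version:           2,
		Command:           proxyproto.PROXY,
		TransportProtocol: proxyproto.TCPv4,
		SourceAddr:        &net.TCPAddr{IP: net.IPv4(100, 64, 0, 1), Port: 54321},
		DestinationAddr:   &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080},
	}

	raw, err := header.Format()
	if err != nil {
		fmt.Println("format:", err)
		return
	}

	// In the real code the header bytes are written to backConn before
	// io.Copy starts shuttling the proxied connection's data.
	var backConn bytes.Buffer
	backConn.Write(raw)
	fmt.Printf("wrote %d header bytes\n", backConn.Len())
}
```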
+ // + // If this is an IPv6-mapped IPv4 address, + // though, unmap it. + proxySrcAddr := srcAddr + if proxySrcAddr.Addr().Is4In6() { + proxySrcAddr = netip.AddrPortFrom( + proxySrcAddr.Addr().Unmap(), + proxySrcAddr.Port(), + ) + } + + is4 := proxySrcAddr.Addr().Is4() + + var destAddr netip.Addr + if self := b.currentNode().Self(); self.Valid() { + if is4 { + destAddr = nodeIP(self, netip.Addr.Is4) + } else { + destAddr = nodeIP(self, netip.Addr.Is6) + } + } + if !destAddr.IsValid() { + // Pick a best-effort destination address of localhost. + if is4 { + destAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) + } else { + destAddr = netip.IPv6Loopback() + } + } + + header := &proxyproto.Header{ + Version: byte(ver), + Command: proxyproto.PROXY, + SourceAddr: net.TCPAddrFromAddrPort(proxySrcAddr), + DestinationAddr: &net.TCPAddr{ + IP: destAddr.AsSlice(), + Port: backAddr.Port, + }, + } + if is4 { + header.TransportProtocol = proxyproto.TCPv4 + } else { + header.TransportProtocol = proxyproto.TCPv6 + } + var err error + proxyHeader, err = header.Format() + if err != nil { + b.logf("localbackend: failed to format proxy protocol header for port %v (from %v) to %s: %v", dport, srcAddr, backDst, err) + } + } + // TODO(bradfitz): do the RegisterIPPortIdentity and // UnregisterIPPortIdentity stuff that netstack does errc := make(chan error, 1) go func() { + if len(proxyHeader) > 0 { + if _, err := backConn.Write(proxyHeader); err != nil { + errc <- err + backConn.Close() // to ensure that the other side gets EOF + return + } + } _, err := io.Copy(backConn, conn) errc <- err }() @@ -803,9 +883,11 @@ func (rp *reverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Out.Host = r.In.Host addProxyForwardedHeaders(r) rp.lb.addTailscaleIdentityHeaders(r) - }} - - // There is no way to autodetect h2c as per RFC 9113 + if err := rp.lb.addAppCapabilitiesHeader(r); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }} // There is no way to autodetect h2c as per RFC 9113 // https://datatracker.ietf.org/doc/html/rfc9113#name-starting-http-2. // However, we assume that http:// proxy prefix in combination with the // protocol being HTTP/2 is sufficient to detect h2c for our needs. Only use this for @@ -927,6 +1009,53 @@ func encTailscaleHeaderValue(v string) string { return mime.QEncoding.Encode("utf-8", v) } +func (b *LocalBackend) addAppCapabilitiesHeader(r *httputil.ProxyRequest) error { + const appCapabilitiesHeaderName = "Tailscale-App-Capabilities" + r.Out.Header.Del(appCapabilitiesHeaderName) + + c, ok := serveHTTPContextKey.ValueOk(r.Out.Context()) + if !ok || c.Funnel != nil { + return nil + } + acceptCaps := c.AppCapabilities + if acceptCaps.IsNil() { + return nil + } + peerCaps := b.PeerCaps(c.SrcAddr.Addr()) + if peerCaps == nil { + return nil + } + + peerCapsFiltered := make(map[tailcfg.PeerCapability][]tailcfg.RawMessage, acceptCaps.Len()) + for _, cap := range acceptCaps.AsSlice() { + if peerCaps.HasCapability(cap) { + peerCapsFiltered[cap] = peerCaps[cap] + } + } + + peerCapsSerialized, err := json.Marshal(peerCapsFiltered) + if err != nil { + b.logf("serve: failed to serialize filtered PeerCapMap: %v", err) + return fmt.Errorf("unable to process app capabilities") + } + + r.Out.Header.Set(appCapabilitiesHeaderName, encTailscaleHeaderValue(string(peerCapsSerialized))) + return nil +} + +// parseRedirectWithCode parses a redirect string that may optionally start with +// an HTTP redirect status code ("3xx:"). 
+// Returns the status code and the final redirect URL. +// If no code prefix is found, returns http.StatusFound (302). +func parseRedirectWithCode(redirect string) (code int, url string) { + if len(redirect) >= 4 && redirect[3] == ':' { + if statusCode, err := strconv.Atoi(redirect[:3]); err == nil && statusCode >= 300 && statusCode <= 399 { + return statusCode, redirect[4:] + } + } + return http.StatusFound, redirect +} + // serveWebHandler is an http.HandlerFunc that maps incoming requests to the // correct *http. func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { @@ -940,6 +1069,13 @@ func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { io.WriteString(w, s) return } + if v := h.Redirect(); v != "" { + code, v := parseRedirectWithCode(v) + v = strings.ReplaceAll(v, "${HOST}", r.Host) + v = strings.ReplaceAll(v, "${REQUEST_URI}", r.RequestURI) + http.Redirect(w, r, v, code) + return + } if v := h.Path(); v != "" { b.serveFileOrDirectory(w, r, v, mountPoint) return @@ -950,6 +1086,12 @@ func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "unknown proxy destination", http.StatusInternalServerError) return } + // Inject app capabilities to forward into the request context + c, ok := serveHTTPContextKey.ValueOk(r.Context()) + if !ok { + return + } + c.AppCapabilities = h.AcceptAppCaps() h := p.(http.Handler) // Trim the mount point from the URL path before proxying. (#6571) if r.URL.Path != "/" { @@ -1162,7 +1304,7 @@ func handleC2NVIPServicesGet(b *LocalBackend, w http.ResponseWriter, r *http.Req b.logf("c2n: GET /vip-services received") var res tailcfg.C2NVIPServicesResponse res.VIPServices = b.VIPServices() - res.ServicesHash = b.vipServiceHash(res.VIPServices) + res.ServicesHash = vipServiceHash(b.logf, res.VIPServices) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(res) @@ -1378,3 +1520,192 @@ func (b *LocalBackend) setVIPServicesTCPPortsInterceptedLocked(svcPorts map[tail b.shouldInterceptVIPServicesTCPPortAtomic.Store(generateInterceptVIPServicesTCPPortFunc(svcAddrPorts)) } + +func maybeUpdateHostinfoServicesHashLocked(b *LocalBackend, hi *tailcfg.Hostinfo, prefs ipn.PrefsView) bool { + latestHash := vipServiceHash(b.logf, b.vipServicesFromPrefsLocked(prefs)) + if hi.ServicesHash != latestHash { + hi.ServicesHash = latestHash + return true + } + return false +} + +func maybeUpdateHostinfoFunnelLocked(b *LocalBackend, hi *tailcfg.Hostinfo, prefs ipn.PrefsView) (changed bool) { + // The Hostinfo.IngressEnabled field is used to communicate to control whether + // the node has funnel enabled. + if ie := b.hasIngressEnabledLocked(); hi.IngressEnabled != ie { + b.logf("Hostinfo.IngressEnabled changed to %v", ie) + hi.IngressEnabled = ie + changed = true + } + // The Hostinfo.WireIngress field tells control whether the user intends + // to use funnel with this node even though it is not currently enabled. + // This is an optimization to control- Funnel requires creation of DNS + // records and because DNS propagation can take time, we want to ensure + // that the records exist for any node that intends to use funnel even + // if it's not enabled. If hi.IngressEnabled is true, control knows that + // DNS records are needed, so we can save bandwidth and not send + // WireIngress. 
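Pulling the redirect pieces together: an optional "3xx:" prefix selects the status code (defaulting to 302 otherwise), and the handler substitutes ${HOST} and ${REQUEST_URI} before redirecting. A runnable sketch that mirrors parseRedirectWithCode:

```go
package main

import (
	"fmt"
	"net/http"
	"strconv"
	"strings"
)

// parseRedirect mirrors parseRedirectWithCode: an optional "3xx:" prefix
// selects the status code; anything else redirects with 302 (Found).
func parseRedirect(redirect string) (int, string) {
	if len(redirect) >= 4 && redirect[3] == ':' {
		if code, err := strconv.Atoi(redirect[:3]); err == nil && code >= 300 && code <= 399 {
			return code, redirect[4:]
		}
	}
	return http.StatusFound, redirect
}

func main() {
	code, url := parseRedirect("308:https://${HOST}${REQUEST_URI}")
	// The handler substitutes request values before redirecting.
	url = strings.ReplaceAll(url, "${HOST}", "example.com")
	url = strings.ReplaceAll(url, "${REQUEST_URI}", "/old/path")
	fmt.Println(code, url) // 308 https://example.com/old/path
}
```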
+ if wire := b.shouldWireInactiveIngressLocked(); hi.WireIngress != wire {
+ b.logf("Hostinfo.WireIngress changed to %v", wire)
+ hi.WireIngress = wire
+ changed = true
+ }
+ return changed
+}
+
+func vipServiceHash(logf logger.Logf, services []*tailcfg.VIPService) string {
+ if len(services) == 0 {
+ return ""
+ }
+ h := sha256.New()
+ jh := json.NewEncoder(h)
+ if err := jh.Encode(services); err != nil {
+ logf("vipServiceHash: %v", err)
+ return ""
+ }
+ var buf [sha256.Size]byte
+ h.Sum(buf[:0])
+ return hex.EncodeToString(buf[:])
+}
+
+// validateServeConfigUpdate validates changes proposed by incoming serve
+// configuration.
+func validateServeConfigUpdate(existing, incoming ipn.ServeConfigView) error {
+ // Error messages returned by this function may be presented to end-users by
+ // frontends like the CLI. Thus these error messages should provide enough
+ // information for end-users to diagnose and resolve conflicts.
+
+ if !incoming.Valid() {
+ return nil
+ }
+
+ // For Services, TUN mode is mutually exclusive with L4 or L7 handlers.
+ for svcName, svcCfg := range incoming.Services().All() {
+ hasTCP := svcCfg.TCP().Len() > 0
+ hasWeb := svcCfg.Web().Len() > 0
+ if svcCfg.Tun() && (hasTCP || hasWeb) {
+ return fmt.Errorf("cannot configure TUN mode in combination with TCP or web handlers for %s", svcName)
+ }
+ }
+
+ if !existing.Valid() {
+ return nil
+ }
+
+ // New foreground listeners must be on open ports.
+ for sessionID, incomingFg := range incoming.Foreground().All() {
+ if !existing.Foreground().Has(sessionID) {
+ // This is a new session.
+ for port := range incomingFg.TCPs() {
+ if _, exists := existing.FindTCP(port); exists {
+ return fmt.Errorf("listener already exists for port %d", port)
+ }
+ }
+ }
+ }
+
+ // New background listeners cannot overwrite existing foreground listeners.
+ for port := range incoming.TCP().All() {
+ if _, exists := existing.FindForegroundTCP(port); exists {
+ return fmt.Errorf("foreground listener already exists for port %d", port)
+ }
+ }
+
+ // Incoming configuration cannot change the serve type in use by a port.
+ for port, incomingHandler := range incoming.TCP().All() {
+ existingHandler, exists := existing.FindTCP(port)
+ if !exists {
+ continue
+ }
+
+ existingServeType := serveTypeFromPortHandler(existingHandler)
+ incomingServeType := serveTypeFromPortHandler(incomingHandler)
+ if incomingServeType != existingServeType {
+ return fmt.Errorf("want to serve %q, but port %d is already serving %q", incomingServeType, port, existingServeType)
+ }
+ }
+
+ // Validations for Tailscale Services.
+ for svcName, incomingSvcCfg := range incoming.Services().All() {
+ existingSvcCfg, exists := existing.Services().GetOk(svcName)
+ if !exists {
+ continue
+ }
+
+ // Incoming configuration cannot change the serve type in use by a port.
+ for port, incomingHandler := range incomingSvcCfg.TCP().All() {
+ existingHandler, exists := existingSvcCfg.TCP().GetOk(port)
+ if !exists {
+ continue
+ }
+
+ existingServeType := serveTypeFromPortHandler(existingHandler)
+ incomingServeType := serveTypeFromPortHandler(incomingHandler)
+ if incomingServeType != existingServeType {
+ return fmt.Errorf("want to serve %q, but port %d is already serving %q for %s", incomingServeType, port, existingServeType, svcName)
+ }
+ }
+
+ existingHasTCP := existingSvcCfg.TCP().Len() > 0
+ existingHasWeb := existingSvcCfg.Web().Len() > 0
+
+ // A Service cannot turn on TUN mode if TCP or web handlers exist. 
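+ // (The converse case, adding handlers while TUN mode is already
+ // enabled, is checked just below.)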
+ if incomingSvcCfg.Tun() && (existingHasTCP || existingHasWeb) { + return fmt.Errorf("cannot turn on TUN mode with existing TCP or web handlers for %s", svcName) + } + + incomingHasTCP := incomingSvcCfg.TCP().Len() > 0 + incomingHasWeb := incomingSvcCfg.Web().Len() > 0 + + // A Service cannot add TCP or web handlers if TUN mode is enabled. + if (incomingHasTCP || incomingHasWeb) && existingSvcCfg.Tun() { + return fmt.Errorf("cannot add TCP or web handlers as TUN mode is enabled for %s", svcName) + } + } + + return nil +} + +// serveType is a high-level descriptor of the kind of serve performed by a TCP +// port handler. +type serveType int + +const ( + serveTypeHTTPS serveType = iota + serveTypeHTTP + serveTypeTCP + serveTypeTLSTerminatedTCP +) + +func (s serveType) String() string { + switch s { + case serveTypeHTTP: + return "http" + case serveTypeHTTPS: + return "https" + case serveTypeTCP: + return "tcp" + case serveTypeTLSTerminatedTCP: + return "tls-terminated-tcp" + default: + return "unknownServeType" + } +} + +// serveTypeFromPortHandler is used to get a high-level descriptor of the kind +// of serve being performed by a port handler. +func serveTypeFromPortHandler(ph ipn.TCPPortHandlerView) serveType { + switch { + case ph.HTTP(): + return serveTypeHTTP + case ph.HTTPS(): + return serveTypeHTTPS + case ph.TerminateTLS() != "": + return serveTypeTLSTerminatedTCP + case ph.TCPForward() != "": + return serveTypeTCP + default: + return -1 + } +} diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index b4461d12f2ad0..6ee2181a0aaa2 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -16,6 +16,7 @@ import ( "errors" "fmt" "io" + "mime" "net/http" "net/http/httptest" "net/netip" @@ -27,6 +28,7 @@ import ( "testing" "time" + "tailscale.com/control/controlclient" "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" @@ -41,6 +43,7 @@ import ( "tailscale.com/util/must" "tailscale.com/util/syspolicy/policyclient" "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" ) func TestExpandProxyArg(t *testing.T) { @@ -69,6 +72,41 @@ func TestExpandProxyArg(t *testing.T) { } } +func TestParseRedirectWithRedirectCode(t *testing.T) { + tests := []struct { + in string + wantCode int + wantURL string + }{ + {"301:https://example.com", 301, "https://example.com"}, + {"302:https://example.com", 302, "https://example.com"}, + {"303:/path", 303, "/path"}, + {"307:https://example.com/path?query=1", 307, "https://example.com/path?query=1"}, + {"308:https://example.com", 308, "https://example.com"}, + + {"https://example.com", 302, "https://example.com"}, + {"/path", 302, "/path"}, + {"http://example.com", 302, "http://example.com"}, + {"git://example.com", 302, "git://example.com"}, + + {"200:https://example.com", 302, "200:https://example.com"}, + {"404:https://example.com", 302, "404:https://example.com"}, + {"500:https://example.com", 302, "500:https://example.com"}, + {"30:https://example.com", 302, "30:https://example.com"}, + {"3:https://example.com", 302, "3:https://example.com"}, + {"3012:https://example.com", 302, "3012:https://example.com"}, + {"abc:https://example.com", 302, "abc:https://example.com"}, + {"301", 302, "301"}, + } + for _, tt := range tests { + gotCode, gotURL := parseRedirectWithCode(tt.in) + if gotCode != tt.wantCode || gotURL != tt.wantURL { + t.Errorf("parseRedirectWithCode(%q) = (%d, %q), want (%d, %q)", + tt.in, gotCode, gotURL, tt.wantCode, tt.wantURL) + } + } +} + func TestGetServeHandler(t *testing.T) { const 
serverName = "example.ts.net" conf1 := &ipn.ServeConfig{ @@ -350,7 +388,7 @@ func TestServeConfigServices(t *testing.T) { tests := []struct { name string conf *ipn.ServeConfig - expectedErr error + errExpected bool packetDstAddrPort []netip.AddrPort intercepted bool }{ @@ -374,7 +412,7 @@ func TestServeConfigServices(t *testing.T) { }, }, }, - expectedErr: ipn.ErrServiceConfigHasBothTCPAndTun, + errExpected: true, }, { // one correctly configured service with packet should be intercepted @@ -481,13 +519,13 @@ func TestServeConfigServices(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := b.SetServeConfig(tt.conf, "") - if err != nil && tt.expectedErr != nil { - if !errors.Is(err, tt.expectedErr) { - t.Fatalf("expected error %v,\n got %v", tt.expectedErr, err) - } - return + if err == nil && tt.errExpected { + t.Fatal("expected error") } if err != nil { + if tt.errExpected { + return + } t.Fatal(err) } for _, addrPort := range tt.packetDstAddrPort { @@ -768,6 +806,156 @@ func TestServeHTTPProxyHeaders(t *testing.T) { } } +func TestServeHTTPProxyGrantHeader(t *testing.T) { + b := newTestBackend(t) + + nm := b.NetMap() + matches, err := filter.MatchesFromFilterRules([]tailcfg.FilterRule{ + { + SrcIPs: []string{"100.150.151.152"}, + CapGrant: []tailcfg.CapGrant{{ + Dsts: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.151/32"), + }, + CapMap: tailcfg.PeerCapMap{ + "example.com/cap/interesting": []tailcfg.RawMessage{ + `{"role": "🐿"}`, + }, + }, + }}, + }, + { + SrcIPs: []string{"100.150.151.153"}, + CapGrant: []tailcfg.CapGrant{{ + Dsts: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.151/32"), + }, + CapMap: tailcfg.PeerCapMap{ + "example.com/cap/boring": []tailcfg.RawMessage{ + `{"role": "Viewer"}`, + }, + "example.com/cap/irrelevant": []tailcfg.RawMessage{ + `{"role": "Editor"}`, + }, + }, + }}, + }, + }) + if err != nil { + t.Fatal(err) + } + nm.PacketFilter = matches + b.SetControlClientStatus(nil, controlclient.Status{NetMap: nm}) + + // Start test serve endpoint. + testServ := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Piping all the headers through the response writer + // so we can check their values in tests below. 
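+ // Multi-valued request headers are flattened with a comma
+ // separator so each one can be read back with a single Get call.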
+ for key, val := range r.Header { + w.Header().Add(key, strings.Join(val, ",")) + } + }, + )) + defer testServ.Close() + + conf := &ipn.ServeConfig{ + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "example.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: testServ.URL, + AcceptAppCaps: []tailcfg.PeerCapability{"example.com/cap/interesting", "example.com/cap/boring"}, + }, + }}, + }, + } + if err := b.SetServeConfig(conf, ""); err != nil { + t.Fatal(err) + } + + type headerCheck struct { + header string + want string + } + + tests := []struct { + name string + srcIP string + wantHeaders []headerCheck + }{ + { + name: "request-from-user-within-tailnet", + srcIP: "100.150.151.152", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.150.151.152"}, + {"Tailscale-User-Login", "someone@example.com"}, + {"Tailscale-User-Name", "Some One"}, + {"Tailscale-User-Profile-Pic", "https://example.com/photo.jpg"}, + {"Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers"}, + {"Tailscale-App-Capabilities", `{"example.com/cap/interesting":[{"role":"🐿"}]}`}, + }, + }, + { + name: "request-from-tagged-node-within-tailnet", + srcIP: "100.150.151.153", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.150.151.153"}, + {"Tailscale-User-Login", ""}, + {"Tailscale-User-Name", ""}, + {"Tailscale-User-Profile-Pic", ""}, + {"Tailscale-Headers-Info", ""}, + {"Tailscale-App-Capabilities", `{"example.com/cap/boring":[{"role":"Viewer"}]}`}, + }, + }, + { + name: "request-from-outside-tailnet", + srcIP: "100.160.161.162", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.160.161.162"}, + {"Tailscale-User-Login", ""}, + {"Tailscale-User-Name", ""}, + {"Tailscale-User-Profile-Pic", ""}, + {"Tailscale-Headers-Info", ""}, + {"Tailscale-App-Capabilities", ""}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Path: "/"}, + TLS: &tls.ConnectionState{ServerName: "example.ts.net"}, + } + req = req.WithContext(serveHTTPContextKey.WithValue(req.Context(), &serveHTTPContext{ + DestPort: 443, + SrcAddr: netip.MustParseAddrPort(tt.srcIP + ":1234"), // random src port for tests + })) + + w := httptest.NewRecorder() + b.serveWebHandler(w, req) + + // Verify the headers. The contract with users is that identity and grant headers containing non-ASCII + // UTF-8 characters will be Q-encoded. 
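+ // For example, the header value containing "🐿" above arrives as an
+ // RFC 2047 encoded-word (roughly "=?utf-8?q?...?="), which the
+ // mime.WordDecoder below reverses before comparing.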
+ h := w.Result().Header + dec := new(mime.WordDecoder) + for _, c := range tt.wantHeaders { + maybeEncoded := h.Get(c.header) + got, err := dec.DecodeHeader(maybeEncoded) + if err != nil { + t.Fatalf("invalid %q header; failed to decode: %v", maybeEncoded, err) + } + if got != c.want { + t.Errorf("invalid %q header; want=%q, got=%q", c.header, c.want, got) + } + } + }) + } +} + func Test_reverseProxyConfiguration(t *testing.T) { b := newTestBackend(t) type test struct { @@ -926,6 +1114,9 @@ func newTestBackend(t *testing.T, opts ...any) *LocalBackend { b.currentNode().SetNetMap(&netmap.NetworkMap{ SelfNode: (&tailcfg.Node{ Name: "example.ts.net", + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.151/32"), + }, }).View(), UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ tailcfg.UserID(1): (&tailcfg.UserProfile{ @@ -1171,3 +1362,407 @@ func TestServeGRPCProxy(t *testing.T) { }) } } + +func TestServeHTTPRedirect(t *testing.T) { + b := newTestBackend(t) + + tests := []struct { + host string + path string + redirect string + reqURI string + wantCode int + wantLoc string + }{ + { + host: "hardcoded-root", + path: "/", + redirect: "https://example.com/", + reqURI: "/old", + wantCode: http.StatusFound, // 302 is the default + wantLoc: "https://example.com/", + }, + { + host: "template-host-and-uri", + path: "/", + redirect: "https://${HOST}${REQUEST_URI}", + reqURI: "/path?foo=bar", + wantCode: http.StatusFound, // 302 is the default + wantLoc: "https://template-host-and-uri/path?foo=bar", + }, + { + host: "custom-301", + path: "/", + redirect: "301:https://example.com/", + reqURI: "/old", + wantCode: http.StatusMovedPermanently, // 301 + wantLoc: "https://example.com/", + }, + { + host: "custom-307", + path: "/", + redirect: "307:https://example.com/new", + reqURI: "/old", + wantCode: http.StatusTemporaryRedirect, // 307 + wantLoc: "https://example.com/new", + }, + { + host: "custom-308", + path: "/", + redirect: "308:https://example.com/permanent", + reqURI: "/old", + wantCode: http.StatusPermanentRedirect, // 308 + wantLoc: "https://example.com/permanent", + }, + } + + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + conf := &ipn.ServeConfig{ + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + ipn.HostPort(tt.host + ":80"): { + Handlers: map[string]*ipn.HTTPHandler{ + tt.path: {Redirect: tt.redirect}, + }, + }, + }, + } + if err := b.SetServeConfig(conf, ""); err != nil { + t.Fatal(err) + } + + req := &http.Request{ + Host: tt.host, + URL: &url.URL{Path: tt.path}, + RequestURI: tt.reqURI, + TLS: &tls.ConnectionState{ServerName: tt.host}, + } + req = req.WithContext(serveHTTPContextKey.WithValue(req.Context(), &serveHTTPContext{ + DestPort: 80, + SrcAddr: netip.MustParseAddrPort("1.2.3.4:1234"), + })) + + w := httptest.NewRecorder() + b.serveWebHandler(w, req) + + if w.Code != tt.wantCode { + t.Errorf("got status %d, want %d", w.Code, tt.wantCode) + } + if got := w.Header().Get("Location"); got != tt.wantLoc { + t.Errorf("got Location %q, want %q", got, tt.wantLoc) + } + }) + } +} + +func TestValidateServeConfigUpdate(t *testing.T) { + tests := []struct { + name, description string + existing, incoming *ipn.ServeConfig + wantError bool + }{ + { + name: "empty existing config", + description: "should be able to update with empty existing config", + existing: &ipn.ServeConfig{}, + incoming: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8080: {}, + }, + }, + wantError: false, + }, + { + name: "no existing config", + description: "should be able 
to update with no existing config", + existing: nil, + incoming: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8080: {}, + }, + }, + wantError: false, + }, + { + name: "empty incoming config", + description: "wiping config should work", + existing: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {}, + }, + }, + incoming: &ipn.ServeConfig{}, + wantError: false, + }, + { + name: "no incoming config", + description: "missing incoming config should not result in an error", + existing: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {}, + }, + }, + incoming: nil, + wantError: false, + }, + { + name: "non-overlapping update", + description: "non-overlapping update should work", + existing: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {}, + }, + }, + incoming: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8080: {}, + }, + }, + wantError: false, + }, + { + name: "overwriting background port", + description: "should be able to overwrite a background port", + existing: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: { + TCPForward: "localhost:8080", + }, + }, + }, + incoming: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: { + TCPForward: "localhost:9999", + }, + }, + }, + wantError: false, + }, + { + name: "broken existing config", + description: "broken existing config should not prevent new config updates", + existing: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + // Broken because HTTPS and TCPForward are mutually exclusive. + 9000: { + HTTPS: true, + TCPForward: "127.0.0.1:9000", + }, + // Broken because foreground and background handlers cannot coexist. + 443: {}, + }, + Foreground: map[string]*ipn.ServeConfig{ + "12345": { + TCP: map[uint16]*ipn.TCPPortHandler{ + // Broken because foreground and background handlers cannot coexist. + 443: {}, + }, + }, + }, + // Broken because Services cannot specify TUN mode and a TCP handler. 
+ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{
+ "svc:foo": {
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 6060: {},
+ },
+ Tun: true,
+ },
+ },
+ },
+ incoming: &ipn.ServeConfig{
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {},
+ },
+ },
+ wantError: false,
+ },
+ {
+ name: "services same port as background",
+ description: "services should be able to use the same port as background listeners",
+ existing: &ipn.ServeConfig{
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {},
+ },
+ },
+ incoming: &ipn.ServeConfig{
+ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{
+ "svc:foo": {
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {},
+ },
+ },
+ },
+ },
+ wantError: false,
+ },
+ {
+ name: "services tun mode",
+ description: "TUN mode should be mutually exclusive with TCP or web handlers for new Services",
+ existing: &ipn.ServeConfig{},
+ incoming: &ipn.ServeConfig{
+ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{
+ "svc:foo": {
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 6060: {},
+ },
+ Tun: true,
+ },
+ },
+ },
+ wantError: true,
+ },
+ {
+ name: "new foreground listener",
+ description: "new foreground listeners must be on open ports",
+ existing: &ipn.ServeConfig{
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {},
+ },
+ },
+ incoming: &ipn.ServeConfig{
+ Foreground: map[string]*ipn.ServeConfig{
+ "12345": {
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {},
+ },
+ },
+ },
+ },
+ wantError: true,
+ },
+ {
+ name: "new background listener",
+ description: "new background listeners cannot overwrite foreground listeners",
+ existing: &ipn.ServeConfig{
+ Foreground: map[string]*ipn.ServeConfig{
+ "12345": {
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {},
+ },
+ },
+ },
+ },
+ incoming: &ipn.ServeConfig{
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {},
+ },
+ },
+ wantError: true,
+ },
+ {
+ name: "serve type overwrite",
+ description: "incoming configuration cannot change the serve type in use by a port",
+ existing: &ipn.ServeConfig{
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {
+ HTTP: true,
+ },
+ },
+ },
+ incoming: &ipn.ServeConfig{
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {
+ TCPForward: "localhost:8080",
+ },
+ },
+ },
+ wantError: true,
+ },
+ {
+ name: "serve type overwrite services",
+ description: "incoming Services configuration cannot change the serve type in use by a port",
+ existing: &ipn.ServeConfig{
+ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{
+ "svc:foo": {
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {
+ HTTP: true,
+ },
+ },
+ },
+ },
+ },
+ incoming: &ipn.ServeConfig{
+ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{
+ "svc:foo": {
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 80: {
+ TCPForward: "localhost:8080",
+ },
+ },
+ },
+ },
+ },
+ wantError: true,
+ },
+ {
+ name: "tun mode with handlers",
+ description: "Services cannot enable TUN mode if L4 or L7 handlers already exist",
+ existing: &ipn.ServeConfig{
+ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{
+ "svc:foo": {
+ TCP: map[uint16]*ipn.TCPPortHandler{
+ 443: {
+ HTTPS: true,
+ },
+ },
+ Web: map[ipn.HostPort]*ipn.WebServerConfig{
+ "127.0.0.1:443": {
+ Handlers: map[string]*ipn.HTTPHandler{},
+ },
+ },
+ },
+ },
+ },
+ incoming: &ipn.ServeConfig{
+ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{
+ "svc:foo": {
+ Tun: true,
+ },
+ },
+ },
+ wantError: true,
+ },
+ {
+ name: "handlers with tun mode",
+ description: "Services cannot add L4 or L7 handlers if TUN mode is already enabled",
+ existing: &ipn.ServeConfig{
+ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{
+ 
"svc:foo": { + Tun: true, + }, + }, + }, + incoming: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "127.0.0.1:443": { + Handlers: map[string]*ipn.HTTPHandler{}, + }, + }, + }, + }, + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateServeConfigUpdate(tt.existing.View(), tt.incoming.View()) + if err != nil && !tt.wantError { + t.Error("unexpected error:", err) + } + if err == nil && tt.wantError { + t.Error("expected error, got nil;", tt.description) + } + }) + } +} diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index fca01f1056fcb..152b375b0f7b8 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -206,9 +206,7 @@ func (cc *mockControl) send(opts sendOpt) { Err: err, } if loginFinished { - s.SetStateForTest(controlclient.StateAuthenticated) - } else if url == "" && err == nil && nm == nil { - s.SetStateForTest(controlclient.StateNotAuthenticated) + s.LoggedIn = true } cc.opts.Observer.SetControlClientStatus(cc, s) } @@ -228,7 +226,6 @@ func (cc *mockControl) sendAuthURL(nm *netmap.NetworkMap) { NetMap: nm, Persist: cc.persist.View(), } - s.SetStateForTest(controlclient.StateURLVisitRequired) cc.opts.Observer.SetControlClientStatus(cc, s) } @@ -319,6 +316,11 @@ func (cc *mockControl) UpdateEndpoints(endpoints []tailcfg.Endpoint) { cc.called("UpdateEndpoints") } +func (cc *mockControl) SetDiscoPublicKey(key key.DiscoPublic) { + cc.logf("SetDiscoPublicKey: %v", key) + cc.called("SetDiscoPublicKey") +} + func (cc *mockControl) ClientID() int64 { return cc.controlClientID } @@ -434,8 +436,11 @@ func runTestStateMachine(t *testing.T, seamless bool) { // for it, so it doesn't count as Prefs.LoggedOut==true. c.Assert(prefs.LoggedOut(), qt.IsTrue) c.Assert(prefs.WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, *nn[1].State) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify notification indicates we need login (prefs show logged out) + c.Assert(nn[1].Prefs == nil || nn[1].Prefs.LoggedOut(), qt.IsTrue) + // Verify the actual facts about our state + c.Assert(needsLogin(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsFalse) } // Restart the state machine. @@ -455,8 +460,11 @@ func runTestStateMachine(t *testing.T, seamless bool) { c.Assert(nn[1].State, qt.IsNotNil) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[0].Prefs.WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, *nn[1].State) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify notification indicates we need login + c.Assert(nn[1].Prefs == nil || nn[1].Prefs.LoggedOut(), qt.IsTrue) + // Verify the actual facts about our state + c.Assert(needsLogin(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsFalse) } // Start non-interactive login with no token. @@ -473,7 +481,8 @@ func runTestStateMachine(t *testing.T, seamless bool) { // (This behaviour is needed so that b.Login() won't // start connecting to an old account right away, if one // exists when you launch another login.) 
- c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we still need login + c.Assert(needsLogin(b), qt.IsTrue) } // Attempted non-interactive login with no key; indicate that @@ -500,10 +509,11 @@ func runTestStateMachine(t *testing.T, seamless bool) { c.Assert(nn[1].Prefs, qt.IsNotNil) c.Assert(nn[1].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we need URL visit + c.Assert(hasAuthURL(b), qt.IsTrue) c.Assert(nn[2].BrowseToURL, qt.IsNotNil) c.Assert(url1, qt.Equals, *nn[2].BrowseToURL) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + c.Assert(isFullyAuthenticated(b), qt.IsFalse) } // Now we'll try an interactive login. @@ -518,7 +528,8 @@ func runTestStateMachine(t *testing.T, seamless bool) { cc.assertCalls() c.Assert(nn[0].BrowseToURL, qt.IsNotNil) c.Assert(url1, qt.Equals, *nn[0].BrowseToURL) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we still need to complete login + c.Assert(needsLogin(b), qt.IsTrue) } // Sometimes users press the Login button again, in the middle of @@ -534,7 +545,8 @@ func runTestStateMachine(t *testing.T, seamless bool) { notifies.drain(0) // backend asks control for another login sequence cc.assertCalls("Login") - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we still need login + c.Assert(needsLogin(b), qt.IsTrue) } // Provide a new interactive login URL. @@ -550,7 +562,8 @@ func runTestStateMachine(t *testing.T, seamless bool) { nn := notifies.drain(1) c.Assert(nn[0].BrowseToURL, qt.IsNotNil) c.Assert(url2, qt.Equals, *nn[0].BrowseToURL) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we still need to complete login + c.Assert(needsLogin(b), qt.IsTrue) } // Pretend that the interactive login actually happened. @@ -582,10 +595,18 @@ func runTestStateMachine(t *testing.T, seamless bool) { cc.assertCalls() c.Assert(nn[0].LoginFinished, qt.IsNotNil) c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(nn[2].State, qt.IsNotNil) c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName, qt.Equals, "user1") - c.Assert(ipn.NeedsMachineAuth, qt.Equals, *nn[2].State) - c.Assert(ipn.NeedsMachineAuth, qt.Equals, b.State()) + // nn[2] is a state notification after login + // Verify login finished but need machine auth using backend state + c.Assert(isFullyAuthenticated(b), qt.IsTrue) + c.Assert(needsMachineAuth(b), qt.IsTrue) + nm := b.NetMap() + c.Assert(nm, qt.IsNotNil) + // For an empty netmap (after initial login), SelfNode may not be valid yet. + // In this case, we can't check MachineAuthorized, but needsMachineAuth already verified the state. + if nm.SelfNode.Valid() { + c.Assert(nm.SelfNode.MachineAuthorized(), qt.IsFalse) + } } // Pretend that the administrator has authorized our machine. @@ -603,8 +624,13 @@ func runTestStateMachine(t *testing.T, seamless bool) { { nn := notifies.drain(1) cc.assertCalls() - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) + // nn[0] is a state notification after machine auth granted + c.Assert(len(nn), qt.Equals, 1) + // Verify machine authorized using backend state + nm := b.NetMap() + c.Assert(nm, qt.IsNotNil) + c.Assert(nm.SelfNode.Valid(), qt.IsTrue) + c.Assert(nm.SelfNode.MachineAuthorized(), qt.IsTrue) } // TODO: add a fake DERP server to our fake netmap, so we can @@ -627,9 +653,9 @@ func runTestStateMachine(t *testing.T, seamless bool) { nn := notifies.drain(2) cc.assertCalls("pause") // BUG: I would expect Prefs to change first, and state after. 
- c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification, nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Stopped, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) } // The user changes their preference to WantRunning after all. @@ -645,15 +671,12 @@ func runTestStateMachine(t *testing.T, seamless bool) { // BUG: Login isn't needed here. We never logged out. cc.assertCalls("Login", "unpause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification, nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) c.Assert(store.sawWrite(), qt.IsTrue) } - // undo the state hack above. - b.state = ipn.Starting - // User wants to logout. store.awaitWrite() t.Logf("\n\nLogout") @@ -662,27 +685,26 @@ func runTestStateMachine(t *testing.T, seamless bool) { { nn := notifies.drain(5) previousCC.assertCalls("pause", "Logout", "unpause", "Shutdown") + // nn[0] is state notification (Stopped) c.Assert(nn[0].State, qt.IsNotNil) c.Assert(*nn[0].State, qt.Equals, ipn.Stopped) - + // nn[1] is prefs notification after logout c.Assert(nn[1].Prefs, qt.IsNotNil) c.Assert(nn[1].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) cc.assertCalls("New") - c.Assert(nn[2].State, qt.IsNotNil) - c.Assert(*nn[2].State, qt.Equals, ipn.NoState) - - c.Assert(nn[3].Prefs, qt.IsNotNil) // emptyPrefs + // nn[2] is the initial state notification after New (NoState) + // nn[3] is prefs notification with emptyPrefs + c.Assert(nn[3].Prefs, qt.IsNotNil) c.Assert(nn[3].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[3].Prefs.WantRunning(), qt.IsFalse) - c.Assert(nn[4].State, qt.IsNotNil) - c.Assert(*nn[4].State, qt.Equals, ipn.NeedsLogin) - - c.Assert(b.State(), qt.Equals, ipn.NeedsLogin) - c.Assert(store.sawWrite(), qt.IsTrue) + // nn[4] is state notification (NeedsLogin) + // Verify logged out and needs new login using backend state + c.Assert(needsLogin(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsFalse) } // A second logout should be a no-op as we are in the NeedsLogin state. @@ -694,7 +716,8 @@ func runTestStateMachine(t *testing.T, seamless bool) { cc.assertCalls() c.Assert(b.Prefs().LoggedOut(), qt.IsTrue) c.Assert(b.Prefs().WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify still needs login + c.Assert(needsLogin(b), qt.IsTrue) } // A third logout should also be a no-op as the cc should be in @@ -707,7 +730,8 @@ func runTestStateMachine(t *testing.T, seamless bool) { cc.assertCalls() c.Assert(b.Prefs().LoggedOut(), qt.IsTrue) c.Assert(b.Prefs().WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify still needs login + c.Assert(needsLogin(b), qt.IsTrue) } // Oh, you thought we were done? Ha! 
Now we have to test what @@ -730,11 +754,13 @@ func runTestStateMachine(t *testing.T, seamless bool) { nn := notifies.drain(2) cc.assertCalls() c.Assert(nn[0].Prefs, qt.IsNotNil) - c.Assert(nn[1].State, qt.IsNotNil) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[0].Prefs.WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, *nn[1].State) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify notification indicates we need login + c.Assert(nn[1].Prefs == nil || nn[1].Prefs.LoggedOut(), qt.IsTrue) + // Verify we need login after restart + c.Assert(needsLogin(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsFalse) } // Explicitly set the ControlURL to avoid defaulting to [ipn.DefaultControlURL]. @@ -785,8 +811,9 @@ func runTestStateMachine(t *testing.T, seamless bool) { c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) // If a user initiates an interactive login, they also expect WantRunning to become true. c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) - c.Assert(nn[2].State, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[2].State) + // nn[2] is state notification (Starting) - verify using backend state + c.Assert(isWantRunning(b), qt.IsTrue) + c.Assert(isLoggedIn(b), qt.IsTrue) } // Now we've logged in successfully. Let's disconnect. @@ -800,9 +827,9 @@ func runTestStateMachine(t *testing.T, seamless bool) { nn := notifies.drain(2) cc.assertCalls("pause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification (Stopped), nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Stopped, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) } @@ -820,10 +847,11 @@ func runTestStateMachine(t *testing.T, seamless bool) { // and WantRunning is false, so cc should be paused. cc.assertCalls("New", "Login", "pause") c.Assert(nn[0].Prefs, qt.IsNotNil) - c.Assert(nn[1].State, qt.IsNotNil) c.Assert(nn[0].Prefs.WantRunning(), qt.IsFalse) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsFalse) - c.Assert(*nn[1].State, qt.Equals, ipn.Stopped) + // nn[1] is state notification (Stopped) + // Verify backend shows we're not wanting to run + c.Assert(isWantRunning(b), qt.IsFalse) } // When logged in but !WantRunning, ipn leaves us unpaused to retrieve @@ -861,9 +889,9 @@ func runTestStateMachine(t *testing.T, seamless bool) { nn := notifies.drain(2) cc.assertCalls("Login", "unpause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification (Starting), nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) } // Disconnect. @@ -877,9 +905,9 @@ func runTestStateMachine(t *testing.T, seamless bool) { nn := notifies.drain(2) cc.assertCalls("pause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification (Stopped), nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Stopped, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) } // We want to try logging in as a different user, while Stopped. 
@@ -924,12 +952,13 @@ func runTestStateMachine(t *testing.T, seamless bool) { cc.assertCalls("unpause") c.Assert(nn[0].LoginFinished, qt.IsNotNil) c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(nn[2].State, qt.IsNotNil) // Prefs after finishing the login, so LoginName updated. c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName, qt.Equals, "user3") c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) - c.Assert(ipn.Starting, qt.Equals, *nn[2].State) + // nn[2] is state notification (Starting) - verify using backend state + c.Assert(isWantRunning(b), qt.IsTrue) + c.Assert(isLoggedIn(b), qt.IsTrue) } // The last test case is the most common one: restarting when both @@ -948,11 +977,10 @@ func runTestStateMachine(t *testing.T, seamless bool) { c.Assert(nn[0].Prefs, qt.IsNotNil) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[0].Prefs.WantRunning(), qt.IsTrue) - // We're logged in and have a valid netmap, so we should - // be in the Starting state. - c.Assert(nn[1].State, qt.IsNotNil) - c.Assert(*nn[1].State, qt.Equals, ipn.Starting) - c.Assert(b.State(), qt.Equals, ipn.Starting) + // nn[1] is state notification (Starting) + // Verify we're authenticated with valid netmap using backend state + c.Assert(isFullyAuthenticated(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsTrue) } // Control server accepts our valid key from before. @@ -969,35 +997,44 @@ func runTestStateMachine(t *testing.T, seamless bool) { // NOTE: No prefs change this time. WantRunning stays true. // We were in Starting in the first place, so that doesn't // change either, so we don't expect any notifications. - c.Assert(ipn.Starting, qt.Equals, b.State()) + // Verify we're still authenticated with valid netmap + c.Assert(isFullyAuthenticated(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsTrue) } t.Logf("\n\nExpireKey") notifies.expect(1) cc.send(sendOpt{nm: &netmap.NetworkMap{ - Expiry: time.Now().Add(-time.Minute), - SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), + SelfNode: (&tailcfg.Node{ + KeyExpiry: time.Now().Add(-time.Minute), + MachineAuthorized: true, + }).View(), }}) { nn := notifies.drain(1) cc.assertCalls() - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.NeedsLogin, qt.Equals, *nn[0].State) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // nn[0] is state notification (NeedsLogin) due to key expiry + c.Assert(len(nn), qt.Equals, 1) + // Verify key expired, need new login using backend state + c.Assert(needsLogin(b), qt.IsTrue) c.Assert(b.isEngineBlocked(), qt.IsTrue) } t.Logf("\n\nExtendKey") notifies.expect(1) cc.send(sendOpt{nm: &netmap.NetworkMap{ - Expiry: time.Now().Add(time.Minute), - SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), + SelfNode: (&tailcfg.Node{ + MachineAuthorized: true, + KeyExpiry: time.Now().Add(time.Minute), + }).View(), }}) { nn := notifies.drain(1) cc.assertCalls() - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) - c.Assert(ipn.Starting, qt.Equals, b.State()) + // nn[0] is state notification (Starting) after key extension + c.Assert(len(nn), qt.Equals, 1) + // Verify key extended, authenticated again using backend state + c.Assert(isFullyAuthenticated(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsTrue) c.Assert(b.isEngineBlocked(), qt.IsFalse) } notifies.expect(1) @@ -1006,9 +1043,10 @@ func runTestStateMachine(t *testing.T, seamless bool) { { nn := notifies.drain(1) cc.assertCalls() - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.Running, qt.Equals, 
*nn[0].State) - c.Assert(ipn.Running, qt.Equals, b.State()) + // nn[0] is state notification (Running) after DERP connection + c.Assert(len(nn), qt.Equals, 1) + // Verify we can route traffic using backend state + c.Assert(canRouteTraffic(b), qt.IsTrue) } } @@ -1241,8 +1279,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { // After the auth is completed, the configs must be updated to reflect the node's netmap. wantState: ipn.Starting, wantCfg: &wgcfg.Config{ - Name: "tailscale", - NodeID: node1.SelfNode.StableID(), Peers: []wgcfg.Peer{}, Addresses: node1.SelfNode.Addresses().AsSlice(), }, @@ -1299,8 +1335,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { // Once the auth is completed, the configs must be updated to reflect the node's netmap. wantState: ipn.Starting, wantCfg: &wgcfg.Config{ - Name: "tailscale", - NodeID: node2.SelfNode.StableID(), Peers: []wgcfg.Peer{}, Addresses: node2.SelfNode.Addresses().AsSlice(), }, @@ -1349,8 +1383,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { // must be updated to reflect the node's netmap. wantState: ipn.Starting, wantCfg: &wgcfg.Config{ - Name: "tailscale", - NodeID: node1.SelfNode.StableID(), Peers: []wgcfg.Peer{}, Addresses: node1.SelfNode.Addresses().AsSlice(), }, @@ -1374,8 +1406,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { }, wantState: ipn.Starting, wantCfg: &wgcfg.Config{ - Name: "tailscale", - NodeID: node3.SelfNode.StableID(), Peers: []wgcfg.Peer{ { PublicKey: node1.SelfNode.Key(), @@ -1406,7 +1436,9 @@ func TestEngineReconfigOnStateChange(t *testing.T) { mustDo2(t)(lb.EditPrefs(connect)) cc().authenticated(node1) cc().send(sendOpt{nm: &netmap.NetworkMap{ - Expiry: time.Now().Add(-time.Minute), + SelfNode: (&tailcfg.Node{ + KeyExpiry: time.Now().Add(-time.Minute), + }).View(), }}) }, wantState: ipn.NeedsLogin, @@ -1447,8 +1479,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { }, wantState: ipn.Starting, wantCfg: &wgcfg.Config{ - Name: "tailscale", - NodeID: node1.SelfNode.StableID(), Peers: []wgcfg.Peer{}, Addresses: node1.SelfNode.Addresses().AsSlice(), }, @@ -1478,8 +1508,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { // With seamless renewal, starting a reauth should leave everything up: wantState: ipn.Starting, wantCfg: &wgcfg.Config{ - Name: "tailscale", - NodeID: node1.SelfNode.StableID(), Peers: []wgcfg.Peer{}, Addresses: node1.SelfNode.Addresses().AsSlice(), }, @@ -1511,8 +1539,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { }, wantState: ipn.Starting, wantCfg: &wgcfg.Config{ - Name: "tailscale", - NodeID: node1.SelfNode.StableID(), Peers: []wgcfg.Peer{}, Addresses: node1.SelfNode.Addresses().AsSlice(), }, @@ -1535,7 +1561,9 @@ func TestEngineReconfigOnStateChange(t *testing.T) { mustDo2(t)(lb.EditPrefs(connect)) cc().authenticated(node1) cc().send(sendOpt{nm: &netmap.NetworkMap{ - Expiry: time.Now().Add(-time.Minute), + SelfNode: (&tailcfg.Node{ + KeyExpiry: time.Now().Add(-time.Minute), + }).View(), }}) }, // Even with seamless, if the key we are using expires, we want to disconnect: @@ -1554,6 +1582,11 @@ func TestEngineReconfigOnStateChange(t *testing.T) { tt.steps(t, lb, cc) } + // TODO(bradfitz): this whole event bus settling thing + // should be unnecessary once the bogus uses of eventbus + // are removed. 
(https://github.com/tailscale/tailscale/issues/16369) + lb.settleEventBus() + if gotState := lb.State(); gotState != tt.wantState { t.Errorf("State: got %v; want %v", gotState, tt.wantState) } @@ -1584,35 +1617,30 @@ func TestEngineReconfigOnStateChange(t *testing.T) { } } -// TestStateMachineURLRace tests that wgengine updates arriving in the middle of +// TestSendPreservesAuthURL tests that wgengine updates arriving in the middle of // processing an auth URL doesn't result in the auth URL being cleared. -func TestStateMachineURLRace(t *testing.T) { - runTestStateMachineURLRace(t, false) +func TestSendPreservesAuthURL(t *testing.T) { + runTestSendPreservesAuthURL(t, false) } -func TestStateMachineURLRaceSeamless(t *testing.T) { - runTestStateMachineURLRace(t, true) +func TestSendPreservesAuthURLSeamless(t *testing.T) { + runTestSendPreservesAuthURL(t, true) } -func runTestStateMachineURLRace(t *testing.T, seamless bool) { +func runTestSendPreservesAuthURL(t *testing.T, seamless bool) { var cc *mockControl b := newLocalBackendWithTestControl(t, true, func(tb testing.TB, opts controlclient.Options) controlclient.Client { cc = newClient(t, opts) return cc }) - nw := newNotificationWatcher(t, b, &ipnauth.TestActor{}) - t.Logf("Start") - nw.watch(0, []wantedNotification{ - wantStateNotify(ipn.NeedsLogin)}) b.Start(ipn.Options{ UpdatePrefs: &ipn.Prefs{ WantRunning: true, ControlURL: "https://localhost:1/", }, }) - nw.check() t.Logf("LoginFinished") cc.persist.UserProfile.LoginName = "user1" @@ -1622,72 +1650,16 @@ func runTestStateMachineURLRace(t *testing.T, seamless bool) { b.sys.ControlKnobs().SeamlessKeyRenewal.Store(true) } - nw.watch(0, []wantedNotification{ - wantStateNotify(ipn.Starting)}) cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), }}) - nw.check() t.Logf("Running") - nw.watch(0, []wantedNotification{ - wantStateNotify(ipn.Running)}) b.setWgengineStatus(&wgengine.Status{AsOf: time.Now(), DERPs: 1}, nil) - nw.check() t.Logf("Re-auth (StartLoginInteractive)") b.StartLoginInteractive(t.Context()) - stop := make(chan struct{}) - stopSpamming := sync.OnceFunc(func() { - stop <- struct{}{} - }) - // if seamless renewal is enabled, the engine won't be disabled, and we won't - // ever call stopSpamming, so make sure it does get called - defer stopSpamming() - - // Intercept updates between the engine and localBackend, so that we can see - // when the "stopped" update comes in and ensure we stop sending our "we're - // up" updates after that point. - b.e.SetStatusCallback(func(s *wgengine.Status, err error) { - // This is not one of our fake status updates, this is generated from the - // engine in response to LocalBackend calling RequestStatus. Stop spamming - // our fake statuses. - // - // TODO(zofrex): This is fragile, it works right now but would break if the - // calling pattern of RequestStatus changes. We should ensure that we keep - // sending "we're up" statuses right until Reconfig is called with - // zero-valued configs, and after that point only send "stopped" statuses. - stopSpamming() - - // Once stopSpamming returns we are guaranteed to not send any more updates, - // so we can now send the real update (indicating shutdown) and be certain - // it will be received after any fake updates we sent. This is possibly a - // stronger guarantee than we get from the real engine? 
- b.setWgengineStatus(s, err) - }) - - // time needs to be >= last time for the status to be accepted, send all our - // spam with the same stale time so that when a real update comes in it will - // definitely be accepted. - time := b.lastStatusTime - - // Flood localBackend with a lot of wgengine status updates, so if there are - // any race conditions in the multiple locks/unlocks that happen as we process - // the received auth URL, we will hit them. - go func() { - t.Logf("sending lots of fake wgengine status updates") - for { - select { - case <-stop: - t.Logf("stopping fake wgengine status updates") - return - default: - b.setWgengineStatus(&wgengine.Status{AsOf: time, DERPs: 1}, nil) - } - } - }() - t.Logf("Re-auth (receive URL)") url1 := "https://localhost:1/1" cc.send(sendOpt{url: url1}) @@ -1697,122 +1669,11 @@ func runTestStateMachineURLRace(t *testing.T, seamless bool) { // status update to trample it have ended as well. if b.authURL == "" { t.Fatalf("expected authURL to be set") + } else { + t.Log("authURL was set") } } -func TestWGEngineDownThenUpRace(t *testing.T) { - var cc *mockControl - b := newLocalBackendWithTestControl(t, true, func(tb testing.TB, opts controlclient.Options) controlclient.Client { - cc = newClient(t, opts) - return cc - }) - - nw := newNotificationWatcher(t, b, &ipnauth.TestActor{}) - - t.Logf("Start") - nw.watch(0, []wantedNotification{ - wantStateNotify(ipn.NeedsLogin)}) - b.Start(ipn.Options{ - UpdatePrefs: &ipn.Prefs{ - WantRunning: true, - ControlURL: "https://localhost:1/", - }, - }) - nw.check() - - t.Logf("LoginFinished") - cc.persist.UserProfile.LoginName = "user1" - cc.persist.NodeID = "node1" - - nw.watch(0, []wantedNotification{ - wantStateNotify(ipn.Starting)}) - cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }}) - nw.check() - - nw.watch(0, []wantedNotification{ - wantStateNotify(ipn.Running)}) - b.setWgengineStatus(&wgengine.Status{AsOf: time.Now(), DERPs: 1}, nil) - nw.check() - - t.Logf("Re-auth (StartLoginInteractive)") - b.StartLoginInteractive(t.Context()) - - var timeLock sync.RWMutex - timestamp := b.lastStatusTime - - engineShutdown := make(chan struct{}) - gotShutdown := sync.OnceFunc(func() { - t.Logf("engineShutdown") - engineShutdown <- struct{}{} - }) - - b.e.SetStatusCallback(func(s *wgengine.Status, err error) { - timeLock.Lock() - if s.AsOf.After(timestamp) { - timestamp = s.AsOf - } - timeLock.Unlock() - - if err != nil || (s.DERPs == 0 && len(s.Peers) == 0) { - gotShutdown() - } else { - b.setWgengineStatus(s, err) - } - }) - - t.Logf("Re-auth (receive URL)") - url1 := "https://localhost:1/1" - - done := make(chan struct{}) - var wg sync.WaitGroup - - wg.Go(func() { - t.Log("cc.send starting") - cc.send(sendOpt{url: url1}) // will block until engine stops - t.Log("cc.send returned") - }) - - <-engineShutdown // will get called once cc.send is blocked - gotShutdown = sync.OnceFunc(func() { - t.Logf("engineShutdown") - engineShutdown <- struct{}{} - }) - - wg.Go(func() { - t.Log("StartLoginInteractive starting") - b.StartLoginInteractive(t.Context()) // will also block until engine stops - t.Log("StartLoginInteractive returned") - }) - - <-engineShutdown // will get called once StartLoginInteractive is blocked - - st := controlclient.Status{} - st.SetStateForTest(controlclient.StateAuthenticated) - b.SetControlClientStatus(cc, st) - - timeLock.RLock() - b.setWgengineStatus(&wgengine.Status{AsOf: timestamp}, nil) // engine is down event finally 
arrives - b.setWgengineStatus(&wgengine.Status{AsOf: timestamp, DERPs: 1}, nil) // engine is back up - timeLock.RUnlock() - - go func() { - wg.Wait() - done <- struct{}{} - }() - - t.Log("waiting for .send and .StartLoginInteractive to return") - - select { - case <-done: - case <-time.After(10 * time.Second): - t.Fatalf("timed out waiting") - } - - t.Log("both returned") -} - func buildNetmapWithPeers(self tailcfg.NodeView, peers ...tailcfg.NodeView) *netmap.NetworkMap { const ( firstAutoUserID = tailcfg.UserID(10000) @@ -1877,7 +1738,6 @@ func buildNetmapWithPeers(self tailcfg.NodeView, peers ...tailcfg.NodeView) *net return &netmap.NetworkMap{ SelfNode: self, - Name: self.Name(), Domain: domain, Peers: peers, UserProfiles: users, @@ -2045,6 +1905,14 @@ func (e *mockEngine) RequestStatus() { } } +func (e *mockEngine) ResetAndStop() (*wgengine.Status, error) { + err := e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) + if err != nil { + return nil, err + } + return &wgengine.Status{AsOf: time.Now()}, nil +} + func (e *mockEngine) PeerByKey(key.NodePublic) (_ wgint.Peer, ok bool) { return wgint.Peer{}, false } @@ -2072,3 +1940,77 @@ func (e *mockEngine) Close() { func (e *mockEngine) Done() <-chan struct{} { return e.done } + +// hasValidNetMap returns true if the backend has a valid network map with a valid self node. +func hasValidNetMap(b *LocalBackend) bool { + nm := b.NetMap() + return nm != nil && nm.SelfNode.Valid() +} + +// needsLogin returns true if the backend needs user login action. +// This is true when logged out, when an auth URL is present (interactive login in progress), +// or when the node key has expired. +func needsLogin(b *LocalBackend) bool { + // Note: b.Prefs() handles its own locking, so we lock only for authURL and keyExpired access + b.mu.Lock() + authURL := b.authURL + keyExpired := b.keyExpired + b.mu.Unlock() + return b.Prefs().LoggedOut() || authURL != "" || keyExpired +} + +// needsMachineAuth returns true if the user has logged in but the machine is not yet authorized. +// This includes the case where we have a netmap but no valid SelfNode yet (empty netmap after initial login). +func needsMachineAuth(b *LocalBackend) bool { + // Note: b.NetMap() and b.Prefs() handle their own locking + nm := b.NetMap() + prefs := b.Prefs() + if prefs.LoggedOut() || nm == nil { + return false + } + // If we have a valid SelfNode, check its MachineAuthorized status + if nm.SelfNode.Valid() { + return !nm.SelfNode.MachineAuthorized() + } + // Empty netmap (no SelfNode yet) after login also means we need machine auth + return true +} + +// hasAuthURL returns true if an authentication URL is present (user needs to visit a URL). +func hasAuthURL(b *LocalBackend) bool { + b.mu.Lock() + authURL := b.authURL + b.mu.Unlock() + return authURL != "" +} + +// canRouteTraffic returns true if the backend is capable of routing traffic. +// This requires a valid netmap, machine authorization, and WantRunning preference. +func canRouteTraffic(b *LocalBackend) bool { + // Note: b.NetMap() and b.Prefs() handle their own locking + nm := b.NetMap() + prefs := b.Prefs() + return nm != nil && + nm.SelfNode.Valid() && + nm.SelfNode.MachineAuthorized() && + prefs.WantRunning() +} + +// isFullyAuthenticated returns true if the user has completed login and no auth URL is pending. 
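+// Unlike isLoggedIn below, it also requires that no interactive login
+// (auth URL) is pending.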
+func isFullyAuthenticated(b *LocalBackend) bool { + // Note: b.Prefs() handles its own locking, so we lock only for authURL access + b.mu.Lock() + authURL := b.authURL + b.mu.Unlock() + return !b.Prefs().LoggedOut() && authURL == "" +} + +// isWantRunning returns true if the WantRunning preference is set. +func isWantRunning(b *LocalBackend) bool { + return b.Prefs().WantRunning() +} + +// isLoggedIn returns true if the user is logged in (not logged out). +func isLoggedIn(b *LocalBackend) bool { + return !b.Prefs().LoggedOut() +} diff --git a/ipn/localapi/debug.go b/ipn/localapi/debug.go index 8aca7f0093f7d..ae9cb01e02fe9 100644 --- a/ipn/localapi/debug.go +++ b/ipn/localapi/debug.go @@ -31,6 +31,7 @@ import ( func init() { Register("component-debug-logging", (*Handler).serveComponentDebugLogging) Register("debug", (*Handler).serveDebug) + Register("debug-rotate-disco-key", (*Handler).serveDebugRotateDiscoKey) Register("dev-set-state-store", (*Handler).serveDevSetStateStore) Register("debug-bus-events", (*Handler).serveDebugBusEvents) Register("debug-bus-graph", (*Handler).serveEventBusGraph) @@ -232,6 +233,8 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { if err == nil { return } + case "rotate-disco-key": + err = h.b.DebugRotateDiscoKey() case "": err = fmt.Errorf("missing parameter 'action'") default: @@ -473,3 +476,20 @@ func (h *Handler) serveDebugOptionalFeatures(w http.ResponseWriter, r *http.Requ w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(of) } + +func (h *Handler) serveDebugRotateDiscoKey(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "POST required", http.StatusMethodNotAllowed) + return + } + if err := h.b.DebugRotateDiscoKey(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/plain") + io.WriteString(w, "done\n") +} diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 9e7c16891fc20..7f249fe530e15 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -7,6 +7,7 @@ package localapi import ( "bytes" "cmp" + "crypto/subtle" "encoding/json" "errors" "fmt" @@ -34,6 +35,7 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnstate" "tailscale.com/logtail" + "tailscale.com/net/netns" "tailscale.com/net/netutil" "tailscale.com/tailcfg" "tailscale.com/tstime" @@ -71,20 +73,21 @@ var handler = map[string]LocalAPIHandler{ // The other /localapi/v0/NAME handlers are exact matches and contain only NAME // without a trailing slash: - "check-prefs": (*Handler).serveCheckPrefs, - "derpmap": (*Handler).serveDERPMap, - "goroutines": (*Handler).serveGoroutines, - "login-interactive": (*Handler).serveLoginInteractive, - "logout": (*Handler).serveLogout, - "ping": (*Handler).servePing, - "prefs": (*Handler).servePrefs, - "reload-config": (*Handler).reloadConfig, - "reset-auth": (*Handler).serveResetAuth, - "set-expiry-sooner": (*Handler).serveSetExpirySooner, - "shutdown": (*Handler).serveShutdown, - "start": (*Handler).serveStart, - "status": (*Handler).serveStatus, - "whois": (*Handler).serveWhoIs, + "check-prefs": (*Handler).serveCheckPrefs, + "check-so-mark-in-use": (*Handler).serveCheckSOMarkInUse, + "derpmap": (*Handler).serveDERPMap, + "goroutines": (*Handler).serveGoroutines, + "login-interactive": (*Handler).serveLoginInteractive, + "logout": (*Handler).serveLogout, + "ping": 
(*Handler).servePing,
+ "prefs": (*Handler).servePrefs,
+ "reload-config": (*Handler).reloadConfig,
+ "reset-auth": (*Handler).serveResetAuth,
+ "set-expiry-sooner": (*Handler).serveSetExpirySooner,
+ "shutdown": (*Handler).serveShutdown,
+ "start": (*Handler).serveStart,
+ "status": (*Handler).serveStatus,
+ "whois": (*Handler).serveWhoIs,
 }
 
 func init() {
@@ -257,13 +260,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 http.Error(w, "auth required", http.StatusUnauthorized)
 return
 }
- if pass != h.RequiredPassword {
+ if subtle.ConstantTimeCompare([]byte(pass), []byte(h.RequiredPassword)) == 0 {
 metricInvalidRequests.Add(1)
 http.Error(w, "bad password", http.StatusForbidden)
 return
 }
 }
- if fn, ok := handlerForPath(r.URL.Path); ok {
+ if fn, route, ok := handlerForPath(r.URL.Path); ok {
+ h.logRequest(r.Method, route)
 fn(h, w, r)
 } else {
 http.NotFound(w, r)
@@ -299,9 +303,9 @@ func (h *Handler) validHost(hostname string) bool {
 
 // handlerForPath returns the LocalAPI handler for the provided Request.URI.Path.
 // (the path doesn't include any query parameters)
-func handlerForPath(urlPath string) (h LocalAPIHandler, ok bool) {
+func handlerForPath(urlPath string) (h LocalAPIHandler, route string, ok bool) {
 if urlPath == "/" {
- return (*Handler).serveLocalAPIRoot, true
+ return (*Handler).serveLocalAPIRoot, "/", true
 }
 suff, ok := strings.CutPrefix(urlPath, "/localapi/v0/")
 if !ok {
@@ -309,22 +313,31 @@ func handlerForPath(urlPath string) (h LocalAPIHandler, ok bool) {
 // to people that they're not necessarily stable APIs. In practice we'll
 // probably need to keep them pretty stable anyway, but for now treat
 // them as an internal implementation detail.
- return nil, false
+ return nil, "", false
 }
 if fn, ok := handler[suff]; ok {
 // Here we match exact handler suffixes like "status" or ones with a
 // slash already in their name, like "tka/status".
- return fn, true
+ return fn, "/localapi/v0/" + suff, true
 }
 // Otherwise, it might be a prefix match like "files/*" which we look up
 // by the prefix including first trailing slash.
 if i := strings.IndexByte(suff, '/'); i != -1 {
 suff = suff[:i+1]
 if fn, ok := handler[suff]; ok {
- return fn, true
+ return fn, "/localapi/v0/" + suff, true
 }
 }
- return nil, false
+ return nil, "", false
+}
+
+func (h *Handler) logRequest(method, route string) {
+ switch method {
+ case httpm.GET, httpm.HEAD, httpm.OPTIONS:
+ // don't log safe methods
+ default:
+ h.Logf("localapi: [%s] %s", method, route)
+ }
 }
 
 func (*Handler) serveLocalAPIRoot(w http.ResponseWriter, r *http.Request) {
@@ -749,6 +762,23 @@ func (h *Handler) serveCheckIPForwarding(w http.ResponseWriter, r *http.Request)
 })
 }
 
+// serveCheckSOMarkInUse reports whether SO_MARK is in use on Linux when
+// running without TUN. On any other OS, it reports false. 
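+// Note that it also reports true, regardless of OS, when userspace
+// networking (netstack) is in use.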
+func (h *Handler) serveCheckSOMarkInUse(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "SO_MARK check access denied", http.StatusForbidden) + return + } + usingSOMark := netns.UseSocketMark() + usingUserspaceNetworking := h.b.Sys().IsNetstack() + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(struct { + UseSOMark bool + }{ + UseSOMark: usingSOMark || usingUserspaceNetworking, + }) +} + func (h *Handler) serveCheckReversePathFiltering(w http.ResponseWriter, r *http.Request) { if !h.PermitRead { http.Error(w, "reverse path filtering check access denied", http.StatusForbidden) @@ -876,14 +906,6 @@ func (h *Handler) serveWatchIPNBus(w http.ResponseWriter, r *http.Request) { } mask = ipn.NotifyWatchOpt(v) } - // Users with only read access must request private key filtering. If they - // don't filter out private keys, require write access. - if (mask & ipn.NotifyNoPrivateKeys) == 0 { - if !h.PermitWrite { - http.Error(w, "watch IPN bus access denied, must set ipn.NotifyNoPrivateKeys when not running as admin/root or operator", http.StatusForbidden) - return - } - } w.Header().Set("Content-Type", "application/json") ctx := r.Context() @@ -908,7 +930,10 @@ func (h *Handler) serveLoginInteractive(w http.ResponseWriter, r *http.Request) http.Error(w, "want POST", http.StatusBadRequest) return } - h.b.StartLoginInteractiveAs(r.Context(), h.Actor) + if err := h.b.StartLoginInteractiveAs(r.Context(), h.Actor); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } w.WriteHeader(http.StatusNoContent) return } @@ -927,6 +952,11 @@ func (h *Handler) serveStart(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + + if h.b.HealthTracker().IsUnhealthy(ipn.StateStoreHealth) { + http.Error(w, "cannot start backend when state store is unhealthy", http.StatusInternalServerError) + return + } err := h.b.Start(o) if err != nil { // TODO(bradfitz): map error to a good HTTP error diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index fa24717f7a942..5d228ffd69343 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -25,9 +25,11 @@ import ( "testing" "tailscale.com/client/tailscale/apitype" + "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" "tailscale.com/tsd" @@ -40,6 +42,19 @@ import ( "tailscale.com/wgengine" ) +func handlerForTest(t testing.TB, h *Handler) *Handler { + if h.Actor == nil { + h.Actor = &ipnauth.TestActor{} + } + if h.b == nil { + h.b = &ipnlocal.LocalBackend{} + } + if h.logf == nil { + h.logf = logger.TestLogger(t) + } + return h +} + func TestValidHost(t *testing.T) { tests := []struct { host string @@ -57,7 +72,7 @@ func TestValidHost(t *testing.T) { for _, test := range tests { t.Run(test.host, func(t *testing.T) { - h := &Handler{} + h := handlerForTest(t, &Handler{}) if got := h.validHost(test.host); got != test.valid { t.Errorf("validHost(%q)=%v, want %v", test.host, got, test.valid) } @@ -68,10 +83,9 @@ func TestValidHost(t *testing.T) { func TestSetPushDeviceToken(t *testing.T) { tstest.Replace(t, &validLocalHostForTesting, true) - h := &Handler{ + h := handlerForTest(t, &Handler{ PermitWrite: true, - b: &ipnlocal.LocalBackend{}, - } + }) s := httptest.NewServer(h) defer s.Close() c := s.Client() @@ -125,9 +139,9 @@ func (b whoIsBackend) PeerCaps(ip 
netip.Addr) tailcfg.PeerCapMap { // // And https://github.com/tailscale/tailscale/issues/12465 func TestWhoIsArgTypes(t *testing.T) { - h := &Handler{ + h := handlerForTest(t, &Handler{ PermitRead: true, - } + }) match := func() (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { return (&tailcfg.Node{ @@ -190,7 +204,10 @@ func TestWhoIsArgTypes(t *testing.T) { func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { newHandler := func(connIsLocalAdmin bool) *Handler { - return &Handler{Actor: &ipnauth.TestActor{LocalAdmin: connIsLocalAdmin}, b: newTestLocalBackend(t)} + return handlerForTest(t, &Handler{ + Actor: &ipnauth.TestActor{LocalAdmin: connIsLocalAdmin}, + b: newTestLocalBackend(t), + }) } tests := []struct { name string @@ -263,13 +280,17 @@ func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { }) } +// TestServeWatchIPNBus used to test that various WatchIPNBus mask flags +// changed the permissions required to access the endpoint. +// However, since the removal of the NotifyNoPrivateKeys flag requirement +// for read-only users, this test now only verifies that the endpoint +// behaves correctly based on the PermitRead and PermitWrite settings. func TestServeWatchIPNBus(t *testing.T) { tstest.Replace(t, &validLocalHostForTesting, true) tests := []struct { desc string permitRead, permitWrite bool - mask ipn.NotifyWatchOpt // extra bits in addition to ipn.NotifyInitialState wantStatus int }{ { @@ -279,20 +300,13 @@ func TestServeWatchIPNBus(t *testing.T) { wantStatus: http.StatusForbidden, }, { - desc: "read-initial-state", - permitRead: true, - permitWrite: false, - wantStatus: http.StatusForbidden, - }, - { - desc: "read-initial-state-no-private-keys", + desc: "read-only", permitRead: true, permitWrite: false, - mask: ipn.NotifyNoPrivateKeys, wantStatus: http.StatusOK, }, { - desc: "read-initial-state-with-private-keys", + desc: "read-and-write", permitRead: true, permitWrite: true, wantStatus: http.StatusOK, @@ -301,17 +315,17 @@ func TestServeWatchIPNBus(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - h := &Handler{ + h := handlerForTest(t, &Handler{ PermitRead: tt.permitRead, PermitWrite: tt.permitWrite, b: newTestLocalBackend(t), - } + }) s := httptest.NewServer(h) defer s.Close() c := s.Client() ctx, cancel := context.WithCancel(context.Background()) - req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/localapi/v0/watch-ipn-bus?mask=%d", s.URL, ipn.NotifyInitialState|tt.mask), nil) + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/localapi/v0/watch-ipn-bus?mask=%d", s.URL, ipn.NotifyInitialState), nil) if err != nil { t.Fatal(err) } @@ -416,3 +430,73 @@ func TestKeepItSorted(t *testing.T) { } } } + +func TestServeWithUnhealthyState(t *testing.T) { + tstest.Replace(t, &validLocalHostForTesting, true) + h := &Handler{ + PermitRead: true, + PermitWrite: true, + b: newTestLocalBackend(t), + logf: t.Logf, + } + h.b.HealthTracker().SetUnhealthy(ipn.StateStoreHealth, health.Args{health.ArgError: "testing"}) + if err := h.b.Start(ipn.Options{}); err != nil { + t.Fatal(err) + } + + check500Body := func(wantResp string) func(t *testing.T, code int, resp []byte) { + return func(t *testing.T, code int, resp []byte) { + if code != http.StatusInternalServerError { + t.Errorf("got code: %v, want %v\nresponse: %q", code, http.StatusInternalServerError, resp) + } + if got := strings.TrimSpace(string(resp)); got != wantResp { + t.Errorf("got response: %q, want %q", got, wantResp) + } + } + } + 
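+ // Each case below exercises an endpoint while the state store is marked
+ // unhealthy: read-only endpoints like "status" should still succeed, while
+ // endpoints that mutate state ("login-interactive", "start", and profile
+ // creation) should fail with a 500 and an explanatory message.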
tests := []struct { + desc string + req *http.Request + check func(t *testing.T, code int, resp []byte) + }{ + { + desc: "status", + req: httptest.NewRequest("GET", "http://localhost:1234/localapi/v0/status", nil), + check: func(t *testing.T, code int, resp []byte) { + if code != http.StatusOK { + t.Errorf("got code: %v, want %v\nresponse: %q", code, http.StatusOK, resp) + } + var status ipnstate.Status + if err := json.Unmarshal(resp, &status); err != nil { + t.Fatal(err) + } + if status.BackendState != "NoState" { + t.Errorf("got backend state: %q, want %q", status.BackendState, "NoState") + } + }, + }, + { + desc: "login-interactive", + req: httptest.NewRequest("POST", "http://localhost:1234/localapi/v0/login-interactive", nil), + check: check500Body("cannot log in when state store is unhealthy"), + }, + { + desc: "start", + req: httptest.NewRequest("POST", "http://localhost:1234/localapi/v0/start", strings.NewReader("{}")), + check: check500Body("cannot start backend when state store is unhealthy"), + }, + { + desc: "new-profile", + req: httptest.NewRequest("PUT", "http://localhost:1234/localapi/v0/profiles/", nil), + check: check500Body("cannot log in when state store is unhealthy"), + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + resp := httptest.NewRecorder() + h.ServeHTTP(resp, tt.req) + tt.check(t, resp.Code, resp.Body.Bytes()) + }) + } +} diff --git a/ipn/localapi/tailnetlock.go b/ipn/localapi/tailnetlock.go index 4baadb7339871..e5f999bb8847e 100644 --- a/ipn/localapi/tailnetlock.go +++ b/ipn/localapi/tailnetlock.go @@ -266,12 +266,12 @@ func (h *Handler) serveTKALog(w http.ResponseWriter, r *http.Request) { limit := 50 if limitStr := r.FormValue("limit"); limitStr != "" { - l, err := strconv.Atoi(limitStr) + lm, err := strconv.Atoi(limitStr) if err != nil { http.Error(w, "parsing 'limit' parameter: "+err.Error(), http.StatusBadRequest) return } - limit = int(l) + limit = int(lm) } updates, err := h.b.NetworkLockLog(limit) diff --git a/ipn/prefs.go b/ipn/prefs.go index 81dd1c1c3dc49..9f98465d2d883 100644 --- a/ipn/prefs.go +++ b/ipn/prefs.go @@ -207,6 +207,12 @@ type Prefs struct { // control server. AdvertiseServices []string + // Sync is whether this node should sync its configuration from + // the control plane. If unset, this defaults to true. + // This exists primarily for testing, to verify that netmap caching + // and offline operation work correctly. + Sync opt.Bool + // NoSNAT specifies whether to source NAT traffic going to // destinations in AdvertiseRoutes. The default is to apply source // NAT, which makes the traffic appear to come from the router @@ -277,14 +283,17 @@ type Prefs struct { // RelayServerPort is the UDP port number for the relay server to bind to, // on all interfaces. A non-nil zero value signifies a random unused port // should be used. A nil value signifies relay server functionality - // should be disabled. This field is currently experimental, and therefore - // no guarantees are made about its current naming and functionality when - // non-nil/enabled. - RelayServerPort *int `json:",omitempty"` + // should be disabled. + RelayServerPort *uint16 `json:",omitempty"` + + // RelayServerStaticEndpoints are static IP:port endpoints to advertise as + // candidates for relay connections. Only relevant when RelayServerPort is + // non-nil. + RelayServerStaticEndpoints []netip.AddrPort `json:",omitempty"` // AllowSingleHosts was a legacy field that was always true // for the past 4.5 years. 
It controlled whether Tailscale - // peers got /32 or /127 routes for each other. + // peers got /32 or /128 routes for each other. // As of 2024-05-17 we're starting to ignore it, but to let // people still downgrade Tailscale versions and not break // all peer-to-peer networking we still write it to disk (as JSON) @@ -344,37 +353,39 @@ type AppConnectorPrefs struct { type MaskedPrefs struct { Prefs - ControlURLSet bool `json:",omitempty"` - RouteAllSet bool `json:",omitempty"` - ExitNodeIDSet bool `json:",omitempty"` - ExitNodeIPSet bool `json:",omitempty"` - AutoExitNodeSet bool `json:",omitempty"` - InternalExitNodePriorSet bool `json:",omitempty"` // Internal; can't be set by LocalAPI clients - ExitNodeAllowLANAccessSet bool `json:",omitempty"` - CorpDNSSet bool `json:",omitempty"` - RunSSHSet bool `json:",omitempty"` - RunWebClientSet bool `json:",omitempty"` - WantRunningSet bool `json:",omitempty"` - LoggedOutSet bool `json:",omitempty"` - ShieldsUpSet bool `json:",omitempty"` - AdvertiseTagsSet bool `json:",omitempty"` - HostnameSet bool `json:",omitempty"` - NotepadURLsSet bool `json:",omitempty"` - ForceDaemonSet bool `json:",omitempty"` - EggSet bool `json:",omitempty"` - AdvertiseRoutesSet bool `json:",omitempty"` - AdvertiseServicesSet bool `json:",omitempty"` - NoSNATSet bool `json:",omitempty"` - NoStatefulFilteringSet bool `json:",omitempty"` - NetfilterModeSet bool `json:",omitempty"` - OperatorUserSet bool `json:",omitempty"` - ProfileNameSet bool `json:",omitempty"` - AutoUpdateSet AutoUpdatePrefsMask `json:",omitempty"` - AppConnectorSet bool `json:",omitempty"` - PostureCheckingSet bool `json:",omitempty"` - NetfilterKindSet bool `json:",omitempty"` - DriveSharesSet bool `json:",omitempty"` - RelayServerPortSet bool `json:",omitempty"` + ControlURLSet bool `json:",omitempty"` + RouteAllSet bool `json:",omitempty"` + ExitNodeIDSet bool `json:",omitempty"` + ExitNodeIPSet bool `json:",omitempty"` + AutoExitNodeSet bool `json:",omitempty"` + InternalExitNodePriorSet bool `json:",omitempty"` // Internal; can't be set by LocalAPI clients + ExitNodeAllowLANAccessSet bool `json:",omitempty"` + CorpDNSSet bool `json:",omitempty"` + RunSSHSet bool `json:",omitempty"` + RunWebClientSet bool `json:",omitempty"` + WantRunningSet bool `json:",omitempty"` + LoggedOutSet bool `json:",omitempty"` + ShieldsUpSet bool `json:",omitempty"` + AdvertiseTagsSet bool `json:",omitempty"` + HostnameSet bool `json:",omitempty"` + NotepadURLsSet bool `json:",omitempty"` + ForceDaemonSet bool `json:",omitempty"` + EggSet bool `json:",omitempty"` + AdvertiseRoutesSet bool `json:",omitempty"` + AdvertiseServicesSet bool `json:",omitempty"` + SyncSet bool `json:",omitzero"` + NoSNATSet bool `json:",omitempty"` + NoStatefulFilteringSet bool `json:",omitempty"` + NetfilterModeSet bool `json:",omitempty"` + OperatorUserSet bool `json:",omitempty"` + ProfileNameSet bool `json:",omitempty"` + AutoUpdateSet AutoUpdatePrefsMask `json:",omitzero"` + AppConnectorSet bool `json:",omitempty"` + PostureCheckingSet bool `json:",omitempty"` + NetfilterKindSet bool `json:",omitempty"` + DriveSharesSet bool `json:",omitempty"` + RelayServerPortSet bool `json:",omitempty"` + RelayServerStaticEndpointsSet bool `json:",omitzero"` } // SetsInternal reports whether mp has any of the Internal*Set field bools set @@ -547,6 +558,9 @@ func (p *Prefs) pretty(goos string) string { if p.LoggedOut { sb.WriteString("loggedout=true ") } + if p.Sync.EqualBool(false) { + sb.WriteString("sync=false ") + } if p.ForceDaemon { 
sb.WriteString("server=true ") } @@ -611,6 +625,9 @@ func (p *Prefs) pretty(goos string) string { if buildfeatures.HasRelayServer && p.RelayServerPort != nil { fmt.Fprintf(&sb, "relayServerPort=%d ", *p.RelayServerPort) } + if buildfeatures.HasRelayServer && len(p.RelayServerStaticEndpoints) > 0 { + fmt.Fprintf(&sb, "relayServerStaticEndpoints=%v ", p.RelayServerStaticEndpoints) + } if p.Persist != nil { sb.WriteString(p.Persist.Pretty()) } else { @@ -653,6 +670,7 @@ func (p *Prefs) Equals(p2 *Prefs) bool { p.ExitNodeAllowLANAccess == p2.ExitNodeAllowLANAccess && p.CorpDNS == p2.CorpDNS && p.RunSSH == p2.RunSSH && + p.Sync.Normalized() == p2.Sync.Normalized() && p.RunWebClient == p2.RunWebClient && p.WantRunning == p2.WantRunning && p.LoggedOut == p2.LoggedOut && @@ -674,7 +692,8 @@ func (p *Prefs) Equals(p2 *Prefs) bool { p.PostureChecking == p2.PostureChecking && slices.EqualFunc(p.DriveShares, p2.DriveShares, drive.SharesEqual) && p.NetfilterKind == p2.NetfilterKind && - compareIntPtrs(p.RelayServerPort, p2.RelayServerPort) + compareUint16Ptrs(p.RelayServerPort, p2.RelayServerPort) && + slices.Equal(p.RelayServerStaticEndpoints, p2.RelayServerStaticEndpoints) } func (au AutoUpdatePrefs) Pretty() string { @@ -694,7 +713,7 @@ func (ap AppConnectorPrefs) Pretty() string { return "" } -func compareIntPtrs(a, b *int) bool { +func compareUint16Ptrs(a, b *uint16) bool { if (a == nil) != (b == nil) { return false } @@ -956,10 +975,15 @@ func PrefsFromBytes(b []byte, base *Prefs) error { if len(b) == 0 { return nil } - return json.Unmarshal(b, base) } +func (p *Prefs) normalizeOptBools() { + if p.Sync == opt.ExplicitlyUnset { + p.Sync = "" + } +} + var jsonEscapedZero = []byte(`\u0000`) // LoadPrefsWindows loads a legacy relaynode config file into Prefs with diff --git a/ipn/prefs_test.go b/ipn/prefs_test.go index 3339a631ce827..aa152843a5af9 100644 --- a/ipn/prefs_test.go +++ b/ipn/prefs_test.go @@ -57,6 +57,7 @@ func TestPrefsEqual(t *testing.T) { "Egg", "AdvertiseRoutes", "AdvertiseServices", + "Sync", "NoSNAT", "NoStatefulFiltering", "NetfilterMode", @@ -68,6 +69,7 @@ func TestPrefsEqual(t *testing.T) { "NetfilterKind", "DriveShares", "RelayServerPort", + "RelayServerStaticEndpoints", "AllowSingleHosts", "Persist", } @@ -76,7 +78,7 @@ func TestPrefsEqual(t *testing.T) { have, prefsHandles) } - relayServerPort := func(port int) *int { + relayServerPort := func(port uint16) *uint16 { return &port } nets := func(strs ...string) (ns []netip.Prefix) { @@ -89,6 +91,16 @@ func TestPrefsEqual(t *testing.T) { } return ns } + aps := func(strs ...string) (ret []netip.AddrPort) { + for _, s := range strs { + n, err := netip.ParseAddrPort(s) + if err != nil { + panic(err) + } + ret = append(ret, n) + } + return ret + } tests := []struct { a, b *Prefs want bool @@ -368,6 +380,16 @@ func TestPrefsEqual(t *testing.T) { &Prefs{RelayServerPort: relayServerPort(1)}, false, }, + { + &Prefs{RelayServerStaticEndpoints: aps("[2001:db8::1]:40000", "192.0.2.1:40000")}, + &Prefs{RelayServerStaticEndpoints: aps("[2001:db8::1]:40000", "192.0.2.1:40000")}, + true, + }, + { + &Prefs{RelayServerStaticEndpoints: aps("[2001:db8::1]:40000", "192.0.2.2:40000")}, + &Prefs{RelayServerStaticEndpoints: aps("[2001:db8::1]:40000", "192.0.2.1:40000")}, + false, + }, } for i, tt := range tests { got := tt.a.Equals(tt.b) @@ -404,6 +426,7 @@ func checkPrefs(t *testing.T, p Prefs) { if err != nil { t.Fatalf("PrefsFromBytes(p2) failed: bytes=%q; err=%v\n", p2.ToBytes(), err) } + p2b.normalizeOptBools() p2p := p2.Pretty() p2bp := 
p2b.Pretty() t.Logf("\np2p: %#v\np2bp: %#v\n", p2p, p2bp) @@ -419,6 +442,42 @@ func checkPrefs(t *testing.T, p Prefs) { } } +// PrefsFromBytes documents that it preserves fields unset in the JSON. +// This verifies that stays true. +func TestPrefsFromBytesPreservesOldValues(t *testing.T) { + tests := []struct { + name string + old Prefs + json []byte + want Prefs + }{ + { + name: "preserve-control-url", + old: Prefs{ControlURL: "https://foo"}, + json: []byte(`{"RouteAll": true}`), + want: Prefs{ControlURL: "https://foo", RouteAll: true}, + }, + { + name: "opt.Bool", // test that we don't normalize it early + old: Prefs{Sync: "unset"}, + json: []byte(`{}`), + want: Prefs{Sync: "unset"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + old := tt.old // shallow + err := PrefsFromBytes(tt.json, &old) + if err != nil { + t.Fatalf("PrefsFromBytes failed: %v", err) + } + if !old.Equals(&tt.want) { + t.Fatalf("got %+v; want %+v", old, tt.want) + } + }) + } +} + func TestBasicPrefs(t *testing.T) { tstest.PanicOnLog() @@ -501,7 +560,7 @@ func TestPrefsPretty(t *testing.T) { }, }, "linux", - `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist{o=, n=[B1VKl] u=""}}`, + `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist{o=, n=[B1VKl] u="" ak=-}}`, }, { Prefs{ @@ -591,6 +650,11 @@ func TestPrefsPretty(t *testing.T) { "linux", `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist=nil}`, }, + { + Prefs{Sync: "false"}, + "linux", + "Prefs{ra=false dns=false want=false sync=false routes=[] nf=off update=off Persist=nil}", + }, } for i, tt := range tests { got := tt.p.pretty(tt.os) diff --git a/ipn/serve.go b/ipn/serve.go index a0f1334d7d150..1f15578893d84 100644 --- a/ipn/serve.go +++ b/ipn/serve.go @@ -17,6 +17,7 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" "tailscale.com/types/ipproto" + "tailscale.com/util/dnsname" "tailscale.com/util/mak" "tailscale.com/util/set" ) @@ -149,6 +150,12 @@ type TCPPortHandler struct { // SNI name with this value. It is only used if TCPForward is non-empty. // (the HTTPS mode uses ServeConfig.Web) TerminateTLS string `json:",omitempty"` + + // ProxyProtocol indicates whether to send a PROXY protocol header + // before forwarding the connection to TCPForward. + // + // This is only valid if TCPForward is non-empty. + ProxyProtocol int `json:",omitzero"` } // HTTPHandler is either a path or a proxy to serve. @@ -160,8 +167,19 @@ type HTTPHandler struct { Text string `json:",omitempty"` // plaintext to serve (primarily for testing) + AcceptAppCaps []tailcfg.PeerCapability `json:",omitempty"` // peer capabilities to forward in grant header, e.g. example.com/cap/mon + + // Redirect, if not empty, is the target URL to redirect requests to. + // By default, we redirect with HTTP 302 (Found) status. + // If Redirect starts with ':', then we use that status instead. + // + // The target URL supports the following expansion variables: + // - ${HOST}: replaced with the request's Host header value + // - ${REQUEST_URI}: replaced with the request's full URI (path and query string) + Redirect string `json:",omitempty"` + // TODO(bradfitz): bool to not enumerate directories? TTL on mapping for - // temporary ones? Error codes? Redirects? + // temporary ones? Error codes? 
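+
+ // As an illustrative example, using only the expansion variables
+ // documented on Redirect above, a handler that redirects every request
+ // to the same path on another host could be:
+ //
+ //	HTTPHandler{Redirect: "https://other.example.com${REQUEST_URI}"}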
} // WebHandlerExists reports whether if the ServeConfig Web handler exists for @@ -220,6 +238,20 @@ func (sc *ServeConfig) HasPathHandler() bool { } } + if sc.Services != nil { + for _, serviceConfig := range sc.Services { + if serviceConfig.Web != nil { + for _, webServerConfig := range serviceConfig.Web { + for _, httpHandler := range webServerConfig.Handlers { + if httpHandler.Path != "" { + return true + } + } + } + } + } + } + if sc.Foreground != nil { for _, fgConfig := range sc.Foreground { if fgConfig.HasPathHandler() { @@ -393,7 +425,10 @@ func (sc *ServeConfig) SetWebHandler(handler *HTTPHandler, host string, port uin // connections from the given port. If terminateTLS is true, TLS connections // are terminated with only the given host name permitted before passing them // to the fwdAddr. -func (sc *ServeConfig) SetTCPForwarding(port uint16, fwdAddr string, terminateTLS bool, host string) { +// +// If proxyProtocol is non-zero, the corresponding PROXY protocol version +// header is sent before forwarding the connection. +func (sc *ServeConfig) SetTCPForwarding(port uint16, fwdAddr string, terminateTLS bool, proxyProtocol int, host string) { if sc == nil { sc = new(ServeConfig) } @@ -406,11 +441,15 @@ func (sc *ServeConfig) SetTCPForwarding(port uint16, fwdAddr string, terminateTL } tcpPortHandler = &svcConfig.TCP } - mak.Set(tcpPortHandler, port, &TCPPortHandler{TCPForward: fwdAddr}) + handler := &TCPPortHandler{ + TCPForward: fwdAddr, + ProxyProtocol: proxyProtocol, // can be 0 + } if terminateTLS { - (*tcpPortHandler)[port].TerminateTLS = host + handler.TerminateTLS = host } + mak.Set(tcpPortHandler, port, handler) } // SetFunnel sets the sc.AllowFunnel value for the given host and port. @@ -649,7 +688,8 @@ func CheckFunnelPort(wantedPort uint16, node *ipnstate.PeerStatus) error { // ExpandProxyTargetValue expands the supported target values to be proxied // allowing for input values to be a port number, a partial URL, or a full URL -// including a path. +// including a path. If it's for a service, remote addresses are allowed and +// there doesn't have to be a port specified. // // examples: // - 3000 @@ -659,17 +699,25 @@ func CheckFunnelPort(wantedPort uint16, node *ipnstate.PeerStatus) error { // - https://localhost:3000 // - https-insecure://localhost:3000 // - https-insecure://localhost:3000/foo +// - https://tailscale.com func ExpandProxyTargetValue(target string, supportedSchemes []string, defaultScheme string) (string, error) { const host = "127.0.0.1" + // empty target is invalid + if target == "" { + return "", fmt.Errorf("empty target") + } + // support target being a port number if port, err := strconv.ParseUint(target, 10, 16); err == nil { return fmt.Sprintf("%s://%s:%d", defaultScheme, host, port), nil } + hasScheme := true // prepend scheme if not present if !strings.Contains(target, "://") { target = defaultScheme + "://" + target + hasScheme = false } // make sure we can parse the target @@ -683,16 +731,28 @@ func ExpandProxyTargetValue(target string, supportedSchemes []string, defaultSch return "", fmt.Errorf("must be a URL starting with one of the supported schemes: %v", supportedSchemes) } - // validate the host. - switch u.Hostname() { - case "localhost", "127.0.0.1": - default: - return "", errors.New("only localhost or 127.0.0.1 proxies are currently supported") + // validate port according to host. 
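+ // Illustrative outcomes of the checks below (derived from the unit tests,
+ // assuming the default "http" scheme):
+ //	"3000"                -> "http://127.0.0.1:3000" (bare port)
+ //	"localhost"           -> error: port required for localhost targets
+ //	"example.com:8080"    -> error: non-localhost targets must include a scheme
+ //	"https://example.com" -> allowed; remote targets may omit the port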
+ if u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1" || u.Hostname() == "::1" { + // require port for localhost targets + if u.Port() == "" { + return "", fmt.Errorf("port required for localhost target %q", target) + } + } else { + validHN := dnsname.ValidHostname(u.Hostname()) == nil + validIP := net.ParseIP(u.Hostname()) != nil + if !validHN && !validIP { + return "", fmt.Errorf("invalid hostname or IP address %q", u.Hostname()) + } + // require scheme for non-localhost targets + if !hasScheme { + return "", fmt.Errorf("non-localhost target %q must include a scheme", target) + } } - - // validate the port port, err := strconv.ParseUint(u.Port(), 10, 16) if err != nil || port == 0 { + if u.Port() == "" { + return u.String(), nil // allow no port for remote destinations + } return "", fmt.Errorf("invalid port %q", u.Port()) } @@ -756,6 +816,7 @@ func (v ServeConfigView) FindServiceTCP(svcName tailcfg.ServiceName, port uint16 return svcCfg.TCP().GetOk(port) } +// FindServiceWeb returns the web handler for the service's host-port. func (v ServeConfigView) FindServiceWeb(svcName tailcfg.ServiceName, hp HostPort) (res WebServerConfigView, ok bool) { if svcCfg, ok := v.Services().GetOk(svcName); ok { if res, ok := svcCfg.Web().GetOk(hp); ok { @@ -769,10 +830,9 @@ func (v ServeConfigView) FindServiceWeb(svcName tailcfg.ServiceName, hp HostPort // prefers a foreground match first followed by a background search if none // existed. func (v ServeConfigView) FindTCP(port uint16) (res TCPPortHandlerView, ok bool) { - for _, conf := range v.Foreground().All() { - if res, ok := conf.TCP().GetOk(port); ok { - return res, ok - } + res, ok = v.FindForegroundTCP(port) + if ok { + return res, ok } return v.TCP().GetOk(port) } @@ -789,6 +849,17 @@ func (v ServeConfigView) FindWeb(hp HostPort) (res WebServerConfigView, ok bool) return v.Web().GetOk(hp) } +// FindForegroundTCP returns the first foreground TCP handler matching the input +// port. +func (v ServeConfigView) FindForegroundTCP(port uint16) (res TCPPortHandlerView, ok bool) { + for _, conf := range v.Foreground().All() { + if res, ok := conf.TCP().GetOk(port); ok { + return res, ok + } + } + return res, false +} + // HasAllowFunnel returns whether this config has at least one AllowFunnel // set in the background or foreground configs. func (v ServeConfigView) HasAllowFunnel() bool { @@ -817,17 +888,6 @@ func (v ServeConfigView) HasFunnelForTarget(target HostPort) bool { return false } -// CheckValidServicesConfig reports whether the ServeConfig has -// invalid service configurations. -func (sc *ServeConfig) CheckValidServicesConfig() error { - for svcName, service := range sc.Services { - if err := service.checkValidConfig(); err != nil { - return fmt.Errorf("invalid service configuration for %q: %w", svcName, err) - } - } - return nil -} - // ServicePortRange returns the list of tailcfg.ProtoPortRange that represents // the proto/ports pairs that are being served by the service. // @@ -865,17 +925,3 @@ func (v ServiceConfigView) ServicePortRange() []tailcfg.ProtoPortRange { } return ranges } - -// ErrServiceConfigHasBothTCPAndTun signals that a service -// in Tun mode cannot also has TCP or Web handlers set. -var ErrServiceConfigHasBothTCPAndTun = errors.New("the VIP Service configuration can not set TUN at the same time as TCP or Web") - -// checkValidConfig checks if the service configuration is valid. -// Currently, the only invalid configuration is when the service is in Tun mode -// and has TCP or Web handlers. 
-func (v *ServiceConfig) checkValidConfig() error { - if v.Tun && (len(v.TCP) > 0 || len(v.Web) > 0) { - return ErrServiceConfigHasBothTCPAndTun - } - return nil -} diff --git a/ipn/serve_test.go b/ipn/serve_test.go index 7028c1e17cd71..5e0f4a43a38e7 100644 --- a/ipn/serve_test.go +++ b/ipn/serve_test.go @@ -117,6 +117,36 @@ func TestHasPathHandler(t *testing.T) { }, want: false, }, + { + name: "with-service-path-handler", + cfg: ServeConfig{ + Services: map[tailcfg.ServiceName]*ServiceConfig{ + "svc:foo": { + Web: map[HostPort]*WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*HTTPHandler{ + "/": {Path: "/tmp"}, + }}, + }, + }, + }, + }, + want: true, + }, + { + name: "with-service-proxy-handler", + cfg: ServeConfig{ + Services: map[tailcfg.ServiceName]*ServiceConfig{ + "svc:foo": { + Web: map[HostPort]*WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*HTTPHandler{ + "/": {Proxy: "http://127.0.0.1:3000"}, + }}, + }, + }, + }, + }, + want: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -260,12 +290,16 @@ func TestExpandProxyTargetDev(t *testing.T) { {name: "https+insecure-scheme", input: "https+insecure://localhost:8080", expected: "https+insecure://localhost:8080"}, {name: "change-default-scheme", input: "localhost:8080", defaultScheme: "https", expected: "https://localhost:8080"}, {name: "change-supported-schemes", input: "localhost:8080", defaultScheme: "tcp", supportedSchemes: []string{"tcp"}, expected: "tcp://localhost:8080"}, + {name: "remote-target", input: "https://example.com:8080", expected: "https://example.com:8080"}, + {name: "remote-IP-target", input: "http://120.133.20.2:8080", expected: "http://120.133.20.2:8080"}, + {name: "remote-target-no-port", input: "https://example.com", expected: "https://example.com"}, // errors {name: "invalid-port", input: "localhost:9999999", wantErr: true}, + {name: "invalid-hostname", input: "192.168.1:8080", wantErr: true}, {name: "unsupported-scheme", input: "ftp://localhost:8080", expected: "", wantErr: true}, - {name: "not-localhost", input: "https://tailscale.com:8080", expected: "", wantErr: true}, {name: "empty-input", input: "", expected: "", wantErr: true}, + {name: "localhost-no-port", input: "localhost", expected: "", wantErr: true}, } for _, tt := range tests { diff --git a/ipn/store.go b/ipn/store.go index 9da5288c0d371..2034ae09a92f9 100644 --- a/ipn/store.go +++ b/ipn/store.go @@ -10,6 +10,8 @@ import ( "fmt" "net" "strconv" + + "tailscale.com/health" ) // ErrStateNotExist is returned by StateStore.ReadState when the @@ -60,6 +62,19 @@ const ( TaildropReceivedKey = StateKey("_taildrop-received") ) +// StateStoreHealth is a Warnable set when store.New fails at startup. If +// unhealthy, we block all login attempts and return a health message in status +// responses. +var StateStoreHealth = health.Register(&health.Warnable{ + Code: "state-store-health", + Severity: health.SeverityHigh, + Title: "Tailscale state store failed to initialize", + Text: func(args health.Args) string { + return fmt.Sprintf("State store failed to initialize, Tailscale will not work until this is resolved. See https://tailscale.com/s/state-store-init-error. Error: %s", args[health.ArgError]) + }, + ImpactsConnectivity: true, +}) + // CurrentProfileID returns the StateKey that stores the // current profile ID. The value is a JSON-encoded LoginProfile. 
// If the userID is empty, the key returned is CurrentProfileStateKey, diff --git a/ipn/store/kubestore/store_kube.go b/ipn/store/kubestore/store_kube.go index f48237c057142..ba45409ed7903 100644 --- a/ipn/store/kubestore/store_kube.go +++ b/ipn/store/kubestore/store_kube.go @@ -6,8 +6,8 @@ package kubestore import ( "context" + "encoding/json" "fmt" - "log" "net" "net/http" "os" @@ -57,6 +57,8 @@ type Store struct { certShareMode string // 'ro', 'rw', or empty podName string + logf logger.Logf + // memory holds the latest tailscale state. Writes write state to a kube // Secret and memory, Reads read from memory. memory mem.Store @@ -96,6 +98,7 @@ func newWithClient(logf logger.Logf, c kubeclient.Client, secretName string) (*S canPatch: canPatch, secretName: secretName, podName: os.Getenv("POD_NAME"), + logf: logf, } if envknob.IsCertShareReadWriteMode() { s.certShareMode = "rw" @@ -113,11 +116,11 @@ func newWithClient(logf logger.Logf, c kubeclient.Client, secretName string) (*S if err := s.loadCerts(context.Background(), sel); err != nil { // We will attempt to again retrieve the certs from Secrets when a request for an HTTPS endpoint // is received. - log.Printf("[unexpected] error loading TLS certs: %v", err) + s.logf("[unexpected] error loading TLS certs: %v", err) } } if s.certShareMode == "ro" { - go s.runCertReload(context.Background(), logf) + go s.runCertReload(context.Background()) } return s, nil } @@ -147,7 +150,7 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { // of a Tailscale Kubernetes node's state Secret. func (s *Store) WriteTLSCertAndKey(domain string, cert, key []byte) (err error) { if s.certShareMode == "ro" { - log.Printf("[unexpected] TLS cert and key write in read-only mode") + s.logf("[unexpected] TLS cert and key write in read-only mode") } if err := dnsname.ValidHostname(domain); err != nil { return fmt.Errorf("invalid domain name %q: %w", domain, err) @@ -258,11 +261,11 @@ func (s *Store) updateSecret(data map[string][]byte, secretName string) (err err defer func() { if err != nil { if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateUpdateFailed, err.Error()); err != nil { - log.Printf("kubestore: error creating tailscaled state update Event: %v", err) + s.logf("kubestore: error creating tailscaled state update Event: %v", err) } } else { if err := s.client.Event(ctx, eventTypeNormal, reasonTailscaleStateUpdated, "Successfully updated tailscaled state Secret"); err != nil { - log.Printf("kubestore: error creating tailscaled state Event: %v", err) + s.logf("kubestore: error creating tailscaled state Event: %v", err) } } cancel() @@ -342,17 +345,72 @@ func (s *Store) loadState() (err error) { return ipn.ErrStateNotExist } if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateLoadFailed, err.Error()); err != nil { - log.Printf("kubestore: error creating Event: %v", err) + s.logf("kubestore: error creating Event: %v", err) } return err } if err := s.client.Event(ctx, eventTypeNormal, reasonTailscaleStateLoaded, "Successfully loaded tailscaled state from Secret"); err != nil { - log.Printf("kubestore: error creating Event: %v", err) + s.logf("kubestore: error creating Event: %v", err) + } + data, err := s.maybeStripAttestationKeyFromProfile(secret.Data) + if err != nil { + return fmt.Errorf("error attempting to strip attestation data from state Secret: %w", err) } - s.memory.LoadFromMap(secret.Data) + s.memory.LoadFromMap(data) return nil } +// maybeStripAttestationKeyFromProfile removes the hardware attestation 
key +// field from a serialized Tailscale profile. This is done to recover from a bug +// introduced in 1.92, where node-bound hardware attestation keys were added to +// Tailscale states stored in Kubernetes Secrets. +// See https://github.com/tailscale/tailscale/issues/18302 +// TODO(irbekrm): it would be good if we could somehow determine when we no +// longer need to run this check. +func (s *Store) maybeStripAttestationKeyFromProfile(data map[string][]byte) (map[string][]byte, error) { + prefsKey := extractPrefsKey(data) + prefsBytes, ok := data[prefsKey] + if !ok { + return data, nil + } + var prefs map[string]any + if err := json.Unmarshal(prefsBytes, &prefs); err != nil { + s.logf("[unexpected] kube store: failed to unmarshal prefs data") + // don't error, as in most cases the state won't have the attestation key + return data, nil + } + + config, ok := prefs["Config"].(map[string]any) + if !ok { + return data, nil + } + if _, hasKey := config["AttestationKey"]; !hasKey { + return data, nil + } + s.logf("kube store: found redundant attestation key, deleting") + delete(config, "AttestationKey") + prefsBytes, err := json.Marshal(prefs) + if err != nil { + return nil, fmt.Errorf("[unexpected] kube store: failed to marshal profile after removing attestation key: %v", err) + } + data[prefsKey] = prefsBytes + if err := s.updateSecret(map[string][]byte{prefsKey: prefsBytes}, s.secretName); err != nil { + // don't error out - this might have been a temporary kube API server + // connection issue. The key will be removed from the in-memory cache + // and we'll retry updating the Secret on the next restart. + s.logf("kube store: error updating Secret after stripping AttestationKey: %v", err) + } + return data, nil +} + +const currentProfileKey = "_current-profile" + +// extractPrefsKey returns the key at which Tailscale prefs are stored in the +// provided Secret data. +func extractPrefsKey(data map[string][]byte) string { + return string(data[currentProfileKey]) +} + // runCertReload relists and reloads all TLS certs for endpoints shared by this // node from Secrets other than the state Secret to ensure that renewed certs get eventually loaded. // It is not critical to reload a cert immediately after @@ -361,7 +419,7 @@ func (s *Store) loadState() (err error) { // Note that if shared certs are not found in memory on an HTTPS request, we // do a Secret lookup, so this mechanism does not need to ensure that newly // added Ingresses' certs get loaded.
-func (s *Store) runCertReload(ctx context.Context, logf logger.Logf) { +func (s *Store) runCertReload(ctx context.Context) { ticker := time.NewTicker(time.Hour * 24) defer ticker.Stop() for { @@ -371,7 +429,7 @@ func (s *Store) runCertReload(ctx context.Context, logf logger.Logf) { case <-ticker.C: sel := s.certSecretSelector() if err := s.loadCerts(ctx, sel); err != nil { - logf("[unexpected] error reloading TLS certs: %v", err) + s.logf("[unexpected] error reloading TLS certs: %v", err) } } } diff --git a/ipn/store/kubestore/store_kube_test.go b/ipn/store/kubestore/store_kube_test.go index 8c8e5e87075f0..44a4bbb7fc14d 100644 --- a/ipn/store/kubestore/store_kube_test.go +++ b/ipn/store/kubestore/store_kube_test.go @@ -20,6 +20,90 @@ import ( "tailscale.com/kube/kubetypes" ) +func TestKubernetesPodMigrationWithTPMAttestationKey(t *testing.T) { + stateWithAttestationKey := `{ + "Config": { + "NodeID": "nSTABLE123456", + "AttestationKey": { + "tpmPrivate": "c2Vuc2l0aXZlLXRwbS1kYXRhLXRoYXQtb25seS13b3Jrcy1vbi1vcmlnaW5hbC1ub2Rl", + "tpmPublic": "cHVibGljLXRwbS1kYXRhLWZvci1hdHRlc3RhdGlvbi1rZXk=" + } + } + }` + + secretData := map[string][]byte{ + "profile-abc123": []byte(stateWithAttestationKey), + "_current-profile": []byte("profile-abc123"), + } + + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{Data: secretData}, nil + }, + CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) { + return true, true, nil + }, + JSONPatchResourceImpl: func(ctx context.Context, name, resourceType string, patches []kubeclient.JSONPatch) error { + for _, p := range patches { + if p.Op == "add" && p.Path == "/data" { + secretData = p.Value.(map[string][]byte) + } + } + return nil + }, + } + + store := &Store{ + client: client, + canPatch: true, + secretName: "ts-state", + memory: mem.Store{}, + logf: t.Logf, + } + + if err := store.loadState(); err != nil { + t.Fatalf("loadState failed: %v", err) + } + + // Verify we can read the state from the store + stateBytes, err := store.ReadState("profile-abc123") + if err != nil { + t.Fatalf("ReadState failed: %v", err) + } + + // The state should be readable as JSON + var state map[string]json.RawMessage + if err := json.Unmarshal(stateBytes, &state); err != nil { + t.Fatalf("failed to unmarshal state: %v", err) + } + + // Verify the Config field exists + configRaw, ok := state["Config"] + if !ok { + t.Fatal("Config field not found in state") + } + + // Parse the Config to verify fields are preserved + var config map[string]json.RawMessage + if err := json.Unmarshal(configRaw, &config); err != nil { + t.Fatalf("failed to unmarshal Config: %v", err) + } + + // The AttestationKey should be stripped by the kubestore + if _, hasAttestation := config["AttestationKey"]; hasAttestation { + t.Error("AttestationKey should be stripped from state loaded by kubestore") + } + + // Verify other fields are preserved + var nodeID string + if err := json.Unmarshal(config["NodeID"], &nodeID); err != nil { + t.Fatalf("failed to unmarshal NodeID: %v", err) + } + if nodeID != "nSTABLE123456" { + t.Errorf("NodeID mismatch: got %q, want %q", nodeID, "nSTABLE123456") + } +} + func TestWriteState(t *testing.T) { tests := []struct { name string diff --git a/k8s-operator/api.md b/k8s-operator/api.md index 979d199cb0783..3a4e692d902ec 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -887,7 +887,7 @@ _Appears in:_ - +RecorderSpec describes a tsrecorder instance 
to be deployed in the cluster @@ -900,6 +900,7 @@ _Appears in:_ | `tags` _[Tags](#tags)_ | Tags that the Tailscale device will be tagged with. Defaults to [tag:k8s].
If you specify custom tags here, make sure you also make the operator
an owner of these tags.
See https://tailscale.com/kb/1236/kubernetes-operator/#setting-up-the-kubernetes-operator.
Tags cannot be changed once a Recorder node has been created.
Tag values must be in form ^tag:[a-zA-Z][a-zA-Z0-9-]*$. | | Pattern: `^tag:[a-zA-Z][a-zA-Z0-9-]*$`
Type: string
| | `enableUI` _boolean_ | Set to true to enable the Recorder UI. The UI lists and plays recorded sessions.
The UI will be served at :443. Defaults to false.
Corresponds to --ui tsrecorder flag https://tailscale.com/kb/1246/tailscale-ssh-session-recording#deploy-a-recorder-node.
Required if S3 storage is not set up, to ensure that recordings are accessible. | | | | `storage` _[Storage](#storage)_ | Configure where to store session recordings. By default, recordings will
be stored in a local ephemeral volume, and will not be persisted past the
lifetime of a specific pod. | | | +| `replicas` _integer_ | Replicas specifies how many instances of tsrecorder to run. Defaults to 1. | | Minimum: 0
| #### RecorderStatefulSet diff --git a/k8s-operator/apis/v1alpha1/types_proxyclass.go b/k8s-operator/apis/v1alpha1/types_proxyclass.go index 4026f90848ef1..670df3b95097e 100644 --- a/k8s-operator/apis/v1alpha1/types_proxyclass.go +++ b/k8s-operator/apis/v1alpha1/types_proxyclass.go @@ -352,12 +352,12 @@ type ServiceMonitor struct { type Labels map[string]LabelValue -func (l Labels) Parse() map[string]string { - if l == nil { +func (lb Labels) Parse() map[string]string { + if lb == nil { return nil } - m := make(map[string]string, len(l)) - for k, v := range l { + m := make(map[string]string, len(lb)) + for k, v := range lb { m[k] = string(v) } return m diff --git a/k8s-operator/apis/v1alpha1/types_recorder.go b/k8s-operator/apis/v1alpha1/types_recorder.go index 16a610b26d179..67cffbf09e969 100644 --- a/k8s-operator/apis/v1alpha1/types_recorder.go +++ b/k8s-operator/apis/v1alpha1/types_recorder.go @@ -44,6 +44,8 @@ type RecorderList struct { Items []Recorder `json:"items"` } +// RecorderSpec describes a tsrecorder instance to be deployed in the cluster +// +kubebuilder:validation:XValidation:rule="!(self.replicas > 1 && (!has(self.storage) || !has(self.storage.s3)))",message="S3 storage must be used when deploying multiple Recorder replicas" type RecorderSpec struct { // Configuration parameters for the Recorder's StatefulSet. The operator // deploys a StatefulSet for each Recorder resource. @@ -74,6 +76,11 @@ type RecorderSpec struct { // lifetime of a specific pod. // +optional Storage Storage `json:"storage,omitempty"` + + // Replicas specifies how many instances of tsrecorder to run. Defaults to 1. + // +optional + // +kubebuilder:validation:Minimum=0 + Replicas *int32 `json:"replicas,omitzero"` } type RecorderStatefulSet struct { diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index 7492f1e547395..ff0f3f6ace415 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -1068,6 +1068,11 @@ func (in *RecorderSpec) DeepCopyInto(out *RecorderSpec) { copy(*out, *in) } in.Storage.DeepCopyInto(&out.Storage) + if in.Replicas != nil { + in, out := &in.Replicas, &out.Replicas + *out = new(int32) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RecorderSpec. 
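The `+kubebuilder:validation:XValidation` rule added to RecorderSpec above is enforced by the Kubernetes API server as a CEL expression: `!(self.replicas > 1 && (!has(self.storage) || !has(self.storage.s3)))` rejects any Recorder that asks for more than one replica without configuring S3 storage, presumably because recordings kept on a pod-local ephemeral volume cannot be shared between replicas. A minimal Go sketch of the same invariant (a hypothetical helper, not part of this change):

	package main

	import (
		"errors"
		"fmt"
	)

	// validateRecorderReplicas mirrors the CEL rule on RecorderSpec:
	// more than one tsrecorder replica requires S3 storage to be set.
	func validateRecorderReplicas(replicas int32, hasS3 bool) error {
		if replicas > 1 && !hasS3 {
			return errors.New("S3 storage must be used when deploying multiple Recorder replicas")
		}
		return nil
	}

	func main() {
		fmt.Println(validateRecorderReplicas(2, false)) // rejected
		fmt.Println(validateRecorderReplicas(2, true))  // allowed: prints <nil>
	}
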
diff --git a/k8s-operator/sessionrecording/ws/conn_test.go b/k8s-operator/sessionrecording/ws/conn_test.go index f2fd4ea55f554..87205c4e6f610 100644 --- a/k8s-operator/sessionrecording/ws/conn_test.go +++ b/k8s-operator/sessionrecording/ws/conn_test.go @@ -99,7 +99,7 @@ func Test_conn_Read(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := zl.Sugar() + log := zl.Sugar() tc := &fakes.TestConn{} sr := &fakes.TestSessionRecorder{} rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) @@ -110,7 +110,7 @@ func Test_conn_Read(t *testing.T) { c := &conn{ ctx: ctx, Conn: tc, - log: l, + log: log, hasTerm: true, initialCastHeaderSent: make(chan struct{}), rec: rec, diff --git a/kube/egressservices/egressservices.go b/kube/egressservices/egressservices.go index 2515f1bf3a476..56c874f31dbb1 100644 --- a/kube/egressservices/egressservices.go +++ b/kube/egressservices/egressservices.go @@ -69,12 +69,12 @@ var _ json.Unmarshaler = &PortMaps{} func (p *PortMaps) UnmarshalJSON(data []byte) error { *p = make(map[PortMap]struct{}) - var l []PortMap - if err := json.Unmarshal(data, &l); err != nil { + var v []PortMap + if err := json.Unmarshal(data, &v); err != nil { return err } - for _, pm := range l { + for _, pm := range v { (*p)[pm] = struct{}{} } @@ -82,12 +82,12 @@ func (p *PortMaps) UnmarshalJSON(data []byte) error { } func (p PortMaps) MarshalJSON() ([]byte, error) { - l := make([]PortMap, 0, len(p)) + v := make([]PortMap, 0, len(p)) for pm := range p { - l = append(l, pm) + v = append(v, pm) } - return json.Marshal(l) + return json.Marshal(v) } // Status represents the currently configured firewall rules for all egress diff --git a/kube/localclient/local-client.go b/kube/localclient/local-client.go index 5d541e3655ddb..550b3ae742c34 100644 --- a/kube/localclient/local-client.go +++ b/kube/localclient/local-client.go @@ -40,10 +40,10 @@ type localClient struct { lc *local.Client } -func (l *localClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (IPNBusWatcher, error) { - return l.lc.WatchIPNBus(ctx, mask) +func (lc *localClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (IPNBusWatcher, error) { + return lc.lc.WatchIPNBus(ctx, mask) } -func (l *localClient) CertPair(ctx context.Context, domain string) ([]byte, []byte, error) { - return l.lc.CertPair(ctx, domain) +func (lc *localClient) CertPair(ctx context.Context, domain string) ([]byte, []byte, error) { + return lc.lc.CertPair(ctx, domain) } diff --git a/licenses/apple.md b/licenses/apple.md index 4c50e95595742..2a795ddbb9cdf 100644 --- a/licenses/apple.md +++ b/licenses/apple.md @@ -29,6 +29,7 @@ See also the dependencies in the [Tailscale CLI][]. 
- [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.2/LICENSE)) - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.2/internal/sync/singleflight/LICENSE)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) + - [github.com/creachadair/msync/trigger](https://pkg.go.dev/github.com/creachadair/msync/trigger) ([BSD-3-Clause](https://github.com/creachadair/msync/blob/v0.7.1/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) ([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) @@ -67,13 +68,13 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.42.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.43.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/df929982:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.44.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.46.0:LICENSE)) - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.17.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.36.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.35.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.29.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.37.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.36.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.30.0:LICENSE)) - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.12.0:LICENSE)) - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) 
([Apache-2.0](https://github.com/google/gvisor/blob/9414b50a5633/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/tailscale.md b/licenses/tailscale.md index 0ef5bcf61d5f8..163a76d404202 100644 --- a/licenses/tailscale.md +++ b/licenses/tailscale.md @@ -37,6 +37,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.2/LICENSE)) - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.2/internal/sync/singleflight/LICENSE)) - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) + - [github.com/creachadair/msync/trigger](https://pkg.go.dev/github.com/creachadair/msync/trigger) ([BSD-3-Clause](https://github.com/creachadair/msync/blob/v0.7.1/LICENSE)) - [github.com/creack/pty](https://pkg.go.dev/github.com/creack/pty) ([MIT](https://github.com/creack/pty/blob/v1.1.23/LICENSE)) - [github.com/dblohm7/wingoes](https://pkg.go.dev/github.com/dblohm7/wingoes) ([BSD-3-Clause](https://github.com/dblohm7/wingoes/blob/a09d6be7affa/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) ([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) @@ -68,6 +69,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) - [github.com/peterbourgon/ff/v3](https://pkg.go.dev/github.com/peterbourgon/ff/v3) ([Apache-2.0](https://github.com/peterbourgon/ff/blob/v3.4.0/LICENSE)) - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.21/LICENSE)) + - [github.com/pires/go-proxyproto](https://pkg.go.dev/github.com/pires/go-proxyproto) ([Apache-2.0](https://github.com/pires/go-proxyproto/blob/v0.8.1/LICENSE)) - [github.com/pkg/sftp](https://pkg.go.dev/github.com/pkg/sftp) ([BSD-2-Clause](https://github.com/pkg/sftp/blob/v1.13.6/LICENSE)) - [github.com/prometheus-community/pro-bing](https://pkg.go.dev/github.com/prometheus-community/pro-bing) ([MIT](https://github.com/prometheus-community/pro-bing/blob/v0.4.0/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) diff --git a/licenses/windows.md b/licenses/windows.md index b284aa1361f5d..06a5712ceb509 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -15,6 +15,7 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. 
- [github.com/beorn7/perks/quantile](https://pkg.go.dev/github.com/beorn7/perks/quantile) ([MIT](https://github.com/beorn7/perks/blob/v1.0.1/LICENSE)) - [github.com/cespare/xxhash/v2](https://pkg.go.dev/github.com/cespare/xxhash/v2) ([MIT](https://github.com/cespare/xxhash/blob/v2.3.0/LICENSE.txt)) - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) + - [github.com/creachadair/msync/trigger](https://pkg.go.dev/github.com/creachadair/msync/trigger) ([BSD-3-Clause](https://github.com/creachadair/msync/blob/v0.7.1/LICENSE)) - [github.com/dblohm7/wingoes](https://pkg.go.dev/github.com/dblohm7/wingoes) ([BSD-3-Clause](https://github.com/dblohm7/wingoes/blob/b75a8a7d7eb0/LICENSE)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) @@ -36,9 +37,9 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/munnerz/goautoneg](https://pkg.go.dev/github.com/munnerz/goautoneg) ([BSD-3-Clause](https://github.com/munnerz/goautoneg/blob/a7dc8b61c822/LICENSE)) - [github.com/nfnt/resize](https://pkg.go.dev/github.com/nfnt/resize) ([ISC](https://github.com/nfnt/resize/blob/83c6a9932646/LICENSE)) - [github.com/peterbourgon/diskv](https://pkg.go.dev/github.com/peterbourgon/diskv) ([MIT](https://github.com/peterbourgon/diskv/blob/v2.0.1/LICENSE)) - - [github.com/prometheus/client_golang/prometheus](https://pkg.go.dev/github.com/prometheus/client_golang/prometheus) ([Apache-2.0](https://github.com/prometheus/client_golang/blob/v1.23.0/LICENSE)) + - [github.com/prometheus/client_golang/prometheus](https://pkg.go.dev/github.com/prometheus/client_golang/prometheus) ([Apache-2.0](https://github.com/prometheus/client_golang/blob/v1.23.2/LICENSE)) - [github.com/prometheus/client_model/go](https://pkg.go.dev/github.com/prometheus/client_model/go) ([Apache-2.0](https://github.com/prometheus/client_model/blob/v0.6.2/LICENSE)) - - [github.com/prometheus/common](https://pkg.go.dev/github.com/prometheus/common) ([Apache-2.0](https://github.com/prometheus/common/blob/v0.65.0/LICENSE)) + - [github.com/prometheus/common](https://pkg.go.dev/github.com/prometheus/common) ([Apache-2.0](https://github.com/prometheus/common/blob/v0.66.1/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/992244df8c5a/LICENSE)) @@ -47,19 +48,20 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. 
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/tc-hib/winres](https://pkg.go.dev/github.com/tc-hib/winres) ([0BSD](https://github.com/tc-hib/winres/blob/v0.2.1/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) + - [go.yaml.in/yaml/v2](https://pkg.go.dev/go.yaml.in/yaml/v2) ([Apache-2.0](https://github.com/yaml/go-yaml/blob/v2.4.2/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.42.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.43.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/df929982:LICENSE)) - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.27.0:LICENSE)) - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.28.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.44.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.46.0:LICENSE)) - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.17.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.36.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.35.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.37.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.36.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) - - [google.golang.org/protobuf](https://pkg.go.dev/google.golang.org/protobuf) ([BSD-3-Clause](https://github.com/protocolbuffers/protobuf-go/blob/v1.36.7/LICENSE)) + - [google.golang.org/protobuf](https://pkg.go.dev/google.golang.org/protobuf) ([BSD-3-Clause](https://github.com/protocolbuffers/protobuf-go/blob/v1.36.8/LICENSE)) - [gopkg.in/Knetic/govaluate.v3](https://pkg.go.dev/gopkg.in/Knetic/govaluate.v3) ([MIT](https://github.com/Knetic/govaluate/blob/v3.0.0/LICENSE)) - [gopkg.in/yaml.v3](https://pkg.go.dev/gopkg.in/yaml.v3) ([MIT](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git 
a/log/sockstatlog/logger.go b/log/sockstatlog/logger.go index e0744de0f089a..8ddfabb866745 100644 --- a/log/sockstatlog/logger.go +++ b/log/sockstatlog/logger.go @@ -146,33 +146,33 @@ func NewLogger(logdir string, logf logger.Logf, logID logid.PublicID, netMon *ne // SetLoggingEnabled enables or disables logging. // When disabled, socket stats are not polled and no new logs are written to disk. // Existing logs can still be fetched via the C2N API. -func (l *Logger) SetLoggingEnabled(v bool) { - old := l.enabled.Load() - if old != v && l.enabled.CompareAndSwap(old, v) { +func (lg *Logger) SetLoggingEnabled(v bool) { + old := lg.enabled.Load() + if old != v && lg.enabled.CompareAndSwap(old, v) { if v { - if l.eventCh == nil { + if lg.eventCh == nil { // eventCh should be large enough for the number of events that will occur within logInterval. // Add an extra second's worth of events to ensure we don't drop any. - l.eventCh = make(chan event, (logInterval+time.Second)/pollInterval) + lg.eventCh = make(chan event, (logInterval+time.Second)/pollInterval) } - l.ctx, l.cancelFn = context.WithCancel(context.Background()) - go l.poll() - go l.logEvents() + lg.ctx, lg.cancelFn = context.WithCancel(context.Background()) + go lg.poll() + go lg.logEvents() } else { - l.cancelFn() + lg.cancelFn() } } } -func (l *Logger) Write(p []byte) (int, error) { - return l.logger.Write(p) +func (lg *Logger) Write(p []byte) (int, error) { + return lg.logger.Write(p) } // poll fetches the current socket stats at the configured time interval, // calculates the delta since the last poll, // and writes any non-zero values to the logger event channel. // This method does not return. -func (l *Logger) poll() { +func (lg *Logger) poll() { // last is the last set of socket stats we saw. var lastStats *sockstats.SockStats var lastTime time.Time @@ -180,7 +180,7 @@ func (l *Logger) poll() { ticker := time.NewTicker(pollInterval) for { select { - case <-l.ctx.Done(): + case <-lg.ctx.Done(): ticker.Stop() return case t := <-ticker.C: @@ -196,7 +196,7 @@ func (l *Logger) poll() { if stats.CurrentInterfaceCellular { e.IsCellularInterface = 1 } - l.eventCh <- e + lg.eventCh <- e } } lastTime = t @@ -207,14 +207,14 @@ func (l *Logger) poll() { // logEvents reads events from the event channel at logInterval and logs them to disk. // This method does not return. -func (l *Logger) logEvents() { - enc := json.NewEncoder(l) +func (lg *Logger) logEvents() { + enc := json.NewEncoder(lg) flush := func() { for { select { - case e := <-l.eventCh: + case e := <-lg.eventCh: if err := enc.Encode(e); err != nil { - l.logf("sockstatlog: error encoding log: %v", err) + lg.logf("sockstatlog: error encoding log: %v", err) } default: return @@ -224,7 +224,7 @@ func (l *Logger) logEvents() { ticker := time.NewTicker(logInterval) for { select { - case <-l.ctx.Done(): + case <-lg.ctx.Done(): ticker.Stop() return case <-ticker.C: @@ -233,29 +233,29 @@ func (l *Logger) logEvents() { } } -func (l *Logger) LogID() string { - if l.logger == nil { +func (lg *Logger) LogID() string { + if lg.logger == nil { return "" } - return l.logger.PrivateID().Public().String() + return lg.logger.PrivateID().Public().String() } // Flush sends pending logs to the log server and flushes them from the local buffer. 
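The SetLoggingEnabled hunk above is a compact start/stop pattern worth spelling out: a CompareAndSwap on the enabled flag guards goroutine startup, the event channel is sized so the poller can run for a full log interval without blocking, and a cancelable context tears both goroutines down. A minimal sketch, with hypothetical names and assumed interval values (the real constants live in sockstatlog):

```go
package sketch

import (
	"context"
	"sync/atomic"
	"time"
)

// Assumed values for illustration only.
const (
	pollInterval = 10 * time.Second
	logInterval  = 3 * time.Minute
)

type event struct{} // stand-in for the real sockstats delta record

type poller struct {
	enabled atomic.Bool
	cancel  context.CancelFunc
	events  chan event
}

// setEnabled mirrors SetLoggingEnabled: CompareAndSwap ensures only the
// caller that actually flips the flag starts or stops the goroutines.
func (p *poller) setEnabled(v bool) {
	old := p.enabled.Load()
	if old != v && p.enabled.CompareAndSwap(old, v) {
		if v {
			if p.events == nil {
				// Capacity for one logInterval's worth of polls, plus an
				// extra second's worth, so the poll loop never drops events
				// while the drain side is busy. Dividing one time.Duration
				// by another yields the integer count that make() needs.
				p.events = make(chan event, (logInterval+time.Second)/pollInterval)
			}
			var ctx context.Context
			ctx, p.cancel = context.WithCancel(context.Background())
			go p.poll(ctx)
			go p.drain(ctx)
		} else {
			p.cancel()
		}
	}
}

func (p *poller) poll(ctx context.Context) {
	t := time.NewTicker(pollInterval)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			p.events <- event{} // in the real code: only non-zero stat deltas
		}
	}
}

func (p *poller) drain(ctx context.Context) {
	t := time.NewTicker(logInterval)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			for len(p.events) > 0 { // in the real code: encode each event to disk
				<-p.events
			}
		}
	}
}
```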
-func (l *Logger) Flush() { - l.logger.StartFlush() +func (lg *Logger) Flush() { + lg.logger.StartFlush() } -func (l *Logger) Shutdown(ctx context.Context) { - if l.cancelFn != nil { - l.cancelFn() +func (lg *Logger) Shutdown(ctx context.Context) { + if lg.cancelFn != nil { + lg.cancelFn() } - l.filch.Close() - l.logger.Shutdown(ctx) + lg.filch.Close() + lg.logger.Shutdown(ctx) type closeIdler interface { CloseIdleConnections() } - if tr, ok := l.tr.(closeIdler); ok { + if tr, ok := lg.tr.(closeIdler); ok { tr.CloseIdleConnections() } } diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index 9c7e62ab0da11..f7491783ad781 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -193,8 +193,8 @@ type logWriter struct { logger *log.Logger } -func (l logWriter) Write(buf []byte) (int, error) { - l.logger.Printf("%s", buf) +func (lg logWriter) Write(buf []byte) (int, error) { + lg.logger.Printf("%s", buf) return len(buf), nil } @@ -640,10 +640,16 @@ func (opts Options) init(disableLogging bool) (*logtail.Config, *Policy) { logHost := logtail.DefaultHost if val := getLogTarget(); val != "" { - opts.Logf("You have enabled a non-default log target. Doing without being told to by Tailscale staff or your network administrator will make getting support difficult.") - conf.BaseURL = val - u, _ := url.Parse(val) - logHost = u.Host + u, err := url.Parse(val) + if err != nil { + opts.Logf("logpolicy: invalid TS_LOG_TARGET %q: %v; using default log host", val, err) + } else if u.Host == "" { + opts.Logf("logpolicy: invalid TS_LOG_TARGET %q: missing host; using default log host", val) + } else { + opts.Logf("You have enabled a non-default log target. Doing so without being told to by Tailscale staff or your network administrator will make getting support difficult.") + conf.BaseURL = val + logHost = u.Host + } } if conf.HTTPC == nil { diff --git a/logpolicy/logpolicy_test.go b/logpolicy/logpolicy_test.go index 28f03448a225d..c09e590bb8399 100644 --- a/logpolicy/logpolicy_test.go +++ b/logpolicy/logpolicy_test.go @@ -84,3 +84,47 @@ func TestOptions(t *testing.T) { }) } } + +// TestInvalidLogTarget is a test for #17792 +func TestInvalidLogTarget(t *testing.T) { + defer resetLogTarget() + + tests := []struct { + name string + logTarget string + }{ + { + name: "invalid_url_no_scheme", + logTarget: "not a url at all", + }, + { + name: "malformed_url", + logTarget: "ht!tp://invalid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetLogTarget() + os.Setenv("TS_LOG_TARGET", tt.logTarget) + + opts := Options{ + Collection: "test.log.tailscale.io", + Logf: t.Logf, + } + + // This should not panic even with an invalid log target + config, policy := opts.init(false) + if policy == nil { + t.Fatal("expected non-nil policy") + } + defer policy.Close() + + // When the log target is invalid, it should fall back to the default log host + // and not crash.
BaseURL should remain empty + if config.BaseURL != "" { + t.Errorf("got BaseURL=%q, want empty", config.BaseURL) + } + }) + } +} diff --git a/logtail/buffer.go b/logtail/buffer.go index d14d8fbf6ae51..82c9b461010b2 100644 --- a/logtail/buffer.go +++ b/logtail/buffer.go @@ -9,7 +9,8 @@ import ( "bytes" "errors" "fmt" - "sync" + + "tailscale.com/syncs" ) type Buffer interface { @@ -36,7 +37,7 @@ type memBuffer struct { next []byte pending chan qentry - dropMu sync.Mutex + dropMu syncs.Mutex dropCount int } diff --git a/logtail/logtail.go b/logtail/logtail.go index 675422890149c..2879c6b0d3cf8 100644 --- a/logtail/logtail.go +++ b/logtail/logtail.go @@ -25,6 +25,7 @@ import ( "sync/atomic" "time" + "github.com/creachadair/msync/trigger" "github.com/go-json-experiment/json/jsontext" "tailscale.com/envknob" "tailscale.com/net/netmon" @@ -99,7 +100,7 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { if !cfg.CopyPrivateID.IsZero() { urlSuffix = "?copyId=" + cfg.CopyPrivateID.String() } - l := &Logger{ + logger := &Logger{ privateID: cfg.PrivateID, stderr: cfg.Stderr, stderrLevel: int64(cfg.StderrLevel), @@ -123,17 +124,19 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { } if cfg.Bus != nil { - l.eventClient = cfg.Bus.Client("logtail.Logger") + logger.eventClient = cfg.Bus.Client("logtail.Logger") + // Subscribe to change deltas from NetMon to detect when the network comes up. + eventbus.SubscribeFunc(logger.eventClient, logger.onChangeDelta) } - l.SetSockstatsLabel(sockstats.LabelLogtailLogger) - l.compressLogs = cfg.CompressLogs + logger.SetSockstatsLabel(sockstats.LabelLogtailLogger) + logger.compressLogs = cfg.CompressLogs ctx, cancel := context.WithCancel(context.Background()) - l.uploadCancel = cancel + logger.uploadCancel = cancel - go l.uploading(ctx) - l.Write([]byte("logtail started")) - return l + go logger.uploading(ctx) + logger.Write([]byte("logtail started")) + return logger } // Logger writes logs, splitting them as configured between local @@ -162,6 +165,7 @@ type Logger struct { httpDoCalls atomic.Int32 sockstatsLabel atomicSocktatsLabel eventClient *eventbus.Client + networkIsUp trigger.Cond // set/reset by netmon.ChangeDelta events procID uint32 includeProcSequence bool @@ -186,27 +190,27 @@ func (p *atomicSocktatsLabel) Store(label sockstats.Label) { p.p.Store(uint32(la // SetVerbosityLevel controls the verbosity level that should be // written to stderr. 0 is the default (not verbose). Levels 1 or higher // are increasingly verbose. -func (l *Logger) SetVerbosityLevel(level int) { - atomic.StoreInt64(&l.stderrLevel, int64(level)) +func (lg *Logger) SetVerbosityLevel(level int) { + atomic.StoreInt64(&lg.stderrLevel, int64(level)) } // SetNetMon sets the network monitor. // // It should not be changed concurrently with log writes and should // only be set once. -func (l *Logger) SetNetMon(lm *netmon.Monitor) { - l.netMonitor = lm +func (lg *Logger) SetNetMon(lm *netmon.Monitor) { + lg.netMonitor = lm } // SetSockstatsLabel sets the label used in sockstat logs to identify network traffic from this logger. -func (l *Logger) SetSockstatsLabel(label sockstats.Label) { - l.sockstatsLabel.Store(label) +func (lg *Logger) SetSockstatsLabel(label sockstats.Label) { + lg.sockstatsLabel.Store(label) } // PrivateID returns the logger's private log ID. // // It exists for internal use only. 
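The logpolicy change exercised by TestInvalidLogTarget above hinges on a subtlety of net/url: url.Parse rejects very little (mostly bad escapes, control characters, and malformed ports), so the error check alone would not catch inputs like "not a url at all"; the u.Host == "" check does the real work. A standalone sketch of the same shape (validateLogTarget is a hypothetical name):

```go
package main

import (
	"fmt"
	"net/url"
)

// validateLogTarget mirrors the shape of the fix above: accept a target
// only if it both parses and carries a host.
func validateLogTarget(val string) (host string, err error) {
	u, err := url.Parse(val)
	if err != nil {
		return "", fmt.Errorf("invalid log target %q: %v", val, err)
	}
	if u.Host == "" {
		// url.Parse is lenient: "not a url at all" parses as a bare path,
		// so a missing host is the common failure mode, not a parse error.
		return "", fmt.Errorf("invalid log target %q: missing host", val)
	}
	return u.Host, nil
}

func main() {
	fmt.Println(validateLogTarget("not a url at all"))    // missing host
	fmt.Println(validateLogTarget("https://log.example")) // log.example <nil>
}
```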
-func (l *Logger) PrivateID() logid.PrivateID { return l.privateID } +func (lg *Logger) PrivateID() logid.PrivateID { return lg.privateID } // Shutdown gracefully shuts down the logger while completing any // remaining uploads. @@ -214,33 +218,33 @@ func (l *Logger) PrivateID() logid.PrivateID { return l.privateID } // It will block, continuing to try and upload unless the passed // context object interrupts it by being done. // If the shutdown is interrupted, an error is returned. -func (l *Logger) Shutdown(ctx context.Context) error { +func (lg *Logger) Shutdown(ctx context.Context) error { done := make(chan struct{}) go func() { select { case <-ctx.Done(): - l.uploadCancel() - <-l.shutdownDone - case <-l.shutdownDone: + lg.uploadCancel() + <-lg.shutdownDone + case <-lg.shutdownDone: } close(done) - l.httpc.CloseIdleConnections() + lg.httpc.CloseIdleConnections() }() - if l.eventClient != nil { - l.eventClient.Close() + if lg.eventClient != nil { + lg.eventClient.Close() } - l.shutdownStartMu.Lock() + lg.shutdownStartMu.Lock() select { - case <-l.shutdownStart: - l.shutdownStartMu.Unlock() + case <-lg.shutdownStart: + lg.shutdownStartMu.Unlock() return nil default: } - close(l.shutdownStart) - l.shutdownStartMu.Unlock() + close(lg.shutdownStart) + lg.shutdownStartMu.Unlock() - io.WriteString(l, "logger closing down\n") + io.WriteString(lg, "logger closing down\n") <-done return nil @@ -250,8 +254,8 @@ func (l *Logger) Shutdown(ctx context.Context) error { // process, and any associated goroutines. // // Deprecated: use Shutdown -func (l *Logger) Close() { - l.Shutdown(context.Background()) +func (lg *Logger) Close() { + lg.Shutdown(context.Background()) } // drainBlock is called by drainPending when there are no logs to drain. @@ -261,11 +265,11 @@ func (l *Logger) Close() { // // If the caller specified FlushInterface, drainWake is only sent to // periodically. -func (l *Logger) drainBlock() (shuttingDown bool) { +func (lg *Logger) drainBlock() (shuttingDown bool) { select { - case <-l.shutdownStart: + case <-lg.shutdownStart: return true - case <-l.drainWake: + case <-lg.drainWake: } return false } @@ -273,20 +277,20 @@ func (l *Logger) drainBlock() (shuttingDown bool) { // drainPending drains and encodes a batch of logs from the buffer for upload. // If no logs are available, drainPending blocks until logs are available. // The returned buffer is only valid until the next call to drainPending. -func (l *Logger) drainPending() (b []byte) { - b = l.drainBuf[:0] +func (lg *Logger) drainPending() (b []byte) { + b = lg.drainBuf[:0] b = append(b, '[') defer func() { b = bytes.TrimRight(b, ",") b = append(b, ']') - l.drainBuf = b + lg.drainBuf = b if len(b) <= len("[]") { b = nil } }() - maxLen := cmp.Or(l.maxUploadSize, maxSize) - if l.lowMem { + maxLen := cmp.Or(lg.maxUploadSize, maxSize) + if lg.lowMem { // When operating in a low memory environment, it is better to upload // in multiple operations than it is to allocate a large body and OOM. 
// Even if maxLen is less than maxSize, we can still upload an entry @@ -294,13 +298,13 @@ func (l *Logger) drainPending() (b []byte) { maxLen /= lowMemRatio } for len(b) < maxLen { - line, err := l.buffer.TryReadLine() + line, err := lg.buffer.TryReadLine() switch { case err == io.EOF: return b case err != nil: b = append(b, '{') - b = l.appendMetadata(b, false, true, 0, 0, "reading ringbuffer: "+err.Error(), nil, 0) + b = lg.appendMetadata(b, false, true, 0, 0, "reading ringbuffer: "+err.Error(), nil, 0) b = bytes.TrimRight(b, ",") b = append(b, '}') return b @@ -314,10 +318,10 @@ func (l *Logger) drainPending() (b []byte) { // in our buffer from a previous large write, let it go. if cap(b) > bufferSize { b = bytes.Clone(b) - l.drainBuf = b + lg.drainBuf = b } - if shuttingDown := l.drainBlock(); shuttingDown { + if shuttingDown := lg.drainBlock(); shuttingDown { return b } continue @@ -334,18 +338,18 @@ func (l *Logger) drainPending() (b []byte) { default: // This is probably a log added to stderr by filch // outside of the logtail logger. Encode it. - if !l.explainedRaw { - fmt.Fprintf(l.stderr, "RAW-STDERR: ***\n") - fmt.Fprintf(l.stderr, "RAW-STDERR: *** Lines prefixed with RAW-STDERR below bypassed logtail and probably come from a previous run of the program\n") - fmt.Fprintf(l.stderr, "RAW-STDERR: ***\n") - fmt.Fprintf(l.stderr, "RAW-STDERR:\n") - l.explainedRaw = true + if !lg.explainedRaw { + fmt.Fprintf(lg.stderr, "RAW-STDERR: ***\n") + fmt.Fprintf(lg.stderr, "RAW-STDERR: *** Lines prefixed with RAW-STDERR below bypassed logtail and probably come from a previous run of the program\n") + fmt.Fprintf(lg.stderr, "RAW-STDERR: ***\n") + fmt.Fprintf(lg.stderr, "RAW-STDERR:\n") + lg.explainedRaw = true } - fmt.Fprintf(l.stderr, "RAW-STDERR: %s", b) + fmt.Fprintf(lg.stderr, "RAW-STDERR: %s", b) // Do not add a client time, as it could be really old. // Do not include instance key or ID either, // since this came from a different instance. - b = l.appendText(b, line, true, 0, 0, 0) + b = lg.appendText(b, line, true, 0, 0, 0) } b = append(b, ',') } @@ -353,14 +357,14 @@ func (l *Logger) drainPending() (b []byte) { } // This is the goroutine that repeatedly uploads logs in the background. -func (l *Logger) uploading(ctx context.Context) { - defer close(l.shutdownDone) +func (lg *Logger) uploading(ctx context.Context) { + defer close(lg.shutdownDone) for { - body := l.drainPending() + body := lg.drainPending() origlen := -1 // sentinel value: uncompressed // Don't attempt to compress tiny bodies; not worth the CPU cycles. - if l.compressLogs && len(body) > 256 { + if lg.compressLogs && len(body) > 256 { zbody := zstdframe.AppendEncode(nil, body, zstdframe.FastestCompression, zstdframe.LowMemory(true)) @@ -377,20 +381,20 @@ func (l *Logger) uploading(ctx context.Context) { var numFailures int var firstFailure time.Time for len(body) > 0 && ctx.Err() == nil { - retryAfter, err := l.upload(ctx, body, origlen) + retryAfter, err := lg.upload(ctx, body, origlen) if err != nil { numFailures++ - firstFailure = l.clock.Now() + firstFailure = lg.clock.Now() - if !l.internetUp() { - fmt.Fprintf(l.stderr, "logtail: internet down; waiting\n") - l.awaitInternetUp(ctx) + if !lg.internetUp() { + fmt.Fprintf(lg.stderr, "logtail: internet down; waiting\n") + lg.awaitInternetUp(ctx) continue } // Only print the same message once. 
if currError := err.Error(); lastError != currError { - fmt.Fprintf(l.stderr, "logtail: upload: %v\n", err) + fmt.Fprintf(lg.stderr, "logtail: upload: %v\n", err) lastError = currError } @@ -403,35 +407,55 @@ func (l *Logger) uploading(ctx context.Context) { } else { // Only print a success message after recovery. if numFailures > 0 { - fmt.Fprintf(l.stderr, "logtail: upload succeeded after %d failures and %s\n", numFailures, l.clock.Since(firstFailure).Round(time.Second)) + fmt.Fprintf(lg.stderr, "logtail: upload succeeded after %d failures and %s\n", numFailures, lg.clock.Since(firstFailure).Round(time.Second)) } break } } select { - case <-l.shutdownStart: + case <-lg.shutdownStart: return default: } } } -func (l *Logger) internetUp() bool { - if l.netMonitor == nil { - // No way to tell, so assume it is. +func (lg *Logger) internetUp() bool { + select { + case <-lg.networkIsUp.Ready(): return true + default: + if lg.netMonitor == nil { + return true // No way to tell, so assume it is. + } + return lg.netMonitor.InterfaceState().AnyInterfaceUp() } - return l.netMonitor.InterfaceState().AnyInterfaceUp() } -func (l *Logger) awaitInternetUp(ctx context.Context) { - if l.eventClient != nil { - l.awaitInternetUpBus(ctx) +// onChangeDelta is an eventbus subscriber function that handles +// [netmon.ChangeDelta] events to detect whether the Internet is expected to be +// reachable. +func (lg *Logger) onChangeDelta(delta *netmon.ChangeDelta) { + if delta.New.AnyInterfaceUp() { + fmt.Fprintf(lg.stderr, "logtail: internet back up\n") + lg.networkIsUp.Set() + } else { + fmt.Fprintf(lg.stderr, "logtail: network changed, but is not up\n") + lg.networkIsUp.Reset() + } +} + +func (lg *Logger) awaitInternetUp(ctx context.Context) { + if lg.eventClient != nil { + select { + case <-lg.networkIsUp.Ready(): + case <-ctx.Done(): + } return } upc := make(chan bool, 1) - defer l.netMonitor.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { + defer lg.netMonitor.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { if delta.New.AnyInterfaceUp() { select { case upc <- true: @@ -439,44 +463,26 @@ func (l *Logger) awaitInternetUp(ctx context.Context) { } } })() - if l.internetUp() { + if lg.internetUp() { return } select { case <-upc: - fmt.Fprintf(l.stderr, "logtail: internet back up\n") + fmt.Fprintf(lg.stderr, "logtail: internet back up\n") case <-ctx.Done(): } } -func (l *Logger) awaitInternetUpBus(ctx context.Context) { - if l.internetUp() { - return - } - sub := eventbus.Subscribe[netmon.ChangeDelta](l.eventClient) - defer sub.Close() - select { - case delta := <-sub.Events(): - if delta.New.AnyInterfaceUp() { - fmt.Fprintf(l.stderr, "logtail: internet back up\n") - return - } - fmt.Fprintf(l.stderr, "logtail: network changed, but is not up") - case <-ctx.Done(): - return - } -} - // upload uploads body to the log server. // origlen indicates the pre-compression body length. // origlen of -1 indicates that the body is not compressed. 
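The new networkIsUp field replaces the ad-hoc awaitInternetUpBus subscription with a level-triggered gate. Based only on how trigger.Cond is used in this diff, Set readies the channel returned by Ready and Reset re-arms it, so the uploader can block until the network comes back without re-subscribing each time. A sketch under that assumption, with hypothetical names:

```go
package sketch

import (
	"context"

	"github.com/creachadair/msync/trigger"
)

type uploader struct {
	networkIsUp trigger.Cond // set/reset from network change events
}

// onNetChange is called with each interface-state delta, as the
// eventbus calls onChangeDelta above.
func (u *uploader) onNetChange(anyUp bool) {
	if anyUp {
		u.networkIsUp.Set() // release every goroutine parked in awaitUp
	} else {
		u.networkIsUp.Reset() // future waiters block again
	}
}

// awaitUp blocks until the network is up or ctx is done.
func (u *uploader) awaitUp(ctx context.Context) {
	select {
	case <-u.networkIsUp.Ready():
	case <-ctx.Done():
	}
}
```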
-func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAfter time.Duration, err error) { +func (lg *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAfter time.Duration, err error) { const maxUploadTime = 45 * time.Second - ctx = sockstats.WithSockStats(ctx, l.sockstatsLabel.Load(), l.Logf) + ctx = sockstats.WithSockStats(ctx, lg.sockstatsLabel.Load(), lg.Logf) ctx, cancel := context.WithTimeout(ctx, maxUploadTime) defer cancel() - req, err := http.NewRequestWithContext(ctx, "POST", l.url, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, "POST", lg.url, bytes.NewReader(body)) if err != nil { // I know of no conditions under which this could fail. // Report it very loudly. @@ -507,8 +513,8 @@ func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAft compressedNote = "compressed" } - l.httpDoCalls.Add(1) - resp, err := l.httpc.Do(req) + lg.httpDoCalls.Add(1) + resp, err := lg.httpc.Do(req) if err != nil { return 0, fmt.Errorf("log upload of %d bytes %s failed: %v", len(body), compressedNote, err) } @@ -527,16 +533,16 @@ func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAft // // TODO(bradfitz): this apparently just returns nil, as of tailscale/corp@9c2ec35. // Finish cleaning this up. -func (l *Logger) Flush() error { +func (lg *Logger) Flush() error { return nil } // StartFlush starts a log upload, if anything is pending. // // If l is nil, StartFlush is a no-op. -func (l *Logger) StartFlush() { - if l != nil { - l.tryDrainWake() +func (lg *Logger) StartFlush() { + if lg != nil { + lg.tryDrainWake() } } @@ -552,41 +558,41 @@ var debugWakesAndUploads = envknob.RegisterBool("TS_DEBUG_LOGTAIL_WAKES") // tryDrainWake tries to send to lg.drainWake, to cause an uploading wakeup. // It does not block. -func (l *Logger) tryDrainWake() { - l.flushPending.Store(false) +func (lg *Logger) tryDrainWake() { + lg.flushPending.Store(false) if debugWakesAndUploads() { // Using println instead of log.Printf here to avoid recursing back into // ourselves. - println("logtail: try drain wake, numHTTP:", l.httpDoCalls.Load()) + println("logtail: try drain wake, numHTTP:", lg.httpDoCalls.Load()) } select { - case l.drainWake <- struct{}{}: + case lg.drainWake <- struct{}{}: default: } } -func (l *Logger) sendLocked(jsonBlob []byte) (int, error) { +func (lg *Logger) sendLocked(jsonBlob []byte) (int, error) { tapSend(jsonBlob) if logtailDisabled.Load() { return len(jsonBlob), nil } - n, err := l.buffer.Write(jsonBlob) + n, err := lg.buffer.Write(jsonBlob) flushDelay := defaultFlushDelay - if l.flushDelayFn != nil { - flushDelay = l.flushDelayFn() + if lg.flushDelayFn != nil { + flushDelay = lg.flushDelayFn() } if flushDelay > 0 { - if l.flushPending.CompareAndSwap(false, true) { - if l.flushTimer == nil { - l.flushTimer = l.clock.AfterFunc(flushDelay, l.tryDrainWake) + if lg.flushPending.CompareAndSwap(false, true) { + if lg.flushTimer == nil { + lg.flushTimer = lg.clock.AfterFunc(flushDelay, lg.tryDrainWake) } else { - l.flushTimer.Reset(flushDelay) + lg.flushTimer.Reset(flushDelay) } } } else { - l.tryDrainWake() + lg.tryDrainWake() } return n, err } @@ -594,13 +600,13 @@ func (l *Logger) sendLocked(jsonBlob []byte) (int, error) { // appendMetadata appends optional "logtail", "metrics", and "v" JSON members. // This assumes dst is already within a JSON object. // Each member is comma-terminated. 
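The flushPending/flushTimer logic in the sendLocked hunk above is a write-coalescing debounce: the first write in a quiet period arms a timer, later writes ride along with the pending flush, and the timer (or an immediate wake when no delay is configured) nudges the upload goroutine through a capacity-1 channel. A minimal sketch with hypothetical names:

```go
package sketch

import (
	"sync/atomic"
	"time"
)

// flusher coalesces writes into one delayed upload wakeup. The caller must
// serialize noteWrite (logtail calls the equivalent under its write lock),
// since the timer field is not lock-protected.
type flusher struct {
	pending atomic.Bool
	timer   *time.Timer
	wake    chan struct{}
}

func newFlusher() *flusher {
	return &flusher{wake: make(chan struct{}, 1)}
}

func (f *flusher) noteWrite(delay time.Duration) {
	if delay <= 0 {
		f.tryWake() // no batching configured: wake the uploader immediately
		return
	}
	// Only the first write of a quiet period arms the timer; writes that
	// land inside the window are batched into the same flush.
	if f.pending.CompareAndSwap(false, true) {
		if f.timer == nil {
			f.timer = time.AfterFunc(delay, f.tryWake)
		} else {
			f.timer.Reset(delay)
		}
	}
}

func (f *flusher) tryWake() {
	f.pending.Store(false)
	select {
	case f.wake <- struct{}{}:
	default: // a wakeup is already queued; the uploader will drain everything
	}
}
```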
-func (l *Logger) appendMetadata(dst []byte, skipClientTime, skipMetrics bool, procID uint32, procSequence uint64, errDetail string, errData jsontext.Value, level int) []byte { +func (lg *Logger) appendMetadata(dst []byte, skipClientTime, skipMetrics bool, procID uint32, procSequence uint64, errDetail string, errData jsontext.Value, level int) []byte { // Append optional logtail metadata. if !skipClientTime || procID != 0 || procSequence != 0 || errDetail != "" || errData != nil { dst = append(dst, `"logtail":{`...) if !skipClientTime { dst = append(dst, `"client_time":"`...) - dst = l.clock.Now().UTC().AppendFormat(dst, time.RFC3339Nano) + dst = lg.clock.Now().UTC().AppendFormat(dst, time.RFC3339Nano) dst = append(dst, '"', ',') } if procID != 0 { @@ -633,8 +639,8 @@ func (l *Logger) appendMetadata(dst []byte, skipClientTime, skipMetrics bool, pr } // Append optional metrics metadata. - if !skipMetrics && l.metricsDelta != nil { - if d := l.metricsDelta(); d != "" { + if !skipMetrics && lg.metricsDelta != nil { + if d := lg.metricsDelta(); d != "" { dst = append(dst, `"metrics":"`...) dst = append(dst, d...) dst = append(dst, '"', ',') @@ -654,10 +660,10 @@ func (l *Logger) appendMetadata(dst []byte, skipClientTime, skipMetrics bool, pr } // appendText appends a raw text message in the Tailscale JSON log entry format. -func (l *Logger) appendText(dst, src []byte, skipClientTime bool, procID uint32, procSequence uint64, level int) []byte { +func (lg *Logger) appendText(dst, src []byte, skipClientTime bool, procID uint32, procSequence uint64, level int) []byte { dst = slices.Grow(dst, len(src)) dst = append(dst, '{') - dst = l.appendMetadata(dst, skipClientTime, false, procID, procSequence, "", nil, level) + dst = lg.appendMetadata(dst, skipClientTime, false, procID, procSequence, "", nil, level) if len(src) == 0 { dst = bytes.TrimRight(dst, ",") return append(dst, "}\n"...) @@ -666,7 +672,7 @@ func (l *Logger) appendText(dst, src []byte, skipClientTime bool, procID uint32, // Append the text string, which may be truncated. // Invalid UTF-8 will be mangled with the Unicode replacement character. max := maxTextSize - if l.lowMem { + if lg.lowMem { max /= lowMemRatio } dst = append(dst, `"text":`...) @@ -691,12 +697,12 @@ func appendTruncatedString(dst, src []byte, n int) []byte { // appendTextOrJSONLocked appends a raw text message or a raw JSON object // in the Tailscale JSON log format. -func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { - if l.includeProcSequence { - l.procSequence++ +func (lg *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { + if lg.includeProcSequence { + lg.procSequence++ } if len(src) == 0 || src[0] != '{' { - return l.appendText(dst, src, l.skipClientTime, l.procID, l.procSequence, level) + return lg.appendText(dst, src, lg.skipClientTime, lg.procID, lg.procSequence, level) } // Check whether the input is a valid JSON object and @@ -708,11 +714,11 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { // However, bytes.NewBuffer normally allocates unless // we immediately shallow copy it into a pre-allocated Buffer struct. // See https://go.dev/issue/67004. 
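The lines that follow implement exactly the allocation dodge this comment describes: bytes.NewBuffer's result normally escapes to the heap, but shallow-copying it into a pre-allocated field lets the allocation be elided (https://go.dev/issue/67004). A reduced sketch of the trick, with hypothetical names:

```go
package sketch

import "bytes"

type decoderState struct {
	bytesBuf bytes.Buffer // pre-allocated home for the shallow copy
}

// peek reads from src through a bytes.Buffer without the per-call heap
// allocation that using *bytes.NewBuffer(src) directly would cost.
func (d *decoderState) peek(src []byte) byte {
	// Shallow-copy the freshly built Buffer into the pre-allocated field.
	d.bytesBuf = *bytes.NewBuffer(src)
	defer func() { d.bytesBuf = bytes.Buffer{} }() // don't pin src after return
	b, err := d.bytesBuf.ReadByte()
	if err != nil {
		return 0
	}
	return b
}
```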
- l.bytesBuf = *bytes.NewBuffer(src) - defer func() { l.bytesBuf = bytes.Buffer{} }() // avoid pinning src + lg.bytesBuf = *bytes.NewBuffer(src) + defer func() { lg.bytesBuf = bytes.Buffer{} }() // avoid pinning src - dec := &l.jsonDec - dec.Reset(&l.bytesBuf) + dec := &lg.jsonDec + dec.Reset(&lg.bytesBuf) if tok, err := dec.ReadToken(); tok.Kind() != '{' || err != nil { return false } @@ -744,7 +750,7 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { // Treat invalid JSON as a raw text message. if !validJSON { - return l.appendText(dst, src, l.skipClientTime, l.procID, l.procSequence, level) + return lg.appendText(dst, src, lg.skipClientTime, lg.procID, lg.procSequence, level) } // Check whether the JSON payload is too large. @@ -752,13 +758,13 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { // That's okay as the Tailscale log service limit is actually 2*maxSize. // However, so long as logging applications aim to target the maxSize limit, // there should be no trouble eventually uploading logs. - maxLen := cmp.Or(l.maxUploadSize, maxSize) + maxLen := cmp.Or(lg.maxUploadSize, maxSize) if len(src) > maxLen { errDetail := fmt.Sprintf("entry too large: %d bytes", len(src)) errData := appendTruncatedString(nil, src, maxLen/len(`\uffff`)) // escaping could increase size dst = append(dst, '{') - dst = l.appendMetadata(dst, l.skipClientTime, true, l.procID, l.procSequence, errDetail, errData, level) + dst = lg.appendMetadata(dst, lg.skipClientTime, true, lg.procID, lg.procSequence, errDetail, errData, level) dst = bytes.TrimRight(dst, ",") return append(dst, "}\n"...) } @@ -775,7 +781,7 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { } dst = slices.Grow(dst, len(src)) dst = append(dst, '{') - dst = l.appendMetadata(dst, l.skipClientTime, true, l.procID, l.procSequence, errDetail, errData, level) + dst = lg.appendMetadata(dst, lg.skipClientTime, true, lg.procID, lg.procSequence, errDetail, errData, level) if logtailValLength > 0 { // Exclude original logtail member from the message. dst = appendWithoutNewline(dst, src[len("{"):logtailKeyOffset]) @@ -802,8 +808,8 @@ func appendWithoutNewline(dst, src []byte) []byte { } // Logf logs to l using the provided fmt-style format and optional arguments. -func (l *Logger) Logf(format string, args ...any) { - fmt.Fprintf(l, format, args...) +func (lg *Logger) Logf(format string, args ...any) { + fmt.Fprintf(lg, format, args...) } // Write logs an encoded JSON blob. @@ -812,29 +818,29 @@ func (l *Logger) Logf(format string, args ...any) { // then contents is fit into a JSON blob and written. // // This is intended as an interface for the stdlib "log" package. -func (l *Logger) Write(buf []byte) (int, error) { +func (lg *Logger) Write(buf []byte) (int, error) { if len(buf) == 0 { return 0, nil } inLen := len(buf) // length as provided to us, before modifications to downstream writers level, buf := parseAndRemoveLogLevel(buf) - if l.stderr != nil && l.stderr != io.Discard && int64(level) <= atomic.LoadInt64(&l.stderrLevel) { + if lg.stderr != nil && lg.stderr != io.Discard && int64(level) <= atomic.LoadInt64(&lg.stderrLevel) { if buf[len(buf)-1] == '\n' { - l.stderr.Write(buf) + lg.stderr.Write(buf) } else { // The log package always line-terminates logs, // so this is an uncommon path. 
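The withNL line that follows deserves a note: buf[:len(buf):len(buf)] is a full slice expression that caps capacity at the length, so the append is forced to copy rather than write '\n' into the caller's backing array. A runnable demonstration:

```go
package main

import "fmt"

func main() {
	backing := []byte("hello??")
	buf := backing[:5] // len 5, cap 7: appends can reach into backing[5:]

	// A plain append writes '\n' into backing[5] in place.
	_ = append(buf, '\n') // backing is now "hello\n?"
	backing[5] = '?'      // undo, for the comparison below

	// The full slice expression caps cap(buf) at len(buf), so append
	// must allocate a fresh array and the caller's bytes stay untouched.
	withNL := append(buf[:len(buf):len(buf)], '\n')
	fmt.Printf("%q %q\n", backing, withNL) // "hello??" "hello\n"
}
```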
withNL := append(buf[:len(buf):len(buf)], '\n') - l.stderr.Write(withNL) + lg.stderr.Write(withNL) } } - l.writeLock.Lock() - defer l.writeLock.Unlock() + lg.writeLock.Lock() + defer lg.writeLock.Unlock() - b := l.appendTextOrJSONLocked(l.writeBuf[:0], buf, level) - _, err := l.sendLocked(b) + b := lg.appendTextOrJSONLocked(lg.writeBuf[:0], buf, level) + _, err := lg.sendLocked(b) return inLen, err } diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go index a92f88b4bb03e..b618fc0d7bc65 100644 --- a/logtail/logtail_test.go +++ b/logtail/logtail_test.go @@ -29,11 +29,11 @@ func TestFastShutdown(t *testing.T) { func(w http.ResponseWriter, r *http.Request) {})) defer testServ.Close() - l := NewLogger(Config{ + logger := NewLogger(Config{ BaseURL: testServ.URL, Bus: eventbustest.NewBus(t), }, t.Logf) - err := l.Shutdown(ctx) + err := logger.Shutdown(ctx) if err != nil { t.Error(err) } @@ -64,7 +64,7 @@ func NewLogtailTestHarness(t *testing.T) (*LogtailTestServer, *Logger) { t.Cleanup(ts.srv.Close) - l := NewLogger(Config{ + logger := NewLogger(Config{ BaseURL: ts.srv.URL, Bus: eventbustest.NewBus(t), }, t.Logf) @@ -75,14 +75,14 @@ func NewLogtailTestHarness(t *testing.T) (*LogtailTestServer, *Logger) { t.Errorf("unknown start logging statement: %q", string(body)) } - return &ts, l + return &ts, logger } func TestDrainPendingMessages(t *testing.T) { - ts, l := NewLogtailTestHarness(t) + ts, logger := NewLogtailTestHarness(t) for range logLines { - l.Write([]byte("log line")) + logger.Write([]byte("log line")) } // all of the "log line" messages usually arrive at once, but poll if needed. @@ -96,14 +96,14 @@ func TestDrainPendingMessages(t *testing.T) { // if we never find count == logLines, the test will eventually time out. } - err := l.Shutdown(context.Background()) + err := logger.Shutdown(context.Background()) if err != nil { t.Error(err) } } func TestEncodeAndUploadMessages(t *testing.T) { - ts, l := NewLogtailTestHarness(t) + ts, logger := NewLogtailTestHarness(t) tests := []struct { name string @@ -123,7 +123,7 @@ func TestEncodeAndUploadMessages(t *testing.T) { } for _, tt := range tests { - io.WriteString(l, tt.log) + io.WriteString(logger, tt.log) body := <-ts.uploaded data := unmarshalOne(t, body) @@ -144,7 +144,7 @@ func TestEncodeAndUploadMessages(t *testing.T) { } } - err := l.Shutdown(context.Background()) + err := logger.Shutdown(context.Background()) if err != nil { t.Error(err) } @@ -322,9 +322,9 @@ func TestLoggerWriteResult(t *testing.T) { } func TestAppendMetadata(t *testing.T) { - var l Logger - l.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) - l.metricsDelta = func() string { return "metrics" } + var lg Logger + lg.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) + lg.metricsDelta = func() string { return "metrics" } for _, tt := range []struct { skipClientTime bool @@ -350,7 +350,7 @@ func TestAppendMetadata(t *testing.T) { {procID: 1, procSeq: 2, errDetail: "error", errData: jsontext.Value(`["something","bad","happened"]`), level: 2, want: `"logtail":{"client_time":"2000-01-01T00:00:00Z","proc_id":1,"proc_seq":2,"error":{"detail":"error","bad_data":["something","bad","happened"]}},"metrics":"metrics","v":2,`}, } { - got := string(l.appendMetadata(nil, tt.skipClientTime, tt.skipMetrics, tt.procID, tt.procSeq, tt.errDetail, tt.errData, tt.level)) + got := string(lg.appendMetadata(nil, tt.skipClientTime, tt.skipMetrics, tt.procID, tt.procSeq, tt.errDetail, tt.errData, 
tt.level)) if got != tt.want { t.Errorf("appendMetadata(%v, %v, %v, %v, %v, %v, %v):\n\tgot %s\n\twant %s", tt.skipClientTime, tt.skipMetrics, tt.procID, tt.procSeq, tt.errDetail, tt.errData, tt.level, got, tt.want) } @@ -362,10 +362,10 @@ func TestAppendMetadata(t *testing.T) { } func TestAppendText(t *testing.T) { - var l Logger - l.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) - l.metricsDelta = func() string { return "metrics" } - l.lowMem = true + var lg Logger + lg.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) + lg.metricsDelta = func() string { return "metrics" } + lg.lowMem = true for _, tt := range []struct { text string @@ -382,7 +382,7 @@ func TestAppendText(t *testing.T) { {text: "\b\f\n\r\t\"\\", want: `{"logtail":{"client_time":"2000-01-01T00:00:00Z"},"metrics":"metrics","text":"\b\f\n\r\t\"\\"}`}, {text: "x" + strings.Repeat("😐", maxSize), want: `{"logtail":{"client_time":"2000-01-01T00:00:00Z"},"metrics":"metrics","text":"x` + strings.Repeat("😐", 1023) + `…+1044484"}`}, } { - got := string(l.appendText(nil, []byte(tt.text), tt.skipClientTime, tt.procID, tt.procSeq, tt.level)) + got := string(lg.appendText(nil, []byte(tt.text), tt.skipClientTime, tt.procID, tt.procSeq, tt.level)) if !strings.HasSuffix(got, "\n") { t.Errorf("`%s` does not end with a newline", got) } @@ -397,10 +397,10 @@ func TestAppendText(t *testing.T) { } func TestAppendTextOrJSON(t *testing.T) { - var l Logger - l.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) - l.metricsDelta = func() string { return "metrics" } - l.lowMem = true + var lg Logger + lg.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) + lg.metricsDelta = func() string { return "metrics" } + lg.lowMem = true for _, tt := range []struct { in string @@ -419,7 +419,7 @@ func TestAppendTextOrJSON(t *testing.T) { {in: `{ "fizz" : "buzz" , "logtail" : "duplicate" , "wizz" : "wuzz" }`, want: `{"logtail":{"client_time":"2000-01-01T00:00:00Z","error":{"detail":"duplicate logtail member","bad_data":"duplicate"}}, "fizz" : "buzz" , "wizz" : "wuzz"}`}, {in: `{"long":"` + strings.Repeat("a", maxSize) + `"}`, want: `{"logtail":{"client_time":"2000-01-01T00:00:00Z","error":{"detail":"entry too large: 262155 bytes","bad_data":"{\"long\":\"` + strings.Repeat("a", 43681) + `…+218465"}}}`}, } { - got := string(l.appendTextOrJSONLocked(nil, []byte(tt.in), tt.level)) + got := string(lg.appendTextOrJSONLocked(nil, []byte(tt.in), tt.level)) if !strings.HasSuffix(got, "\n") { t.Errorf("`%s` does not end with a newline", got) } @@ -461,21 +461,21 @@ var testdataTextLog = []byte(`netcheck: report: udp=true v6=false v6os=true mapv var testdataJSONLog = 
[]byte(`{"end":"2024-04-08T21:39:15.715291586Z","nodeId":"nQRJBE7CNTRL","physicalTraffic":[{"dst":"127.x.x.x:2","src":"100.x.x.x:0","txBytes":148,"txPkts":1},{"dst":"127.x.x.x:2","src":"100.x.x.x:0","txBytes":148,"txPkts":1},{"dst":"98.x.x.x:1025","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"24.x.x.x:49973","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"73.x.x.x:41641","rxBytes":732,"rxPkts":6,"src":"100.x.x.x:0","txBytes":820,"txPkts":7},{"dst":"75.x.x.x:1025","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"75.x.x.x:41641","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"174.x.x.x:35497","rxBytes":13008,"rxPkts":98,"src":"100.x.x.x:0","txBytes":26688,"txPkts":150},{"dst":"47.x.x.x:41641","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"64.x.x.x:41641","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5}],"start":"2024-04-08T21:39:11.099495616Z","virtualTraffic":[{"dst":"100.x.x.x:33008","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32984","proto":6,"src":"100.x.x.x:22","txBytes":1340,"txPkts":10},{"dst":"100.x.x.x:32998","proto":6,"src":"100.x.x.x:22","txBytes":1020,"txPkts":10},{"dst":"100.x.x.x:32994","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:32980","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32950","proto":6,"src":"100.x.x.x:22","txBytes":1340,"txPkts":10},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:53332","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:0","proto":1,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32966","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:57882","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:53326","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:57892","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:32934","proto":6,"src":"100.x.x.x:22","txBytes":8712,"txPkts":55},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32942","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32964","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:37238","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:37252","txBytes":60,"txPkts":1}]}`) func BenchmarkWriteText(b *testing.B) { - var l Logger - l.clock = tstime.StdClock{} - l.buffer = discardBuffer{} + var lg Logger + lg.clock = tstime.StdClock{} + lg.buffer = discardBuffer{} b.ReportAllocs() for range b.N { - must.Get(l.Write(testdataTextLog)) + 
must.Get(lg.Write(testdataTextLog)) } } func BenchmarkWriteJSON(b *testing.B) { - var l Logger - l.clock = tstime.StdClock{} - l.buffer = discardBuffer{} + var lg Logger + lg.clock = tstime.StdClock{} + lg.buffer = discardBuffer{} b.ReportAllocs() for range b.N { - must.Get(l.Write(testdataJSONLog)) + must.Get(lg.Write(testdataJSONLog)) } } diff --git a/metrics/metrics.go b/metrics/metrics.go index d1b1c06c9dc2c..19966d395f815 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -11,7 +11,6 @@ import ( "io" "slices" "strings" - "sync" "tailscale.com/syncs" ) @@ -41,7 +40,7 @@ type LabelMap struct { Label string expvar.Map // shardedIntMu orders the initialization of new shardedint keys - shardedIntMu sync.Mutex + shardedIntMu syncs.Mutex } // SetInt64 sets the *Int value stored under the given map key. diff --git a/net/art/stride_table.go b/net/art/stride_table.go index 5ff0455fed872..5050df24500ce 100644 --- a/net/art/stride_table.go +++ b/net/art/stride_table.go @@ -303,21 +303,21 @@ func formatPrefixTable(addr uint8, len int) string { // // For example, childPrefixOf("192.168.0.0/16", 8) == "192.168.8.0/24". func childPrefixOf(parent netip.Prefix, stride uint8) netip.Prefix { - l := parent.Bits() - if l%8 != 0 { + ln := parent.Bits() + if ln%8 != 0 { panic("parent prefix is not 8-bit aligned") } - if l >= parent.Addr().BitLen() { + if ln >= parent.Addr().BitLen() { panic("parent prefix cannot be extended further") } - off := l / 8 + off := ln / 8 if parent.Addr().Is4() { bs := parent.Addr().As4() bs[off] = stride - return netip.PrefixFrom(netip.AddrFrom4(bs), l+8) + return netip.PrefixFrom(netip.AddrFrom4(bs), ln+8) } else { bs := parent.Addr().As16() bs[off] = stride - return netip.PrefixFrom(netip.AddrFrom16(bs), l+8) + return netip.PrefixFrom(netip.AddrFrom16(bs), ln+8) } } diff --git a/net/art/stride_table_test.go b/net/art/stride_table_test.go index bff2bb7c507fd..4ccef1fe083cb 100644 --- a/net/art/stride_table_test.go +++ b/net/art/stride_table_test.go @@ -377,8 +377,8 @@ func pfxMask(pfxLen int) uint8 { func allPrefixes() []slowEntry[int] { ret := make([]slowEntry[int], 0, lastHostIndex) for i := 1; i < lastHostIndex+1; i++ { - a, l := inversePrefixIndex(i) - ret = append(ret, slowEntry[int]{a, l, i}) + a, ln := inversePrefixIndex(i) + ret = append(ret, slowEntry[int]{a, ln, i}) } return ret } diff --git a/net/batching/conn_linux.go b/net/batching/conn_linux.go index 7f6c4ed422e31..bd7ac25be2a4d 100644 --- a/net/batching/conn_linux.go +++ b/net/batching/conn_linux.go @@ -353,7 +353,7 @@ func getGSOSizeFromControl(control []byte) (int, error) { ) for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) if err != nil { return 0, fmt.Errorf("error parsing socket control message: %w", err) } diff --git a/net/batching/conn_linux_test.go b/net/batching/conn_linux_test.go index e518c3f9f06d9..c2cc463ebc6ad 100644 --- a/net/batching/conn_linux_test.go +++ b/net/batching/conn_linux_test.go @@ -7,9 +7,11 @@ import ( "encoding/binary" "net" "testing" + "unsafe" "github.com/tailscale/wireguard-go/conn" "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" "tailscale.com/net/packet" ) @@ -314,3 +316,36 @@ func TestMinReadBatchMsgsLen(t *testing.T) { t.Fatalf("IdealBatchSize: %d != conn.IdealBatchSize(): %d", IdealBatchSize, conn.IdealBatchSize) } } + +func Test_getGSOSizeFromControl_MultipleMessages(t *testing.T) { + // Test that getGSOSizeFromControl correctly parses UDP_GRO when 
it's not the first control message. + const expectedGSOSize = 1420 + + // First message: IP_TOS + firstMsgLen := unix.CmsgSpace(1) + firstMsg := make([]byte, firstMsgLen) + hdr1 := (*unix.Cmsghdr)(unsafe.Pointer(&firstMsg[0])) + hdr1.Level = unix.SOL_IP + hdr1.Type = unix.IP_TOS + hdr1.SetLen(unix.CmsgLen(1)) + firstMsg[unix.SizeofCmsghdr] = 0 + + // Second message: UDP_GRO + secondMsgLen := unix.CmsgSpace(2) + secondMsg := make([]byte, secondMsgLen) + hdr2 := (*unix.Cmsghdr)(unsafe.Pointer(&secondMsg[0])) + hdr2.Level = unix.SOL_UDP + hdr2.Type = unix.UDP_GRO + hdr2.SetLen(unix.CmsgLen(2)) + binary.NativeEndian.PutUint16(secondMsg[unix.SizeofCmsghdr:], expectedGSOSize) + + control := append(firstMsg, secondMsg...) + + gsoSize, err := getGSOSizeFromControl(control) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gsoSize != expectedGSOSize { + t.Errorf("got GSO size %d, want %d", gsoSize, expectedGSOSize) + } +} diff --git a/net/captivedetection/captivedetection.go b/net/captivedetection/captivedetection.go index a06362a5b4d1d..3ec820b794400 100644 --- a/net/captivedetection/captivedetection.go +++ b/net/captivedetection/captivedetection.go @@ -18,6 +18,7 @@ import ( "time" "tailscale.com/net/netmon" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/logger" ) @@ -32,7 +33,7 @@ type Detector struct { // currIfIndex is the index of the interface that is currently being used by the httpClient. currIfIndex int // mu guards currIfIndex. - mu sync.Mutex + mu syncs.Mutex // logf is the logger used for logging messages. If it is nil, log.Printf is used. logf logger.Logf } diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go index 444c5d37debf4..5ccadbab2d9ad 100644 --- a/net/dns/manager_windows.go +++ b/net/dns/manager_windows.go @@ -16,7 +16,6 @@ import ( "slices" "sort" "strings" - "sync" "syscall" "time" @@ -27,6 +26,7 @@ import ( "tailscale.com/control/controlknobs" "tailscale.com/envknob" "tailscale.com/health" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/dnsname" "tailscale.com/util/syspolicy/pkey" @@ -51,7 +51,7 @@ type windowsManager struct { unregisterPolicyChangeCb func() // called when the manager is closing - mu sync.Mutex + mu syncs.Mutex closing bool } diff --git a/net/dns/manager_windows_test.go b/net/dns/manager_windows_test.go index 7c0139f455d70..aa538a0f66dcb 100644 --- a/net/dns/manager_windows_test.go +++ b/net/dns/manager_windows_test.go @@ -550,8 +550,8 @@ func genRandomSubdomains(t *testing.T, n int) []dnsname.FQDN { const charset = "abcdefghijklmnopqrstuvwxyz" for len(domains) < cap(domains) { - l := r.Intn(19) + 1 - b := make([]byte, l) + ln := r.Intn(19) + 1 + b := make([]byte, ln) for i := range b { b[i] = charset[r.Intn(len(charset))] } diff --git a/net/dns/resolver/debug.go b/net/dns/resolver/debug.go index 0f9b106bb2eb4..a41462e185e24 100644 --- a/net/dns/resolver/debug.go +++ b/net/dns/resolver/debug.go @@ -8,12 +8,12 @@ import ( "html" "net/http" "strconv" - "sync" "sync/atomic" "time" "tailscale.com/feature/buildfeatures" "tailscale.com/health" + "tailscale.com/syncs" ) func init() { @@ -39,7 +39,7 @@ func init() { var fwdLogAtomic atomic.Pointer[fwdLog] type fwdLog struct { - mu sync.Mutex + mu syncs.Mutex pos int // ent[pos] is next entry ent []fwdLogEntry } diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 86f0f5b8c48c4..5adc43efca860 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -37,6 +37,7 @@ import ( 
"tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tsdial" + "tailscale.com/syncs" "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/types/nettype" @@ -231,7 +232,7 @@ type forwarder struct { ctx context.Context // good until Close ctxCancel context.CancelFunc // closes ctx - mu sync.Mutex // guards following + mu syncs.Mutex // guards following dohClient map[string]*http.Client // urlBase -> client diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 93cbf3839c923..3185cbe2b35ff 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -214,7 +214,7 @@ type Resolver struct { closed chan struct{} // mu guards the following fields from being updated while used. - mu sync.Mutex + mu syncs.Mutex localDomains []dnsname.FQDN hostToIP map[dnsname.FQDN][]netip.Addr ipToHost map[netip.Addr]dnsname.FQDN diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 94d4bbee7955f..e222b983f0287 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -20,6 +20,7 @@ import ( "tailscale.com/envknob" "tailscale.com/net/netx" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/cloudenv" "tailscale.com/util/singleflight" @@ -97,7 +98,7 @@ type Resolver struct { sf singleflight.Group[string, ipRes] - mu sync.Mutex + mu syncs.Mutex ipCache map[string]ipCacheEntry } @@ -474,7 +475,7 @@ type dialCall struct { d *dialer network, address, host, port string - mu sync.Mutex // lock ordering: dialer.mu, then dialCall.mu + mu syncs.Mutex // lock ordering: dialer.mu, then dialCall.mu fails map[netip.Addr]error // set of IPs that failed to dial thus far } diff --git a/net/ktimeout/ktimeout_linux_test.go b/net/ktimeout/ktimeout_linux_test.go index df41567454f4b..0330923a96c13 100644 --- a/net/ktimeout/ktimeout_linux_test.go +++ b/net/ktimeout/ktimeout_linux_test.go @@ -19,11 +19,11 @@ func TestSetUserTimeout(t *testing.T) { // set in ktimeout.UserTimeout above. 
lc.SetMultipathTCP(false) - l := must.Get(lc.Listen(context.Background(), "tcp", "localhost:0")) - defer l.Close() + ln := must.Get(lc.Listen(context.Background(), "tcp", "localhost:0")) + defer ln.Close() var err error - if e := must.Get(l.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) { + if e := must.Get(ln.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) { err = SetUserTimeout(fd, 0) }); e != nil { t.Fatal(e) @@ -31,12 +31,12 @@ func TestSetUserTimeout(t *testing.T) { if err != nil { t.Fatal(err) } - v := must.Get(unix.GetsockoptInt(int(must.Get(l.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT)) + v := must.Get(unix.GetsockoptInt(int(must.Get(ln.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT)) if v != 0 { t.Errorf("TCP_USER_TIMEOUT: got %v; want 0", v) } - if e := must.Get(l.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) { + if e := must.Get(ln.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) { err = SetUserTimeout(fd, 30*time.Second) }); e != nil { t.Fatal(e) @@ -44,7 +44,7 @@ func TestSetUserTimeout(t *testing.T) { if err != nil { t.Fatal(err) } - v = must.Get(unix.GetsockoptInt(int(must.Get(l.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT)) + v = must.Get(unix.GetsockoptInt(int(must.Get(ln.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT)) if v != 30000 { t.Errorf("TCP_USER_TIMEOUT: got %v; want 30000", v) } diff --git a/net/ktimeout/ktimeout_test.go b/net/ktimeout/ktimeout_test.go index 7befa3b1ab077..b534f046caddb 100644 --- a/net/ktimeout/ktimeout_test.go +++ b/net/ktimeout/ktimeout_test.go @@ -14,11 +14,11 @@ func ExampleUserTimeout() { lc := net.ListenConfig{ Control: UserTimeout(30 * time.Second), } - l, err := lc.Listen(context.TODO(), "tcp", "127.0.0.1:0") + ln, err := lc.Listen(context.TODO(), "tcp", "127.0.0.1:0") if err != nil { fmt.Printf("error: %v", err) return } - l.Close() + ln.Close() // Output: } diff --git a/net/memnet/listener.go b/net/memnet/listener.go index 202026e160b27..dded97995bbc1 100644 --- a/net/memnet/listener.go +++ b/net/memnet/listener.go @@ -39,16 +39,16 @@ func Listen(addr string) *Listener { } // Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { - return l.addr +func (ln *Listener) Addr() net.Addr { + return ln.addr } // Close closes the pipe listener. -func (l *Listener) Close() error { +func (ln *Listener) Close() error { var cleanup func() - l.closeOnce.Do(func() { - cleanup = l.onClose - close(l.closed) + ln.closeOnce.Do(func() { + cleanup = ln.onClose + close(ln.closed) }) if cleanup != nil { cleanup() @@ -57,11 +57,11 @@ func (l *Listener) Close() error { } // Accept blocks until a new connection is available or the listener is closed. -func (l *Listener) Accept() (net.Conn, error) { +func (ln *Listener) Accept() (net.Conn, error) { select { - case c := <-l.ch: + case c := <-ln.ch: return c, nil - case <-l.closed: + case <-ln.closed: return nil, net.ErrClosed } } @@ -70,18 +70,18 @@ func (l *Listener) Accept() (net.Conn, error) { // The provided Context must be non-nil. If the context expires before the // connection is complete, an error is returned. Once successfully connected // any expiration of the context will not affect the connection. 
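Stepping back to the net/batching/conn_linux.go change earlier in this diff: that one-word fix (parsing rem instead of control) matters because unix.ParseOneSocketControlMessage returns the unparsed remainder, and only by feeding that remainder back in does the loop advance past the first control message; the new test pins this down by placing an IP_TOS message ahead of the UDP_GRO one. A reduced sketch of the corrected walk (error handling trimmed, helper name hypothetical):

```go
package sketch

import (
	"encoding/binary"

	"golang.org/x/sys/unix"
)

// gsoSizeFromControl walks every control message looking for UDP_GRO.
func gsoSizeFromControl(control []byte) (int, bool) {
	rem := control
	for len(rem) > unix.SizeofCmsghdr {
		// Parse rem (the unparsed tail), not control, so the walk advances.
		hdr, data, next, err := unix.ParseOneSocketControlMessage(rem)
		if err != nil {
			return 0, false
		}
		rem = next
		if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 {
			return int(binary.NativeEndian.Uint16(data)), true
		}
	}
	return 0, false
}
```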
-func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { +func (ln *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { if !strings.HasSuffix(network, "tcp") { return nil, net.UnknownNetworkError(network) } - if connAddr(addr) != l.addr { + if connAddr(addr) != ln.addr { return nil, &net.AddrError{ Err: "invalid address", Addr: addr, } } - newConn := l.NewConn + newConn := ln.NewConn if newConn == nil { newConn = func(network, addr string, maxBuf int) (Conn, Conn) { return NewConn(addr, maxBuf) @@ -98,9 +98,9 @@ func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, select { case <-ctx.Done(): return nil, ctx.Err() - case <-l.closed: + case <-ln.closed: return nil, net.ErrClosed - case l.ch <- s: + case ln.ch <- s: return c, nil } } diff --git a/net/memnet/listener_test.go b/net/memnet/listener_test.go index 73b67841ad08c..b6ceb3dfa94cf 100644 --- a/net/memnet/listener_test.go +++ b/net/memnet/listener_test.go @@ -9,10 +9,10 @@ import ( ) func TestListener(t *testing.T) { - l := Listen("srv.local") - defer l.Close() + ln := Listen("srv.local") + defer ln.Close() go func() { - c, err := l.Accept() + c, err := ln.Accept() if err != nil { t.Error(err) return @@ -20,11 +20,11 @@ func TestListener(t *testing.T) { defer c.Close() }() - if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { + if c, err := ln.Dial(context.Background(), "tcp", "invalid"); err == nil { c.Close() t.Fatalf("dial to invalid address succeeded") } - c, err := l.Dial(context.Background(), "tcp", "srv.local") + c, err := ln.Dial(context.Background(), "tcp", "srv.local") if err != nil { t.Fatalf("dial failed: %v", err) return diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go index 1e43df2daaaae..db9e3872f6f26 100644 --- a/net/memnet/memnet.go +++ b/net/memnet/memnet.go @@ -12,9 +12,9 @@ import ( "fmt" "net" "net/netip" - "sync" "tailscale.com/net/netx" + "tailscale.com/syncs" ) var _ netx.Network = (*Network)(nil) @@ -26,7 +26,7 @@ var _ netx.Network = (*Network)(nil) // // Its zero value is a valid [netx.Network] implementation. type Network struct { - mu sync.Mutex + mu syncs.Mutex lns map[string]*Listener // address -> listener } diff --git a/net/netaddr/netaddr.go b/net/netaddr/netaddr.go index 1ab6c053a523e..a04acd57aa670 100644 --- a/net/netaddr/netaddr.go +++ b/net/netaddr/netaddr.go @@ -34,7 +34,7 @@ func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { } ip = ip.Unmap() - if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { + if ln := len(std.Mask); ln != net.IPv4len && ln != net.IPv6len { // Invalid mask. 
return netip.Prefix{}, false } diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 726221675fb03..c5a3d2392007e 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -235,7 +235,7 @@ type Client struct { testEnoughRegions int testCaptivePortalDelay time.Duration - mu sync.Mutex // guards following + mu syncs.Mutex // guards following nextFull bool // do a full region scan, even if last != nil prev map[time.Time]*Report // some previous reports last *Report // most recent report @@ -597,7 +597,7 @@ type reportState struct { stopProbeCh chan struct{} waitPortMap sync.WaitGroup - mu sync.Mutex + mu syncs.Mutex report *Report // to be returned by GetReport inFlight map[stun.TxID]func(netip.AddrPort) // called without c.mu held gotEP4 netip.AddrPort @@ -993,9 +993,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe c.logf("[v1] netcheck: measuring HTTPS latency of %v (%d): %v", reg.RegionCode, reg.RegionID, err) } else { rs.mu.Lock() - if l, ok := rs.report.RegionLatency[reg.RegionID]; !ok { + if latency, ok := rs.report.RegionLatency[reg.RegionID]; !ok { mak.Set(&rs.report.RegionLatency, reg.RegionID, d) - } else if l >= d { + } else if latency >= d { rs.report.RegionLatency[reg.RegionID] = d } // We set these IPv4 and IPv6 but they're not really used @@ -1214,9 +1214,9 @@ func (c *Client) measureAllICMPLatency(ctx context.Context, rs *reportState, nee } else if ok { c.logf("[v1] ICMP latency of %v (%d): %v", reg.RegionCode, reg.RegionID, d) rs.mu.Lock() - if l, ok := rs.report.RegionLatency[reg.RegionID]; !ok { + if latency, ok := rs.report.RegionLatency[reg.RegionID]; !ok { mak.Set(&rs.report.RegionLatency, reg.RegionID, d) - } else if l >= d { + } else if latency >= d { rs.report.RegionLatency[reg.RegionID] = d } diff --git a/net/netmon/interfaces_darwin.go b/net/netmon/interfaces_darwin.go index b175f980a2109..126040350bdb2 100644 --- a/net/netmon/interfaces_darwin.go +++ b/net/netmon/interfaces_darwin.go @@ -7,12 +7,12 @@ import ( "fmt" "net" "strings" - "sync" "syscall" "unsafe" "golang.org/x/net/route" "golang.org/x/sys/unix" + "tailscale.com/syncs" "tailscale.com/util/mak" ) @@ -26,7 +26,7 @@ func parseRoutingTable(rib []byte) ([]route.Message, error) { } var ifNames struct { - sync.Mutex + syncs.Mutex m map[int]string // ifindex => name } diff --git a/net/netmon/netmon.go b/net/netmon/netmon.go index f7d1b1107e379..657da04d5978c 100644 --- a/net/netmon/netmon.go +++ b/net/netmon/netmon.go @@ -15,6 +15,7 @@ import ( "time" "tailscale.com/feature/buildfeatures" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -65,7 +66,7 @@ type Monitor struct { // and not change at runtime. tsIfName string // tailscale interface name, if known/set ("tailscale0", "utun3", ...) - mu sync.Mutex // guards all following fields + mu syncs.Mutex // guards all following fields cbs set.HandleSet[ChangeFunc] ifState *State gwValid bool // whether gw and gwSelfIP are valid diff --git a/net/netns/netns.go b/net/netns/netns.go index a473506fac024..81ab5e2a212a6 100644 --- a/net/netns/netns.go +++ b/net/netns/netns.go @@ -17,6 +17,7 @@ import ( "context" "net" "net/netip" + "runtime" "sync/atomic" "tailscale.com/net/netknob" @@ -39,18 +40,36 @@ var bindToInterfaceByRoute atomic.Bool // setting the TS_BIND_TO_INTERFACE_BY_ROUTE. // // Currently, this only changes the behaviour on macOS and Windows. 
-func SetBindToInterfaceByRoute(v bool) { - bindToInterfaceByRoute.Store(v) +func SetBindToInterfaceByRoute(logf logger.Logf, v bool) { + if bindToInterfaceByRoute.Swap(v) != v { + logf("netns: bindToInterfaceByRoute changed to %v", v) + } } var disableBindConnToInterface atomic.Bool // SetDisableBindConnToInterface disables the (normal) behavior of binding -// connections to the default network interface. +// connections to the default network interface on Darwin nodes. // -// Currently, this only has an effect on Darwin. -func SetDisableBindConnToInterface(v bool) { - disableBindConnToInterface.Store(v) +// Unless you intend to disable this for tailscaled on macOS (which is likely +// to break things), you probably want to use +// SetDisableBindConnToInterfaceAppleExt, which disables explicit interface +// binding only when tailscaled is running inside a network extension process. +func SetDisableBindConnToInterface(logf logger.Logf, v bool) { + if disableBindConnToInterface.Swap(v) != v { + logf("netns: disableBindConnToInterface changed to %v", v) + } +} + +var disableBindConnToInterfaceAppleExt atomic.Bool + +// SetDisableBindConnToInterfaceAppleExt disables the (normal) behavior of binding +// connections to the default network interface, but only on Apple clients where +// tailscaled is running inside a network extension. +func SetDisableBindConnToInterfaceAppleExt(logf logger.Logf, v bool) { + if runtime.GOOS == "darwin" && disableBindConnToInterfaceAppleExt.Swap(v) != v { + logf("netns: disableBindConnToInterfaceAppleExt changed to %v", v) + } } // Listener returns a new net.Listener with its Control hook func
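All three setters above share the same log-on-change idiom. As a minimal, self-contained sketch (an editorial illustration, not part of the diff): atomic.Bool.Swap returns the previous value, so a message is logged only when the setting actually flips, never on redundant calls.

```go
package main

import (
	"log"
	"sync/atomic"
)

var bindToInterfaceByRoute atomic.Bool

// setBindToInterfaceByRoute mirrors the pattern above: Swap returns the
// old value, so we log only on a real transition.
func setBindToInterfaceByRoute(logf func(format string, args ...any), v bool) {
	if bindToInterfaceByRoute.Swap(v) != v {
		logf("netns: bindToInterfaceByRoute changed to %v", v)
	}
}

func main() {
	setBindToInterfaceByRoute(log.Printf, true)  // logs: changed to true
	setBindToInterfaceByRoute(log.Printf, true)  // no change, no log
	setBindToInterfaceByRoute(log.Printf, false) // logs: changed to false
}
```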
diff --git a/net/netns/netns_darwin.go index 1f30f00d2a870..ff05a3f3139c3 100644 --- a/net/netns/netns_darwin.go +++ b/net/netns/netns_darwin.go @@ -21,6 +21,7 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" + "tailscale.com/version" ) func control(logf logger.Logf, netMon *netmon.Monitor) func(network, address string, c syscall.RawConn) error { @@ -36,13 +37,11 @@ var errInterfaceStateInvalid = errors.New("interface state invalid") // controlLogf binds c to a particular interface as necessary to dial the // provided (network, address). func controlLogf(logf logger.Logf, netMon *netmon.Monitor, network, address string, c syscall.RawConn) error { - if isLocalhost(address) { - // Don't bind to an interface for localhost connections. + if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) { return nil } - if disableBindConnToInterface.Load() { - logf("netns_darwin: binding connection to interfaces disabled") + if isLocalhost(address) { return nil } diff --git a/net/netns/netns_default.go index 94f24d8fa4e19..58c5936640e4f 100644 --- a/net/netns/netns_default.go +++ b/net/netns/netns_default.go @@ -20,3 +20,7 @@ func control(logger.Logf, *netmon.Monitor) func(network, address string, c sysca func controlC(network, address string, c syscall.RawConn) error { return nil } + +func UseSocketMark() bool { + return false +} diff --git a/net/netns/netns_dw.go index f92ba9462c32a..b9f750e8a6657 100644 --- a/net/netns/netns_dw.go +++ b/net/netns/netns_dw.go @@ -25,3 +25,7 @@ func parseAddress(address string) (addr netip.Addr, err error) { return netip.ParseAddr(host) } + +func UseSocketMark() bool { + return false +} diff --git a/net/netns/zsyscall_windows.go index 07e2181be222c..3d8f06e097340 100644 --- a/net/netns/zsyscall_windows.go +++ b/net/netns/zsyscall_windows.go @@ -45,7 +45,7 @@ var ( ) func getBestInterfaceEx(sockaddr *winipcfg.RawSockaddrInet, bestIfaceIndex *uint32) (ret error) { - r0, _, _ := syscall.Syscall(procGetBestInterfaceEx.Addr(), 2, uintptr(unsafe.Pointer(sockaddr)), uintptr(unsafe.Pointer(bestIfaceIndex)), 0) + r0, _, _ := syscall.SyscallN(procGetBestInterfaceEx.Addr(), uintptr(unsafe.Pointer(sockaddr)), uintptr(unsafe.Pointer(bestIfaceIndex))) if r0 != 0 { ret = syscall.Errno(r0) } diff --git a/net/netutil/netutil.go index bc64e8fdc9eb4..5c42f51c64837 100644 --- a/net/netutil/netutil.go +++ b/net/netutil/netutil.go @@ -8,7 +8,8 @@ import ( "bufio" "io" "net" - "sync" + + "tailscale.com/syncs" ) // NewOneConnListener returns a net.Listener that returns c on its @@ -29,7 +30,7 @@ func NewOneConnListener(c net.Conn, addr net.Addr) net.Listener { type oneConnListener struct { addr net.Addr - mu sync.Mutex + mu syncs.Mutex conn net.Conn } diff --git a/net/packet/tsmp.go index 0ea321e84eb2a..8fad1d5037468 100644 --- a/net/packet/tsmp.go +++ b/net/packet/tsmp.go @@ -15,7 +15,9 @@ import ( "fmt" "net/netip" + "go4.org/mem" "tailscale.com/types/ipproto" + "tailscale.com/types/key" ) const minTSMPSize = 7 // the rejected body is 7 bytes @@ -72,6 +74,9 @@ const ( // TSMPTypePong is the type byte for a TailscalePongResponse. TSMPTypePong TSMPType = 'o' + + // TSMPTypeDiscoAdvertisement is the type byte for a TSMPDiscoKeyAdvertisement, used for sending disco keys. + TSMPTypeDiscoAdvertisement TSMPType = 'a' ) type TailscaleRejectReason byte @@ -259,3 +264,53 @@ func (h TSMPPongReply) Marshal(buf []byte) error { binary.BigEndian.PutUint16(buf[9:11], h.PeerAPIPort) return nil } + +// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys.
+// +// On the wire, after the IP header, it's currently 33 bytes: +// - 'a' (TSMPTypeDiscoAdvertisement) +// - 32 disco key bytes +type TSMPDiscoKeyAdvertisement struct { + Src, Dst netip.Addr + Key key.DiscoPublic +} + +func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) { + var iph Header + if ka.Src.Is4() { + iph = IP4Header{ + IPProto: ipproto.TSMP, + Src: ka.Src, + Dst: ka.Dst, + } + } else { + iph = IP6Header{ + IPProto: ipproto.TSMP, + Src: ka.Src, + Dst: ka.Dst, + } + } + payload := make([]byte, 0, 33) + payload = append(payload, byte(TSMPTypeDiscoAdvertisement)) + payload = ka.Key.AppendTo(payload) + if len(payload) != 33 { + // Mostly to safeguard against ourselves changing this in the future. + return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload)) + } + + return Generate(iph, payload), nil +} + +func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) { + if pp.IPProto != ipproto.TSMP { + return + } + p := pp.Payload() + if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) { + return + } + tka.Src = pp.Src.Addr() + tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33])) + + return tka, true +} diff --git a/net/packet/tsmp_test.go index e261e6a4199b3..d8f1d38d57180 100644 --- a/net/packet/tsmp_test.go +++ b/net/packet/tsmp_test.go @@ -4,8 +4,14 @@ package packet import ( + "bytes" + "encoding/hex" "net/netip" + "slices" "testing" + + "go4.org/mem" + "tailscale.com/types/key" ) func TestTailscaleRejectedHeader(t *testing.T) { @@ -71,3 +77,62 @@ func TestTailscaleRejectedHeader(t *testing.T) { } } } + +func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) { + var ( + // IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum + headerV4, _ = hex.DecodeString("45000035000000004063705d") + // IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64) + headerV6, _ = hex.DecodeString("6000000000216340") + + packetType = []byte{'a'} + testKey = bytes.Repeat([]byte{'a'}, 32) + + // IPs + srcV4 = netip.MustParseAddr("1.2.3.4") + dstV4 = netip.MustParseAddr("4.3.2.1") + srcV6 = netip.MustParseAddr("2001:db8::1") + dstV6 = netip.MustParseAddr("2001:db8::2") + ) + + join := func(parts ...[]byte) []byte { + return bytes.Join(parts, nil) + } + + tests := []struct { + name string + tka TSMPDiscoKeyAdvertisement + want []byte + }{ + { + name: "v4Header", + tka: TSMPDiscoKeyAdvertisement{ + Src: srcV4, + Dst: dstV4, + Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + }, + want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey), + }, + { + name: "v6Header", + tka: TSMPDiscoKeyAdvertisement{ + Src: srcV6, + Dst: dstV6, + Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + }, + want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.tka.Marshal() + if err != nil { + t.Errorf("error marshalling TSMPDiscoAdvertisement: %s", err) + } + if !slices.Equal(got, tt.want) { + t.Errorf("error marshalling TSMPDiscoAdvertisement, expected: \n%x, \ngot:\n%x", tt.want, got) + } + }) + } +}
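An editorial aside, not part of the diff: with the pieces above in place, an advertisement can be round-tripped through Marshal and back through Parsed. A minimal sketch, assuming this diff is applied (the addresses are arbitrary examples); note that AsTSMPDiscoAdvertisement fills Src and Key but does not populate Dst.

```go
// Editorial sketch (assumes this diff is applied): round-trip a
// TSMPDiscoKeyAdvertisement through Marshal and Parsed.
package main

import (
	"fmt"
	"net/netip"

	"tailscale.com/net/packet"
	"tailscale.com/types/key"
)

func main() {
	adv := &packet.TSMPDiscoKeyAdvertisement{
		Src: netip.MustParseAddr("100.64.0.1"), // example addresses
		Dst: netip.MustParseAddr("100.64.0.2"),
		Key: key.NewDisco().Public(),
	}
	buf, err := adv.Marshal() // IP header, then 'a', then 32 key bytes
	if err != nil {
		panic(err)
	}

	var p packet.Parsed
	p.Decode(buf)
	if got, ok := p.AsTSMPDiscoAdvertisement(); ok {
		// got.Src and got.Key match adv; Dst is not parsed back.
		fmt.Println(got.Src, got.Key == adv.Key)
	}
}
```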
diff --git a/net/ping/ping.go index 1ff3862dc65a1..8e16a692a8136 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -23,6 +23,7 @@ import ( "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/mak" ) @@ -64,7 +65,7 @@ type Pinger struct { wg sync.WaitGroup // Following fields protected by mu - mu sync.Mutex + mu syncs.Mutex // conns is a map of "type" to net.PacketConn, type is either // "ip4:icmp" or "ip6:icmp" conns map[string]net.PacketConn diff --git a/net/portmapper/pmpresultcode_string.go index 603636adec044..18d911d944126 100644 --- a/net/portmapper/pmpresultcode_string.go +++ b/net/portmapper/pmpresultcode_string.go @@ -24,8 +24,9 @@ const _pmpResultCode_name = "OKUnsupportedVersionNotAuthorizedNetworkFailureOutO var _pmpResultCode_index = [...]uint8{0, 2, 20, 33, 47, 61, 78} func (i pmpResultCode) String() string { - if i >= pmpResultCode(len(_pmpResultCode_index)-1) { + idx := int(i) - 0 + if i < 0 || idx >= len(_pmpResultCode_index)-1 { return "pmpResultCode(" + strconv.FormatInt(int64(i), 10) + ")" } - return _pmpResultCode_name[_pmpResultCode_index[i]:_pmpResultCode_index[i+1]] + return _pmpResultCode_name[_pmpResultCode_index[idx]:_pmpResultCode_index[idx+1]] } diff --git a/net/portmapper/portmapper.go index 9368d1c4ee05b..16a981d1d8336 100644 --- a/net/portmapper/portmapper.go +++ b/net/portmapper/portmapper.go @@ -14,7 +14,6 @@ import ( "net/http" "net/netip" "slices" - "sync" "sync/atomic" "time" @@ -123,7 +122,7 @@ type Client struct { testPxPPort uint16 // if non-zero, pxpPort to use for tests testUPnPPort uint16 // if non-zero, uPnPPort to use for tests - mu sync.Mutex // guards following, and all fields thereof + mu syncs.Mutex // guards following, and all fields thereof // runningCreate is whether we're currently working on creating // a port mapping (whether GetCachedMappingOrStartCreatingOne kicked diff --git a/net/socks5/socks5.go index 4a5befa1d2fef..2e277147bc50d 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -120,10 +120,10 @@ func (s *Server) logf(format string, args ...any) { } // Serve accepts and handles incoming connections on the given listener. -func (s *Server) Serve(l net.Listener) error { - defer l.Close() +func (s *Server) Serve(ln net.Listener) error { + defer ln.Close() for { - c, err := l.Accept() + c, err := ln.Accept() if err != nil { return err } diff --git a/net/sockstats/label_string.go index f9a111ad71e08..cc503d943f622 100644 --- a/net/sockstats/label_string.go +++ b/net/sockstats/label_string.go @@ -28,8 +28,9 @@ const _Label_name = "ControlClientAutoControlClientDialerDERPHTTPClientLogtailLo var _Label_index = [...]uint8{0, 17, 36, 50, 63, 78, 93, 107, 123, 140, 157, 169, 186, 201} func (i Label) String() string { - if i >= Label(len(_Label_index)-1) { + idx := int(i) - 0 + if i < 0 || idx >= len(_Label_index)-1 { return "Label(" + strconv.FormatInt(int64(i), 10) + ")" } - return _Label_name[_Label_index[i]:_Label_index[i+1]] + return _Label_name[_Label_index[idx]:_Label_index[idx+1]] } diff --git a/net/sockstats/sockstats_tsgo.go index fec9ec3b0dad2..aa875df9aeddd 100644 --- a/net/sockstats/sockstats_tsgo.go +++ b/net/sockstats/sockstats_tsgo.go @@ -10,12 +10,12 @@ import ( "fmt" "net" "strings" - "sync" "sync/atomic" "syscall" "time" "tailscale.com/net/netmon" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" "tailscale.com/version" @@ -40,7 +40,7 @@ var sockStats = struct { // mu protects fields in this group (but not the fields within // sockStatCounters). It should not be held in the per-read/write // callbacks.
- mu sync.Mutex + mu syncs.Mutex countersByLabel map[Label]*sockStatCounters knownInterfaces map[int]string // interface index -> name usedInterfaces map[int]int // set of interface indexes diff --git a/net/speedtest/speedtest_server.go b/net/speedtest/speedtest_server.go index 9dd78b195fff4..72f85fa15b019 100644 --- a/net/speedtest/speedtest_server.go +++ b/net/speedtest/speedtest_server.go @@ -17,9 +17,9 @@ import ( // connections and handles each one in a goroutine. Because it runs in an infinite loop, // this function only returns if any of the speedtests return with errors, or if the // listener is closed. -func Serve(l net.Listener) error { +func Serve(ln net.Listener) error { for { - conn, err := l.Accept() + conn, err := ln.Accept() if errors.Is(err, net.ErrClosed) { return nil } diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go index 69fdb6b5685c0..bb8f2676af8c3 100644 --- a/net/speedtest/speedtest_test.go +++ b/net/speedtest/speedtest_test.go @@ -21,13 +21,13 @@ func TestDownload(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/17338") // start a listener and find the port where the server will be listening. - l, err := net.Listen("tcp", ":0") + ln, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } - t.Cleanup(func() { l.Close() }) + t.Cleanup(func() { ln.Close() }) - serverIP := l.Addr().String() + serverIP := ln.Addr().String() t.Log("server IP found:", serverIP) type state struct { @@ -40,7 +40,7 @@ func TestDownload(t *testing.T) { stateChan := make(chan state, 1) go func() { - err := Serve(l) + err := Serve(ln) stateChan <- state{err: err} }() @@ -84,7 +84,7 @@ func TestDownload(t *testing.T) { }) // causes the server goroutine to finish - l.Close() + ln.Close() testState := <-stateChan if testState.err != nil { diff --git a/net/tsdial/dnsmap.go b/net/tsdial/dnsmap.go index 2ef1cb1f171c0..37fedd14c899d 100644 --- a/net/tsdial/dnsmap.go +++ b/net/tsdial/dnsmap.go @@ -36,11 +36,11 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { suffix := nm.MagicDNSSuffix() have4 := false addrs := nm.GetAddresses() - if nm.Name != "" && addrs.Len() > 0 { + if name := nm.SelfName(); name != "" && addrs.Len() > 0 { ip := addrs.At(0).Addr() - ret[canonMapKey(nm.Name)] = ip - if dnsname.HasSuffix(nm.Name, suffix) { - ret[canonMapKey(dnsname.TrimSuffix(nm.Name, suffix))] = ip + ret[canonMapKey(name)] = ip + if dnsname.HasSuffix(name, suffix) { + ret[canonMapKey(dnsname.TrimSuffix(name, suffix))] = ip } for _, p := range addrs.All() { if p.Addr().Is4() { diff --git a/net/tsdial/dnsmap_test.go b/net/tsdial/dnsmap_test.go index 43461a135e1c5..41a957f186f4a 100644 --- a/net/tsdial/dnsmap_test.go +++ b/net/tsdial/dnsmap_test.go @@ -31,8 +31,8 @@ func TestDNSMapFromNetworkMap(t *testing.T) { { name: "self", nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100.102.103.104/32"), pfx("100::123/128"), @@ -47,8 +47,8 @@ func TestDNSMapFromNetworkMap(t *testing.T) { { name: "self_and_peers", nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100.102.103.104/32"), pfx("100::123/128"), @@ -82,8 +82,8 @@ func TestDNSMapFromNetworkMap(t *testing.T) { { name: "self_has_v6_only", nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100::123/128"), }, diff --git a/net/tsdial/tsdial.go 
b/net/tsdial/tsdial.go index c7483a125a07a..065c01384ed55 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -27,6 +27,7 @@ import ( "tailscale.com/net/netns" "tailscale.com/net/netx" "tailscale.com/net/tsaddr" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/util/clientmetric" @@ -86,7 +87,7 @@ type Dialer struct { routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface - mu sync.Mutex + mu syncs.Mutex closed bool dns dnsMap tunName string // tun device name diff --git a/net/tshttpproxy/zsyscall_windows.go b/net/tshttpproxy/zsyscall_windows.go index c07e9ee03a69e..5dcfae83ea1a4 100644 --- a/net/tshttpproxy/zsyscall_windows.go +++ b/net/tshttpproxy/zsyscall_windows.go @@ -48,7 +48,7 @@ var ( ) func globalFree(hglobal winHGlobal) (err error) { - r1, _, e1 := syscall.Syscall(procGlobalFree.Addr(), 1, uintptr(hglobal), 0, 0) + r1, _, e1 := syscall.SyscallN(procGlobalFree.Addr(), uintptr(hglobal)) if r1 == 0 { err = errnoErr(e1) } @@ -56,7 +56,7 @@ func globalFree(hglobal winHGlobal) (err error) { } func winHTTPCloseHandle(whi winHTTPInternet) (err error) { - r1, _, e1 := syscall.Syscall(procWinHttpCloseHandle.Addr(), 1, uintptr(whi), 0, 0) + r1, _, e1 := syscall.SyscallN(procWinHttpCloseHandle.Addr(), uintptr(whi)) if r1 == 0 { err = errnoErr(e1) } @@ -64,7 +64,7 @@ func winHTTPCloseHandle(whi winHTTPInternet) (err error) { } func winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) { - r1, _, e1 := syscall.Syscall6(procWinHttpGetProxyForUrl.Addr(), 4, uintptr(whi), uintptr(unsafe.Pointer(url)), uintptr(unsafe.Pointer(options)), uintptr(unsafe.Pointer(proxyInfo)), 0, 0) + r1, _, e1 := syscall.SyscallN(procWinHttpGetProxyForUrl.Addr(), uintptr(whi), uintptr(unsafe.Pointer(url)), uintptr(unsafe.Pointer(options)), uintptr(unsafe.Pointer(proxyInfo))) if r1 == 0 { err = errnoErr(e1) } @@ -72,7 +72,7 @@ func winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAut } func winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) { - r0, _, e1 := syscall.Syscall6(procWinHttpOpen.Addr(), 5, uintptr(unsafe.Pointer(agent)), uintptr(accessType), uintptr(unsafe.Pointer(proxy)), uintptr(unsafe.Pointer(proxyBypass)), uintptr(flags), 0) + r0, _, e1 := syscall.SyscallN(procWinHttpOpen.Addr(), uintptr(unsafe.Pointer(agent)), uintptr(accessType), uintptr(unsafe.Pointer(proxy)), uintptr(unsafe.Pointer(proxyBypass)), uintptr(flags)) whi = winHTTPInternet(r0) if whi == 0 { err = errnoErr(e1) diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 70cc7118ac208..6e07c7a3dabd0 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -34,6 +34,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/netlogfunc" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" "tailscale.com/util/usermetric" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/netstack/gro" @@ -209,6 +210,9 @@ type Wrapper struct { captureHook syncs.AtomicValue[packet.CaptureCallback] metrics *metrics + + eventClient *eventbus.Client + discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement] } type metrics struct { @@ -254,15 +258,15 @@ func (w *Wrapper) Start() { close(w.startCh) } -func WrapTAP(logf logger.Logf, tdev tun.Device, m *usermetric.Registry) *Wrapper { - return wrap(logf, tdev, true, 
m) +func WrapTAP(logf logger.Logf, tdev tun.Device, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper { + return wrap(logf, tdev, true, m, bus) } -func Wrap(logf logger.Logf, tdev tun.Device, m *usermetric.Registry) *Wrapper { - return wrap(logf, tdev, false, m) +func Wrap(logf logger.Logf, tdev tun.Device, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper { + return wrap(logf, tdev, false, m, bus) } -func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry) *Wrapper { +func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper { logf = logger.WithPrefix(logf, "tstun: ") w := &Wrapper{ logf: logf, @@ -283,6 +287,9 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry) metrics: registerMetrics(m), } + w.eventClient = bus.Client("net.tstun") + w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient) + w.vectorBuffer = make([][]byte, tdev.BatchSize()) for i := range w.vectorBuffer { w.vectorBuffer[i] = make([]byte, maxBufferSize) @@ -357,6 +364,7 @@ func (t *Wrapper) Close() error { close(t.vectorOutbound) t.outboundMu.Unlock() err = t.tdev.Close() + t.eventClient.Close() }) return err } @@ -967,6 +975,11 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { continue } } + if buildfeatures.HasNetLog { + if update := t.connCounter.Load(); update != nil { + updateConnCounter(update, p.Buffer(), false) + } + } // Make sure to do SNAT after filtering, so that any flow tracking in // the filter sees the original source address. See #12133. @@ -976,11 +989,6 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { panic(fmt.Sprintf("short copy: %d != %d", n, len(data)-res.dataOffset)) } sizes[buffsPos] = n - if buildfeatures.HasNetLog { - if update := t.connCounter.Load(); update != nil { - updateConnCounter(update, p.Buffer(), false) - } - } buffsPos++ } if buffsGRO != nil { @@ -1118,6 +1126,11 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i return n, err } +type DiscoKeyAdvertisement struct { + Src netip.Addr + Key key.DiscoPublic +} + func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) { if captHook != nil { captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) @@ -1128,6 +1141,12 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa t.noteActivity() t.injectOutboundPong(p, pingReq) return filter.DropSilently, gro + } else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok { + t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{ + Src: discoKeyAdvert.Src, + Key: discoKeyAdvert.Key, + }) + return filter.DropSilently, gro } else if data, ok := p.AsTSMPPong(); ok { if f := t.OnTSMPPongReceived; f != nil { f(data) diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 75cf5afb21f8f..c7d0708df85eb 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -36,6 +36,8 @@ import ( "tailscale.com/types/netlogtype" "tailscale.com/types/ptr" "tailscale.com/types/views" + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" "tailscale.com/util/usermetric" "tailscale.com/wgengine/filter" @@ -170,10 +172,10 @@ func setfilter(logf logger.Logf, tun *Wrapper) { tun.SetFilter(filter.New(matches, nil, ipSet, ipSet, nil, logf)) } -func newChannelTUN(logf logger.Logf, 
secure bool) (*tuntest.ChannelTUN, *Wrapper) { +func newChannelTUN(logf logger.Logf, bus *eventbus.Bus, secure bool) (*tuntest.ChannelTUN, *Wrapper) { chtun := tuntest.NewChannelTUN() reg := new(usermetric.Registry) - tun := Wrap(logf, chtun.TUN(), reg) + tun := Wrap(logf, chtun.TUN(), reg, bus) if secure { setfilter(logf, tun) } else { @@ -183,10 +185,10 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper return chtun, tun } -func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) { +func newFakeTUN(logf logger.Logf, bus *eventbus.Bus, secure bool) (*fakeTUN, *Wrapper) { ftun := NewFake() reg := new(usermetric.Registry) - tun := Wrap(logf, ftun, reg) + tun := Wrap(logf, ftun, reg, bus) if secure { setfilter(logf, tun) } else { @@ -196,7 +198,8 @@ func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) { } func TestReadAndInject(t *testing.T) { - chtun, tun := newChannelTUN(t.Logf, false) + bus := eventbustest.NewBus(t) + chtun, tun := newChannelTUN(t.Logf, bus, false) defer tun.Close() const size = 2 // all payloads have this size @@ -221,7 +224,7 @@ func TestReadAndInject(t *testing.T) { } var buf [MaxPacketSize]byte - var seen = make(map[string]bool) + seen := make(map[string]bool) sizes := make([]int, 1) // We expect the same packets back, in no particular order. for i := range len(written) + len(injected) { @@ -257,7 +260,8 @@ func TestReadAndInject(t *testing.T) { } func TestWriteAndInject(t *testing.T) { - chtun, tun := newChannelTUN(t.Logf, false) + bus := eventbustest.NewBus(t) + chtun, tun := newChannelTUN(t.Logf, bus, false) defer tun.Close() written := []string{"w0", "w1"} @@ -316,8 +320,8 @@ func mustHexDecode(s string) []byte { } func TestFilter(t *testing.T) { - - chtun, tun := newChannelTUN(t.Logf, true) + bus := eventbustest.NewBus(t) + chtun, tun := newChannelTUN(t.Logf, bus, true) defer tun.Close() // Reset the metrics before test. 
These are global @@ -462,7 +466,8 @@ func assertMetricPackets(t *testing.T, metricName string, want, got int64) { } func TestAllocs(t *testing.T) { - ftun, tun := newFakeTUN(t.Logf, false) + bus := eventbustest.NewBus(t) + ftun, tun := newFakeTUN(t.Logf, bus, false) defer tun.Close() buf := [][]byte{{0x00}} @@ -473,14 +478,14 @@ func TestAllocs(t *testing.T) { return } }) - if err != nil { t.Error(err) } } func TestClose(t *testing.T) { - ftun, tun := newFakeTUN(t.Logf, false) + bus := eventbustest.NewBus(t) + ftun, tun := newFakeTUN(t.Logf, bus, false) data := [][]byte{udp4("1.2.3.4", "5.6.7.8", 98, 98)} _, err := ftun.Write(data, 0) @@ -497,7 +502,8 @@ func TestClose(t *testing.T) { func BenchmarkWrite(b *testing.B) { b.ReportAllocs() - ftun, tun := newFakeTUN(b.Logf, true) + bus := eventbustest.NewBus(b) + ftun, tun := newFakeTUN(b.Logf, bus, true) defer tun.Close() packet := [][]byte{udp4("5.6.7.8", "1.2.3.4", 89, 89)} @@ -887,7 +893,8 @@ func TestCaptureHook(t *testing.T) { now := time.Unix(1682085856, 0) - _, w := newFakeTUN(t.Logf, true) + bus := eventbustest.NewBus(t) + _, w := newFakeTUN(t.Logf, bus, true) w.timeNow = func() time.Time { return now } @@ -957,3 +964,30 @@ func TestCaptureHook(t *testing.T) { captured, want) } } + +func TestTSMPDisco(t *testing.T) { + t.Run("IPv6DiscoAdvert", func(t *testing.T) { + src := netip.MustParseAddr("2001:db8::1") + dst := netip.MustParseAddr("2001:db8::2") + discoKey := key.NewDisco() + buf, _ := (&packet.TSMPDiscoKeyAdvertisement{ + Src: src, + Dst: dst, + Key: discoKey.Public(), + }).Marshal() + + var p packet.Parsed + p.Decode(buf) + + tda, ok := p.AsTSMPDiscoAdvertisement() + if !ok { + t.Error("Unable to parse message as TSMPDiscoAdvertisement") + } + if tda.Src != src { + t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src) + } + if !reflect.DeepEqual(tda.Key, discoKey.Public()) { + t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key) + } + }) +} diff --git a/net/udprelay/server.go index 86dee18e12f67..e7ca24960ea1d 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -10,6 +10,7 @@ import ( "bytes" "context" "crypto/rand" + "encoding/binary" "errors" "fmt" "net" @@ -20,6 +21,7 @@ import ( "time" "go4.org/mem" + "golang.org/x/crypto/blake2s" "golang.org/x/net/ipv6" "tailscale.com/disco" "tailscale.com/net/batching" @@ -36,6 +38,7 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/nettype" + "tailscale.com/types/views" "tailscale.com/util/eventbus" "tailscale.com/util/set" ) @@ -72,17 +75,22 @@ type Server struct { closeCh chan struct{} netChecker *netcheck.Client - mu sync.Mutex // guards the following fields - derpMap *tailcfg.DERPMap - addrDiscoveryOnce bool // addrDiscovery completed once (successfully or unsuccessfully) - addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints - closed bool - lamportID uint64 - nextVNI uint32 - byVNI map[uint32]*serverEndpoint - byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint + mu sync.Mutex // guards the following fields + macSecrets [][blake2s.Size]byte // [0] is most recent, max 2 elements + macSecretRotatedAt time.Time + derpMap *tailcfg.DERPMap + onlyStaticAddrPorts bool // no dynamic addr port discovery when set + staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts] + dynamicAddrPorts []netip.AddrPort // dynamically discovered ip:port pairs + closed bool + lamportID uint64 + nextVNI
uint32 + byVNI map[uint32]*serverEndpoint + byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint } +const macSecretRotationInterval = time.Minute * 2 + const ( minVNI = uint32(1) maxVNI = uint32(1<<24 - 1) @@ -96,22 +104,42 @@ type serverEndpoint struct { // indexing of this array aligns with the following fields, e.g. // discoSharedSecrets[0] is the shared secret to use when sealing // Disco protocol messages for transmission towards discoPubKeys[0]. - discoPubKeys key.SortedPairOfDiscoPublic - discoSharedSecrets [2]key.DiscoShared - handshakeGeneration [2]uint32 // or zero if a handshake has never started for that relay leg - handshakeAddrPorts [2]netip.AddrPort // or zero value if a handshake has never started for that relay leg - boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg - lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time - challenge [2][disco.BindUDPRelayChallengeLen]byte - packetsRx [2]uint64 // num packets received from/sent by each client after they are bound - bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound + discoPubKeys key.SortedPairOfDiscoPublic + discoSharedSecrets [2]key.DiscoShared + inProgressGeneration [2]uint32 // or zero if a handshake has never started, or has just completed + boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg + lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time + packetsRx [2]uint64 // num packets received from/sent by each client after they are bound + bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound lamportID uint64 vni uint32 allocatedAt time.Time } -func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { +func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg disco.BindUDPRelayEndpointCommon) ([blake2s.Size]byte, error) { + input := make([]byte, 8, 4+4+32+18) // vni + generation + invited party disco key + addr:port + binary.BigEndian.PutUint32(input[0:4], msg.VNI) + binary.BigEndian.PutUint32(input[4:8], msg.Generation) + input = msg.RemoteKey.AppendTo(input) + input, err := src.AppendBinary(input) + if err != nil { + return [blake2s.Size]byte{}, err + } + h, err := blake2s.New256(blakeKey[:]) + if err != nil { + return [blake2s.Size]byte{}, err + } + _, err = h.Write(input) + if err != nil { + return [blake2s.Size]byte{}, err + } + var out [blake2s.Size]byte + h.Sum(out[:0]) + return out, nil +} + +func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte) (write []byte, to netip.AddrPort) { if senderIndex != 0 && senderIndex != 1 { return nil, netip.AddrPort{} } @@ -142,18 +170,11 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex // Generation must be nonzero, silently drop return nil, netip.AddrPort{} } - if e.handshakeGeneration[senderIndex] == discoMsg.Generation { - // we've seen this generation before, silently drop - return nil, netip.AddrPort{} - } - e.handshakeGeneration[senderIndex] = discoMsg.Generation - e.handshakeAddrPorts[senderIndex] = from + e.inProgressGeneration[senderIndex] = discoMsg.Generation m := new(disco.BindUDPRelayEndpointChallenge) m.VNI = e.vni m.Generation = discoMsg.Generation m.RemoteKey = 
e.discoPubKeys.Get()[otherSender] - rand.Read(e.challenge[senderIndex][:]) - copy(m.Challenge[:], e.challenge[senderIndex][:]) reply := make([]byte, packet.GeneveFixedHeaderLength, 512) gh := packet.GeneveHeader{Control: true, Protocol: packet.GeneveProtocolDisco} gh.VNI.Set(e.vni) @@ -163,6 +184,11 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex } reply = append(reply, disco.Magic...) reply = serverDisco.AppendTo(reply) + mac, err := blakeMACFromBindMsg(macSecrets[0], from, m.BindUDPRelayEndpointCommon) + if err != nil { + return nil, netip.AddrPort{} + } + m.Challenge = mac box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) reply = append(reply, box...) return reply, from @@ -172,17 +198,29 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex // silently drop return nil, netip.AddrPort{} } - generation := e.handshakeGeneration[senderIndex] - if generation == 0 || // we have no active handshake - generation != discoMsg.Generation || // mismatching generation for the active handshake - e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake - !bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake + generation := e.inProgressGeneration[senderIndex] + if generation == 0 || // we have no in-progress handshake + generation != discoMsg.Generation { // mismatching generation for the in-progress handshake // silently drop return nil, netip.AddrPort{} } - // Handshake complete. Update the binding for this sender. - e.boundAddrPorts[senderIndex] = from - e.lastSeen[senderIndex] = time.Now() // record last seen as bound time + for _, macSecret := range macSecrets { + mac, err := blakeMACFromBindMsg(macSecret, from, discoMsg.BindUDPRelayEndpointCommon) + if err != nil { + // silently drop + return nil, netip.AddrPort{} + } + // Speed is favored over constant-time comparison here. The sender is + // already authenticated via disco. + if bytes.Equal(mac[:], discoMsg.Challenge[:]) { + // Handshake complete. Update the binding for this sender. 
+ e.boundAddrPorts[senderIndex] = from + e.lastSeen[senderIndex] = time.Now() // record last seen as bound time + e.inProgressGeneration[senderIndex] = 0 // reset to zero, which indicates there is no in-progress handshake + return nil, netip.AddrPort{} + } + } + // MAC does not match, silently drop return nil, netip.AddrPort{} default: // unexpected message types, silently drop @@ -190,7 +228,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex } } -func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { +func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte) (write []byte, to netip.AddrPort) { senderRaw, isDiscoMsg := disco.Source(b) if !isDiscoMsg { // Not a Disco message @@ -221,39 +259,29 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by return nil, netip.AddrPort{} } - return e.handleDiscoControlMsg(from, senderIndex, discoMsg, serverDisco) + return e.handleDiscoControlMsg(from, senderIndex, discoMsg, serverDisco, macSecrets) } -func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { - if !gh.Control { - if !e.isBound() { - // not a control packet, but serverEndpoint isn't bound - return nil, netip.AddrPort{} - } - switch { - case from == e.boundAddrPorts[0]: - e.lastSeen[0] = time.Now() - e.packetsRx[0]++ - e.bytesRx[0] += uint64(len(b)) - return b, e.boundAddrPorts[1] - case from == e.boundAddrPorts[1]: - e.lastSeen[1] = time.Now() - e.packetsRx[1]++ - e.bytesRx[1] += uint64(len(b)) - return b, e.boundAddrPorts[0] - default: - // unrecognized source - return nil, netip.AddrPort{} - } +func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now time.Time) (write []byte, to netip.AddrPort) { + if !e.isBound() { + // not a control packet, but serverEndpoint isn't bound + return nil, netip.AddrPort{} } - - if gh.Protocol != packet.GeneveProtocolDisco { - // control packet, but not Disco + switch { + case from == e.boundAddrPorts[0]: + e.lastSeen[0] = now + e.packetsRx[0]++ + e.bytesRx[0] += uint64(len(b)) + return b, e.boundAddrPorts[1] + case from == e.boundAddrPorts[1]: + e.lastSeen[1] = now + e.packetsRx[1]++ + e.bytesRx[1] += uint64(len(b)) + return b, e.boundAddrPorts[0] + default: + // unrecognized source return nil, netip.AddrPort{} } - - msg := b[packet.GeneveFixedHeaderLength:] - return e.handleSealedDiscoControlMsg(from, msg, serverDisco) } func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
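A brief editorial aside on the handshake rework above: the server's challenge is now stateless. Rather than storing a random per-leg challenge, the challenge is a keyed BLAKE2s MAC over the VNI, generation, invited party's disco key, and observed source addr:port (blakeMACFromBindMsg), and answers are checked against both the current and previous MAC secrets, so handshakes that straddle a secret rotation still complete. A self-contained sketch of that check follows (illustrative names and message bytes, not from the diff).

```go
package main

import (
	"bytes"
	"crypto/rand"
	"fmt"

	"golang.org/x/crypto/blake2s"
)

// keyedMAC is a stand-in for blakeMACFromBindMsg above: a BLAKE2s-256 MAC
// keyed with the secret, computed over some canonical message bytes.
func keyedMAC(secret [blake2s.Size]byte, msg []byte) [blake2s.Size]byte {
	h, err := blake2s.New256(secret[:])
	if err != nil {
		panic(err) // only fails for over-long keys; 32 bytes is valid
	}
	h.Write(msg)
	var out [blake2s.Size]byte
	h.Sum(out[:0])
	return out
}

func main() {
	msg := []byte("vni|generation|remote-disco-key|src-addrport") // placeholder input

	// Challenge issued under the old secret...
	var old [blake2s.Size]byte
	rand.Read(old[:])
	challenge := keyedMAC(old, msg)

	// ...then the server rotates: a fresh secret at index 0, the old one kept at 1.
	var cur [blake2s.Size]byte
	rand.Read(cur[:])
	secrets := [][blake2s.Size]byte{cur, old}

	// Verification walks both secrets, as handleDiscoControlMsg does, so the
	// in-flight handshake still completes.
	valid := false
	for _, s := range secrets {
		mac := keyedMAC(s, msg)
		if bytes.Equal(mac[:], challenge[:]) {
			valid = true
			break
		}
	}
	fmt.Println(valid) // true
}
```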
@@ -278,15 +306,17 @@ func (e *serverEndpoint) isBound() bool { // NewServer constructs a [Server] listening on port. If port is zero, then // port selection is left up to the host networking stack. If -// len(overrideAddrs) > 0 these will be used in place of dynamic discovery, -// which is useful to override in tests. -func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) { +// onlyStaticAddrPorts is true, then dynamic addr:port discovery will be +// disabled, and only addr:ports set via [Server.SetStaticAddrPorts] will be +// used. +func NewServer(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (s *Server, err error) { s = &Server{ logf: logf, disco: key.NewDisco(), bindLifetime: defaultBindLifetime, steadyStateLifetime: defaultSteadyStateLifetime, closeCh: make(chan struct{}), + onlyStaticAddrPorts: onlyStaticAddrPorts, byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), nextVNI: minVNI, byVNI: make(map[uint32]*serverEndpoint), @@ -321,19 +351,7 @@ return nil, err } - if len(overrideAddrs) > 0 { - addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs)) - for _, addr := range overrideAddrs { - if addr.IsValid() { - if addr.Is4() { - addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port)) - } else if s.uc6 != nil { - addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port)) - } - } - } - s.addrPorts = addrPorts.Slice() - } else { + if !s.onlyStaticAddrPorts { s.wg.Add(1) go s.addrDiscoveryLoop() } @@ -392,14 +410,29 @@ func (s *Server) addrDiscoveryLoop() { if err != nil { return nil, err } - if rep.GlobalV4.IsValid() { - addrPorts.Add(rep.GlobalV4) + // Add STUN-discovered endpoints with their observed ports. + v4Addrs, v6Addrs := rep.GetGlobalAddrs() + for _, addr := range v4Addrs { + if addr.IsValid() { + addrPorts.Add(addr) + } } - if rep.GlobalV6.IsValid() { - addrPorts.Add(rep.GlobalV6) + for _, addr := range v6Addrs { + if addr.IsValid() { + addrPorts.Add(addr) + } + } + + if len(v4Addrs) >= 1 && v4Addrs[0].IsValid() { + // If they're behind a hard NAT and are using a fixed + // port locally, assume they might've added a static + // port mapping on their router to the same explicit + // port that the relay is running with. Worst case + // it's an invalid candidate mapping. + if rep.MappingVariesByDestIP.EqualBool(true) && s.uc4Port != 0 { + addrPorts.Add(netip.AddrPortFrom(v4Addrs[0].Addr(), s.uc4Port)) + } } - // TODO(jwhited): consider logging if rep.MappingVariesByDestIP as - // that's a hint we are not well-positioned to operate as a UDP relay. return addrPorts.Slice(), nil } @@ -414,8 +447,7 @@ s.logf("error discovering IP:port candidates: %v", err) } s.mu.Lock() - s.addrPorts = addrPorts - s.addrDiscoveryOnce = true + s.dynamicAddrPorts = addrPorts s.mu.Unlock() case <-s.closeCh: return @@ -494,9 +526,9 @@ func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) { // [magicsock.RebindingConn], which would also remove the need for // [singlePacketConn], as [magicsock.RebindingConn] also handles fallback to // single packet syscall operations.
-func (s *Server) listenOn(port int) error { +func (s *Server) listenOn(port uint16) error { for _, network := range []string{"udp4", "udp6"} { - uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + uc, err := net.ListenUDP(network, &net.UDPAddr{Port: int(port)}) if err != nil { if network == "udp4" { return err @@ -615,7 +647,35 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n return nil, netip.AddrPort{} } - return e.handlePacket(from, gh, b, s.discoPublic) + now := time.Now() + if gh.Control { + if gh.Protocol != packet.GeneveProtocolDisco { + // control packet, but not Disco + return nil, netip.AddrPort{} + } + msg := b[packet.GeneveFixedHeaderLength:] + s.maybeRotateMACSecretLocked(now) + return e.handleSealedDiscoControlMsg(from, msg, s.discoPublic, s.macSecrets) + } + return e.handleDataPacket(from, b, now) +} + +func (s *Server) maybeRotateMACSecretLocked(now time.Time) { + if !s.macSecretRotatedAt.IsZero() && now.Sub(s.macSecretRotatedAt) < macSecretRotationInterval { + return + } + switch len(s.macSecrets) { + case 0: + s.macSecrets = make([][blake2s.Size]byte, 1, 2) + case 1: + s.macSecrets = append(s.macSecrets, [blake2s.Size]byte{}) + fallthrough + case 2: + s.macSecrets[1] = s.macSecrets[0] + } + rand.Read(s.macSecrets[0][:]) + s.macSecretRotatedAt = now + return } func (s *Server) packetReadLoop(readFromSocket, otherSocket batching.Conn, readFromSocketIsIPv4 bool) { @@ -732,6 +792,15 @@ func (s *Server) getNextVNILocked() (uint32, error) { return 0, errors.New("VNI pool exhausted") } +// getAllAddrPortsCopyLocked returns a copy of the combined +// [Server.staticAddrPorts] and [Server.dynamicAddrPorts] slices. +func (s *Server) getAllAddrPortsCopyLocked() []netip.AddrPort { + addrPorts := make([]netip.AddrPort, 0, len(s.dynamicAddrPorts)+s.staticAddrPorts.Len()) + addrPorts = append(addrPorts, s.staticAddrPorts.AsSlice()...) + addrPorts = append(addrPorts, slices.Clone(s.dynamicAddrPorts)...) + return addrPorts +} + // AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair // of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB // it is returned without modification/reallocation. AllocateEndpoint returns @@ -745,11 +814,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv return endpoint.ServerEndpoint{}, ErrServerClosed } - if len(s.addrPorts) == 0 { - if !s.addrDiscoveryOnce { - return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter} - } - return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known") + if s.staticAddrPorts.Len() == 0 && len(s.dynamicAddrPorts) == 0 { + return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter} } if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 { @@ -772,7 +838,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv // consider storing them (maybe interning) in the [*serverEndpoint] // at allocation time. 
ClientDisco: pair.Get(), - AddrPorts: slices.Clone(s.addrPorts), + AddrPorts: s.getAllAddrPortsCopyLocked(), VNI: e.vni, LamportID: e.lamportID, BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, @@ -802,7 +868,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv return endpoint.ServerEndpoint{ ServerDisco: s.discoPublic, ClientDisco: pair.Get(), - AddrPorts: slices.Clone(s.addrPorts), + AddrPorts: s.getAllAddrPortsCopyLocked(), VNI: e.vni, LamportID: e.lamportID, BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, @@ -865,3 +931,13 @@ func (s *Server) getDERPMap() *tailcfg.DERPMap { defer s.mu.Unlock() return s.derpMap } + +// SetStaticAddrPorts sets addr:port pairs the [Server] will advertise +// as candidates it is potentially reachable over, in combination with +// dynamically discovered pairs. This replaces any previously-provided static +// values. +func (s *Server) SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) { + s.mu.Lock() + defer s.mu.Unlock() + s.staticAddrPorts = addrPorts +} diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index bf7f0a9b5f1de..582d4cf671918 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -5,6 +5,7 @@ package udprelay import ( "bytes" + "crypto/rand" "net" "net/netip" "testing" @@ -14,9 +15,11 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "go4.org/mem" + "golang.org/x/crypto/blake2s" "tailscale.com/disco" "tailscale.com/net/packet" "tailscale.com/types/key" + "tailscale.com/types/views" ) type testClient struct { @@ -185,31 +188,40 @@ func TestServer(t *testing.T) { cases := []struct { name string - overrideAddrs []netip.Addr + staticAddrs []netip.Addr forceClientsMixedAF bool }{ { - name: "over ipv4", - overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + name: "over ipv4", + staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, }, { - name: "over ipv6", - overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")}, + name: "over ipv6", + staticAddrs: []netip.Addr{netip.MustParseAddr("::1")}, }, { name: "mixed address families", - overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")}, + staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")}, forceClientsMixedAF: true, }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - server, err := NewServer(t.Logf, 0, tt.overrideAddrs) + server, err := NewServer(t.Logf, 0, true) if err != nil { t.Fatal(err) } defer server.Close() + addrPorts := make([]netip.AddrPort, 0, len(tt.staticAddrs)) + for _, addr := range tt.staticAddrs { + if addr.Is4() { + addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc4Port)) + } else if server.uc6Port != 0 { + addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc6Port)) + } + } + server.SetStaticAddrPorts(views.SliceOf(addrPorts)) endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) if err != nil { @@ -342,3 +354,117 @@ func TestServer_getNextVNILocked(t *testing.T) { _, err = s.getNextVNILocked() c.Assert(err, qt.IsNil) } + +func Test_blakeMACFromBindMsg(t *testing.T) { + var macSecret [blake2s.Size]byte + rand.Read(macSecret[:]) + src := netip.MustParseAddrPort("[2001:db8::1]:7") + + msgA := disco.BindUDPRelayEndpointCommon{ + VNI: 1, + Generation: 1, + RemoteKey: key.NewDisco().Public(), + Challenge: [32]byte{}, + } + macA, err := blakeMACFromBindMsg(macSecret, src, msgA) + if err != nil { + 
t.Fatal(err) + } + + msgB := msgA + msgB.VNI++ + macB, err := blakeMACFromBindMsg(macSecret, src, msgB) + if err != nil { + t.Fatal(err) + } + if macA == macB { + t.Fatalf("varying VNI input produced identical mac: %v", macA) + } + + msgC := msgA + msgC.Generation++ + macC, err := blakeMACFromBindMsg(macSecret, src, msgC) + if err != nil { + t.Fatal(err) + } + if macA == macC { + t.Fatalf("varying Generation input produced identical mac: %v", macA) + } + + msgD := msgA + msgD.RemoteKey = key.NewDisco().Public() + macD, err := blakeMACFromBindMsg(macSecret, src, msgD) + if err != nil { + t.Fatal(err) + } + if macA == macD { + t.Fatalf("varying RemoteKey input produced identical mac: %v", macA) + } + + msgE := msgA + msgE.Challenge = [32]byte{0x01} // challenge is not part of the MAC and should be ignored + macE, err := blakeMACFromBindMsg(macSecret, src, msgE) + if err != nil { + t.Fatal(err) + } + if macA != macE { + t.Fatalf("varying Challenge input produced varying mac: %v", macA) + } + + macSecretB := macSecret + macSecretB[0] ^= 0xFF + macF, err := blakeMACFromBindMsg(macSecretB, src, msgA) + if err != nil { + t.Fatal(err) + } + if macA == macF { + t.Fatalf("varying macSecret input produced identical mac: %v", macA) + } + + srcB := netip.AddrPortFrom(src.Addr(), src.Port()+1) + macG, err := blakeMACFromBindMsg(macSecret, srcB, msgA) + if err != nil { + t.Fatal(err) + } + if macA == macG { + t.Fatalf("varying src input produced identical mac: %v", macA) + } +} + +func Benchmark_blakeMACFromBindMsg(b *testing.B) { + var macSecret [blake2s.Size]byte + rand.Read(macSecret[:]) + src := netip.MustParseAddrPort("[2001:db8::1]:7") + msg := disco.BindUDPRelayEndpointCommon{ + VNI: 1, + Generation: 1, + RemoteKey: key.NewDisco().Public(), + Challenge: [32]byte{}, + } + b.ReportAllocs() + for b.Loop() { + _, err := blakeMACFromBindMsg(macSecret, src, msg) + if err != nil { + b.Fatal(err) + } + } +} + +func TestServer_maybeRotateMACSecretLocked(t *testing.T) { + s := &Server{} + start := time.Now() + s.maybeRotateMACSecretLocked(start) + qt.Assert(t, len(s.macSecrets), qt.Equals, 1) + macSecret := s.macSecrets[0] + s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval - time.Nanosecond)) + qt.Assert(t, len(s.macSecrets), qt.Equals, 1) + qt.Assert(t, s.macSecrets[0], qt.Equals, macSecret) + s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval)) + qt.Assert(t, len(s.macSecrets), qt.Equals, 2) + qt.Assert(t, s.macSecrets[1], qt.Equals, macSecret) + qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1]) + s.maybeRotateMACSecretLocked(s.macSecretRotatedAt.Add(macSecretRotationInterval)) + qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[0]) + qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[1]) + qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1]) +} diff --git a/net/udprelay/status/status.go b/net/udprelay/status/status.go index 3866efada2542..9ed9a0d2a8def 100644 --- a/net/udprelay/status/status.go +++ b/net/udprelay/status/status.go @@ -14,8 +14,9 @@ import ( type ServerStatus struct { // UDPPort is the UDP port number that the peer relay server forwards over, // as configured by the user with 'tailscale set --relay-server-port='. - // If the port has not been configured, UDPPort will be nil. - UDPPort *int + // If the port has not been configured, UDPPort will be nil. A non-nil zero + // value signifies the user has opted for a random unused port. 
+ UDPPort *uint16 // Sessions is a slice of detailed status information about each peer // relay session that this node's peer relay server is involved with. It // may be empty. diff --git a/net/wsconn/wsconn.go b/net/wsconn/wsconn.go index 3c83ffd8c320f..9e44da59ca1d7 100644 --- a/net/wsconn/wsconn.go +++ b/net/wsconn/wsconn.go @@ -12,11 +12,11 @@ import ( "math" "net" "os" - "sync" "sync/atomic" "time" "github.com/coder/websocket" + "tailscale.com/syncs" ) // NetConn converts a *websocket.Conn into a net.Conn. @@ -102,7 +102,7 @@ type netConn struct { reading atomic.Bool afterReadDeadline atomic.Bool - readMu sync.Mutex + readMu syncs.Mutex // eofed is true if the reader should return io.EOF from the Read call. // // +checklocks:readMu diff --git a/packages/deb/deb.go b/packages/deb/deb.go index 30e3f2b4d360c..cab0fea075e74 100644 --- a/packages/deb/deb.go +++ b/packages/deb/deb.go @@ -166,14 +166,14 @@ var ( func findArchAndVersion(control []byte) (arch string, version string, err error) { b := bytes.NewBuffer(control) for { - l, err := b.ReadBytes('\n') + ln, err := b.ReadBytes('\n') if err != nil { return "", "", err } - if bytes.HasPrefix(l, archKey) { - arch = string(bytes.TrimSpace(l[len(archKey):])) - } else if bytes.HasPrefix(l, versionKey) { - version = string(bytes.TrimSpace(l[len(versionKey):])) + if bytes.HasPrefix(ln, archKey) { + arch = string(bytes.TrimSpace(ln[len(archKey):])) + } else if bytes.HasPrefix(ln, versionKey) { + version = string(bytes.TrimSpace(ln[len(versionKey):])) } if arch != "" && version != "" { return arch, version, nil diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 34277fdbaba91..791a8b118427f 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -5,12 +5,24 @@ package portlist import ( "net" + "runtime" "testing" "tailscale.com/tstest" ) +func maybeSkip(t *testing.T) { + if runtime.GOOS == "linux" { + tstest.SkipOnKernelVersions(t, + "https://github.com/tailscale/tailscale/issues/16966", + "6.6.102", "6.6.103", "6.6.104", + "6.12.42", "6.12.43", "6.12.44", "6.12.45", + ) + } +} + func TestGetList(t *testing.T) { + maybeSkip(t) tstest.ResourceCheck(t) var p Poller @@ -25,6 +37,7 @@ func TestGetList(t *testing.T) { } func TestIgnoreLocallyBoundPorts(t *testing.T) { + maybeSkip(t) tstest.ResourceCheck(t) ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -47,6 +60,8 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) { } func TestPoller(t *testing.T) { + maybeSkip(t) + var p Poller p.IncludeLocalhost = true get := func(t *testing.T) []Port { diff --git a/prober/derp.go b/prober/derp.go index 52e56fd4eff1e..22843b53a4049 100644 --- a/prober/derp.go +++ b/prober/derp.go @@ -323,14 +323,14 @@ func (d *derpProber) probeBandwidth(from, to string, size int64) ProbeClass { "derp_path": derpPath, "tcp_in_tcp": strconv.FormatBool(d.bwTUNIPv4Prefix != nil), }, - Metrics: func(l prometheus.Labels) []prometheus.Metric { + Metrics: func(lb prometheus.Labels) []prometheus.Metric { metrics := []prometheus.Metric{ - prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_probe_size_bytes", "Payload size of the bandwidth prober", nil, l), prometheus.GaugeValue, float64(size)), - prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_transfer_time_seconds_total", "Time it took to transfer data", nil, l), prometheus.CounterValue, transferTimeSeconds.Value()), + prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_probe_size_bytes", "Payload size of the bandwidth prober", nil, lb), prometheus.GaugeValue, float64(size)), + 
prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_transfer_time_seconds_total", "Time it took to transfer data", nil, lb), prometheus.CounterValue, transferTimeSeconds.Value()), } if d.bwTUNIPv4Prefix != nil { // For TCP-in-TCP probes, also record cumulative bytes transferred. - metrics = append(metrics, prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_bytes_total", "Amount of data transferred", nil, l), prometheus.CounterValue, totalBytesTransferred.Value())) + metrics = append(metrics, prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_bytes_total", "Amount of data transferred", nil, lb), prometheus.CounterValue, totalBytesTransferred.Value())) } return metrics }, @@ -361,11 +361,11 @@ func (d *derpProber) probeQueuingDelay(from, to string, packetsPerSecond int, pa }, Class: "derp_qd", Labels: Labels{"derp_path": derpPath}, - Metrics: func(l prometheus.Labels) []prometheus.Metric { + Metrics: func(lb prometheus.Labels) []prometheus.Metric { qdh.mx.Lock() result := []prometheus.Metric{ - prometheus.MustNewConstMetric(prometheus.NewDesc("derp_qd_probe_dropped_packets", "Total packets dropped", nil, l), prometheus.CounterValue, float64(packetsDropped.Value())), - prometheus.MustNewConstHistogram(prometheus.NewDesc("derp_qd_probe_delays_seconds", "Distribution of queuing delays", nil, l), qdh.count, qdh.sum, maps.Clone(qdh.bucketedCounts)), + prometheus.MustNewConstMetric(prometheus.NewDesc("derp_qd_probe_dropped_packets", "Total packets dropped", nil, lb), prometheus.CounterValue, float64(packetsDropped.Value())), + prometheus.MustNewConstHistogram(prometheus.NewDesc("derp_qd_probe_delays_seconds", "Distribution of queuing delays", nil, lb), qdh.count, qdh.sum, maps.Clone(qdh.bucketedCounts)), } qdh.mx.Unlock() return result @@ -1046,11 +1046,11 @@ func derpProbeBandwidthTUN(ctx context.Context, transferTimeSeconds, totalBytesT }() // Start a listener to receive the data - l, err := net.Listen("tcp", net.JoinHostPort(ifAddr.String(), "0")) + ln, err := net.Listen("tcp", net.JoinHostPort(ifAddr.String(), "0")) if err != nil { return fmt.Errorf("failed to listen: %s", err) } - defer l.Close() + defer ln.Close() // 128KB by default const writeChunkSize = 128 << 10 @@ -1062,9 +1062,9 @@ func derpProbeBandwidthTUN(ctx context.Context, transferTimeSeconds, totalBytesT } // Dial ourselves - _, port, err := net.SplitHostPort(l.Addr().String()) + _, port, err := net.SplitHostPort(ln.Addr().String()) if err != nil { - return fmt.Errorf("failed to split address %q: %w", l.Addr().String(), err) + return fmt.Errorf("failed to split address %q: %w", ln.Addr().String(), err) } connAddr := net.JoinHostPort(destinationAddr.String(), port) @@ -1085,7 +1085,7 @@ func derpProbeBandwidthTUN(ctx context.Context, transferTimeSeconds, totalBytesT go func() { defer wg.Done() - readConn, err := l.Accept() + readConn, err := ln.Accept() if err != nil { readFinishedC <- err return @@ -1146,11 +1146,11 @@ func derpProbeBandwidthTUN(ctx context.Context, transferTimeSeconds, totalBytesT func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode, isProber bool, meshKey key.DERPMesh) (*derphttp.Client, error) { // To avoid spamming the log with regular connection messages. 
- l := logger.Filtered(log.Printf, func(s string) bool { + logf := logger.Filtered(log.Printf, func(s string) bool { return !strings.Contains(s, "derphttp.Client.Connect: connecting to") }) priv := key.NewNode() - dc := derphttp.NewRegionClient(priv, l, netmon.NewStatic(), func() *tailcfg.DERPRegion { + dc := derphttp.NewRegionClient(priv, logf, netmon.NewStatic(), func() *tailcfg.DERPRegion { rid := n.RegionID return &tailcfg.DERPRegion{ RegionID: rid, diff --git a/prober/prober.go b/prober/prober.go index 9073a95029163..6b904dd97d231 100644 --- a/prober/prober.go +++ b/prober/prober.go @@ -118,25 +118,25 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob panic(fmt.Sprintf("probe named %q already registered", name)) } - l := prometheus.Labels{ + lb := prometheus.Labels{ "name": name, "class": pc.Class, } for k, v := range pc.Labels { - l[k] = v + lb[k] = v } for k, v := range labels { - l[k] = v + lb[k] = v } - probe := newProbe(p, name, interval, l, pc) + probe := newProbe(p, name, interval, lb, pc) p.probes[name] = probe go probe.loop() return probe } // newProbe creates a new Probe with the given parameters, but does not start it. -func newProbe(p *Prober, name string, interval time.Duration, l prometheus.Labels, pc ProbeClass) *Probe { +func newProbe(p *Prober, name string, interval time.Duration, lg prometheus.Labels, pc ProbeClass) *Probe { ctx, cancel := context.WithCancel(context.Background()) probe := &Probe{ prober: p, @@ -155,17 +155,17 @@ func newProbe(p *Prober, name string, interval time.Duration, l prometheus.Label latencyHist: ring.New(recentHistSize), metrics: prometheus.NewRegistry(), - metricLabels: l, - mInterval: prometheus.NewDesc("interval_secs", "Probe interval in seconds", nil, l), - mStartTime: prometheus.NewDesc("start_secs", "Latest probe start time (seconds since epoch)", nil, l), - mEndTime: prometheus.NewDesc("end_secs", "Latest probe end time (seconds since epoch)", nil, l), - mLatency: prometheus.NewDesc("latency_millis", "Latest probe latency (ms)", nil, l), - mResult: prometheus.NewDesc("result", "Latest probe result (1 = success, 0 = failure)", nil, l), + metricLabels: lg, + mInterval: prometheus.NewDesc("interval_secs", "Probe interval in seconds", nil, lg), + mStartTime: prometheus.NewDesc("start_secs", "Latest probe start time (seconds since epoch)", nil, lg), + mEndTime: prometheus.NewDesc("end_secs", "Latest probe end time (seconds since epoch)", nil, lg), + mLatency: prometheus.NewDesc("latency_millis", "Latest probe latency (ms)", nil, lg), + mResult: prometheus.NewDesc("result", "Latest probe result (1 = success, 0 = failure)", nil, lg), mAttempts: prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "attempts_total", Help: "Total number of probing attempts", ConstLabels: l, + Name: "attempts_total", Help: "Total number of probing attempts", ConstLabels: lg, }, []string{"status"}), mSeconds: prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "seconds_total", Help: "Total amount of time spent executing the probe", ConstLabels: l, + Name: "seconds_total", Help: "Total amount of time spent executing the probe", ConstLabels: lg, }, []string{"status"}), } if p.metrics != nil { @@ -512,8 +512,8 @@ func (probe *Probe) probeInfoLocked() ProbeInfo { inf.Latency = probe.latency } probe.latencyHist.Do(func(v any) { - if l, ok := v.(time.Duration); ok { - inf.RecentLatencies = append(inf.RecentLatencies, l) + if latency, ok := v.(time.Duration); ok { + inf.RecentLatencies = append(inf.RecentLatencies, latency) } }) 
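// A minimal, self-contained sketch (not part of this change) of the
// container/ring pattern used by latencyHist above: ring.Do visits every
// slot, including ones that were never written, so the type assertion
// doubles as a filter for empty slots. The recentHistSize value here is
// assumed for the example.
package main

import (
	"container/ring"
	"fmt"
	"time"
)

func main() {
	const recentHistSize = 3
	hist := ring.New(recentHistSize)
	for _, ms := range []int{10, 20, 30, 40} {
		hist.Value = time.Duration(ms) * time.Millisecond
		hist = hist.Next() // once full, the oldest sample is overwritten
	}
	var recent []time.Duration
	hist.Do(func(v any) {
		if d, ok := v.(time.Duration); ok { // skips never-filled slots
			recent = append(recent, d)
		}
	})
	fmt.Println(recent) // [20ms 30ms 40ms]
}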
probe.successHist.Do(func(v any) { @@ -719,8 +719,8 @@ func initialDelay(seed string, interval time.Duration) time.Duration { // Labels is a set of metric labels used by a prober. type Labels map[string]string -func (l Labels) With(k, v string) Labels { - new := maps.Clone(l) +func (lb Labels) With(k, v string) Labels { + new := maps.Clone(lb) new[k] = v return new } diff --git a/proxymap/proxymap.go b/proxymap/proxymap.go index dfe6f2d586000..20dc96c848307 100644 --- a/proxymap/proxymap.go +++ b/proxymap/proxymap.go @@ -9,9 +9,9 @@ import ( "fmt" "net/netip" "strings" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/util/mak" ) @@ -22,7 +22,7 @@ import ( // ask tailscaled (via the LocalAPI WhoIs method) the Tailscale identity that a // given localhost:port corresponds to. type Mapper struct { - mu sync.Mutex + mu syncs.Mutex // m holds the mapping from localhost IP:ports to Tailscale IPs. It is // keyed first by the protocol ("tcp" or "udp"), then by the IP:port. diff --git a/scripts/installer.sh b/scripts/installer.sh index b40177005821b..e5b6cd23bc9a7 100755 --- a/scripts/installer.sh +++ b/scripts/installer.sh @@ -42,6 +42,8 @@ main() { # - VERSION_CODENAME: the codename of the OS release, if any (e.g. "buster") # - UBUNTU_CODENAME: if it exists, use instead of VERSION_CODENAME . /etc/os-release + VERSION_MAJOR="${VERSION_ID:-}" + VERSION_MAJOR="${VERSION_MAJOR%%.*}" case "$ID" in ubuntu|pop|neon|zorin|tuxedo) OS="ubuntu" @@ -53,10 +55,10 @@ main() { PACKAGETYPE="apt" # Third-party keyrings became the preferred method of # installation in Ubuntu 20.04. - if expr "$VERSION_ID" : "2.*" >/dev/null; then - APT_KEY_TYPE="keyring" - else + if [ "$VERSION_MAJOR" -lt 20 ]; then APT_KEY_TYPE="legacy" + else + APT_KEY_TYPE="keyring" fi ;; debian) @@ -76,7 +78,7 @@ main() { # They don't specify the Debian version they're based off in os-release # but Parrot 6 is based on Debian 12 Bookworm. VERSION=bookworm - elif [ "$VERSION_ID" -lt 11 ]; then + elif [ "$VERSION_MAJOR" -lt 11 ]; then APT_KEY_TYPE="legacy" else APT_KEY_TYPE="keyring" @@ -94,7 +96,7 @@ main() { VERSION="$VERSION_CODENAME" fi PACKAGETYPE="apt" - if [ "$VERSION_ID" -lt 5 ]; then + if [ "$VERSION_MAJOR" -lt 5 ]; then APT_KEY_TYPE="legacy" else APT_KEY_TYPE="keyring" @@ -104,7 +106,7 @@ main() { OS="ubuntu" VERSION="$UBUNTU_CODENAME" PACKAGETYPE="apt" - if [ "$VERSION_ID" -lt 6 ]; then + if [ "$VERSION_MAJOR" -lt 6 ]; then APT_KEY_TYPE="legacy" else APT_KEY_TYPE="keyring" @@ -113,7 +115,7 @@ main() { industrial-os) OS="debian" PACKAGETYPE="apt" - if [ "$(printf %.1s "$VERSION_ID")" -lt 5 ]; then + if [ "$VERSION_MAJOR" -lt 5 ]; then VERSION="buster" APT_KEY_TYPE="legacy" else @@ -124,7 +126,7 @@ main() { parrot|mendel) OS="debian" PACKAGETYPE="apt" - if [ "$VERSION_ID" -lt 5 ]; then + if [ "$VERSION_MAJOR" -lt 5 ]; then VERSION="buster" APT_KEY_TYPE="legacy" else @@ -150,7 +152,7 @@ main() { PACKAGETYPE="apt" # Third-party keyrings became the preferred method of # installation in Raspbian 11 (Bullseye). 
- if [ "$VERSION_ID" -lt 11 ]; then + if [ "$VERSION_MAJOR" -lt 11 ]; then APT_KEY_TYPE="legacy" else APT_KEY_TYPE="keyring" @@ -159,12 +161,11 @@ main() { kali) OS="debian" PACKAGETYPE="apt" - YEAR="$(echo "$VERSION_ID" | cut -f1 -d.)" APT_SYSTEMCTL_START=true # Third-party keyrings became the preferred method of # installation in Debian 11 (Bullseye), which Kali switched # to in roughly 2021.x releases - if [ "$YEAR" -lt 2021 ]; then + if [ "$VERSION_MAJOR" -lt 2021 ]; then # Kali VERSION_ID is "kali-rolling", which isn't distinguishing VERSION="buster" APT_KEY_TYPE="legacy" @@ -176,7 +177,7 @@ main() { Deepin|deepin) # https://github.com/tailscale/tailscale/issues/7862 OS="debian" PACKAGETYPE="apt" - if [ "$VERSION_ID" -lt 20 ]; then + if [ "$VERSION_MAJOR" -lt 20 ]; then APT_KEY_TYPE="legacy" VERSION="buster" else @@ -189,7 +190,7 @@ main() { # All versions of PikaOS are new enough to prefer keyring APT_KEY_TYPE="keyring" # Older versions of PikaOS are based on Ubuntu rather than Debian - if [ "$VERSION_ID" -lt 4 ]; then + if [ "$VERSION_MAJOR" -lt 4 ]; then OS="ubuntu" VERSION="$UBUNTU_CODENAME" else @@ -205,7 +206,7 @@ main() { ;; centos) OS="$ID" - VERSION="$VERSION_ID" + VERSION="$VERSION_MAJOR" PACKAGETYPE="dnf" if [ "$VERSION" = "7" ]; then PACKAGETYPE="yum" @@ -213,7 +214,7 @@ main() { ;; ol) OS="oracle" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="dnf" if [ "$VERSION" = "7" ]; then PACKAGETYPE="yum" @@ -224,7 +225,7 @@ main() { if [ "$ID" = "miraclelinux" ]; then OS="rhel" fi - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="dnf" if [ "$VERSION" = "7" ]; then PACKAGETYPE="yum" @@ -247,7 +248,7 @@ main() { ;; xenenterprise) OS="centos" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="yum" ;; opensuse-leap|sles) @@ -311,7 +312,7 @@ main() { ;; freebsd) OS="$ID" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="pkg" ;; osmc) @@ -322,7 +323,7 @@ main() { ;; photon) OS="photon" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="tdnf" ;; diff --git a/shell.nix b/shell.nix index 99fc7fa4de547..d412693d9fdd1 100644 --- a/shell.nix +++ b/shell.nix @@ -16,4 +16,4 @@ ) { src = ./.; }).shellNix -# nix-direnv cache busting line: sha256-AUOjLomba75qfzb9Vxc0Sktyeces6hBSuOMgboWcDnE= +# nix-direnv cache busting line: sha256-jJSSXMyUqcJoZuqfSlBsKDQezyqS+jDkRglMMjG1K8g= diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index dd280143e36e3..f75646771057a 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -74,6 +74,9 @@ var maybeStartLoginSession = func(dlogf logger.Logf, ia incubatorArgs) (close fu return nil } +// truePaths are the common locations to find the true binary, in likelihood order. +var truePaths = [...]string{"/usr/bin/true", "/bin/true"} + // tryExecInDir tries to run a command in dir and returns nil if it succeeds. // Otherwise, it returns a filesystem error or a timeout error if the command // took too long. @@ -93,10 +96,14 @@ func tryExecInDir(ctx context.Context, dir string) error { windir := os.Getenv("windir") return run(filepath.Join(windir, "system32", "doskey.exe")) } - if err := run("/bin/true"); !errors.Is(err, exec.ErrNotFound) { // including nil - return err + // Execute the first "true" we find in the list. + for _, path := range truePaths { + // Note: LookPath does not consult $PATH when passed multi-label paths. 
+ if p, err := exec.LookPath(path); err == nil { + return run(p) + } } - return run("/usr/bin/true") + return exec.ErrNotFound } // newIncubatorCommand returns a new exec.Cmd configured with diff --git a/syncs/locked.go b/syncs/locked.go index d2048665dee3d..d2e9edef7a9dd 100644 --- a/syncs/locked.go +++ b/syncs/locked.go @@ -8,7 +8,7 @@ import ( ) // AssertLocked panics if m is not locked. -func AssertLocked(m *sync.Mutex) { +func AssertLocked(m *Mutex) { if m.TryLock() { m.Unlock() panic("mutex is not locked") @@ -16,7 +16,7 @@ func AssertLocked(m *sync.Mutex) { } // AssertRLocked panics if rw is not locked for reading or writing. -func AssertRLocked(rw *sync.RWMutex) { +func AssertRLocked(rw *RWMutex) { if rw.TryLock() { rw.Unlock() panic("mutex is not locked") diff --git a/syncs/mutex.go b/syncs/mutex.go new file mode 100644 index 0000000000000..e61d1d1ab0687 --- /dev/null +++ b/syncs/mutex.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_mutex_debug + +package syncs + +import "sync" + +// Mutex is an alias for sync.Mutex. +// +// It's only not a sync.Mutex when built with the ts_mutex_debug build tag. +type Mutex = sync.Mutex + +// RWMutex is an alias for sync.RWMutex. +// +// It's only not a sync.RWMutex when built with the ts_mutex_debug build tag. +type RWMutex = sync.RWMutex diff --git a/syncs/mutex_debug.go b/syncs/mutex_debug.go new file mode 100644 index 0000000000000..14b52ffe3cc51 --- /dev/null +++ b/syncs/mutex_debug.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_mutex_debug + +package syncs + +import "sync" + +type Mutex struct { + sync.Mutex +} + +type RWMutex struct { + sync.RWMutex +} + +// TODO(bradfitz): actually track stuff when in debug mode. diff --git a/syncs/shardedint_test.go b/syncs/shardedint_test.go index d355a15400a90..815a739d13842 100644 --- a/syncs/shardedint_test.go +++ b/syncs/shardedint_test.go @@ -1,13 +1,14 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package syncs +package syncs_test import ( "expvar" "sync" "testing" + . "tailscale.com/syncs" "tailscale.com/tstest" ) diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index a95d0559c2bec..8468aa09efb3e 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -177,7 +177,8 @@ type CapabilityVersion int // - 128: 2025-10-02: can handle C2N /debug/health. // - 129: 2025-10-04: Fixed sleep/wake deadlock in magicsock when using peer relay (PR #17449) // - 130: 2025-10-06: client can send key.HardwareAttestationPublic and key.HardwareAttestationKeySignature in MapRequest -const CurrentCapabilityVersion CapabilityVersion = 130 +// - 131: 2025-11-25: client respects [NodeAttrDefaultAutoUpdate] +const CurrentCapabilityVersion CapabilityVersion = 131 // ID is an integer ID for a user, node, or login allocated by the // control plane. @@ -255,9 +256,9 @@ func (u StableNodeID) IsZero() bool { // have a general gmail address login associated with the user. 
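// A minimal sketch (not part of this change) of why syncs.Mutex above is a
// type alias rather than a defined type: an alias keeps *Mutex and
// *sync.Mutex interchangeable, so helpers written against *sync.Mutex keep
// compiling, while the ts_mutex_debug build tag can substitute a wrapper
// type behind the same name. Names here are local to the example.
package main

import (
	"fmt"
	"sync"
)

type Mutex = sync.Mutex // alias, as in syncs/mutex.go (default build)

// lockedElsewhere is a helper written against *sync.Mutex.
func lockedElsewhere(m *sync.Mutex) bool {
	if m.TryLock() {
		defer m.Unlock()
		return false // we got the lock, so nobody else held it
	}
	return true
}

func main() {
	var mu Mutex
	// Because Mutex is an alias, &mu is a *sync.Mutex: no conversion needed.
	fmt.Println(lockedElsewhere(&mu)) // false
}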
type User struct { ID UserID - DisplayName string // if non-empty overrides Login field - ProfilePicURL string // if non-empty overrides Login field - Created time.Time + DisplayName string // if non-empty overrides Login field + ProfilePicURL string `json:",omitzero"` // if non-empty overrides Login field + Created time.Time `json:",omitzero"` } // Login is a user from a specific identity provider, not associated with any @@ -268,7 +269,7 @@ type Login struct { Provider string // "google", "github", "okta_foo", etc. LoginName string // an email address or "email-ish" string (like alice@github) DisplayName string // from the IdP - ProfilePicURL string // from the IdP + ProfilePicURL string `json:",omitzero"` // from the IdP } // A UserProfile is display-friendly data for a [User]. @@ -278,7 +279,7 @@ type UserProfile struct { ID UserID LoginName string // "alice@smith.com"; for display purposes only (provider is not listed) DisplayName string // "Alice Smith" - ProfilePicURL string `json:",omitempty"` + ProfilePicURL string `json:",omitzero"` } func (p *UserProfile) Equal(p2 *UserProfile) bool { @@ -345,13 +346,13 @@ type Node struct { User UserID // Sharer, if non-zero, is the user who shared this node, if different than User. - Sharer UserID `json:",omitempty"` + Sharer UserID `json:",omitzero"` Key key.NodePublic - KeyExpiry time.Time // the zero value if this node does not expire + KeyExpiry time.Time `json:",omitzero"` // the zero value if this node does not expire KeySignature tkatype.MarshaledSignature `json:",omitempty"` - Machine key.MachinePublic - DiscoKey key.DiscoPublic + Machine key.MachinePublic `json:",omitzero"` + DiscoKey key.DiscoPublic `json:",omitzero"` // Addresses are the IP addresses of this Node directly. Addresses []netip.Prefix @@ -361,7 +362,7 @@ type Node struct { // As of CapabilityVersion 112, this may be nil (null or undefined) on the wire // to mean the same as Addresses. Internally, it is always filled in with // its possibly-implicit value. - AllowedIPs []netip.Prefix + AllowedIPs []netip.Prefix `json:",omitzero"` // _not_ omitempty; only nil is special Endpoints []netip.AddrPort `json:",omitempty"` // IP+port (public via STUN, and local LANs) @@ -375,18 +376,18 @@ type Node struct { // this field. See tailscale/tailscale#14636. Do not use this field in code // other than in the upgradeNode func, which canonicalizes it to HomeDERP // if it arrives as a LegacyDERPString string on the wire. - LegacyDERPString string `json:"DERP,omitempty"` // DERP-in-IP:port ("127.3.3.40:N") endpoint + LegacyDERPString string `json:"DERP,omitzero"` // DERP-in-IP:port ("127.3.3.40:N") endpoint // HomeDERP is the modern version of the DERP string field, with just an // integer. The client advertises support for this as of capver 111. // // HomeDERP may be zero if not (yet) known, but ideally always be non-zero // for magicsock connectivity to function normally. - HomeDERP int `json:",omitempty"` // DERP region ID of the node's home DERP + HomeDERP int `json:",omitzero"` // DERP region ID of the node's home DERP - Hostinfo HostinfoView - Created time.Time - Cap CapabilityVersion `json:",omitempty"` // if non-zero, the node's capability version; old servers might not send + Hostinfo HostinfoView `json:",omitzero"` + Created time.Time `json:",omitzero"` + Cap CapabilityVersion `json:",omitzero"` // if non-zero, the node's capability version; old servers might not send // Tags are the list of ACL tags applied to this node. 
// Tags take the form of `tag:` where value starts @@ -453,25 +454,25 @@ type Node struct { // it do anything. It is the tailscaled client's job to double-check the // MapResponse's PacketFilter to verify that its AllowedIPs will not be // accepted by the packet filter. - UnsignedPeerAPIOnly bool `json:",omitempty"` + UnsignedPeerAPIOnly bool `json:",omitzero"` // The following three computed fields hold the various names that can // be used for this node in UIs. They are populated from controlclient // (not from control) by calling node.InitDisplayNames. These can be // used directly or accessed via node.DisplayName or node.DisplayNames. - ComputedName string `json:",omitempty"` // MagicDNS base name (for normal non-shared-in nodes), FQDN (without trailing dot, for shared-in nodes), or Hostname (if no MagicDNS) + ComputedName string `json:",omitzero"` // MagicDNS base name (for normal non-shared-in nodes), FQDN (without trailing dot, for shared-in nodes), or Hostname (if no MagicDNS) computedHostIfDifferent string // hostname, if different than ComputedName, otherwise empty - ComputedNameWithHost string `json:",omitempty"` // either "ComputedName" or "ComputedName (computedHostIfDifferent)", if computedHostIfDifferent is set + ComputedNameWithHost string `json:",omitzero"` // either "ComputedName" or "ComputedName (computedHostIfDifferent)", if computedHostIfDifferent is set // DataPlaneAuditLogID is the per-node logtail ID used for data plane audit logging. - DataPlaneAuditLogID string `json:",omitempty"` + DataPlaneAuditLogID string `json:",omitzero"` // Expired is whether this node's key has expired. Control may send // this; clients are only allowed to set this from false to true. On // the client, this is calculated client-side based on a timestamp sent // from control, to avoid clock skew issues. - Expired bool `json:",omitempty"` + Expired bool `json:",omitzero"` // SelfNodeV4MasqAddrForThisPeer is the IPv4 that this peer knows the current node as. // It may be empty if the peer knows the current node by its native @@ -486,7 +487,7 @@ type Node struct { // This only applies to traffic originating from the current node to the // peer or any of its subnets. Traffic originating from subnet routes will // not be masqueraded (e.g. in case of --snat-subnet-routes). - SelfNodeV4MasqAddrForThisPeer *netip.Addr `json:",omitempty"` + SelfNodeV4MasqAddrForThisPeer *netip.Addr `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // SelfNodeV6MasqAddrForThisPeer is the IPv6 that this peer knows the current node as. // It may be empty if the peer knows the current node by its native @@ -501,17 +502,17 @@ type Node struct { // This only applies to traffic originating from the current node to the // peer or any of its subnets. Traffic originating from subnet routes will // not be masqueraded (e.g. in case of --snat-subnet-routes). - SelfNodeV6MasqAddrForThisPeer *netip.Addr `json:",omitempty"` + SelfNodeV6MasqAddrForThisPeer *netip.Addr `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // IsWireGuardOnly indicates that this is a non-Tailscale WireGuard peer, it // is not expected to speak Disco or DERP, and it must have Endpoints in // order to be reachable. - IsWireGuardOnly bool `json:",omitempty"` + IsWireGuardOnly bool `json:",omitzero"` // IsJailed indicates that this node is jailed and should not be allowed // initiate connections, however outbound connections to it should still be // allowed. 
- IsJailed bool `json:",omitempty"` + IsJailed bool `json:",omitzero"` // ExitNodeDNSResolvers is the list of DNS servers that should be used when this // node is marked IsWireGuardOnly and being used as an exit node. @@ -827,10 +828,10 @@ type Location struct { // Because it contains pointers (slices), this type should not be used // as a value type. type Hostinfo struct { - IPNVersion string `json:",omitempty"` // version of this code (in version.Long format) - FrontendLogID string `json:",omitempty"` // logtail ID of frontend instance - BackendLogID string `json:",omitempty"` // logtail ID of backend instance - OS string `json:",omitempty"` // operating system the client runs on (a version.OS value) + IPNVersion string `json:",omitzero"` // version of this code (in version.Long format) + FrontendLogID string `json:",omitzero"` // logtail ID of frontend instance + BackendLogID string `json:",omitzero"` // logtail ID of backend instance + OS string `json:",omitzero"` // operating system the client runs on (a version.OS value) // OSVersion is the version of the OS, if available. // @@ -842,25 +843,25 @@ type Hostinfo struct { // string on Linux, like "Debian 10.4; kernel=xxx; container; env=kn" and so // on. As of Tailscale 1.32, this is simply the kernel version on Linux, like // "5.10.0-17-amd64". - OSVersion string `json:",omitempty"` + OSVersion string `json:",omitzero"` - Container opt.Bool `json:",omitempty"` // best-effort whether the client is running in a container - Env string `json:",omitempty"` // a hostinfo.EnvType in string form - Distro string `json:",omitempty"` // "debian", "ubuntu", "nixos", ... - DistroVersion string `json:",omitempty"` // "20.04", ... - DistroCodeName string `json:",omitempty"` // "jammy", "bullseye", ... + Container opt.Bool `json:",omitzero"` // best-effort whether the client is running in a container + Env string `json:",omitzero"` // a hostinfo.EnvType in string form + Distro string `json:",omitzero"` // "debian", "ubuntu", "nixos", ... + DistroVersion string `json:",omitzero"` // "20.04", ... + DistroCodeName string `json:",omitzero"` // "jammy", "bullseye", ... // App is used to disambiguate Tailscale clients that run using tsnet. - App string `json:",omitempty"` // "k8s-operator", "golinks", ... - - Desktop opt.Bool `json:",omitempty"` // if a desktop was detected on Linux - Package string `json:",omitempty"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown) - DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone12,3") - PushDeviceToken string `json:",omitempty"` // macOS/iOS APNs device token for notifications (and Android in the future) - Hostname string `json:",omitempty"` // name of the host the client runs on - ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections - ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user - NoLogsNoSupport bool `json:",omitempty"` // indicates that the user has opted out of sending logs and support + App string `json:",omitzero"` // "k8s-operator", "golinks", ... 
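// A minimal sketch (not part of this change) of the omitempty-to-omitzero
// migration these tags are performing, assuming the omitzero tag option from
// Go 1.24's encoding/json (also in encoding/json/v2): omitzero drops a field
// whose value is its type's zero value, or whose IsZero() method reports
// true, whereas omitempty never fires for struct types like time.Time.
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

type withOmitempty struct {
	Created time.Time `json:",omitempty"` // a struct value is never "empty"
}

type withOmitzero struct {
	Created time.Time `json:",omitzero"` // omitted when Created.IsZero()
}

func main() {
	a, _ := json.Marshal(withOmitempty{})
	b, _ := json.Marshal(withOmitzero{})
	fmt.Println(string(a)) // {"Created":"0001-01-01T00:00:00Z"}
	fmt.Println(string(b)) // {}
}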
+ + Desktop opt.Bool `json:",omitzero"` // if a desktop was detected on Linux + Package string `json:",omitzero"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown) + DeviceModel string `json:",omitzero"` // mobile phone model ("Pixel 3a", "iPhone12,3") + PushDeviceToken string `json:",omitzero"` // macOS/iOS APNs device token for notifications (and Android in the future) + Hostname string `json:",omitzero"` // name of the host the client runs on + ShieldsUp bool `json:",omitzero"` // indicates whether the host is blocking incoming connections + ShareeNode bool `json:",omitzero"` // indicates this node exists in netmap because it's owned by a shared-to user + NoLogsNoSupport bool `json:",omitzero"` // indicates that the user has opted out of sending logs and support // WireIngress indicates that the node would like to be wired up server-side // (DNS, etc) to be able to use Tailscale Funnel, even if it's not currently // enabled. For example, the user might only use it for intermittent @@ -868,38 +869,38 @@ type Hostinfo struct { // away, even if it's disabled most of the time. As an optimization, this is // only sent if IngressEnabled is false, as IngressEnabled implies that this // option is true. - WireIngress bool `json:",omitempty"` - IngressEnabled bool `json:",omitempty"` // if the node has any funnel endpoint enabled - AllowsUpdate bool `json:",omitempty"` // indicates that the node has opted-in to admin-console-drive remote updates - Machine string `json:",omitempty"` // the current host's machine type (uname -m) - GoArch string `json:",omitempty"` // GOARCH value (of the built binary) - GoArchVar string `json:",omitempty"` // GOARM, GOAMD64, etc (of the built binary) - GoVersion string `json:",omitempty"` // Go version binary was built with + WireIngress bool `json:",omitzero"` + IngressEnabled bool `json:",omitzero"` // if the node has any funnel endpoint enabled + AllowsUpdate bool `json:",omitzero"` // indicates that the node has opted-in to admin-console-drive remote updates + Machine string `json:",omitzero"` // the current host's machine type (uname -m) + GoArch string `json:",omitzero"` // GOARCH value (of the built binary) + GoArchVar string `json:",omitzero"` // GOARM, GOAMD64, etc (of the built binary) + GoVersion string `json:",omitzero"` // Go version binary was built with RoutableIPs []netip.Prefix `json:",omitempty"` // set of IP ranges this client can route RequestTags []string `json:",omitempty"` // set of ACL tags this node wants to claim WoLMACs []string `json:",omitempty"` // MAC address(es) to send Wake-on-LAN packets to wake this node (lowercase hex w/ colons) Services []Service `json:",omitempty"` // services advertised by this machine - NetInfo *NetInfo `json:",omitempty"` + NetInfo *NetInfo `json:",omitzero"` SSH_HostKeys []string `json:"sshHostKeys,omitempty"` // if advertised - Cloud string `json:",omitempty"` - Userspace opt.Bool `json:",omitempty"` // if the client is running in userspace (netstack) mode - UserspaceRouter opt.Bool `json:",omitempty"` // if the client's subnet router is running in userspace (netstack) mode - AppConnector opt.Bool `json:",omitempty"` // if the client is running the app-connector service - ServicesHash string `json:",omitempty"` // opaque hash of the most recent list of tailnet services, change in hash indicates config should be fetched via c2n - ExitNodeID StableNodeID `json:",omitzero"` // the client’s selected exit node, empty when unselected. 
+ Cloud string `json:",omitzero"` + Userspace opt.Bool `json:",omitzero"` // if the client is running in userspace (netstack) mode + UserspaceRouter opt.Bool `json:",omitzero"` // if the client's subnet router is running in userspace (netstack) mode + AppConnector opt.Bool `json:",omitzero"` // if the client is running the app-connector service + ServicesHash string `json:",omitzero"` // opaque hash of the most recent list of tailnet services, change in hash indicates config should be fetched via c2n + ExitNodeID StableNodeID `json:",omitzero"` // the client’s selected exit node, empty when unselected. // Location represents geographical location data about a // Tailscale host. Location is optional and only set if // explicitly declared by a node. - Location *Location `json:",omitempty"` + Location *Location `json:",omitzero"` - TPM *TPMInfo `json:",omitempty"` // TPM device metadata, if available + TPM *TPMInfo `json:",omitzero"` // TPM device metadata, if available // StateEncrypted reports whether the node state is stored encrypted on // disk. The actual mechanism is platform-specific: // * Apple nodes use the Keychain // * Linux and Windows nodes use the TPM // * Android apps use EncryptedSharedPreferences - StateEncrypted opt.Bool `json:",omitempty"` + StateEncrypted opt.Bool `json:",omitzero"` // NOTE: any new fields containing pointers in this type // require changes to Hostinfo.Equal. @@ -913,25 +914,25 @@ type TPMInfo struct { // https://trustedcomputinggroup.org/resource/vendor-id-registry/, // for example "MSFT" for Microsoft. // Read from TPM_PT_MANUFACTURER. - Manufacturer string `json:",omitempty"` + Manufacturer string `json:",omitzero"` // Vendor is a vendor ID string, up to 16 characters. // Read from TPM_PT_VENDOR_STRING_*. - Vendor string `json:",omitempty"` + Vendor string `json:",omitzero"` // Model is a vendor-defined TPM model. // Read from TPM_PT_VENDOR_TPM_TYPE. - Model int `json:",omitempty"` + Model int `json:",omitzero"` // FirmwareVersion is the version number of the firmware. // Read from TPM_PT_FIRMWARE_VERSION_*. - FirmwareVersion uint64 `json:",omitempty"` + FirmwareVersion uint64 `json:",omitzero"` // SpecRevision is the TPM 2.0 spec revision encoded as a single number. All // revisions can be found at // https://trustedcomputinggroup.org/resource/tpm-library-specification/. // Before revision 184, TCG used the "01.83" format for revision 183. - SpecRevision int `json:",omitempty"` + SpecRevision int `json:",omitzero"` // FamilyIndicator is the TPM spec family, like "2.0". // Read from TPM_PT_FAMILY_INDICATOR. - FamilyIndicator string `json:",omitempty"` + FamilyIndicator string `json:",omitzero"` } // Present reports whether a TPM device is present on this machine. @@ -1016,41 +1017,37 @@ func (v HostinfoView) TailscaleSSHEnabled() bool { return v.ж.TailscaleSSHEnabl type NetInfo struct { // MappingVariesByDestIP says whether the host's NAT mappings // vary based on the destination IP. - MappingVariesByDestIP opt.Bool - - // HairPinning is their router does hairpinning. - // It reports true even if there's no NAT involved. - HairPinning opt.Bool + MappingVariesByDestIP opt.Bool `json:",omitzero"` // WorkingIPv6 is whether the host has IPv6 internet connectivity. - WorkingIPv6 opt.Bool + WorkingIPv6 opt.Bool `json:",omitzero"` // OSHasIPv6 is whether the OS supports IPv6 at all, regardless of // whether IPv6 internet connectivity is available. 
- OSHasIPv6 opt.Bool + OSHasIPv6 opt.Bool `json:",omitzero"` // WorkingUDP is whether the host has UDP internet connectivity. - WorkingUDP opt.Bool + WorkingUDP opt.Bool `json:",omitzero"` // WorkingICMPv4 is whether ICMPv4 works. // Empty means not checked. - WorkingICMPv4 opt.Bool + WorkingICMPv4 opt.Bool `json:",omitzero"` // HavePortMap is whether we have an existing portmap open // (UPnP, PMP, or PCP). - HavePortMap bool `json:",omitempty"` + HavePortMap bool `json:",omitzero"` // UPnP is whether UPnP appears present on the LAN. // Empty means not checked. - UPnP opt.Bool + UPnP opt.Bool `json:",omitzero"` // PMP is whether NAT-PMP appears present on the LAN. // Empty means not checked. - PMP opt.Bool + PMP opt.Bool `json:",omitzero"` // PCP is whether PCP appears present on the LAN. // Empty means not checked. - PCP opt.Bool + PCP opt.Bool `json:",omitzero"` // PreferredDERP is this node's preferred (home) DERP region ID. // This is where the node expects to be contacted to begin a @@ -1059,10 +1056,10 @@ type NetInfo struct { // that are located elsewhere) but PreferredDERP is the region ID // that the node subscribes to traffic at. // Zero means disconnected or unknown. - PreferredDERP int + PreferredDERP int `json:",omitzero"` // LinkType is the current link type, if known. - LinkType string `json:",omitempty"` // "wired", "wifi", "mobile" (LTE, 4G, 3G, etc) + LinkType string `json:",omitzero"` // "wired", "wifi", "mobile" (LTE, 4G, 3G, etc) // DERPLatency is the fastest recent time to reach various // DERP STUN servers, in seconds. The map key is the @@ -1080,7 +1077,7 @@ type NetInfo struct { // "{nft,ift}-REASON", like "nft-forced" or "ipt-default". Empty means // either not Linux or a configuration in which the host firewall rules // are not managed by tailscaled. - FirewallMode string `json:",omitempty"` + FirewallMode string `json:",omitzero"` // Update BasicallyEqual when adding fields. } @@ -1089,8 +1086,8 @@ func (ni *NetInfo) String() string { if ni == nil { return "NetInfo(nil)" } - return fmt.Sprintf("NetInfo{varies=%v hairpin=%v ipv6=%v ipv6os=%v udp=%v icmpv4=%v derp=#%v portmap=%v link=%q firewallmode=%q}", - ni.MappingVariesByDestIP, ni.HairPinning, ni.WorkingIPv6, + return fmt.Sprintf("NetInfo{varies=%v ipv6=%v ipv6os=%v udp=%v icmpv4=%v derp=#%v portmap=%v link=%q firewallmode=%q}", + ni.MappingVariesByDestIP, ni.WorkingIPv6, ni.OSHasIPv6, ni.WorkingUDP, ni.WorkingICMPv4, ni.PreferredDERP, ni.portMapSummary(), ni.LinkType, ni.FirewallMode) } @@ -1133,7 +1130,6 @@ func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool { return true } return ni.MappingVariesByDestIP == ni2.MappingVariesByDestIP && - ni.HairPinning == ni2.HairPinning && ni.WorkingIPv6 == ni2.WorkingIPv6 && ni.OSHasIPv6 == ni2.OSHasIPv6 && ni.WorkingUDP == ni2.WorkingUDP && @@ -1369,8 +1365,8 @@ type MapRequest struct { // For current values and history, see the CapabilityVersion type's docs. Version CapabilityVersion - Compress string // "zstd" or "" (no compression) - KeepAlive bool // whether server should send keep-alives back to us + Compress string `json:",omitzero"` // "zstd" or "" (no compression) + KeepAlive bool `json:",omitzero"` // whether server should send keep-alives back to us NodeKey key.NodePublic DiscoKey key.DiscoPublic @@ -1393,7 +1389,7 @@ type MapRequest struct { // // If true and Version >= 68, the server should treat this as a read-only // request and ignore any Hostinfo or other fields that might be set. 
- Stream bool + Stream bool `json:",omitzero"` // Hostinfo is the client's current Hostinfo. Although it is always included // in the request, the server may choose to ignore it when Stream is true @@ -1410,14 +1406,14 @@ type MapRequest struct { // // The server may choose to ignore the request for any reason and start a // new map session. This is only applicable when Stream is true. - MapSessionHandle string `json:",omitempty"` + MapSessionHandle string `json:",omitzero"` // MapSessionSeq is the sequence number in the map session identified by // MapSesssionHandle that was most recently processed by the client. // It is only applicable when MapSessionHandle is specified. // If the server chooses to honor the MapSessionHandle request, only sequence // numbers greater than this value will be returned. - MapSessionSeq int64 `json:",omitempty"` + MapSessionSeq int64 `json:",omitzero"` // Endpoints are the client's magicsock UDP ip:port endpoints (IPv4 or IPv6). // These can be ignored if Stream is true and Version >= 68. @@ -1428,7 +1424,7 @@ type MapRequest struct { // TKAHead describes the hash of the latest AUM applied to the local // tailnet key authority, if one is operating. // It is encoded as tka.AUMHash.MarshalText. - TKAHead string `json:",omitempty"` + TKAHead string `json:",omitzero"` // ReadOnly was set when client just wanted to fetch the MapResponse, // without updating their Endpoints. The intended use was for clients to @@ -1436,7 +1432,7 @@ type MapRequest struct { // update. // // Deprecated: always false as of Version 68. - ReadOnly bool `json:",omitempty"` + ReadOnly bool `json:",omitzero"` // OmitPeers is whether the client is okay with the Peers list being omitted // in the response. @@ -1452,7 +1448,7 @@ type MapRequest struct { // If OmitPeers is true, Stream is false, but ReadOnly is true, // then all the response fields are included. (This is what the client does // when initially fetching the DERP map.) - OmitPeers bool `json:",omitempty"` + OmitPeers bool `json:",omitzero"` // DebugFlags is a list of strings specifying debugging and // development features to enable in handling this map @@ -1472,7 +1468,7 @@ type MapRequest struct { // identifies this specific connection to the server. The server may choose to // use this handle to identify the connection for debugging or testing // purposes. It has no semantic meaning. - ConnectionHandleForTest string `json:",omitempty"` + ConnectionHandleForTest string `json:",omitzero"` } // PortRange represents a range of UDP or TCP port numbers. @@ -1763,7 +1759,7 @@ type DNSConfig struct { // in the network map, aka MagicDNS. // Despite the (legacy) name, does not necessarily cause request // proxying to be enabled. - Proxied bool `json:",omitempty"` + Proxied bool `json:",omitzero"` // Nameservers are the IP addresses of the global nameservers to use. // @@ -1800,7 +1796,7 @@ type DNSConfig struct { // TempCorpIssue13969 is a temporary (2023-08-16) field for an internal hack day prototype. // It contains a user inputed URL that should have a list of domains to be blocked. // See https://github.com/tailscale/corp/issues/13969. - TempCorpIssue13969 string `json:",omitempty"` + TempCorpIssue13969 string `json:",omitzero"` } // DNSRecord is an extra DNS record to add to MagicDNS. @@ -1812,7 +1808,7 @@ type DNSRecord struct { // Type is the DNS record type. // Empty means A or AAAA, depending on value. // Other values are currently ignored. 
- Type string `json:",omitempty"` + Type string `json:",omitzero"` // Value is the IP address in string form. // TODO(bradfitz): if we ever add support for record types @@ -1860,11 +1856,11 @@ type PingRequest struct { // URLIsNoise, if true, means that the client should hit URL over the Noise // transport instead of TLS. - URLIsNoise bool `json:",omitempty"` + URLIsNoise bool `json:",omitzero"` // Log is whether to log about this ping in the success case. // For failure cases, the client will log regardless. - Log bool `json:",omitempty"` + Log bool `json:",omitzero"` // Types is the types of ping that are initiated. Can be any PingType, comma // separated, e.g. "disco,TSMP" @@ -1874,10 +1870,10 @@ type PingRequest struct { // node's c2n handler and the HTTP response sent in a POST to URL. For c2n, // the value of URLIsNoise is ignored and only the Noise transport (back to // the control plane) will be used, as if URLIsNoise were true. - Types string `json:",omitempty"` + Types string `json:",omitzero"` // IP is the ping target, when needed by the PingType(s) given in Types. - IP netip.Addr + IP netip.Addr `json:",omitzero"` // Payload is the ping payload. // @@ -2154,12 +2150,14 @@ type MapResponse struct { // or nothing to report. ClientVersion *ClientVersion `json:",omitempty"` - // DefaultAutoUpdate is the default node auto-update setting for this + // DeprecatedDefaultAutoUpdate is the default node auto-update setting for this // tailnet. The node is free to opt-in or out locally regardless of this - // value. This value is only used on first MapResponse from control, the - // auto-update setting doesn't change if the tailnet admin flips the - // default after the node registered. - DefaultAutoUpdate opt.Bool `json:",omitempty"` + // value. Once this value has been set and stored in the client, future + // changes from the control plane are ignored. + // + // Deprecated: use NodeAttrDefaultAutoUpdate instead. See + // https://github.com/tailscale/tailscale/issues/11502. + DeprecatedDefaultAutoUpdate opt.Bool `json:"DefaultAutoUpdate,omitempty"` } // DisplayMessage represents a health state of the node from the control plane's @@ -2465,6 +2463,10 @@ const ( // of connections to the default network interface on Darwin nodes. CapabilityDebugDisableBindConnToInterface NodeCapability = "https://tailscale.com/cap/debug-disable-bind-conn-to-interface" + // CapabilityDebugDisableBindConnToInterfaceAppleExt disables the automatic binding + // of connections to the default network interface on Darwin nodes using network extensions. + CapabilityDebugDisableBindConnToInterfaceAppleExt NodeCapability = "https://tailscale.com/cap/debug-disable-bind-conn-to-interface-apple-ext" + // CapabilityTailnetLock indicates the node may initialize tailnet lock. CapabilityTailnetLock NodeCapability = "https://tailscale.com/cap/tailnet-lock" @@ -2722,6 +2724,14 @@ const ( // default behavior is to trust the control plane when it claims that a // node is no longer online, but that is not a reliable signal. NodeAttrClientSideReachability = "client-side-reachability" + + // NodeAttrDefaultAutoUpdate advertises the default node auto-update setting + // for this tailnet. The node is free to opt-in or out locally regardless of + // this value. Once this has been set and stored in the client, future + // changes from the control plane are ignored. + // + // The value of the key in [NodeCapMap] is a JSON boolean.
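// A minimal sketch (not part of this change) of consuming the JSON boolean
// carried by the NodeAttrDefaultAutoUpdate attribute defined just below. The
// capability map is modeled as a plain map of raw JSON values rather than
// the real tailcfg.NodeCapMap type.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	capMap := map[string][]json.RawMessage{
		"default-auto-update": {json.RawMessage(`true`)},
	}
	var defaultAutoUpdate bool
	if vals := capMap["default-auto-update"]; len(vals) > 0 {
		if err := json.Unmarshal(vals[0], &defaultAutoUpdate); err != nil {
			panic(err) // malformed attribute value
		}
	}
	fmt.Println(defaultAutoUpdate) // true
}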
+ NodeAttrDefaultAutoUpdate NodeCapability = "default-auto-update" ) // SetDNSRequest is a request to add a DNS record. @@ -3044,29 +3054,29 @@ type SSHRecordingAttempt struct { // See QueryFeatureResponse for response structure. type QueryFeatureRequest struct { // Feature is the string identifier for a feature. - Feature string `json:",omitempty"` + Feature string `json:",omitzero"` // NodeKey is the client's current node key. - NodeKey key.NodePublic `json:",omitempty"` + NodeKey key.NodePublic `json:",omitzero"` } // QueryFeatureResponse is the response to an QueryFeatureRequest. // See cli.enableFeatureInteractive for usage. type QueryFeatureResponse struct { // Complete is true when the feature is already enabled. - Complete bool `json:",omitempty"` + Complete bool `json:",omitzero"` // Text holds lines to display in the CLI with information // about the feature and how to enable it. // // Lines are separated by newline characters. The final // newline may be omitted. - Text string `json:",omitempty"` + Text string `json:",omitzero"` // URL is the link for the user to visit to take action on // enabling the feature. // // When empty, there is no action for this user to take. - URL string `json:",omitempty"` + URL string `json:",omitzero"` // ShouldWait specifies whether the CLI should block and // wait for the user to enable the feature. @@ -3079,7 +3089,7 @@ type QueryFeatureResponse struct { // // The CLI can watch the IPN notification bus for changes in // required node capabilities to know when to continue. - ShouldWait bool `json:",omitempty"` + ShouldWait bool `json:",omitzero"` } // WebClientAuthResponse is the response to a web client authentication request @@ -3089,15 +3099,15 @@ type WebClientAuthResponse struct { // ID is a unique identifier for the session auth request. // It can be supplied to "/machine/webclient/wait" to pause until // the session authentication has been completed. - ID string `json:",omitempty"` + ID string `json:",omitzero"` // URL is the link for the user to visit to authenticate the session. // // When empty, there is no action for the user to take. - URL string `json:",omitempty"` + URL string `json:",omitzero"` // Complete is true when the session authentication has been completed. - Complete bool `json:",omitempty"` + Complete bool `json:",omitzero"` } // OverTLSPublicKeyResponse is the JSON response to /key?v= @@ -3173,10 +3183,10 @@ type PeerChange struct { // DERPRegion, if non-zero, means that NodeID's home DERP // region ID is now this number. - DERPRegion int `json:",omitempty"` + DERPRegion int `json:",omitzero"` // Cap, if non-zero, means that NodeID's capability version has changed. - Cap CapabilityVersion `json:",omitempty"` + Cap CapabilityVersion `json:",omitzero"` // CapMap, if non-nil, means that NodeID's capability map has changed. CapMap NodeCapMap `json:",omitempty"` @@ -3186,23 +3196,23 @@ type PeerChange struct { Endpoints []netip.AddrPort `json:",omitempty"` // Key, if non-nil, means that the NodeID's wireguard public key changed. - Key *key.NodePublic `json:",omitempty"` + Key *key.NodePublic `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // KeySignature, if non-nil, means that the signature of the wireguard // public key has changed. KeySignature tkatype.MarshaledSignature `json:",omitempty"` // DiscoKey, if non-nil, means that the NodeID's discokey changed. 
- DiscoKey *key.DiscoPublic `json:",omitempty"` + DiscoKey *key.DiscoPublic `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // Online, if non-nil, means that the NodeID's online status changed. - Online *bool `json:",omitempty"` + Online *bool `json:",omitzero"` // LastSeen, if non-nil, means that the NodeID's online status changed. - LastSeen *time.Time `json:",omitempty"` + LastSeen *time.Time `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // KeyExpiry, if non-nil, changes the NodeID's key expiry. - KeyExpiry *time.Time `json:",omitempty"` + KeyExpiry *time.Time `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 } // DerpMagicIP is a fake WireGuard endpoint IP address that means to @@ -3280,14 +3290,14 @@ const ( // POST https:///machine/audit-log type AuditLogRequest struct { // Version is the client's current CapabilityVersion. - Version CapabilityVersion `json:",omitempty"` + Version CapabilityVersion `json:",omitzero"` // NodeKey is the client's current node key. NodeKey key.NodePublic `json:",omitzero"` // Action is the action to be logged. It must correspond to a known action in the control plane. - Action ClientAuditAction `json:",omitempty"` + Action ClientAuditAction `json:",omitzero"` // Details is an opaque string, specific to the action being logged. Empty strings may not // be valid depending on the action being logged. - Details string `json:",omitempty"` + Details string `json:",omitzero"` // Timestamp is the time at which the audit log was generated on the node. Timestamp time.Time `json:",omitzero"` } diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 9aa7673886bc6..751b7c288f274 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -207,7 +207,6 @@ func (src *NetInfo) Clone() *NetInfo { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _NetInfoCloneNeedsRegeneration = NetInfo(struct { MappingVariesByDestIP opt.Bool - HairPinning opt.Bool WorkingIPv6 opt.Bool OSHasIPv6 opt.Bool WorkingUDP opt.Bool diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index addd2330ba239..6691263eb997a 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -607,7 +607,6 @@ func TestNodeEqual(t *testing.T) { func TestNetInfoFields(t *testing.T) { handled := []string{ "MappingVariesByDestIP", - "HairPinning", "WorkingIPv6", "OSHasIPv6", "WorkingUDP", diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index 88dd90096ab55..dbd29a87a354e 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -741,10 +741,6 @@ func (v *NetInfoView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { // vary based on the destination IP. func (v NetInfoView) MappingVariesByDestIP() opt.Bool { return v.ж.MappingVariesByDestIP } -// HairPinning is their router does hairpinning. -// It reports true even if there's no NAT involved. -func (v NetInfoView) HairPinning() opt.Bool { return v.ж.HairPinning } - // WorkingIPv6 is whether the host has IPv6 internet connectivity. func (v NetInfoView) WorkingIPv6() opt.Bool { return v.ж.WorkingIPv6 } @@ -809,7 +805,6 @@ func (v NetInfoView) String() string { return v.ж.String() } // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
var _NetInfoViewNeedsRegeneration = NetInfo(struct { MappingVariesByDestIP opt.Bool - HairPinning opt.Bool WorkingIPv6 opt.Bool OSHasIPv6 opt.Bool WorkingUDP opt.Bool diff --git a/tka/aum.go b/tka/aum.go index 08d70897ee70f..b8c4b6c9e14d4 100644 --- a/tka/aum.go +++ b/tka/aum.go @@ -31,8 +31,8 @@ func (h AUMHash) String() string { // UnmarshalText implements encoding.TextUnmarshaler. func (h *AUMHash) UnmarshalText(text []byte) error { - if l := base32StdNoPad.DecodedLen(len(text)); l != len(h) { - return fmt.Errorf("tka.AUMHash.UnmarshalText: text wrong length: %d, want %d", l, len(text)) + if ln := base32StdNoPad.DecodedLen(len(text)); ln != len(h) { + return fmt.Errorf("tka.AUMHash.UnmarshalText: text wrong length: %d, want %d", ln, len(h)) } if _, err := base32StdNoPad.Decode(h[:], text); err != nil { return fmt.Errorf("tka.AUMHash.UnmarshalText: %w", err) @@ -55,6 +55,17 @@ func (h AUMHash) IsZero() bool { return h == (AUMHash{}) } +// PrevAUMHash represents the BLAKE2s digest of an Authority Update Message (AUM). +// Unlike an AUMHash, this can be empty if there is no previous AUM hash +// (which occurs in the genesis AUM). +type PrevAUMHash []byte + +// String returns the PrevAUMHash encoded as base32. +// This is suitable for use as a filename, and for storing in text-preferred media. +func (h PrevAUMHash) String() string { + return base32StdNoPad.EncodeToString(h[:]) +} + // AUMKind describes valid AUM types. type AUMKind uint8 @@ -119,8 +130,8 @@ func (k AUMKind) String() string { // behavior of old clients (which will ignore the field). // - No floats! type AUM struct { - MessageKind AUMKind `cbor:"1,keyasint"` - PrevAUMHash []byte `cbor:"2,keyasint"` + MessageKind AUMKind `cbor:"1,keyasint"` + PrevAUMHash PrevAUMHash `cbor:"2,keyasint"` // Key encodes a public key to be added to the key authority. // This field is used for AddKey AUMs. @@ -226,7 +237,7 @@ func (a *AUM) Serialize() tkatype.MarshaledAUM { // Further, experience with other attempts (JWS/JWT,SAML,X509 etc) has // taught us that even subtle behaviors such as how you handle invalid // or unrecognized fields + any invariants in subsequent re-serialization - // can easily lead to security-relevant logic bugs. Its certainly possible + // can easily lead to security-relevant logic bugs. It's certainly possible // to invent a workable scheme by massaging a JSON parsing library, though // profoundly unwise. // diff --git a/tka/aum_test.go b/tka/aum_test.go index 4297efabff13f..833a026544f54 100644 --- a/tka/aum_test.go +++ b/tka/aum_test.go @@ -5,6 +5,8 @@ package tka import ( "bytes" + "encoding/base64" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -156,6 +158,80 @@ func TestSerialization(t *testing.T) { } } +func fromBase64(s string) []byte { + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + panic(fmt.Sprintf("base64 decode failed: %v", err)) + } + return data +} + +// This test verifies that we can read AUMs which were serialized with +// older versions of our code. +func TestDeserializeExistingAUMs(t *testing.T) { + for _, tt := range []struct { + Name string + Data []byte + Want AUM + }{ + { + // This is an AUM which was created in a test tailnet, and encoded + // on 2025-11-07 with commit d4c5b27.
+ Name: "genesis-aum-2025-11-07", + Data: fromBase64("pAEFAvYFpQH2AopYII0sLaLSEZU3W5DT1dG2WYnzjCBr4tXtVbCT2LvA9LS6WCAQhwVGDiUGRiu3P63gucZ/8otjt2DXyk+OBjbh5iWx1Fgg5VU4oRQiMoq5qK00McfpwtmjcheVammLCRwzdp2Zje9YIHDoOXe4ogPSy7lfA/veyPCKM6iZe3PTgzhQZ4W5Sh7wWCBYQtiQ6NcRlyVARJxgAj1BbbvdJQ0t4m+vHqU1J02oDlgg2sksJA+COfsBkrohwHBWlbKrpS8Mvigpl+enuHw9rIJYIB/+CUBBBLUz0KeHu7NKrg5ZEhjjPUWhNcf9QTNHjuNWWCCJuxqPZ6/IASPTmAERaoKnBNH/D+zY4p4TUGHR4fACjFggMtDAipPutgcxKnU9Tg2663gP3KlTQfztV3hBwiePZdRYIGYeD2erBkRouSL20lOnWHHlRq5kmNfN6xFb2CTaPjnXA4KjAQECAQNYIADftG3yaitV/YMoKSBP45zgyeodClumN9ZaeQg/DmCEowEBAgEDWCBRKbmWSzOyHXbHJuYn8s7dmMPDzxmIjgBoA80cBYgItAQbEWOrxfqJzIkFG/5uNUp0s/ScF4GiAVggAN+0bfJqK1X9gygpIE/jnODJ6h0KW6Y31lp5CD8OYIQCWEAENvzblKV2qx6PED5YdGy8kWa7nxEnaeuMmS5Wkx0n7CXs0XxD5f2NIE+pSv9cOsNkfYNndQkYD7ne33hQOsQM"), + Want: AUM{ + MessageKind: AUMCheckpoint, + State: &State{ + DisablementSecrets: [][]byte{ + fromBase64("jSwtotIRlTdbkNPV0bZZifOMIGvi1e1VsJPYu8D0tLo="), + fromBase64("EIcFRg4lBkYrtz+t4LnGf/KLY7dg18pPjgY24eYlsdQ="), + fromBase64("5VU4oRQiMoq5qK00McfpwtmjcheVammLCRwzdp2Zje8="), + fromBase64("cOg5d7iiA9LLuV8D+97I8IozqJl7c9ODOFBnhblKHvA="), + fromBase64("WELYkOjXEZclQEScYAI9QW273SUNLeJvrx6lNSdNqA4="), + fromBase64("2sksJA+COfsBkrohwHBWlbKrpS8Mvigpl+enuHw9rII="), + fromBase64("H/4JQEEEtTPQp4e7s0quDlkSGOM9RaE1x/1BM0eO41Y="), + fromBase64("ibsaj2evyAEj05gBEWqCpwTR/w/s2OKeE1Bh0eHwAow="), + fromBase64("MtDAipPutgcxKnU9Tg2663gP3KlTQfztV3hBwiePZdQ="), + fromBase64("Zh4PZ6sGRGi5IvbSU6dYceVGrmSY183rEVvYJNo+Odc="), + }, + Keys: []Key{ + { + Kind: Key25519, + Votes: 1, + Public: fromBase64("AN+0bfJqK1X9gygpIE/jnODJ6h0KW6Y31lp5CD8OYIQ="), + }, + { + Kind: Key25519, + Votes: 1, + Public: fromBase64("USm5lkszsh12xybmJ/LO3ZjDw88ZiI4AaAPNHAWICLQ="), + }, + }, + StateID1: 1253033988139371657, + StateID2: 18333649726973670556, + }, + Signatures: []tkatype.Signature{ + { + KeyID: fromBase64("AN+0bfJqK1X9gygpIE/jnODJ6h0KW6Y31lp5CD8OYIQ="), + Signature: fromBase64("BDb825SldqsejxA+WHRsvJFmu58RJ2nrjJkuVpMdJ+wl7NF8Q+X9jSBPqUr/XDrDZH2DZ3UJGA+53t94UDrEDA=="), + }, + }, + }, + }, + } { + t.Run(tt.Name, func(t *testing.T) { + var got AUM + + if err := got.Unserialize(tt.Data); err != nil { + t.Fatalf("Unserialize: %v", err) + } + + if diff := cmp.Diff(got, tt.Want); diff != "" { + t.Fatalf("wrong AUM (-got, +want):\n%s", diff) + } + }) + } +} + func TestAUMWeight(t *testing.T) { var fakeKeyID [blake2s.Size]byte testingRand(t, 1).Read(fakeKeyID[:]) diff --git a/tka/builder.go b/tka/builder.go index 642f39d77422d..ab2364d856ee2 100644 --- a/tka/builder.go +++ b/tka/builder.go @@ -114,7 +114,7 @@ func (b *UpdateBuilder) generateCheckpoint() error { } } - // Checkpoints cant specify a parent AUM. + // Checkpoints can't specify a parent AUM. 
state.LastAUMHash = nil return b.mkUpdate(AUM{MessageKind: AUMCheckpoint, State: &state}) } @@ -136,7 +136,7 @@ func (b *UpdateBuilder) Finalize(storage Chonk) ([]AUM, error) { needCheckpoint = false break } - return nil, fmt.Errorf("reading AUM: %v", err) + return nil, fmt.Errorf("reading AUM (%v): %v", cursor, err) } if aum.MessageKind == AUMCheckpoint { diff --git a/tka/builder_test.go b/tka/builder_test.go index 52907186b6d30..3fd32f64eac12 100644 --- a/tka/builder_test.go +++ b/tka/builder_test.go @@ -28,7 +28,7 @@ func TestAuthorityBuilderAddKey(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -62,7 +62,7 @@ func TestAuthorityBuilderMaxKey(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -109,7 +109,7 @@ func TestAuthorityBuilderRemoveKey(t *testing.T) { pub2, _ := testingKey25519(t, 2) key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key, key2}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -155,7 +155,7 @@ func TestAuthorityBuilderSetKeyVote(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -191,7 +191,7 @@ func TestAuthorityBuilderSetKeyMeta(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -227,7 +227,7 @@ func TestAuthorityBuilderMultiple(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -275,7 +275,7 @@ func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, diff --git a/tka/chaintest_test.go b/tka/chaintest_test.go index 5811f9c8381ed..a3122b5d19da8 100644 --- a/tka/chaintest_test.go +++ b/tka/chaintest_test.go @@ -285,25 +285,25 @@ func (c *testChain) makeAUM(v *testchainNode) AUM { // Chonk returns a tailchonk containing all AUMs. func (c *testChain) Chonk() Chonk { - var out Mem + out := ChonkMem() for _, update := range c.AUMs { if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil { panic(err) } } - return &out + return out } // ChonkWith returns a tailchonk containing the named AUMs. 
func (c *testChain) ChonkWith(names ...string) Chonk { - var out Mem + out := ChonkMem() for _, name := range names { update := c.AUMs[name] if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil { panic(err) } } - return &out + return out } type testchainOpt struct { diff --git a/tka/disabled_stub.go b/tka/disabled_stub.go index 15bf12c333fc8..4c4afa3706d98 100644 --- a/tka/disabled_stub.go +++ b/tka/disabled_stub.go @@ -22,7 +22,24 @@ type Authority struct { func (*Authority) Head() AUMHash { return AUMHash{} } -func (AUMHash) MarshalText() ([]byte, error) { return nil, errNoTailnetLock } +// MarshalText returns a dummy (empty) value, because Tailnet Lock +// is not compiled into this binary. +// +// We need to be able to marshal AUMHash to text because it's included +// in [netmap.NetworkMap], which gets serialised as JSON in the +// c2n /debug/netmap endpoint. +// +// We provide a basic marshaller so that the endpoint works correctly +// with nodes that omit Tailnet Lock support, but we don't want the +// base32 dependency used for the regular marshaller, and we don't +// need unmarshalling support at time of writing (2025-11-18). +func (h AUMHash) MarshalText() ([]byte, error) { + return []byte(""), nil +} + +func (h *AUMHash) UnmarshalText(text []byte) error { + return errors.New("tailnet lock is not supported by this binary") +} type State struct{} @@ -128,12 +145,6 @@ type NodeKeySignature struct { type DeeplinkValidationResult struct { } -func (h *AUMHash) UnmarshalText(text []byte) error { - return errNoTailnetLock -} - -var errNoTailnetLock = errors.New("tailnet lock is not enabled") - func DecodeWrappedAuthkey(wrappedAuthKey string, logf logger.Logf) (authKey string, isWrapped bool, sig *NodeKeySignature, priv ed25519.PrivateKey) { return wrappedAuthKey, false, nil, nil } diff --git a/tka/key_test.go b/tka/key_test.go index e912f89c4f7eb..327de1a0e2851 100644 --- a/tka/key_test.go +++ b/tka/key_test.go @@ -42,7 +42,7 @@ func TestVerify25519(t *testing.T) { aum := AUM{ MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, - // Signatures is set to crap so we are sure its ignored in the sigHash computation. + // Signatures is set to crap so we are sure it's ignored in the sigHash computation. Signatures: []tkatype.Signature{{KeyID: []byte{45, 42}}}, } sigHash := aum.SigHash() @@ -72,7 +72,7 @@ func TestNLPrivate(t *testing.T) { // Test that key.NLPrivate implements Signer by making a new // authority. k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} - _, aum, err := Create(&Mem{}, State{ + _, aum, err := Create(ChonkMem(), State{ Keys: []Key{k}, DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, }, p) @@ -89,7 +89,7 @@ func TestNLPrivate(t *testing.T) { t.Error("signature did not verify") } - // We manually compute the keyID, so make sure its consistent with + // We manually compute the keyID, so make sure it's consistent with // tka.Key.ID(). if !bytes.Equal(k.MustID(), p.KeyID()) { t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID()) diff --git a/tka/scenario_test.go b/tka/scenario_test.go index 89a8111e18cef..a0361a130dcc6 100644 --- a/tka/scenario_test.go +++ b/tka/scenario_test.go @@ -204,7 +204,7 @@ func TestNormalPropagation(t *testing.T) { `) control := s.mkNode("control") - // Lets say theres a node with some updates! + // Let's say there's a node with some updates!
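The AUMHash stub above must return a value rather than an error because a failing TextMarshaler aborts the whole JSON encode of the netmap. A minimal, self-contained sketch of that encoding/json behavior (standalone illustration; the `failing` and `succeeding` types are hypothetical stand-ins for the old and new stubs, not code from this patch):

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// failing mimics the deleted stub: MarshalText returns an error.
type failing struct{}

func (failing) MarshalText() ([]byte, error) { return nil, errors.New("tailnet lock is not enabled") }

// succeeding mimics the new stub: an empty but well-formed value.
type succeeding struct{}

func (succeeding) MarshalText() ([]byte, error) { return []byte(""), nil }

func main() {
	_, err := json.Marshal(struct{ H failing }{})
	fmt.Println(err != nil) // true: one failing MarshalText sinks the entire document

	b, _ := json.Marshal(struct{ H succeeding }{})
	fmt.Println(string(b)) // {"H":""}
}
```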
n1 := s.mkNodeWithForks("n1", true, map[string]*testChain{ "L2": newTestchain(t, `L3 -> L4`), }) diff --git a/tka/sig.go b/tka/sig.go index 7b1838d409130..46d598ad97b47 100644 --- a/tka/sig.go +++ b/tka/sig.go @@ -277,7 +277,7 @@ func (s *NodeKeySignature) verifySignature(nodeKey key.NodePublic, verificationK // Recurse to verify the signature on the nested structure. var nestedPub key.NodePublic // SigCredential signatures certify an indirection key rather than a node - // key, so theres no need to check the node key. + // key, so there's no need to check the node key. if s.Nested.SigKind != SigCredential { if err := nestedPub.UnmarshalBinary(s.Nested.Pubkey); err != nil { return fmt.Errorf("nested pubkey: %v", err) diff --git a/tka/sig_test.go b/tka/sig_test.go index 99c25f8e57ae6..c5c03ef2e0055 100644 --- a/tka/sig_test.go +++ b/tka/sig_test.go @@ -76,8 +76,8 @@ func TestSigNested(t *testing.T) { if err := nestedSig.verifySignature(oldNode.Public(), k); err != nil { t.Fatalf("verifySignature(oldNode) failed: %v", err) } - if l := sigChainLength(nestedSig); l != 1 { - t.Errorf("nestedSig chain length = %v, want 1", l) + if ln := sigChainLength(nestedSig); ln != 1 { + t.Errorf("nestedSig chain length = %v, want 1", ln) } // The signature authorizing the rotation, signed by the @@ -93,8 +93,8 @@ func TestSigNested(t *testing.T) { if err := sig.verifySignature(node.Public(), k); err != nil { t.Fatalf("verifySignature(node) failed: %v", err) } - if l := sigChainLength(sig); l != 2 { - t.Errorf("sig chain length = %v, want 2", l) + if ln := sigChainLength(sig); ln != 2 { + t.Errorf("sig chain length = %v, want 2", ln) } // Test verification fails if the wrong verification key is provided @@ -119,7 +119,7 @@ func TestSigNested(t *testing.T) { } // Test verification fails if the outer signature is signed with a - // different public key to whats specified in WrappingPubkey + // different public key to what's specified in WrappingPubkey sig.Signature = ed25519.Sign(priv, sigHash[:]) if err := sig.verifySignature(node.Public(), k); err == nil { t.Error("verifySignature(node) succeeded with different signature") @@ -275,7 +275,7 @@ func TestSigCredential(t *testing.T) { } // Test verification fails if the outer signature is signed with a - // different public key to whats specified in WrappingPubkey + // different public key to what's specified in WrappingPubkey sig.Signature = ed25519.Sign(priv, sigHash[:]) if err := sig.verifySignature(node.Public(), k); err == nil { t.Error("verifySignature(node) succeeded with different signature") diff --git a/tka/state.go b/tka/state.go index 0a30c56a02fa8..95a319bd9bd7d 100644 --- a/tka/state.go +++ b/tka/state.go @@ -140,7 +140,7 @@ func (s State) checkDisablement(secret []byte) bool { // Specifically, the rules are: // - The last AUM hash must match (transitively, this implies that this // update follows the last update message applied to the state machine) -// - Or, the state machine knows no parent (its brand new). +// - Or, the state machine knows no parent (it's brand new). func (s State) parentMatches(update AUM) bool { if s.LastAUMHash == nil { return true diff --git a/tka/sync.go b/tka/sync.go index 6c2b7cbb8c81a..2dbfb7ac435b2 100644 --- a/tka/sync.go +++ b/tka/sync.go @@ -32,6 +32,41 @@ type SyncOffer struct { Ancestors []AUMHash } +// ToSyncOffer creates a SyncOffer from the fields received in +// a [tailcfg.TKASyncOfferRequest]. 
+func ToSyncOffer(head string, ancestors []string) (SyncOffer, error) { + var out SyncOffer + if err := out.Head.UnmarshalText([]byte(head)); err != nil { + return SyncOffer{}, fmt.Errorf("head.UnmarshalText: %v", err) + } + out.Ancestors = make([]AUMHash, len(ancestors)) + for i, a := range ancestors { + if err := out.Ancestors[i].UnmarshalText([]byte(a)); err != nil { + return SyncOffer{}, fmt.Errorf("ancestor[%d].UnmarshalText: %v", i, err) + } + } + return out, nil +} + +// FromSyncOffer marshals the fields of a SyncOffer so they can be +// sent in a [tailcfg.TKASyncOfferRequest]. +func FromSyncOffer(offer SyncOffer) (head string, ancestors []string, err error) { + headBytes, err := offer.Head.MarshalText() + if err != nil { + return "", nil, fmt.Errorf("head.MarshalText: %v", err) + } + + ancestors = make([]string, len(offer.Ancestors)) + for i, ancestor := range offer.Ancestors { + hash, err := ancestor.MarshalText() + if err != nil { + return "", nil, fmt.Errorf("ancestor[%d].MarshalText: %v", i, err) + } + ancestors[i] = string(hash) + } + return string(headBytes), ancestors, nil +} + const ( // The starting number of AUMs to skip when listing // ancestors in a SyncOffer. @@ -54,7 +89,7 @@ const ( // can then be applied locally with Inform(). // // This SyncOffer + AUM exchange should be performed by both ends, -// because its possible that either end has AUMs that the other needs +// because it's possible that either end has AUMs that the other needs // to find out about. func (a *Authority) SyncOffer(storage Chonk) (SyncOffer, error) { oldest := a.oldestAncestor.Hash() @@ -123,7 +158,7 @@ func computeSyncIntersection(storage Chonk, localOffer, remoteOffer SyncOffer) ( } // Case: 'head intersection' - // If we have the remote's head, its more likely than not that + // If we have the remote's head, it's more likely than not that // we have updates that build on that head. To confirm this, // we iterate backwards through our chain to see if the given // head is an ancestor of our current chain. @@ -165,7 +200,7 @@ func computeSyncIntersection(storage Chonk, localOffer, remoteOffer SyncOffer) ( // Case: 'tail intersection' // So we don't have a clue what the remote's head is, but // if one of the ancestors they gave us is part of our chain, - // then theres an intersection, which is a starting point for + // then there's an intersection, which is a starting point for // the remote to send us AUMs from. // // We iterate the list of ancestors in order because the remote diff --git a/tka/sync_test.go b/tka/sync_test.go index 7250eacf7d143..ea14a37e57e9b 100644 --- a/tka/sync_test.go +++ b/tka/sync_test.go @@ -346,7 +346,7 @@ func TestSyncSimpleE2E(t *testing.T) { optKey("key", key, priv), optSignAllUsing("key")) - nodeStorage := &Mem{} + nodeStorage := ChonkMem() node, err := Bootstrap(nodeStorage, c.AUMs["G1"]) if err != nil { t.Fatalf("node Bootstrap() failed: %v", err) @@ -357,7 +357,7 @@ func TestSyncSimpleE2E(t *testing.T) { t.Fatalf("control Open() failed: %v", err) } - // Control knows the full chain, node only knows the genesis. Lets see + // Control knows the full chain, node only knows the genesis. Let's see // if they can sync. 
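ToSyncOffer and FromSyncOffer above are intended as inverses. A quick sketch of the round trip (standalone example; it assumes a build with full Tailnet Lock support, where AUMHash's text form is defined even for the zero hash and MarshalText/UnmarshalText invert each other):

```go
package main

import (
	"fmt"

	"tailscale.com/tka"
)

func main() {
	// A zero-valued offer is enough to exercise the round trip:
	// one head hash, no ancestors.
	var offer tka.SyncOffer

	head, ancestors, err := tka.FromSyncOffer(offer)
	if err != nil {
		panic(err)
	}
	back, err := tka.ToSyncOffer(head, ancestors)
	if err != nil {
		panic(err)
	}
	fmt.Println(back.Head == offer.Head) // true: the string form round-trips
}
```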
nodeOffer, err := node.SyncOffer(nodeStorage) if err != nil { diff --git a/tka/tailchonk.go b/tka/tailchonk.go index 616abaf2b4190..13bdf6aac86d4 100644 --- a/tka/tailchonk.go +++ b/tka/tailchonk.go @@ -19,6 +19,8 @@ import ( "github.com/fxamacker/cbor/v2" "tailscale.com/atomicfile" + "tailscale.com/tstime" + "tailscale.com/util/testenv" ) // Chonk implementations provide durable storage for AUMs and other @@ -92,6 +94,7 @@ type Mem struct { mu sync.RWMutex aums map[AUMHash]AUM commitTimes map[AUMHash]time.Time + clock tstime.Clock // parentIndex is a map of AUMs to the AUMs for which they are // the parent. @@ -103,6 +106,23 @@ type Mem struct { lastActiveAncestor *AUMHash } +// ChonkMem returns an implementation of Chonk which stores TKA state +// in-memory. +func ChonkMem() *Mem { + return &Mem{ + clock: tstime.DefaultClock{}, + } +} + +// SetClock sets the clock used by [Mem]. This is only for use in tests, +// and will panic if called from non-test code. +func (c *Mem) SetClock(clock tstime.Clock) { + if !testenv.InTest() { + panic("used SetClock in non-test code") + } + c.clock = clock +} + func (c *Mem) SetLastActiveAncestor(hash AUMHash) error { c.mu.Lock() defer c.mu.Unlock() @@ -173,7 +193,7 @@ updateLoop: for _, aum := range updates { aumHash := aum.Hash() c.aums[aumHash] = aum - c.commitTimes[aumHash] = time.Now() + c.commitTimes[aumHash] = c.now() parent, ok := aum.Parent() if ok { @@ -189,6 +209,16 @@ updateLoop: return nil } +// now returns the current time, optionally using the overridden +// clock if set. +func (c *Mem) now() time.Time { + if c.clock == nil { + return time.Now() + } else { + return c.clock.Now() + } +} + // RemoveAll permanently and completely clears the TKA state. func (c *Mem) RemoveAll() error { c.mu.Lock() @@ -668,7 +698,7 @@ const ( ) // markActiveChain marks AUMs in the active chain. -// All AUMs that are within minChain ancestors of head are +// All AUMs that are within minChain ancestors of head, or are marked as young, are // marked retainStateActive, and all remaining ancestors are // marked retainStateCandidate. // @@ -694,27 +724,30 @@ func markActiveChain(storage Chonk, verdict map[AUMHash]retainState, minChain in // We've reached the end of the chain we have stored. return h, nil } - return AUMHash{}, fmt.Errorf("reading active chain (retainStateActive) (%d): %w", i, err) + return AUMHash{}, fmt.Errorf("reading active chain (retainStateActive) (%d, %v): %w", i, parent, err) } } // If we got this far, we have at least minChain AUMs stored, and minChain number // of ancestors have been marked for retention. We now continue to iterate backwards - // till we find an AUM which we can compact to (a Checkpoint AUM). + // till we find an AUM which we can compact to: either a Checkpoint AUM which is old + // enough, or the genesis AUM. 
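The clock plumbing added to Mem earlier in the tailchonk.go hunk exists so tests can pin commit times. A sketch of how that might look (assumptions: tstest.Clock satisfies tstime.Clock, as it does elsewhere in the tree, and has the Start field used below):

```go
package tka_test

import (
	"testing"
	"time"

	"tailscale.com/tka"
	"tailscale.com/tstest"
)

func TestMemCommitTimes(t *testing.T) {
	clk := &tstest.Clock{Start: time.Unix(1700000000, 0)}
	storage := tka.ChonkMem()
	storage.SetClock(clk) // would panic outside a test, per testenv.InTest

	// AUMs committed from here on are stamped with clk.Now() rather than
	// the wall clock, making age-based retention (CompactionOptions.MinAge)
	// deterministic.
}
```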
for { h := next.Hash() verdict[h] |= retainStateActive + + parent, hasParent := next.Parent() + isYoung := verdict[h]&retainStateYoung != 0 + if next.MessageKind == AUMCheckpoint { lastActiveAncestor = h - break + if !isYoung || !hasParent { + break + } } - parent, hasParent := next.Parent() - if !hasParent { - return AUMHash{}, errors.New("reached genesis AUM without finding an appropriate lastActiveAncestor") - } if next, err = storage.AUM(parent); err != nil { - return AUMHash{}, fmt.Errorf("searching for compaction target: %w", err) + return AUMHash{}, fmt.Errorf("searching for compaction target (%v): %w", parent, err) } } @@ -730,7 +763,7 @@ func markActiveChain(storage Chonk, verdict map[AUMHash]retainState, minChain in // We've reached the end of the chain we have stored. break } - return AUMHash{}, fmt.Errorf("reading active chain (retainStateCandidate): %w", err) + return AUMHash{}, fmt.Errorf("reading active chain (retainStateCandidate, %v): %w", parent, err) } } @@ -768,7 +801,7 @@ func markAncestorIntersectionAUMs(storage Chonk, verdict map[AUMHash]retainState toScan := make([]AUMHash, 0, len(verdict)) for h, v := range verdict { if (v & retainAUMMask) == 0 { - continue // not marked for retention, so dont need to consider it + continue // not marked for retention, so don't need to consider it } if h == candidateAncestor { continue @@ -842,7 +875,7 @@ func markAncestorIntersectionAUMs(storage Chonk, verdict map[AUMHash]retainState if didAdjustCandidateAncestor { var next AUM if next, err = storage.AUM(candidateAncestor); err != nil { - return AUMHash{}, fmt.Errorf("searching for compaction target: %w", err) + return AUMHash{}, fmt.Errorf("searching for compaction target (%v): %w", candidateAncestor, err) } for { @@ -858,7 +891,7 @@ func markAncestorIntersectionAUMs(storage Chonk, verdict map[AUMHash]retainState return AUMHash{}, errors.New("reached genesis AUM without finding an appropriate candidateAncestor") } if next, err = storage.AUM(parent); err != nil { - return AUMHash{}, fmt.Errorf("searching for compaction target: %w", err) + return AUMHash{}, fmt.Errorf("searching for compaction target (%v): %w", parent, err) } } } @@ -871,7 +904,7 @@ func markDescendantAUMs(storage Chonk, verdict map[AUMHash]retainState) error { toScan := make([]AUMHash, 0, len(verdict)) for h, v := range verdict { if v&retainAUMMask == 0 { - continue // not marked, so dont need to mark descendants + continue // not marked, so don't need to mark descendants } toScan = append(toScan, h) } @@ -917,12 +950,12 @@ func Compact(storage CompactableChonk, head AUMHash, opts CompactionOptions) (la verdict[h] = 0 } - if lastActiveAncestor, err = markActiveChain(storage, verdict, opts.MinChain, head); err != nil { - return AUMHash{}, fmt.Errorf("marking active chain: %w", err) - } if err := markYoungAUMs(storage, verdict, opts.MinAge); err != nil { return AUMHash{}, fmt.Errorf("marking young AUMs: %w", err) } + if lastActiveAncestor, err = markActiveChain(storage, verdict, opts.MinChain, head); err != nil { + return AUMHash{}, fmt.Errorf("marking active chain: %w", err) + } if err := markDescendantAUMs(storage, verdict); err != nil { return AUMHash{}, fmt.Errorf("marking descendant AUMs: %w", err) } diff --git a/tka/tailchonk_test.go b/tka/tailchonk_test.go index 70b7dc9a72fbb..eeb6edfff3018 100644 --- a/tka/tailchonk_test.go +++ b/tka/tailchonk_test.go @@ -15,6 +15,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/crypto/blake2s" + "tailscale.com/types/key" 
"tailscale.com/util/must" ) @@ -34,7 +35,7 @@ func randHash(t *testing.T, seed int64) [blake2s.Size]byte { } func TestImplementsChonk(t *testing.T) { - impls := []Chonk{&Mem{}, &FS{}} + impls := []Chonk{ChonkMem(), &FS{}} t.Logf("chonks: %v", impls) } @@ -228,7 +229,7 @@ func TestMarkActiveChain(t *testing.T) { verdict := make(map[AUMHash]retainState, len(tc.chain)) // Build the state of the tailchonk for tests. - storage := &Mem{} + storage := ChonkMem() var prev AUMHash for i := range tc.chain { if !prev.IsZero() { @@ -601,3 +602,33 @@ func TestCompact(t *testing.T) { } } } + +func TestCompactLongButYoung(t *testing.T) { + ourPriv := key.NewNLPrivate() + ourKey := Key{Kind: Key25519, Public: ourPriv.Public().Verifier(), Votes: 1} + someOtherKey := Key{Kind: Key25519, Public: key.NewNLPrivate().Public().Verifier(), Votes: 1} + + storage := ChonkMem() + auth, _, err := Create(storage, State{ + Keys: []Key{ourKey, someOtherKey}, + DisablementSecrets: [][]byte{DisablementKDF(bytes.Repeat([]byte{0xa5}, 32))}, + }, ourPriv) + if err != nil { + t.Fatalf("tka.Create() failed: %v", err) + } + + genesis := auth.Head() + + for range 100 { + upd := auth.NewUpdater(ourPriv) + must.Do(upd.RemoveKey(someOtherKey.MustID())) + must.Do(upd.AddKey(someOtherKey)) + aums := must.Get(upd.Finalize(storage)) + must.Do(auth.Inform(storage, aums)) + } + + lastActiveAncestor := must.Get(Compact(storage, auth.Head(), CompactionOptions{MinChain: 5, MinAge: time.Hour})) + if lastActiveAncestor != genesis { + t.Errorf("last active ancestor = %v, want %v", lastActiveAncestor, genesis) + } +} diff --git a/tka/tka.go b/tka/tka.go index 234c87fe1b89c..ed029c82e0592 100644 --- a/tka/tka.go +++ b/tka/tka.go @@ -94,7 +94,7 @@ func computeChainCandidates(storage Chonk, lastKnownOldest *AUMHash, maxIter int // candidates.Oldest needs to be computed by working backwards from // head as far as we can. - iterAgain := true // if theres still work to be done. + iterAgain := true // if there's still work to be done. for i := 0; iterAgain; i++ { if i >= maxIter { return nil, fmt.Errorf("iteration limit exceeded (%d)", maxIter) @@ -102,14 +102,14 @@ func computeChainCandidates(storage Chonk, lastKnownOldest *AUMHash, maxIter int iterAgain = false for j := range candidates { - parent, hasParent := candidates[j].Oldest.Parent() + parentHash, hasParent := candidates[j].Oldest.Parent() if hasParent { - parent, err := storage.AUM(parent) + parent, err := storage.AUM(parentHash) if err != nil { if err == os.ErrNotExist { continue } - return nil, fmt.Errorf("reading parent: %v", err) + return nil, fmt.Errorf("reading parent %s: %v", parentHash, err) } candidates[j].Oldest = parent if lastKnownOldest != nil && *lastKnownOldest == parent.Hash() { @@ -210,7 +210,7 @@ func fastForwardWithAdvancer( } nextAUM, err := storage.AUM(*startState.LastAUMHash) if err != nil { - return AUM{}, State{}, fmt.Errorf("reading next: %v", err) + return AUM{}, State{}, fmt.Errorf("reading next (%v): %v", *startState.LastAUMHash, err) } curs := nextAUM @@ -295,9 +295,9 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) } // If we got here, the current state is dependent on the previous. - // Keep iterating backwards till thats not the case. + // Keep iterating backwards till that's not the case. 
if curs, err = storage.AUM(parent); err != nil { - return State{}, fmt.Errorf("reading parent: %v", err) + return State{}, fmt.Errorf("reading parent (%v): %v", parent, err) } } @@ -324,7 +324,7 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) return curs.Hash() == wantHash }) // fastForward only terminates before the done condition if it - // doesnt have any later AUMs to process. This cant be the case + // doesn't have any later AUMs to process. This can't be the case // as we've already iterated through them above so they must exist, // but we check anyway to be super duper sure. if err == nil && *state.LastAUMHash != wantHash { @@ -336,13 +336,13 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) // computeActiveAncestor determines which ancestor AUM to use as the // ancestor of the valid chain. // -// If all the chains end up having the same ancestor, then thats the +// If all the chains end up having the same ancestor, then that's the // only possible ancestor, ezpz. However if there are multiple distinct // ancestors, that means there are distinct chains, and we need some // hint to choose what to use. For that, we rely on the chainsThroughActive // bit, which signals to us that that ancestor was part of the // chain in a previous run. -func computeActiveAncestor(storage Chonk, chains []chain) (AUMHash, error) { +func computeActiveAncestor(chains []chain) (AUMHash, error) { // Dedupe possible ancestors, tracking if they were part of // the active chain on a previous run. ancestors := make(map[AUMHash]bool, len(chains)) @@ -357,7 +357,7 @@ func computeActiveAncestor(storage Chonk, chains []chain) (AUMHash, error) { } } - // Theres more than one, so we need to use the ancestor that was + // There's more than one, so we need to use the ancestor that was // part of the active chain in a previous iteration. // Note that there can only be one distinct ancestor that was // formerly part of the active chain, because AUMs can only have @@ -391,8 +391,12 @@ func computeActiveChain(storage Chonk, lastKnownOldest *AUMHash, maxIter int) (c return chain{}, fmt.Errorf("computing candidates: %v", err) } + if len(chains) == 0 { + return chain{}, errors.New("no chain candidates in AUM storage") + } + // Find the right ancestor. - oldestHash, err := computeActiveAncestor(storage, chains) + oldestHash, err := computeActiveAncestor(chains) if err != nil { return chain{}, fmt.Errorf("computing ancestor: %v", err) } @@ -475,7 +479,7 @@ func (a *Authority) Head() AUMHash { // Open initializes an existing TKA from the given tailchonk. // // Only use this if the current node has initialized an Authority before. -// If a TKA exists on other nodes but theres nothing locally, use Bootstrap(). +// If a TKA exists on other nodes but there's nothing locally, use Bootstrap(). // If no TKA exists anywhere and you are creating it for the first // time, use New(). func Open(storage Chonk) (*Authority, error) { @@ -588,14 +592,14 @@ func (a *Authority) InformIdempotent(storage Chonk, updates []AUM) (Authority, e toCommit := make([]AUM, 0, len(updates)) prevHash := a.Head() - // The state at HEAD is the current state of the authority. Its likely + // The state at HEAD is the current state of the authority. It's likely // to be needed, so we prefill it rather than computing it. stateAt[prevHash] = a.state // Optimization: If the set of updates is a chain building from // the current head, EG: // ==> updates[0] ==> updates[1] ... 
- // Then theres no need to recompute the resulting state from the + // Then there's no need to recompute the resulting state from the // stored ancestor, because the last state computed during iteration // is the new state. This should be the common case. // isHeadChain keeps track of this. @@ -775,8 +779,8 @@ func (a *Authority) findParentForRewrite(storage Chonk, removeKeys []tkatype.Key } } if !keyTrusted { - // Success: the revoked keys are not trusted! - // Lets check that our key was trusted to ensure + // Success: the revoked keys are not trusted. + // Check that our key was trusted to ensure // we can sign a fork from here. if _, err := state.GetKey(ourKey); err == nil { break diff --git a/tka/tka_test.go b/tka/tka_test.go index 9e3c4e79d05bd..78af7400daff3 100644 --- a/tka/tka_test.go +++ b/tka/tka_test.go @@ -253,7 +253,7 @@ func TestOpenAuthority(t *testing.T) { } // Construct the state of durable storage. - chonk := &Mem{} + chonk := ChonkMem() err := chonk.CommitVerifiedAUMs([]AUM{g1, i1, l1, i2, i3, l2, l3, g2, l4}) if err != nil { t.Fatal(err) @@ -275,7 +275,7 @@ func TestOpenAuthority(t *testing.T) { } func TestOpenAuthority_EmptyErrors(t *testing.T) { - _, err := Open(&Mem{}) + _, err := Open(ChonkMem()) if err == nil { t.Error("Expected an error initializing an empty authority, got nil") } @@ -319,7 +319,7 @@ func TestCreateBootstrapAuthority(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - a1, genesisAUM, err := Create(&Mem{}, State{ + a1, genesisAUM, err := Create(ChonkMem(), State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) @@ -327,7 +327,7 @@ func TestCreateBootstrapAuthority(t *testing.T) { t.Fatalf("Create() failed: %v", err) } - a2, err := Bootstrap(&Mem{}, genesisAUM) + a2, err := Bootstrap(ChonkMem(), genesisAUM) if err != nil { t.Fatalf("Bootstrap() failed: %v", err) } @@ -366,7 +366,7 @@ func TestAuthorityInformNonLinear(t *testing.T) { optKey("key", key, priv), optSignAllUsing("key")) - storage := &Mem{} + storage := ChonkMem() a, err := Bootstrap(storage, c.AUMs["G1"]) if err != nil { t.Fatalf("Bootstrap() failed: %v", err) @@ -411,7 +411,7 @@ func TestAuthorityInformLinear(t *testing.T) { optKey("key", key, priv), optSignAllUsing("key")) - storage := &Mem{} + storage := ChonkMem() a, err := Bootstrap(storage, c.AUMs["G1"]) if err != nil { t.Fatalf("Bootstrap() failed: %v", err) @@ -444,7 +444,7 @@ func TestInteropWithNLKey(t *testing.T) { pub2 := key.NewNLPrivate().Public() pub3 := key.NewNLPrivate().Public() - a, _, err := Create(&Mem{}, State{ + a, _, err := Create(ChonkMem(), State{ Keys: []Key{ { Kind: Key25519, diff --git a/tka/verify.go b/tka/verify.go index e4e22e5518e8b..ed0ecea669817 100644 --- a/tka/verify.go +++ b/tka/verify.go @@ -18,7 +18,7 @@ import ( // provided AUM BLAKE2s digest, using the given key. func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { // NOTE(tom): Even if we can compute the public from the KeyID, - // its possible for the KeyID to be attacker-controlled + // it's possible for the KeyID to be attacker-controlled // so we should use the public contained in the state machine. 
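The rule in the comment above (trust the key material recorded in TKA state, never key material derived from the attacker-controllable KeyID) reduces to very little code. A standalone sketch of the ed25519 case (illustrative only; `verifyAgainstState` is a hypothetical helper, not this package's signatureVerify):

```go
package main

import (
	"crypto/ed25519"
	"fmt"
)

// verifyAgainstState checks sig over digest using only statePub, the
// public key held in TKA state. The signature's own KeyID is deliberately
// never used to look up or derive key material.
func verifyAgainstState(statePub ed25519.PublicKey, digest [32]byte, sig []byte) error {
	if ed25519.Verify(statePub, digest[:], sig) {
		return nil
	}
	return fmt.Errorf("invalid signature")
}

func main() {
	pub, priv, _ := ed25519.GenerateKey(nil) // nil reader defaults to crypto/rand
	var digest [32]byte                      // stand-in for an AUM SigHash (BLAKE2s digest)
	fmt.Println(verifyAgainstState(pub, digest, ed25519.Sign(priv, digest[:]))) // <nil>
}
```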
switch key.Kind { case Key25519: diff --git a/tsconsensus/monitor.go b/tsconsensus/monitor.go index 61a5a74a07c42..c84e83454f3f7 100644 --- a/tsconsensus/monitor.go +++ b/tsconsensus/monitor.go @@ -92,8 +92,8 @@ func (m *monitor) handleSummaryStatus(w http.ResponseWriter, r *http.Request) { } slices.Sort(lines) - for _, l := range lines { - _, err = w.Write([]byte(fmt.Sprintf("%s\n", l))) + for _, ln := range lines { + _, err = w.Write([]byte(fmt.Sprintf("%s\n", ln))) if err != nil { log.Printf("monitor: error writing status: %v", err) return @@ -102,15 +102,13 @@ func (m *monitor) handleSummaryStatus(w http.ResponseWriter, r *http.Request) { } func (m *monitor) handleNetmap(w http.ResponseWriter, r *http.Request) { - var mask ipn.NotifyWatchOpt = ipn.NotifyInitialNetMap - mask |= ipn.NotifyNoPrivateKeys lc, err := m.ts.LocalClient() if err != nil { log.Printf("monitor: error LocalClient: %v", err) http.Error(w, "", http.StatusInternalServerError) return } - watcher, err := lc.WatchIPNBus(r.Context(), mask) + watcher, err := lc.WatchIPNBus(r.Context(), ipn.NotifyInitialNetMap) if err != nil { log.Printf("monitor: error WatchIPNBus: %v", err) http.Error(w, "", http.StatusInternalServerError) diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go index 3b51a093f12ad..796c8f51b76a9 100644 --- a/tsconsensus/tsconsensus_test.go +++ b/tsconsensus/tsconsensus_test.go @@ -17,7 +17,6 @@ import ( "net/netip" "os" "path/filepath" - "runtime" "strings" "sync" "testing" @@ -27,7 +26,6 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/raft" "tailscale.com/client/tailscale" - "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" "tailscale.com/tailcfg" @@ -75,10 +73,10 @@ func fromCommand(bs []byte) (string, error) { return args, nil } -func (f *fsm) Apply(l *raft.Log) any { +func (f *fsm) Apply(lg *raft.Log) any { f.mu.Lock() defer f.mu.Unlock() - s, err := fromCommand(l.Data) + s, err := fromCommand(lg.Data) if err != nil { return CommandResult{ Err: err, @@ -115,8 +113,8 @@ func (f *fsm) Restore(rc io.ReadCloser) error { } func testConfig(t *testing.T) { - if runtime.GOOS == "windows" && cibuild.On() { - t.Skip("cmd/natc isn't supported on Windows, so skipping tsconsensus tests on CI for now; see https://github.com/tailscale/tailscale/issues/16340") + if cibuild.On() { + t.Skip("these integration tests don't always work well in CI and that's bad for CI; see https://github.com/tailscale/tailscale/issues/16340 and https://github.com/tailscale/tailscale/issues/18022") } // -race AND Parallel makes things start to take too long. 
if !racebuild.On { @@ -580,7 +578,6 @@ func TestRejoin(t *testing.T) { } func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") testConfig(t) ctx := context.Background() clusterTag := "tag:whatever" @@ -638,7 +635,6 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { } func TestOnlyTaggedPeersCanBeDialed(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") testConfig(t) ctx := context.Background() clusterTag := "tag:whatever" diff --git a/tsnet/depaware.txt b/tsnet/depaware.txt index cd734e9959041..825a39e34877f 100644 --- a/tsnet/depaware.txt +++ b/tsnet/depaware.txt @@ -9,6 +9,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) LDW github.com/coder/websocket/internal/errd from github.com/coder/websocket LDW github.com/coder/websocket/internal/util from github.com/coder/websocket LDW github.com/coder/websocket/internal/xsync from github.com/coder/websocket + github.com/creachadair/msync/trigger from tailscale.com/logtail W 💣 github.com/dblohm7/wingoes from tailscale.com/net/tshttpproxy+ W 💣 github.com/dblohm7/wingoes/com from tailscale.com/util/osdiag+ W 💣 github.com/dblohm7/wingoes/com/automation from tailscale.com/util/osdiag/internal/wsc @@ -35,6 +36,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + 💣 github.com/klauspost/compress/internal/le from github.com/klauspost/compress/huff0+ github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd @@ -42,6 +44,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ LA 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink+ LDW 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket + github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal DI github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack L 💣 github.com/safchain/ethtool from tailscale.com/net/netkernelconf W 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient @@ -224,7 +227,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) LDW tailscale.com/tsweb from tailscale.com/util/eventbus tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ - tailscale.com/types/bools from tailscale.com/tsnet + tailscale.com/types/bools from tailscale.com/tsnet+ tailscale.com/types/dnstype from tailscale.com/client/local+ tailscale.com/types/empty from tailscale.com/ipn+ tailscale.com/types/ipproto from tailscale.com/ipn+ @@ -247,7 +250,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/control/controlclient+ tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from 
tailscale.com/appc+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ LW tailscale.com/util/cmpver from tailscale.com/net/dns+ @@ -390,7 +393,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ - crypto/fips140 from crypto/tls/internal/fips140tls + crypto/fips140 from crypto/tls/internal/fips140tls+ crypto/hkdf from crypto/internal/hpke+ crypto/hmac from crypto/tls+ crypto/internal/boring from crypto/aes+ diff --git a/tsnet/packet_filter_test.go b/tsnet/packet_filter_test.go index 462234222f936..455400eaa0c8a 100644 --- a/tsnet/packet_filter_test.go +++ b/tsnet/packet_filter_test.go @@ -12,6 +12,7 @@ import ( "tailscale.com/ipn" "tailscale.com/tailcfg" + "tailscale.com/tstest" "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/netmap" @@ -47,6 +48,7 @@ func waitFor(t testing.TB, ctx context.Context, s *Server, f func(*netmap.Networ // netmaps and turning them into packet filters together. Only the control-plane // side is mocked out. func TestPacketFilterFromNetmap(t *testing.T) { + tstest.Shard(t) t.Parallel() var key key.NodePublic diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 2944f63595a48..14747650f42ee 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -350,7 +350,7 @@ func (s *Server) Up(ctx context.Context) (*ipnstate.Status, error) { return nil, fmt.Errorf("tsnet.Up: %w", err) } - watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) + watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState) if err != nil { return nil, fmt.Errorf("tsnet.Up: %w", err) } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 1e22681fcfe36..f1531d013d4b7 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -235,6 +235,7 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) } func TestDialBlocks(t *testing.T) { + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -274,30 +275,57 @@ func TestDialBlocks(t *testing.T) { defer c.Close() } +// TestConn tests basic TCP connections between two tsnet Servers, s1 and s2: +// +// - s1, a subnet router, first listens on its TCP :8081. +// - s2 can connect to s1:8081 +// - s2 cannot connect to s1:8082 (no listener) +// - s2 can dial through the subnet router functionality (getting a synthetic RST +// that we verify we generated & saw) func TestConn(t *testing.T) { + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() controlURL, c := startControl(t) s1, s1ip, s1PubKey := startServer(t, ctx, controlURL, "s1") - s2, _, _ := startServer(t, ctx, controlURL, "s2") - s1.lb.EditPrefs(&ipn.MaskedPrefs{ + // Track whether we saw an attempted dial to 192.0.2.1:8081. 
+ var saw192DocNetDial atomic.Bool + s1.RegisterFallbackTCPHandler(func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + t.Logf("s1: fallback TCP handler called for %v -> %v", src, dst) + if dst.String() == "192.0.2.1:8081" { + saw192DocNetDial.Store(true) + } + return nil, true // nil handler but intercept=true means to send RST + }) + + lc1 := must.Get(s1.LocalClient()) + + must.Get(lc1.EditPrefs(ctx, &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")}, }, AdvertiseRoutesSet: true, - }) + })) c.SetSubnetRoutes(s1PubKey, []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")}) - lc2, err := s2.LocalClient() - if err != nil { - t.Fatal(err) - } + // Start s2 after s1 is fully set up, including advertising its routes, + // otherwise the test is flaky if the test starts dialing through s2 before + // our test control server has told s2 about s1's routes. + s2, _, _ := startServer(t, ctx, controlURL, "s2") + lc2 := must.Get(s2.LocalClient()) + + must.Get(lc2.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + RouteAll: true, + }, + RouteAllSet: true, + })) // ping to make sure the connection is up. - res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + res, err := lc2.Ping(ctx, s1ip, tailcfg.PingTSMP) if err != nil { t.Fatal(err) } @@ -310,12 +338,26 @@ func TestConn(t *testing.T) { } defer ln.Close() - w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) - if err != nil { - t.Fatal(err) - } + s1Conns := make(chan net.Conn) + go func() { + for { + c, err := ln.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + t.Errorf("s1.Accept: %v", err) + return + } + select { + case s1Conns <- c: + case <-ctx.Done(): + c.Close() + } + } + }() - r, err := ln.Accept() + w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) if err != nil { t.Fatal(err) } @@ -325,32 +367,56 @@ func TestConn(t *testing.T) { t.Fatal(err) } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(r, got, len(got)); err != nil { - t.Fatal(err) - } - t.Logf("got: %q", got) - if string(got) != want { - t.Errorf("got %q, want %q", got, want) + select { + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for connection") + case r := <-s1Conns: + got := make([]byte, len(want)) + _, err := io.ReadAtLeast(r, got, len(got)) + r.Close() + if err != nil { + t.Fatal(err) + } + t.Logf("got: %q", got) + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } } + // Dial a non-existent port on s1 and expect it to fail. _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8082", s1ip)) // some random port if err == nil { t.Fatalf("unexpected success; should have seen a connection refused error") } - - // s1 is a subnet router for TEST-NET-1 (192.0.2.0/24). Lets dial to that - // subnet from s2 to ensure a listener without an IP address (i.e. ":8081") - // only matches destination IPs corresponding to the node's IP, and not - // to any random IP a subnet is routing. - _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", "192.0.2.1")) + t.Logf("got expected failure: %v", err) + + // s1 is a subnet router for TEST-NET-1 (192.0.2.0/24). Let's dial to that + // subnet from s2 to ensure a listener without an IP address (i.e. our + // ":8081" listen above) only matches destination IPs corresponding to the + // s1 node's IP addresses, and not to any random IP of a subnet it's routing. + // + // The RegisterFallbackTCPHandler on s1 above handles sending a RST when the + // TCP SYN arrives from s2. 
But we bound it to 5 seconds lest a regression + // like tailscale/tailscale#17805 recur. + s2dialer := s2.Sys().Dialer.Get() + s2dialer.SetSystemDialerForTest(func(ctx context.Context, netw, addr string) (net.Conn, error) { + t.Logf("s2: unexpected system dial called for %s %s", netw, addr) + return nil, fmt.Errorf("system dialer called unexpectedly for %s %s", netw, addr) + }) + docCtx, docCancel := context.WithTimeout(ctx, 5*time.Second) + defer docCancel() + _, err = s2.Dial(docCtx, "tcp", "192.0.2.1:8081") if err == nil { t.Fatalf("unexpected success; should have seen a connection refused error") } + if !saw192DocNetDial.Load() { + t.Errorf("expected s1's fallback TCP handler to have been called for 192.0.2.1:8081") + } } func TestLoopbackLocalAPI(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/8557") + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -426,6 +492,7 @@ func TestLoopbackLocalAPI(t *testing.T) { func TestLoopbackSOCKS5(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/8198") + tstest.Shard(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -476,6 +543,7 @@ func TestLoopbackSOCKS5(t *testing.T) { } func TestTailscaleIPs(t *testing.T) { + tstest.Shard(t) controlURL, _ := startControl(t) tmp := t.TempDir() @@ -518,6 +586,7 @@ func TestTailscaleIPs(t *testing.T) { // TestListenerCleanup is a regression test to verify that s.Close doesn't // deadlock if a listener is still open. func TestListenerCleanup(t *testing.T) { + tstest.Shard(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -560,6 +629,7 @@ func (wc *closeTrackConn) Close() error { // tests https://github.com/tailscale/tailscale/issues/6973 -- that we can start a tsnet server, // stop it, and restart it, even on Windows. 
func TestStartStopStartGetsSameIP(t *testing.T) { + tstest.Shard(t) controlURL, _ := startControl(t) tmp := t.TempDir() @@ -609,6 +679,7 @@ func TestStartStopStartGetsSameIP(t *testing.T) { } func TestFunnel(t *testing.T) { + tstest.Shard(t) ctx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second) defer dialCancel() @@ -670,6 +741,7 @@ func TestFunnel(t *testing.T) { } func TestListenerClose(t *testing.T) { + tstest.Shard(t) ctx := context.Background() controlURL, _ := startControl(t) @@ -749,6 +821,7 @@ func (c *bufferedConn) Read(b []byte) (int, error) { } func TestFallbackTCPHandler(t *testing.T) { + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -791,6 +864,7 @@ func TestFallbackTCPHandler(t *testing.T) { } func TestCapturePcap(t *testing.T) { + tstest.Shard(t) const timeLimit = 120 ctx, cancel := context.WithTimeout(context.Background(), timeLimit*time.Second) defer cancel() @@ -844,6 +918,7 @@ func TestCapturePcap(t *testing.T) { } func TestUDPConn(t *testing.T) { + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -946,11 +1021,11 @@ func promMetricLabelsStr(labels []*dto.LabelPair) string { } var b strings.Builder b.WriteString("{") - for i, l := range labels { + for i, lb := range labels { if i > 0 { b.WriteString(",") } - b.WriteString(fmt.Sprintf("%s=%q", l.GetName(), l.GetValue())) + b.WriteString(fmt.Sprintf("%s=%q", lb.GetName(), lb.GetValue())) } b.WriteString("}") return b.String() @@ -958,8 +1033,8 @@ func promMetricLabelsStr(labels []*dto.LabelPair) string { // sendData sends a given amount of bytes from s1 to s2. func sendData(logf func(format string, args ...any), ctx context.Context, bytesCount int, s1, s2 *Server, s1ip, s2ip netip.Addr) error { - l := must.Get(s1.Listen("tcp", fmt.Sprintf("%s:8081", s1ip))) - defer l.Close() + lb := must.Get(s1.Listen("tcp", fmt.Sprintf("%s:8081", s1ip))) + defer lb.Close() // Dial to s1 from s2 w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) @@ -974,7 +1049,7 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC defer close(allReceived) go func() { - conn, err := l.Accept() + conn, err := lb.Accept() if err != nil { allReceived <- err return @@ -1035,6 +1110,7 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC } func TestUserMetricsByteCounters(t *testing.T) { + tstest.Shard(t) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() @@ -1149,6 +1225,7 @@ func TestUserMetricsByteCounters(t *testing.T) { } func TestUserMetricsRouteGauges(t *testing.T) { + tstest.Shard(t) // Windows does not seem to support or report back routes when running in // userspace via tsnet. So, we skip this check on Windows. // TODO(kradalby): Figure out if this is correct. 
@@ -1305,6 +1382,7 @@ func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) { } func TestDeps(t *testing.T) { + tstest.Shard(t) deptest.DepChecker{ GOOS: "linux", GOARCH: "amd64", diff --git a/tstest/chonktest/tailchonk_test.go b/tstest/chonktest/tailchonk_test.go index 6dfab798ed11f..d9343e9160ea9 100644 --- a/tstest/chonktest/tailchonk_test.go +++ b/tstest/chonktest/tailchonk_test.go @@ -18,7 +18,7 @@ func TestImplementsChonk(t *testing.T) { { name: "Mem", newChonk: func(t *testing.T) tka.Chonk { - return &tka.Mem{} + return tka.ChonkMem() }, }, { @@ -42,7 +42,7 @@ func TestImplementsCompactableChonk(t *testing.T) { { name: "Mem", newChonk: func(t *testing.T) tka.CompactableChonk { - return &tka.Mem{} + return tka.ChonkMem() }, }, { diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index 6700205cf8f55..a62173ae3e353 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -576,6 +576,7 @@ type TestNode struct { stateFile string upFlagGOOS string // if non-empty, sets TS_DEBUG_UP_FLAG_GOOS for cmd/tailscale CLI encryptState bool + allowUpdates bool mu sync.Mutex onLogLine []func([]byte) @@ -840,6 +841,9 @@ func (n *TestNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { "TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost "TS_DEBUG_LOG_RATE=all", ) + if n.allowUpdates { + cmd.Env = append(cmd.Env, "TS_TEST_ALLOW_AUTO_UPDATE=1") + } if n.env.loopbackPort != nil { cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(*n.env.loopbackPort)) } @@ -914,7 +918,7 @@ func (n *TestNode) Ping(otherNode *TestNode) error { t := n.env.t ip := otherNode.AwaitIP4().String() t.Logf("Running ping %v (from %v)...", ip, n.AwaitIP4()) - return n.Tailscale("ping", ip).Run() + return n.Tailscale("ping", "--timeout=1s", ip).Run() } // AwaitListening waits for the tailscaled to be serving local clients @@ -1073,6 +1077,46 @@ func (n *TestNode) MustStatus() *ipnstate.Status { return st } +// PublicKey returns the hex-encoded public key of this node, +// e.g. `nodekey:123456abc` +func (n *TestNode) PublicKey() string { + tb := n.env.t + tb.Helper() + cmd := n.Tailscale("status", "--json") + out, err := cmd.CombinedOutput() + if err != nil { + tb.Fatalf("running `tailscale status`: %v, %s", err, out) + } + + type Self struct{ PublicKey string } + type StatusOutput struct{ Self Self } + + var st StatusOutput + if err := json.Unmarshal(out, &st); err != nil { + tb.Fatalf("decoding `tailscale status` JSON: %v\njson:\n%s", err, out) + } + return st.Self.PublicKey +} + +// NLPublicKey returns the hex-encoded network lock public key of +// this node, e.g. `tlpub:123456abc` +func (n *TestNode) NLPublicKey() string { + tb := n.env.t + tb.Helper() + cmd := n.Tailscale("lock", "status", "--json") + out, err := cmd.CombinedOutput() + if err != nil { + tb.Fatalf("running `tailscale lock status`: %v, %s", err, out) + } + st := struct { + PublicKey string `json:"PublicKey"` + }{} + if err := json.Unmarshal(out, &st); err != nil { + tb.Fatalf("decoding `tailscale lock status` JSON: %v\njson:\n%s", err, out) + } + return st.PublicKey +} + // trafficTrap is an HTTP proxy handler to note whether any // HTTP traffic tries to leave localhost from tailscaled. We don't // expect any, so any request triggers a failure. 
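The PublicKey and NLPublicKey helpers added to integration.go above slot into the existing harness. A sketch of a hypothetical test using them (the harness names are the real ones from this file; the test itself is illustrative and assumes the usual daemon-per-node setup):

```go
func TestLockKeyHelpers(t *testing.T) {
	tstest.Shard(t)
	env := NewTestEnv(t)
	n := NewTestNode(t, env)

	d := n.StartDaemon()
	defer d.MustCleanShutdown(t)
	n.MustUp()
	n.AwaitRunning()

	// Both helpers shell out to the CLI and fail the test on error.
	t.Logf("nodekey=%s tlpub=%s", n.PublicKey(), n.NLPublicKey())
}
```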
diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 64f49c7b80afd..fc891ad722b28 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -22,8 +22,10 @@ import ( "path/filepath" "regexp" "runtime" + "slices" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -36,6 +38,7 @@ import ( "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/feature" _ "tailscale.com/feature/clientupdate" + "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/net/tsaddr" @@ -1410,14 +1413,27 @@ func TestLogoutRemovesAllPeers(t *testing.T) { wantNode0PeerCount(expectedPeers) // all existing peers and the new node } -func TestAutoUpdateDefaults(t *testing.T) { - if !feature.CanAutoUpdate() { - t.Skip("auto-updates not supported on this platform") - } +func TestAutoUpdateDefaults(t *testing.T) { testAutoUpdateDefaults(t, false) } +func TestAutoUpdateDefaults_cap(t *testing.T) { testAutoUpdateDefaults(t, true) } + +// useCap is whether to use NodeAttrDefaultAutoUpdate (as opposed to the old +// DeprecatedDefaultAutoUpdate top-level MapResponse field). +func testAutoUpdateDefaults(t *testing.T, useCap bool) { + t.Cleanup(feature.HookCanAutoUpdate.SetForTest(func() bool { return true })) + tstest.Shard(t) - tstest.Parallel(t) env := NewTestEnv(t) + var ( + modifyMu sync.Mutex + modifyFirstMapResponse = func(*tailcfg.MapResponse, *tailcfg.MapRequest) {} + ) + env.Control.ModifyFirstMapResponse = func(mr *tailcfg.MapResponse, req *tailcfg.MapRequest) { + modifyMu.Lock() + defer modifyMu.Unlock() + modifyFirstMapResponse(mr, req) + } + checkDefault := func(n *TestNode, want bool) error { enabled, ok := n.diskPrefs().AutoUpdate.Apply.Get() if !ok { @@ -1429,17 +1445,23 @@ func TestAutoUpdateDefaults(t *testing.T) { return nil } - sendAndCheckDefault := func(t *testing.T, n *TestNode, send, want bool) { - t.Helper() - if !env.Control.AddRawMapResponse(n.MustStatus().Self.PublicKey, &tailcfg.MapResponse{ - DefaultAutoUpdate: opt.NewBool(send), - }) { - t.Fatal("failed to send MapResponse to node") - } - if err := tstest.WaitFor(2*time.Second, func() error { - return checkDefault(n, want) - }); err != nil { - t.Fatal(err) + setDefaultAutoUpdate := func(send bool) { + modifyMu.Lock() + defer modifyMu.Unlock() + modifyFirstMapResponse = func(mr *tailcfg.MapResponse, req *tailcfg.MapRequest) { + if mr.Node == nil { + mr.Node = &tailcfg.Node{} + } + if useCap { + if mr.Node.CapMap == nil { + mr.Node.CapMap = make(tailcfg.NodeCapMap) + } + mr.Node.CapMap[tailcfg.NodeAttrDefaultAutoUpdate] = []tailcfg.RawMessage{ + tailcfg.RawMessage(fmt.Sprintf("%t", send)), + } + } else { + mr.DeprecatedDefaultAutoUpdate = opt.NewBool(send) + } } } + + // waitDefault polls checkDefault for up to two seconds and fails the + // test on timeout; prefs can land shortly after AwaitRunning returns, + // so a bare (unchecked) checkDefault call would race the map response. + waitDefault := func(t *testing.T, n *TestNode, want bool) { + t.Helper() + if err := tstest.WaitFor(2*time.Second, func() error { + return checkDefault(n, want) + }); err != nil { + t.Fatal(err) + } + } @@ -1450,29 +1472,54 @@ { desc: "tailnet-default-false", run: func(t *testing.T, n *TestNode) { - // First received default "false". - sendAndCheckDefault(t, n, false, false) - // Should not be changed even if sent "true" later. - sendAndCheckDefault(t, n, true, false) + + // First the server sends "false", and client should remember that. + setDefaultAutoUpdate(false) + n.MustUp() + n.AwaitRunning() + waitDefault(t, n, false) + + // Now we disconnect and change the server to send "true", which + // the client should ignore, having previously remembered + // "false". + n.MustDown() + setDefaultAutoUpdate(true) // control sends default "true" + n.MustUp() + n.AwaitRunning() + waitDefault(t, n, false) // still false + + // But can be changed explicitly by the user. if out, err := n.TailscaleForOutput("set", "--auto-update").CombinedOutput(); err != nil { t.Fatalf("failed to enable auto-update on node: %v\noutput: %s", err, out) } - sendAndCheckDefault(t, n, false, true) + waitDefault(t, n, true) }, }, { desc: "tailnet-default-true", run: func(t *testing.T, n *TestNode) { - // First received default "true". - sendAndCheckDefault(t, n, true, true) - // Should not be changed even if sent "false" later. - sendAndCheckDefault(t, n, false, true) + // Same as above but starting with default "true". + + // First the server sends "true", and client should remember that. + setDefaultAutoUpdate(true) + n.MustUp() + n.AwaitRunning() + waitDefault(t, n, true) + + // Now we disconnect and change the server to send "false", which + // the client should ignore, having previously remembered + // "true". + n.MustDown() + setDefaultAutoUpdate(false) // control sends default "false" + n.MustUp() + n.AwaitRunning() + waitDefault(t, n, true) // still true + + // But can be changed explicitly by the user. if out, err := n.TailscaleForOutput("set", "--auto-update=false").CombinedOutput(); err != nil { t.Fatalf("failed to disable auto-update on node: %v\noutput: %s", err, out) } - sendAndCheckDefault(t, n, true, false) + waitDefault(t, n, false) }, }, { @@ -1482,22 +1529,21 @@ if out, err := n.TailscaleForOutput("set", "--auto-update=false").CombinedOutput(); err != nil { t.Fatalf("failed to disable auto-update on node: %v\noutput: %s", err, out) } - // Defaults sent from control should be ignored. - sendAndCheckDefault(t, n, true, false) - sendAndCheckDefault(t, n, false, false) + + setDefaultAutoUpdate(true) + n.MustUp() + n.AwaitRunning() + waitDefault(t, n, false) }, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { n := NewTestNode(t, env) + n.allowUpdates = true d := n.StartDaemon() defer d.MustCleanShutdown(t) - n.AwaitResponding() - n.MustUp() - n.AwaitRunning() - tt.run(t, n) }) } @@ -2128,16 +2174,10 @@ func TestC2NDebugNetmap(t *testing.T) { var current netmap.NetworkMap must.Do(json.Unmarshal(resp.Current, &current)) - if !current.PrivateKey.IsZero() { - t.Errorf("current netmap has non-zero private key: %v", current.PrivateKey) - } // Check candidate netmap if we sent a map response. if cand != nil { var candidate netmap.NetworkMap must.Do(json.Unmarshal(resp.Candidate, &candidate)) - if !candidate.PrivateKey.IsZero() { - t.Errorf("candidate netmap has non-zero private key: %v", candidate.PrivateKey) - } if diff := cmp.Diff(current.SelfNode, candidate.SelfNode); diff != "" { t.Errorf("SelfNode differs (-current +candidate):\n%s", diff) } @@ -2213,7 +2253,7 @@ } } -func TestNetworkLock(t *testing.T) { +func TestTailnetLock(t *testing.T) { // If you run `tailscale lock log` on a node where Tailnet Lock isn't // enabled, you get an error explaining that. @@ -2251,4 +2291,112 @@ t.Fatalf("stderr: want %q, got %q", wantErr, errBuf.String()) } }) + + // If you create a tailnet with two signed nodes and one unsigned, + // the signed nodes can talk to each other but the unsigned node cannot + // talk to anybody.
+ t.Run("node-connectivity", func(t *testing.T) { + tstest.Shard(t) + t.Parallel() + + env := NewTestEnv(t) + env.Control.DefaultNodeCapabilities = &tailcfg.NodeCapMap{ + tailcfg.CapabilityTailnetLock: []tailcfg.RawMessage{}, + } + + // Start two nodes which will be our signing nodes. + signing1 := NewTestNode(t, env) + signing2 := NewTestNode(t, env) + + nodes := []*TestNode{signing1, signing2} + for _, n := range nodes { + d := n.StartDaemon() + defer d.MustCleanShutdown(t) + + n.MustUp() + n.AwaitRunning() + } + + // Initiate Tailnet Lock with the two signing nodes. + initCmd := signing1.Tailscale("lock", "init", + "--gen-disablements", "10", + "--confirm", + signing1.NLPublicKey(), signing2.NLPublicKey(), + ) + out, err := initCmd.CombinedOutput() + if err != nil { + t.Fatalf("init command failed: %q\noutput=%v", err, string(out)) + } + + // Check that the two signing nodes can ping each other + if err := signing1.Ping(signing2); err != nil { + t.Fatalf("ping signing1 -> signing2: %v", err) + } + if err := signing2.Ping(signing1); err != nil { + t.Fatalf("ping signing2 -> signing1: %v", err) + } + + // Create and start a third node + node3 := NewTestNode(t, env) + d3 := node3.StartDaemon() + defer d3.MustCleanShutdown(t) + node3.MustUp() + node3.AwaitRunning() + + if err := signing1.Ping(node3); err == nil { + t.Fatal("ping signing1 -> node3: expected err, but succeeded") + } + if err := node3.Ping(signing1); err == nil { + t.Fatal("ping node3 -> signing1: expected err, but succeeded") + } + + // Sign node3, and check the nodes can now talk to each other + signCmd := signing1.Tailscale("lock", "sign", node3.PublicKey()) + out, err = signCmd.CombinedOutput() + if err != nil { + t.Fatalf("sign command failed: %q\noutput = %v", err, string(out)) + } + + if err := signing1.Ping(node3); err != nil { + t.Fatalf("ping signing1 -> node3: expected success, got err: %v", err) + } + if err := node3.Ping(signing1); err != nil { + t.Fatalf("ping node3 -> signing1: expected success, got err: %v", err) + } + }) +} + +func TestNodeWithBadStateFile(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) + if err := os.WriteFile(n1.stateFile, []byte("bad json"), 0644); err != nil { + t.Fatal(err) + } + + d1 := n1.StartDaemon() + n1.AwaitResponding() + + // Make sure the health message shows up in status output. + n1.AwaitBackendState("NoState") + st := n1.MustStatus() + wantHealth := ipn.StateStoreHealth.Text(health.Args{health.ArgError: ""}) + if !slices.ContainsFunc(st.Health, func(m string) bool { return strings.HasPrefix(m, wantHealth) }) { + t.Errorf("Status does not contain expected health message %q\ngot health messages: %q", wantHealth, st.Health) + } + + // Make sure login attempts are rejected. 
+ cmd := n1.Tailscale("up", "--login-server="+n1.env.ControlURL()) + t.Logf("Running %v ...", cmd) + out, err := cmd.CombinedOutput() + if err == nil { + t.Fatalf("up succeeded with output %q", out) + } + wantOut := "cannot start backend when state store is unhealthy" + if !strings.Contains(string(out), wantOut) { + t.Fatalf("got up output:\n%s\nwant:\n%s", string(out), wantOut) + } + + d1.MustCleanShutdown(t) } diff --git a/tstest/integration/tailscaled_deps_test_darwin.go b/tstest/integration/tailscaled_deps_test_darwin.go index 217188f75f6c0..9f92839d8cde7 100644 --- a/tstest/integration/tailscaled_deps_test_darwin.go +++ b/tstest/integration/tailscaled_deps_test_darwin.go @@ -27,6 +27,7 @@ import ( _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/ipnserver" _ "tailscale.com/ipn/store" + _ "tailscale.com/ipn/store/mem" _ "tailscale.com/logpolicy" _ "tailscale.com/logtail" _ "tailscale.com/net/dns" diff --git a/tstest/integration/tailscaled_deps_test_freebsd.go b/tstest/integration/tailscaled_deps_test_freebsd.go index 217188f75f6c0..9f92839d8cde7 100644 --- a/tstest/integration/tailscaled_deps_test_freebsd.go +++ b/tstest/integration/tailscaled_deps_test_freebsd.go @@ -27,6 +27,7 @@ import ( _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/ipnserver" _ "tailscale.com/ipn/store" + _ "tailscale.com/ipn/store/mem" _ "tailscale.com/logpolicy" _ "tailscale.com/logtail" _ "tailscale.com/net/dns" diff --git a/tstest/integration/tailscaled_deps_test_linux.go b/tstest/integration/tailscaled_deps_test_linux.go index 217188f75f6c0..9f92839d8cde7 100644 --- a/tstest/integration/tailscaled_deps_test_linux.go +++ b/tstest/integration/tailscaled_deps_test_linux.go @@ -27,6 +27,7 @@ import ( _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/ipnserver" _ "tailscale.com/ipn/store" + _ "tailscale.com/ipn/store/mem" _ "tailscale.com/logpolicy" _ "tailscale.com/logtail" _ "tailscale.com/net/dns" diff --git a/tstest/integration/tailscaled_deps_test_openbsd.go b/tstest/integration/tailscaled_deps_test_openbsd.go index 217188f75f6c0..9f92839d8cde7 100644 --- a/tstest/integration/tailscaled_deps_test_openbsd.go +++ b/tstest/integration/tailscaled_deps_test_openbsd.go @@ -27,6 +27,7 @@ import ( _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/ipnserver" _ "tailscale.com/ipn/store" + _ "tailscale.com/ipn/store/mem" _ "tailscale.com/logpolicy" _ "tailscale.com/logtail" _ "tailscale.com/net/dns" diff --git a/tstest/integration/tailscaled_deps_test_windows.go b/tstest/integration/tailscaled_deps_test_windows.go index f3cd5e75b9e36..82f8097c8bc36 100644 --- a/tstest/integration/tailscaled_deps_test_windows.go +++ b/tstest/integration/tailscaled_deps_test_windows.go @@ -37,6 +37,7 @@ import ( _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/ipnserver" _ "tailscale.com/ipn/store" + _ "tailscale.com/ipn/store/mem" _ "tailscale.com/logpolicy" _ "tailscale.com/logtail" _ "tailscale.com/net/dns" diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index f9a33705b7f56..19964c91ff8a4 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -33,6 +33,8 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/syncs" "tailscale.com/tailcfg" + "tailscale.com/tka" + "tailscale.com/tstest/tkatest" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/opt" @@ -79,6 +81,10 @@ type Server struct { ExplicitBaseURL string // e.g. 
"http://127.0.0.1:1234" with no trailing URL HTTPTestServer *httptest.Server // if non-nil, used to get BaseURL + // ModifyFirstMapResponse, if non-nil, is called exactly once per + // MapResponse stream to modify the first MapResponse sent in response to it. + ModifyFirstMapResponse func(*tailcfg.MapResponse, *tailcfg.MapRequest) + initMuxOnce sync.Once mux *http.ServeMux @@ -119,6 +125,10 @@ type Server struct { nodeKeyAuthed set.Set[key.NodePublic] msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse allExpired bool // All nodes will be told their node key is expired. + + // tkaStorage records the Tailnet Lock state, if any. + // If nil, Tailnet Lock is not enabled in the Tailnet. + tkaStorage tka.CompactableChonk } // BaseURL returns the server's base URL, without trailing slash. @@ -325,6 +335,7 @@ func (s *Server) initMux() { w.WriteHeader(http.StatusNoContent) }) s.mux.HandleFunc("/key", s.serveKey) + s.mux.HandleFunc("/machine/tka/", s.serveTKA) s.mux.HandleFunc("/machine/", s.serveMachine) s.mux.HandleFunc("/ts2021", s.serveNoiseUpgrade) s.mux.HandleFunc("/c2n/", s.serveC2N) @@ -435,7 +446,7 @@ func (s *Server) serveKey(w http.ResponseWriter, r *http.Request) { func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { - http.Error(w, "POST required", 400) + http.Error(w, "POST required for serveMachine", 400) return } ctx := r.Context() @@ -464,6 +475,9 @@ func (s *Server) SetSubnetRoutes(nodeKey key.NodePublic, routes []netip.Prefix) defer s.mu.Unlock() s.logf("Setting subnet routes for %s: %v", nodeKey.ShortString(), routes) mak.Set(&s.nodeSubnetRoutes, nodeKey, routes) + if node, ok := s.nodes[nodeKey]; ok { + sendUpdate(s.updates[node.ID], updateSelfChanged) + } } // MasqueradePair is a pair of nodes and the IP address that the @@ -854,6 +868,132 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. w.Write(res) } +func (s *Server) serveTKA(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + http.Error(w, "GET required for serveTKA", 400) + return + } + + switch r.URL.Path { + case "/machine/tka/init/begin": + s.serveTKAInitBegin(w, r) + case "/machine/tka/init/finish": + s.serveTKAInitFinish(w, r) + case "/machine/tka/bootstrap": + s.serveTKABootstrap(w, r) + case "/machine/tka/sync/offer": + s.serveTKASyncOffer(w, r) + case "/machine/tka/sign": + s.serveTKASign(w, r) + default: + s.serveUnhandled(w, r) + } +} + +func (s *Server) serveTKAInitBegin(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + nodes := maps.Values(s.nodes) + genesisAUM, err := tkatest.HandleTKAInitBegin(w, r, nodes) + if err != nil { + go panic(fmt.Sprintf("HandleTKAInitBegin: %v", err)) + } + s.tkaStorage = tka.ChonkMem() + s.tkaStorage.CommitVerifiedAUMs([]tka.AUM{*genesisAUM}) +} + +func (s *Server) serveTKAInitFinish(w http.ResponseWriter, r *http.Request) { + signatures, err := tkatest.HandleTKAInitFinish(w, r) + if err != nil { + go panic(fmt.Sprintf("HandleTKAInitFinish: %v", err)) + } + + s.mu.Lock() + defer s.mu.Unlock() + + // Apply the signatures to each of the nodes. Because s.nodes is keyed + // by public key instead of node ID, we have to do this inefficiently. + // + // We only have small tailnets in the integration tests, so this isn't + // much of an issue. 
+ for nodeID, sig := range signatures { + for _, n := range s.nodes { + if n.ID == nodeID { + n.KeySignature = sig + } + } + } +} + +func (s *Server) serveTKABootstrap(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + if s.tkaStorage == nil { + http.Error(w, "no TKA state when calling serveTKABootstrap", 400) + return + } + + // Find the genesis AUM, which we need to include in the response. + var genesis *tka.AUM + allAUMs, err := s.tkaStorage.AllAUMs() + if err != nil { + http.Error(w, "unable to retrieve all AUMs from TKA state", 500) + return + } + for _, h := range allAUMs { + aum := must.Get(s.tkaStorage.AUM(h)) + if _, hasParent := aum.Parent(); !hasParent { + genesis = &aum + break + } + } + if genesis == nil { + http.Error(w, "unable to find genesis AUM in TKA state", 500) + return + } + + resp := tailcfg.TKABootstrapResponse{ + GenesisAUM: genesis.Serialize(), + } + _, err = tkatest.HandleTKABootstrap(w, r, resp) + if err != nil { + go panic(fmt.Sprintf("HandleTKABootstrap: %v", err)) + } +} + +func (s *Server) serveTKASyncOffer(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + authority, err := tka.Open(s.tkaStorage) + if err != nil { + go panic(fmt.Sprintf("serveTKASyncOffer: tka.Open: %v", err)) + } + + err = tkatest.HandleTKASyncOffer(w, r, authority, s.tkaStorage) + if err != nil { + go panic(fmt.Sprintf("HandleTKASyncOffer: %v", err)) + } +} + +func (s *Server) serveTKASign(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + authority, err := tka.Open(s.tkaStorage) + if err != nil { + go panic(fmt.Sprintf("serveTKASign: tka.Open: %v", err)) + } + + sig, keyBeingSigned, err := tkatest.HandleTKASign(w, r, authority) + if err != nil { + go panic(fmt.Sprintf("HandleTKASign: %v", err)) + } + s.nodes[*keyBeingSigned].KeySignature = *sig + s.updateLocked("TKASign", s.nodeIDsLocked(0)) +} + // updateType indicates why a long-polling map request is being woken // up for an update. type updateType int @@ -990,6 +1130,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi // register an updatesCh to get updates. streaming := req.Stream && !req.ReadOnly compress := req.Compress != "" + first := true w.WriteHeader(200) for { @@ -1022,6 +1163,10 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi if allExpired { res.Node.KeyExpiry = time.Now().Add(-1 * time.Minute) } + if f := s.ModifyFirstMapResponse; first && f != nil { + first = false + f(res, req) + } // TODO: add minner if/when needed resBytes, err := json.Marshal(res) if err != nil { @@ -1185,6 +1330,21 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, v6Prefix, } + // If the server is tracking TKA state, and there's a single TKA head, + // add it to the MapResponse. + if s.tkaStorage != nil { + heads, err := s.tkaStorage.Heads() + if err != nil { + log.Printf("unable to get TKA heads: %v", err) + } else if len(heads) != 1 { + log.Printf("unable to get single TKA head, got %v", heads) + } else { + res.TKAInfo = &tailcfg.TKAInfo{ + Head: heads[0].Hash().String(), + } + } + } + s.mu.Lock() defer s.mu.Unlock() res.Node.PrimaryRoutes = s.nodeSubnetRoutes[nk] diff --git a/tstest/integration/vms/vms_test.go b/tstest/integration/vms/vms_test.go index 0bab3ba5d96d5..c3a3775de9407 100644 --- a/tstest/integration/vms/vms_test.go +++ b/tstest/integration/vms/vms_test.go @@ -184,14 +184,14 @@ type ipMapping struct { // it is difficult to be 100% sure. 
This function should be used with care. It // will probably do what you want, but it is very easy to hold this wrong. func getProbablyFreePortNumber() (int, error) { - l, err := net.Listen("tcp", ":0") + ln, err := net.Listen("tcp", ":0") if err != nil { return 0, err } - defer l.Close() + defer ln.Close() - _, port, err := net.SplitHostPort(l.Addr().String()) + _, port, err := net.SplitHostPort(ln.Addr().String()) if err != nil { return 0, err } diff --git a/tstest/kernel_linux.go b/tstest/kernel_linux.go new file mode 100644 index 0000000000000..664fe9bdd7b9f --- /dev/null +++ b/tstest/kernel_linux.go @@ -0,0 +1,50 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tstest + +import ( + "strconv" + "strings" + + "golang.org/x/sys/unix" +) + +// KernelVersion returns the major, minor, and patch version of the Linux kernel. +// It returns (0, 0, 0) if the version cannot be determined. +func KernelVersion() (major, minor, patch int) { + var uname unix.Utsname + if err := unix.Uname(&uname); err != nil { + return 0, 0, 0 + } + release := unix.ByteSliceToString(uname.Release[:]) + + // Parse version string (e.g., "5.15.0-...") + parts := strings.Split(release, ".") + if len(parts) < 3 { + return 0, 0, 0 + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, 0 + } + + minor, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, 0 + } + + // Patch version may have additional info after a hyphen (e.g., "0-76-generic") + // Extract just the numeric part before any hyphen + patchStr, _, _ := strings.Cut(parts[2], "-") + + patch, err = strconv.Atoi(patchStr) + if err != nil { + return 0, 0, 0 + } + + return major, minor, patch +} diff --git a/tstest/kernel_other.go b/tstest/kernel_other.go new file mode 100644 index 0000000000000..bf69be6df4b27 --- /dev/null +++ b/tstest/kernel_other.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package tstest + +// KernelVersion returns (0, 0, 0) on unsupported platforms. +func KernelVersion() (major, minor, patch int) { + return 0, 0, 0 +} diff --git a/tstest/tkatest/tkatest.go b/tstest/tkatest/tkatest.go new file mode 100644 index 0000000000000..fb157a1a19315 --- /dev/null +++ b/tstest/tkatest/tkatest.go @@ -0,0 +1,220 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// tkatest has functions for creating a mock control server that responds +// to TKA endpoints. +package tkatest + +import ( + "encoding/json" + "errors" + "fmt" + "iter" + "log" + "net/http" + + "tailscale.com/tailcfg" + "tailscale.com/tka" + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +func serverError(w http.ResponseWriter, format string, a ...any) error { + err := fmt.Sprintf(format, a...) + http.Error(w, err, 500) + log.Printf("returning HTTP 500 error: %v", err) + return errors.New(err) +} + +func userError(w http.ResponseWriter, format string, a ...any) error { + err := fmt.Sprintf(format, a...) + http.Error(w, err, 400) + return errors.New(err) +} + +// HandleTKAInitBegin handles a request to /machine/tka/init/begin. +// +// If the request contains a valid genesis AUM, it sends a response to the +// client, and returns the AUM to the caller. 
+func HandleTKAInitBegin(w http.ResponseWriter, r *http.Request, nodes iter.Seq[*tailcfg.Node]) (*tka.AUM, error) {
+	var req *tailcfg.TKAInitBeginRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		return nil, userError(w, "Decode: %v", err)
+	}
+	var aum tka.AUM
+	if err := aum.Unserialize(req.GenesisAUM); err != nil {
+		return nil, userError(w, "invalid genesis AUM: %v", err)
+	}
+	beginResp := tailcfg.TKAInitBeginResponse{}
+	for n := range nodes {
+		beginResp.NeedSignatures = append(
+			beginResp.NeedSignatures,
+			tailcfg.TKASignInfo{
+				NodeID:     n.ID,
+				NodePublic: n.Key,
+			},
+		)
+	}
+
+	w.WriteHeader(200)
+	if err := json.NewEncoder(w).Encode(beginResp); err != nil {
+		return nil, serverError(w, "Encode: %v", err)
+	}
+	return &aum, nil
+}
+
+// HandleTKAInitFinish handles a request to /machine/tka/init/finish.
+//
+// It sends a response to the client, and gives the caller a list of node
+// signatures to apply.
+//
+// This method assumes that the node signatures are valid, and does not
+// verify them with the supplied public key.
+func HandleTKAInitFinish(w http.ResponseWriter, r *http.Request) (map[tailcfg.NodeID]tkatype.MarshaledSignature, error) {
+	var req *tailcfg.TKAInitFinishRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		return nil, userError(w, "Decode: %v", err)
+	}
+
+	w.WriteHeader(200)
+	w.Write([]byte("{}"))
+
+	return req.Signatures, nil
+}
+
+// HandleTKABootstrap handles a request to /machine/tka/bootstrap.
+//
+// If the request is valid, it sends a response to the client, and returns
+// the parsed request to the caller.
+func HandleTKABootstrap(w http.ResponseWriter, r *http.Request, resp tailcfg.TKABootstrapResponse) (*tailcfg.TKABootstrapRequest, error) {
+	req := new(tailcfg.TKABootstrapRequest)
+	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
+		return nil, userError(w, "Decode: %v", err)
+	}
+	if req.Version != tailcfg.CurrentCapabilityVersion {
+		return nil, userError(w, "bootstrap CapVer = %v, want %v", req.Version, tailcfg.CurrentCapabilityVersion)
+	}
+
+	w.WriteHeader(200)
+	if err := json.NewEncoder(w).Encode(resp); err != nil {
+		return nil, serverError(w, "Encode: %v", err)
+	}
+	return req, nil
+}
+
+// HandleTKASyncOffer handles a request to /machine/tka/sync/offer.
+//
+// It replies with the server's own sync offer and any AUMs the client is
+// missing relative to the server's authority.
+func HandleTKASyncOffer(w http.ResponseWriter, r *http.Request, authority *tka.Authority, chonk tka.Chonk) error {
+	body := new(tailcfg.TKASyncOfferRequest)
+	if err := json.NewDecoder(r.Body).Decode(body); err != nil {
+		return userError(w, "Decode: %v", err)
+	}
+
+	log.Printf("got sync offer:\n%+v", body)
+
+	nodeOffer, err := tka.ToSyncOffer(body.Head, body.Ancestors)
+	if err != nil {
+		return userError(w, "ToSyncOffer: %v", err)
+	}
+
+	controlOffer, err := authority.SyncOffer(chonk)
+	if err != nil {
+		return serverError(w, "authority.SyncOffer: %v", err)
+	}
+	sendAUMs, err := authority.MissingAUMs(chonk, nodeOffer)
+	if err != nil {
+		return serverError(w, "authority.MissingAUMs: %v", err)
+	}
+
+	head, ancestors, err := tka.FromSyncOffer(controlOffer)
+	if err != nil {
+		return serverError(w, "FromSyncOffer: %v", err)
+	}
+	resp := tailcfg.TKASyncOfferResponse{
+		Head:        head,
+		Ancestors:   ancestors,
+		MissingAUMs: make([]tkatype.MarshaledAUM, len(sendAUMs)),
+	}
+	for i, a := range sendAUMs {
+		resp.MissingAUMs[i] = a.Serialize()
+	}
+
+	log.Printf("responding to sync offer with:\n%+v", resp)
+	w.WriteHeader(200)
+	if err := json.NewEncoder(w).Encode(resp); err != nil {
+		return serverError(w, "Encode: %v", err)
+	}
+	return nil
+}
+
+// HandleTKASign handles a request to /machine/tka/sign.
+//
+// If the signature request is valid, it sends a response to the client, and
+// gives the caller the signature and public key of the node being signed.
+func HandleTKASign(w http.ResponseWriter, r *http.Request, authority *tka.Authority) (*tkatype.MarshaledSignature, *key.NodePublic, error) {
+	req := new(tailcfg.TKASubmitSignatureRequest)
+	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
+		return nil, nil, userError(w, "Decode: %v", err)
+	}
+	if req.Version != tailcfg.CurrentCapabilityVersion {
+		return nil, nil, userError(w, "sign CapVer = %v, want %v", req.Version, tailcfg.CurrentCapabilityVersion)
+	}
+
+	var sig tka.NodeKeySignature
+	if err := sig.Unserialize(req.Signature); err != nil {
+		return nil, nil, userError(w, "malformed signature: %v", err)
+	}
+	var keyBeingSigned key.NodePublic
+	if err := keyBeingSigned.UnmarshalBinary(sig.Pubkey); err != nil {
+		return nil, nil, userError(w, "malformed signature pubkey: %v", err)
+	}
+	if err := authority.NodeKeyAuthorized(keyBeingSigned, req.Signature); err != nil {
+		return nil, nil, userError(w, "signature does not verify: %v", err)
+	}
+
+	w.WriteHeader(200)
+	if err := json.NewEncoder(w).Encode(tailcfg.TKASubmitSignatureResponse{}); err != nil {
+		return nil, nil, serverError(w, "Encode: %v", err)
+	}
+	return &req.Signature, &keyBeingSigned, nil
+}
+
+// HandleTKASyncSend handles a request to /machine/tka/sync/send.
+//
+// If the request is valid, it adds the new AUMs to the authority, and sends
+// a response to the client with the new head.
+func HandleTKASyncSend(w http.ResponseWriter, r *http.Request, authority *tka.Authority, chonk tka.Chonk) error {
+	body := new(tailcfg.TKASyncSendRequest)
+	if err := json.NewDecoder(r.Body).Decode(body); err != nil {
+		return userError(w, "Decode: %v", err)
+	}
+	log.Printf("got sync send:\n%+v", body)
+
+	var remoteHead tka.AUMHash
+	if err := remoteHead.UnmarshalText([]byte(body.Head)); err != nil {
+		return userError(w, "head unmarshal: %v", err)
+	}
+	toApply := make([]tka.AUM, len(body.MissingAUMs))
+	for i, a := range body.MissingAUMs {
+		if err := toApply[i].Unserialize(a); err != nil {
+			return userError(w, "decoding missingAUM[%d]: %v", i, err)
+		}
+	}
+
+	if len(toApply) > 0 {
+		if err := authority.Inform(chonk, toApply); err != nil {
+			return serverError(w, "control.Inform(%+v) failed: %v", toApply, err)
+		}
+	}
+	head, err := authority.Head().MarshalText()
+	if err != nil {
+		return serverError(w, "head marshal: %v", err)
+	}
+
+	resp := tailcfg.TKASyncSendResponse{
+		Head: string(head),
+	}
+	w.WriteHeader(200)
+	if err := json.NewEncoder(w).Encode(resp); err != nil {
+		return serverError(w, "Encode: %v", err)
+	}
+	return nil
+}
diff --git a/tstest/tstest.go b/tstest/tstest.go
index 169450686966d..d0828f508a46c 100644
--- a/tstest/tstest.go
+++ b/tstest/tstest.go
@@ -6,6 +6,7 @@ package tstest
 
 import (
 	"context"
+	"fmt"
 	"os"
 	"strconv"
 	"strings"
@@ -93,3 +94,20 @@ func Parallel(t *testing.T) {
 		t.Parallel()
 	}
 }
+
+// SkipOnKernelVersions skips the test if the current
+// kernel version is in the specified list.
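+// For example, with issueURL naming whatever issue tracks the kernel bug
+// (hypothetical call site, shown for illustration):
+//
+//	tstest.SkipOnKernelVersions(t, issueURL, "6.8.0", "6.8.1")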
+func SkipOnKernelVersions(t testing.TB, issue string, versions ...string) { + major, minor, patch := KernelVersion() + if major == 0 && minor == 0 && patch == 0 { + t.Logf("could not determine kernel version") + return + } + + current := fmt.Sprintf("%d.%d.%d", major, minor, patch) + for _, v := range versions { + if v == current { + t.Skipf("skipping on kernel version %q - see issue %s", current, issue) + } + } +} diff --git a/tstest/tstest_test.go b/tstest/tstest_test.go index e988d5d5624b6..ce59bde538b9a 100644 --- a/tstest/tstest_test.go +++ b/tstest/tstest_test.go @@ -3,7 +3,10 @@ package tstest -import "testing" +import ( + "runtime" + "testing" +) func TestReplace(t *testing.T) { before := "before" @@ -22,3 +25,17 @@ func TestReplace(t *testing.T) { t.Errorf("before = %q; want %q", before, "before") } } + +func TestKernelVersion(t *testing.T) { + switch runtime.GOOS { + case "linux": + default: + t.Skipf("skipping test on %s", runtime.GOOS) + } + + major, minor, patch := KernelVersion() + if major == 0 && minor == 0 && patch == 0 { + t.Fatal("KernelVersion returned (0, 0, 0); expected valid version") + } + t.Logf("Kernel version: %d.%d.%d", major, minor, patch) +} diff --git a/tstest/typewalk/typewalk.go b/tstest/typewalk/typewalk.go new file mode 100644 index 0000000000000..b22505351b1a2 --- /dev/null +++ b/tstest/typewalk/typewalk.go @@ -0,0 +1,106 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package typewalk provides utilities to walk Go types using reflection. +package typewalk + +import ( + "iter" + "reflect" + "strings" +) + +// Path describes a path via a type where a private key may be found, +// along with a function to test whether a reflect.Value at that path is +// non-zero. +type Path struct { + // Name is the path from the root type, suitable for using as a t.Run name. + Name string + + // Walk returns the reflect.Value at the end of the path, given a root + // reflect.Value. + Walk func(root reflect.Value) (leaf reflect.Value) +} + +// MatchingPaths returns a sequence of [Path] for all paths +// within the given type that end in a type matching match. +func MatchingPaths(rt reflect.Type, match func(reflect.Type) bool) iter.Seq[Path] { + // valFromRoot is a function that, given a reflect.Value of the root struct, + // returns the reflect.Value at some path within it. 
+ type valFromRoot func(reflect.Value) reflect.Value + + return func(yield func(Path) bool) { + var walk func(reflect.Type, valFromRoot) + var path []string + var done bool + seen := map[reflect.Type]bool{} + + walk = func(t reflect.Type, getV valFromRoot) { + if seen[t] { + return + } + seen[t] = true + defer func() { seen[t] = false }() + if done { + return + } + if match(t) { + if !yield(Path{ + Name: strings.Join(path, "."), + Walk: getV, + }) { + done = true + } + return + } + switch t.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Array: + walk(t.Elem(), func(root reflect.Value) reflect.Value { + v := getV(root) + return v.Elem() + }) + case reflect.Struct: + for i := range t.NumField() { + sf := t.Field(i) + fieldName := sf.Name + if fieldName == "_" { + continue + } + path = append(path, fieldName) + walk(sf.Type, func(root reflect.Value) reflect.Value { + return getV(root).FieldByName(fieldName) + }) + path = path[:len(path)-1] + if done { + return + } + } + case reflect.Map: + walk(t.Elem(), func(root reflect.Value) reflect.Value { + v := getV(root) + if v.Len() == 0 { + return reflect.Zero(t.Elem()) + } + iter := v.MapRange() + iter.Next() + return iter.Value() + }) + if done { + return + } + walk(t.Key(), func(root reflect.Value) reflect.Value { + v := getV(root) + if v.Len() == 0 { + return reflect.Zero(t.Key()) + } + iter := v.MapRange() + iter.Next() + return iter.Key() + }) + } + } + + path = append(path, rt.Name()) + walk(rt, func(v reflect.Value) reflect.Value { return v }) + } +} diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index 119fed2e61012..869b4cc8ea566 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -628,8 +628,8 @@ type loggingResponseWriter struct { // from r, or falls back to logf. If a nil logger is given, the logs are // discarded. func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Request) *loggingResponseWriter { - if l, ok := logger.LogfKey.ValueOk(r.Context()); ok && l != nil { - logf = l + if lg, ok := logger.LogfKey.ValueOk(r.Context()); ok && lg != nil { + logf = lg } if logf == nil { logf = logger.Discard @@ -642,46 +642,46 @@ func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Reque } // WriteHeader implements [http.ResponseWriter]. -func (l *loggingResponseWriter) WriteHeader(statusCode int) { - if l.code != 0 { - l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode) +func (lg *loggingResponseWriter) WriteHeader(statusCode int) { + if lg.code != 0 { + lg.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", lg.code, statusCode) return } - if l.ctx.Err() == nil { - l.code = statusCode + if lg.ctx.Err() == nil { + lg.code = statusCode } - l.ResponseWriter.WriteHeader(statusCode) + lg.ResponseWriter.WriteHeader(statusCode) } // Write implements [http.ResponseWriter]. -func (l *loggingResponseWriter) Write(bs []byte) (int, error) { - if l.code == 0 { - l.code = 200 +func (lg *loggingResponseWriter) Write(bs []byte) (int, error) { + if lg.code == 0 { + lg.code = 200 } - n, err := l.ResponseWriter.Write(bs) - l.bytes += n + n, err := lg.ResponseWriter.Write(bs) + lg.bytes += n return n, err } // Hijack implements http.Hijacker. Note that hijacking can still fail // because the wrapped ResponseWriter is not required to implement // Hijacker, as this breaks HTTP/2. 
-func (l *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - h, ok := l.ResponseWriter.(http.Hijacker) +func (lg *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := lg.ResponseWriter.(http.Hijacker) if !ok { return nil, nil, errors.New("ResponseWriter is not a Hijacker") } conn, buf, err := h.Hijack() if err == nil { - l.hijacked = true + lg.hijacked = true } return conn, buf, err } -func (l loggingResponseWriter) Flush() { - f, _ := l.ResponseWriter.(http.Flusher) +func (lg loggingResponseWriter) Flush() { + f, _ := lg.ResponseWriter.(http.Flusher) if f == nil { - l.logf("[unexpected] tried to Flush a ResponseWriter that can't flush") + lg.logf("[unexpected] tried to Flush a ResponseWriter that can't flush") return } f.Flush() diff --git a/types/geo/quantize_test.go b/types/geo/quantize_test.go index 3c707e303c250..bc1f62c9be32f 100644 --- a/types/geo/quantize_test.go +++ b/types/geo/quantize_test.go @@ -32,20 +32,20 @@ func TestPointAnonymize(t *testing.T) { last := geo.MakePoint(llat, 0) cur := geo.MakePoint(lat, 0) anon := cur.Quantize() - switch l, g, err := anon.LatLng(); { + switch latlng, g, err := anon.LatLng(); { case err != nil: t.Fatal(err) case lat == southPole: // initialize llng, to the first snapped longitude - llat = l + llat = latlng goto Lng case g != 0: t.Fatalf("%v is west or east of %v", anon, last) - case l < llat: + case latlng < llat: t.Fatalf("%v is south of %v", anon, last) - case l == llat: + case latlng == llat: continue - case l > llat: + case latlng > llat: switch dist, err := last.DistanceTo(anon); { case err != nil: t.Fatal(err) @@ -55,7 +55,7 @@ func TestPointAnonymize(t *testing.T) { t.Logf("lat=%v last=%v cur=%v anon=%v", lat, last, cur, anon) t.Fatalf("%v is too close to %v", anon, last) default: - llat = l + llat = latlng } } @@ -65,14 +65,14 @@ func TestPointAnonymize(t *testing.T) { last := geo.MakePoint(llat, llng) cur := geo.MakePoint(lat, lng) anon := cur.Quantize() - switch l, g, err := anon.LatLng(); { + switch latlng, g, err := anon.LatLng(); { case err != nil: t.Fatal(err) case lng == dateLine: // initialize llng, to the first snapped longitude llng = g continue - case l != llat: + case latlng != llat: t.Fatalf("%v is north or south of %v", anon, last) case g != llng: const tolerance = geo.MinSeparation * 0x1p-9 diff --git a/types/key/disco.go b/types/key/disco.go index ce5f9b36fd9a1..52b40c766fbbf 100644 --- a/types/key/disco.go +++ b/types/key/disco.go @@ -167,11 +167,11 @@ func (k DiscoPublic) String() string { } // Compare returns an integer comparing DiscoPublic k and l lexicographically. -// The result will be 0 if k == l, -1 if k < l, and +1 if k > l. This is useful -// for situations requiring only one node in a pair to perform some operation, -// e.g. probing UDP path lifetime. -func (k DiscoPublic) Compare(l DiscoPublic) int { - return bytes.Compare(k.k[:], l.k[:]) +// The result will be 0 if k == other, -1 if k < other, and +1 if k > other. +// This is useful for situations requiring only one node in a pair to perform +// some operation, e.g. probing UDP path lifetime. +func (k DiscoPublic) Compare(other DiscoPublic) int { + return bytes.Compare(k.k[:], other.k[:]) } // AppendText implements encoding.TextAppender. 
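The reworded `Compare` doc above mentions electing one node of a pair to perform an operation such as UDP path-lifetime probing. A short sketch of that pattern (hypothetical helper and package names; both sides evaluate the same rule, so exactly one of two distinct nodes elects itself):

```go
package probe // hypothetical package, for illustration only

import "tailscale.com/types/key"

// isDesignatedProber reports whether the local node should perform the
// pair-wise operation. The lexicographic order of disco keys is an
// arbitrary but symmetric tie-breaker: local.Compare(remote) and
// remote.Compare(local) always disagree for distinct keys.
func isDesignatedProber(local, remote key.DiscoPublic) bool {
	return local.Compare(remote) < 0
}
```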
diff --git a/types/key/util.go b/types/key/util.go index bdb2a06f68e67..50fac827556aa 100644 --- a/types/key/util.go +++ b/types/key/util.go @@ -10,9 +10,12 @@ import ( "errors" "fmt" "io" + "reflect" "slices" "go4.org/mem" + "tailscale.com/util/set" + "tailscale.com/util/testenv" ) // rand fills b with cryptographically strong random bytes. Panics if @@ -115,3 +118,18 @@ func debug32(k [32]byte) string { dst[6] = ']' return string(dst[:7]) } + +// PrivateTypesForTest returns the set of private key types +// in this package, for testing purposes. +func PrivateTypesForTest() set.Set[reflect.Type] { + testenv.AssertInTest() + return set.Of( + reflect.TypeFor[ChallengePrivate](), + reflect.TypeFor[ControlPrivate](), + reflect.TypeFor[DiscoPrivate](), + reflect.TypeFor[MachinePrivate](), + reflect.TypeFor[NodePrivate](), + reflect.TypeFor[NLPrivate](), + reflect.TypeFor[HardwareAttestationKey](), + ) +} diff --git a/types/netlogtype/netlogtype.go b/types/netlogtype/netlogtype.go index a29ea6f03dffa..cc38684a30dbf 100644 --- a/types/netlogtype/netlogtype.go +++ b/types/netlogtype/netlogtype.go @@ -21,6 +21,9 @@ type Message struct { Start time.Time `json:"start"` // inclusive End time.Time `json:"end"` // inclusive + SrcNode Node `json:"srcNode,omitzero"` + DstNodes []Node `json:"dstNodes,omitempty"` + VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty"` SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty"` ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty"` @@ -28,14 +31,18 @@ type Message struct { } const ( - messageJSON = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}` + messageJSON = `{"nodeId":` + maxJSONStableID + `,` + minJSONNodes + `,` + maxJSONTimeRange + `,` + minJSONTraffic + `}` + maxJSONStableID = `"n0123456789abcdefCNTRL"` + minJSONNodes = `"srcNode":{},"dstNodes":[]` maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339 maxJSONRFC3339 = `"0001-01-01T00:00:00.000000000Z"` minJSONTraffic = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}` - // MaxMessageJSONSize is the overhead size of Message when it is - // serialized as JSON assuming that each traffic map is populated. - MaxMessageJSONSize = len(messageJSON) + // MinMessageJSONSize is the overhead size of Message when it is + // serialized as JSON assuming that each field is minimally populated. + // Each [Node] occupies at least [MinNodeJSONSize]. + // Each [ConnectionCounts] occupies at most [MaxConnectionCountsJSONSize]. + MinMessageJSONSize = len(messageJSON) maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}` maxJSONConn = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort @@ -52,6 +59,29 @@ const ( MaxConnectionCountsJSONSize = len(maxJSONConnCounts) ) +// Node is information about a node. +type Node struct { + // NodeID is the stable ID of the node. + NodeID tailcfg.StableNodeID `json:"nodeId"` + + // Name is the fully-qualified name of the node. + Name string `json:"name,omitzero"` // e.g., "carbonite.example.ts.net" + + // Addresses are the Tailscale IP addresses of the node. + Addresses []netip.Addr `json:"addresses,omitempty"` + + // OS is the operating system of the node. + OS string `json:"os,omitzero"` // e.g., "linux" + + // User is the user that owns the node. + // It is not populated if the node is tagged. + User string `json:"user,omitzero"` // e.g., "johndoe@example.com" + + // Tags are the tags of the node. 
+ // It is not populated if the node is owned by a user. + Tags []string `json:"tags,omitempty"` // e.g., ["tag:prod","tag:logs"] +} + // ConnectionCounts is a flattened struct of both a connection and counts. type ConnectionCounts struct { Connection diff --git a/types/netmap/netmap.go b/types/netmap/netmap.go index cc6bec1db8edb..c54562f4d5b53 100644 --- a/types/netmap/netmap.go +++ b/types/netmap/netmap.go @@ -26,14 +26,9 @@ import ( // The fields should all be considered read-only. They might // alias parts of previous NetworkMap values. type NetworkMap struct { - SelfNode tailcfg.NodeView - AllCaps set.Set[tailcfg.NodeCapability] // set version of SelfNode.Capabilities + SelfNode.CapMap - NodeKey key.NodePublic - PrivateKey key.NodePrivate - Expiry time.Time - // Name is the DNS name assigned to this node. - // It is the MapResponse.Node.Name value and ends with a period. - Name string + SelfNode tailcfg.NodeView + AllCaps set.Set[tailcfg.NodeCapability] // set version of SelfNode.Capabilities + SelfNode.CapMap + NodeKey key.NodePublic MachineKey key.MachinePublic @@ -236,10 +231,25 @@ func MagicDNSSuffixOfNodeName(nodeName string) string { // // It will neither start nor end with a period. func (nm *NetworkMap) MagicDNSSuffix() string { - if nm == nil { + return MagicDNSSuffixOfNodeName(nm.SelfName()) +} + +// SelfName returns nm.SelfNode.Name, or the empty string +// if nm is nil or nm.SelfNode is invalid. +func (nm *NetworkMap) SelfName() string { + if nm == nil || !nm.SelfNode.Valid() { return "" } - return MagicDNSSuffixOfNodeName(nm.Name) + return nm.SelfNode.Name() +} + +// SelfKeyExpiry returns nm.SelfNode.KeyExpiry, or the zero +// value if nil or nm.SelfNode is invalid. +func (nm *NetworkMap) SelfKeyExpiry() time.Time { + if nm == nil || !nm.SelfNode.Valid() { + return time.Time{} + } + return nm.SelfNode.KeyExpiry() } // DomainName returns the name of the NetworkMap's diff --git a/types/netmap/netmap_test.go b/types/netmap/netmap_test.go index 40f504741bfea..ee4fecdb4ff4e 100644 --- a/types/netmap/netmap_test.go +++ b/types/netmap/netmap_test.go @@ -6,11 +6,13 @@ package netmap import ( "encoding/hex" "net/netip" + "reflect" "testing" "go4.org/mem" "tailscale.com/net/netaddr" "tailscale.com/tailcfg" + "tailscale.com/tstest/typewalk" "tailscale.com/types/key" ) @@ -316,3 +318,10 @@ func TestPeerIndexByNodeID(t *testing.T) { } } } + +func TestNoPrivateKeyMaterial(t *testing.T) { + private := key.PrivateTypesForTest() + for path := range typewalk.MatchingPaths(reflect.TypeFor[NetworkMap](), private.Contains) { + t.Errorf("NetworkMap contains private key material at path: %q", path.Name) + } +} diff --git a/types/netmap/nodemut.go b/types/netmap/nodemut.go index f4de1bf0b8f02..4f93be21c6d68 100644 --- a/types/netmap/nodemut.go +++ b/types/netmap/nodemut.go @@ -177,5 +177,5 @@ func mapResponseContainsNonPatchFields(res *tailcfg.MapResponse) bool { // function is called, so it should never be set anyway. But for // completedness, and for tests, check it too: res.PeersChanged != nil || - res.DefaultAutoUpdate != "" + res.DeprecatedDefaultAutoUpdate != "" } diff --git a/types/opt/bool.go b/types/opt/bool.go index e2fd6a054ff0d..fbc39e1dc3754 100644 --- a/types/opt/bool.go +++ b/types/opt/bool.go @@ -83,6 +83,17 @@ func (b *Bool) Scan(src any) error { } } +// Normalized returns the normalized form of b, mapping "unset" to "" +// and leaving other values unchanged. 
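+//
+// For example:
+//
+//	Bool("unset").Normalized() // == Bool("")
+//	Bool("true").Normalized()  // == Bool("true")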
+func (b Bool) Normalized() Bool { + switch b { + case ExplicitlyUnset: + return Empty + default: + return b + } +} + // EqualBool reports whether b is equal to v. // If b is empty or not a valid bool, it reports false. func (b Bool) EqualBool(v bool) bool { diff --git a/types/opt/bool_test.go b/types/opt/bool_test.go index dddbcfc195d04..e61d66dbe9e96 100644 --- a/types/opt/bool_test.go +++ b/types/opt/bool_test.go @@ -106,6 +106,8 @@ func TestBoolEqualBool(t *testing.T) { }{ {"", true, false}, {"", false, false}, + {"unset", true, false}, + {"unset", false, false}, {"sdflk;", true, false}, {"sldkf;", false, false}, {"true", true, true}, @@ -122,6 +124,24 @@ func TestBoolEqualBool(t *testing.T) { } } +func TestBoolNormalized(t *testing.T) { + tests := []struct { + in Bool + want Bool + }{ + {"", ""}, + {"true", "true"}, + {"false", "false"}, + {"unset", ""}, + {"foo", "foo"}, + } + for _, tt := range tests { + if got := tt.in.Normalized(); got != tt.want { + t.Errorf("(%q).Normalized() = %q; want %q", string(tt.in), string(got), string(tt.want)) + } + } +} + func TestUnmarshalAlloc(t *testing.T) { b := json.Unmarshaler(new(Bool)) n := testing.AllocsPerRun(10, func() { b.UnmarshalJSON(trueBytes) }) diff --git a/types/persist/persist.go b/types/persist/persist.go index d888a6afb6af5..80bac9b5e2741 100644 --- a/types/persist/persist.go +++ b/types/persist/persist.go @@ -26,6 +26,7 @@ type Persist struct { UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate NodeID tailcfg.StableNodeID + AttestationKey key.HardwareAttestationKey `json:",omitzero"` // DisallowedTKAStateIDs stores the tka.State.StateID values which // this node will not operate network lock on. This is used to @@ -84,11 +85,20 @@ func (p *Persist) Equals(p2 *Persist) bool { return false } + var pub, p2Pub key.HardwareAttestationPublic + if p.AttestationKey != nil && !p.AttestationKey.IsZero() { + pub = key.HardwareAttestationPublicFromPlatformKey(p.AttestationKey) + } + if p2.AttestationKey != nil && !p2.AttestationKey.IsZero() { + p2Pub = key.HardwareAttestationPublicFromPlatformKey(p2.AttestationKey) + } + return p.PrivateNodeKey.Equal(p2.PrivateNodeKey) && p.OldPrivateNodeKey.Equal(p2.OldPrivateNodeKey) && p.UserProfile.Equal(&p2.UserProfile) && p.NetworkLockKey.Equal(p2.NetworkLockKey) && p.NodeID == p2.NodeID && + pub.Equal(p2Pub) && reflect.DeepEqual(nilIfEmpty(p.DisallowedTKAStateIDs), nilIfEmpty(p2.DisallowedTKAStateIDs)) } @@ -96,12 +106,16 @@ func (p *Persist) Pretty() string { var ( ok, nk key.NodePublic ) + akString := "-" if !p.OldPrivateNodeKey.IsZero() { ok = p.OldPrivateNodeKey.Public() } if !p.PrivateNodeKey.IsZero() { nk = p.PublicNodeKey() } - return fmt.Sprintf("Persist{o=%v, n=%v u=%#v}", - ok.ShortString(), nk.ShortString(), p.UserProfile.LoginName) + if p.AttestationKey != nil && !p.AttestationKey.IsZero() { + akString = fmt.Sprintf("%v", p.AttestationKey.Public()) + } + return fmt.Sprintf("Persist{o=%v, n=%v u=%#v ak=%s}", + ok.ShortString(), nk.ShortString(), p.UserProfile.LoginName, akString) } diff --git a/types/persist/persist_clone.go b/types/persist/persist_clone.go index 680419ff2f30b..9dbe7e0f6fa6d 100644 --- a/types/persist/persist_clone.go +++ b/types/persist/persist_clone.go @@ -19,6 +19,9 @@ func (src *Persist) Clone() *Persist { } dst := new(Persist) *dst = *src + if src.AttestationKey != nil { + dst.AttestationKey = src.AttestationKey.Clone() + } dst.DisallowedTKAStateIDs = append(src.DisallowedTKAStateIDs[:0:0], src.DisallowedTKAStateIDs...) 
 	return dst
 }
 
@@ -31,5 +34,6 @@ var _PersistCloneNeedsRegeneration = Persist(struct {
 	UserProfile           tailcfg.UserProfile
 	NetworkLockKey        key.NLPrivate
 	NodeID                tailcfg.StableNodeID
+	AttestationKey        key.HardwareAttestationKey
 	DisallowedTKAStateIDs []string
 }{})
diff --git a/types/persist/persist_test.go b/types/persist/persist_test.go
index dbf2a6d8c7662..713114b74dcd5 100644
--- a/types/persist/persist_test.go
+++ b/types/persist/persist_test.go
@@ -21,7 +21,7 @@ func fieldsOf(t reflect.Type) (fields []string) {
 }
 
 func TestPersistEqual(t *testing.T) {
-	persistHandles := []string{"PrivateNodeKey", "OldPrivateNodeKey", "UserProfile", "NetworkLockKey", "NodeID", "DisallowedTKAStateIDs"}
+	persistHandles := []string{"PrivateNodeKey", "OldPrivateNodeKey", "UserProfile", "NetworkLockKey", "NodeID", "AttestationKey", "DisallowedTKAStateIDs"}
 	if have := fieldsOf(reflect.TypeFor[Persist]()); !reflect.DeepEqual(have, persistHandles) {
 		t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
 			have, persistHandles)
diff --git a/types/persist/persist_view.go b/types/persist/persist_view.go
index 7d1507468fc65..dbf8294ef5a7a 100644
--- a/types/persist/persist_view.go
+++ b/types/persist/persist_view.go
@@ -89,10 +89,11 @@ func (v *PersistView) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
 
 func (v PersistView) PrivateNodeKey() key.NodePrivate { return v.ж.PrivateNodeKey } // needed to request key rotation
 
-func (v PersistView) OldPrivateNodeKey() key.NodePrivate { return v.ж.OldPrivateNodeKey }
-func (v PersistView) UserProfile() tailcfg.UserProfile    { return v.ж.UserProfile }
-func (v PersistView) NetworkLockKey() key.NLPrivate       { return v.ж.NetworkLockKey }
-func (v PersistView) NodeID() tailcfg.StableNodeID        { return v.ж.NodeID }
+func (v PersistView) OldPrivateNodeKey() key.NodePrivate         { return v.ж.OldPrivateNodeKey }
+func (v PersistView) UserProfile() tailcfg.UserProfile           { return v.ж.UserProfile }
+func (v PersistView) NetworkLockKey() key.NLPrivate              { return v.ж.NetworkLockKey }
+func (v PersistView) NodeID() tailcfg.StableNodeID               { return v.ж.NodeID }
+func (v PersistView) AttestationKey() key.HardwareAttestationKey { panic("unsupported") }
 
 // DisallowedTKAStateIDs stores the tka.State.StateID values which
 // this node will not operate network lock on. This is used to
@@ -110,5 +111,6 @@ var _PersistViewNeedsRegeneration = Persist(struct {
 	UserProfile           tailcfg.UserProfile
 	NetworkLockKey        key.NLPrivate
 	NodeID                tailcfg.StableNodeID
+	AttestationKey        key.HardwareAttestationKey
 	DisallowedTKAStateIDs []string
 }{})
diff --git a/types/prefs/list.go b/types/prefs/list.go
index 7db473887d195..ae6b2fae335db 100644
--- a/types/prefs/list.go
+++ b/types/prefs/list.go
@@ -45,36 +45,36 @@ func ListWithOpts[T ImmutableType](opts ...Options) List[T] {
 // SetValue configures the preference with the specified value.
 // It fails and returns [ErrManaged] if p is a managed preference,
 // and [ErrReadOnly] if p is a read-only preference.
-func (l *List[T]) SetValue(val []T) error {
-	return l.preference.SetValue(cloneSlice(val))
+func (ls *List[T]) SetValue(val []T) error {
+	return ls.preference.SetValue(cloneSlice(val))
 }
 
 // SetManagedValue configures the preference with the specified value
 // and marks the preference as managed.
-func (l *List[T]) SetManagedValue(val []T) {
-	l.preference.SetManagedValue(cloneSlice(val))
+func (ls *List[T]) SetManagedValue(val []T) {
+	ls.preference.SetManagedValue(cloneSlice(val))
 }
 
 // View returns a read-only view of l.
-func (l *List[T]) View() ListView[T] { - return ListView[T]{l} +func (ls *List[T]) View() ListView[T] { + return ListView[T]{ls} } // Clone returns a copy of l that aliases no memory with l. -func (l List[T]) Clone() *List[T] { - res := ptr.To(l) - if v, ok := l.s.Value.GetOk(); ok { +func (ls List[T]) Clone() *List[T] { + res := ptr.To(ls) + if v, ok := ls.s.Value.GetOk(); ok { res.s.Value.Set(append(v[:0:0], v...)) } return res } // Equal reports whether l and l2 are equal. -func (l List[T]) Equal(l2 List[T]) bool { - if l.s.Metadata != l2.s.Metadata { +func (ls List[T]) Equal(l2 List[T]) bool { + if ls.s.Metadata != l2.s.Metadata { return false } - v1, ok1 := l.s.Value.GetOk() + v1, ok1 := ls.s.Value.GetOk() v2, ok2 := l2.s.Value.GetOk() if ok1 != ok2 { return false diff --git a/types/prefs/prefs_test.go b/types/prefs/prefs_test.go index d6af745bf83b8..dc1213adb27ab 100644 --- a/types/prefs/prefs_test.go +++ b/types/prefs/prefs_test.go @@ -487,31 +487,31 @@ func TestItemView(t *testing.T) { } func TestListView(t *testing.T) { - l := ListOf([]int{4, 8, 15, 16, 23, 42}, ReadOnly) + ls := ListOf([]int{4, 8, 15, 16, 23, 42}, ReadOnly) - lv := l.View() + lv := ls.View() checkIsSet(t, lv, true) checkIsManaged(t, lv, false) checkIsReadOnly(t, lv, true) - checkValue(t, lv, views.SliceOf(l.Value())) - checkValueOk(t, lv, views.SliceOf(l.Value()), true) + checkValue(t, lv, views.SliceOf(ls.Value())) + checkValueOk(t, lv, views.SliceOf(ls.Value()), true) l2 := *lv.AsStruct() - checkEqual(t, l, l2, true) + checkEqual(t, ls, l2, true) } func TestStructListView(t *testing.T) { - l := StructListOf([]*TestBundle{{Name: "E1"}, {Name: "E2"}}, ReadOnly) + ls := StructListOf([]*TestBundle{{Name: "E1"}, {Name: "E2"}}, ReadOnly) - lv := StructListViewOf(&l) + lv := StructListViewOf(&ls) checkIsSet(t, lv, true) checkIsManaged(t, lv, false) checkIsReadOnly(t, lv, true) - checkValue(t, lv, views.SliceOfViews(l.Value())) - checkValueOk(t, lv, views.SliceOfViews(l.Value()), true) + checkValue(t, lv, views.SliceOfViews(ls.Value())) + checkValueOk(t, lv, views.SliceOfViews(ls.Value()), true) l2 := *lv.AsStruct() - checkEqual(t, l, l2, true) + checkEqual(t, ls, l2, true) } func TestStructMapView(t *testing.T) { diff --git a/types/prefs/struct_list.go b/types/prefs/struct_list.go index 65f11011af8fb..ba145e2cf7086 100644 --- a/types/prefs/struct_list.go +++ b/types/prefs/struct_list.go @@ -33,20 +33,20 @@ func StructListWithOpts[T views.Cloner[T]](opts ...Options) StructList[T] { // SetValue configures the preference with the specified value. // It fails and returns [ErrManaged] if p is a managed preference, // and [ErrReadOnly] if p is a read-only preference. -func (l *StructList[T]) SetValue(val []T) error { - return l.preference.SetValue(deepCloneSlice(val)) +func (ls *StructList[T]) SetValue(val []T) error { + return ls.preference.SetValue(deepCloneSlice(val)) } // SetManagedValue configures the preference with the specified value // and marks the preference as managed. -func (l *StructList[T]) SetManagedValue(val []T) { - l.preference.SetManagedValue(deepCloneSlice(val)) +func (ls *StructList[T]) SetManagedValue(val []T) { + ls.preference.SetManagedValue(deepCloneSlice(val)) } // Clone returns a copy of l that aliases no memory with l. 
-func (l StructList[T]) Clone() *StructList[T] { - res := ptr.To(l) - if v, ok := l.s.Value.GetOk(); ok { +func (ls StructList[T]) Clone() *StructList[T] { + res := ptr.To(ls) + if v, ok := ls.s.Value.GetOk(); ok { res.s.Value.Set(deepCloneSlice(v)) } return res @@ -56,11 +56,11 @@ func (l StructList[T]) Clone() *StructList[T] { // If the template type T implements an Equal(T) bool method, it will be used // instead of the == operator for value comparison. // It panics if T is not comparable. -func (l StructList[T]) Equal(l2 StructList[T]) bool { - if l.s.Metadata != l2.s.Metadata { +func (ls StructList[T]) Equal(l2 StructList[T]) bool { + if ls.s.Metadata != l2.s.Metadata { return false } - v1, ok1 := l.s.Value.GetOk() + v1, ok1 := ls.s.Value.GetOk() v2, ok2 := l2.s.Value.GetOk() if ok1 != ok2 { return false @@ -105,8 +105,8 @@ type StructListView[T views.ViewCloner[T, V], V views.StructView[T]] struct { // StructListViewOf returns a read-only view of l. // It is used by [tailscale.com/cmd/viewer]. -func StructListViewOf[T views.ViewCloner[T, V], V views.StructView[T]](l *StructList[T]) StructListView[T, V] { - return StructListView[T, V]{l} +func StructListViewOf[T views.ViewCloner[T, V], V views.StructView[T]](ls *StructList[T]) StructListView[T, V] { + return StructListView[T, V]{ls} } // Valid reports whether the underlying [StructList] is non-nil. diff --git a/types/prefs/struct_map.go b/types/prefs/struct_map.go index a081f7c7468e2..83cc7447baedd 100644 --- a/types/prefs/struct_map.go +++ b/types/prefs/struct_map.go @@ -31,14 +31,14 @@ func StructMapWithOpts[K MapKeyType, V views.Cloner[V]](opts ...Options) StructM // SetValue configures the preference with the specified value. // It fails and returns [ErrManaged] if p is a managed preference, // and [ErrReadOnly] if p is a read-only preference. -func (l *StructMap[K, V]) SetValue(val map[K]V) error { - return l.preference.SetValue(deepCloneMap(val)) +func (m *StructMap[K, V]) SetValue(val map[K]V) error { + return m.preference.SetValue(deepCloneMap(val)) } // SetManagedValue configures the preference with the specified value // and marks the preference as managed. -func (l *StructMap[K, V]) SetManagedValue(val map[K]V) { - l.preference.SetManagedValue(deepCloneMap(val)) +func (m *StructMap[K, V]) SetManagedValue(val map[K]V) { + m.preference.SetManagedValue(deepCloneMap(val)) } // Clone returns a copy of m that aliases no memory with m. diff --git a/util/backoff/backoff.go b/util/backoff/backoff.go index c6aeae998fa27..95089fc2479ff 100644 --- a/util/backoff/backoff.go +++ b/util/backoff/backoff.go @@ -78,3 +78,9 @@ func (b *Backoff) BackOff(ctx context.Context, err error) { case <-tChannel: } } + +// Reset resets the backoff schedule, equivalent to calling BackOff with a nil +// error. 
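+//
+// Reset is useful when one Backoff is reused across independent phases of
+// work and the next phase should start from the shortest delay again,
+// without needing the context and error arguments that BackOff takes.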
+func (b *Backoff) Reset() { + b.n = 0 +} diff --git a/util/cache/cache_test.go b/util/cache/cache_test.go deleted file mode 100644 index a6683e12dd772..0000000000000 --- a/util/cache/cache_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cache - -import ( - "errors" - "testing" - "time" -) - -var startTime = time.Date(2023, time.March, 1, 0, 0, 0, 0, time.UTC) - -func TestSingleCache(t *testing.T) { - testTime := startTime - timeNow := func() time.Time { return testTime } - c := &Single[string, int]{ - timeNow: timeNow, - } - - t.Run("NoServeExpired", func(t *testing.T) { - testCacheImpl(t, c, &testTime, false) - }) - - t.Run("ServeExpired", func(t *testing.T) { - c.Empty() - c.ServeExpired = true - testTime = startTime - testCacheImpl(t, c, &testTime, true) - }) -} - -func TestLocking(t *testing.T) { - testTime := startTime - timeNow := func() time.Time { return testTime } - c := NewLocking(&Single[string, int]{ - timeNow: timeNow, - }) - - // Just verify that the inner cache's behaviour hasn't changed. - testCacheImpl(t, c, &testTime, false) -} - -func testCacheImpl(t *testing.T, c Cache[string, int], testTime *time.Time, serveExpired bool) { - var fillTime time.Time - t.Run("InitialFill", func(t *testing.T) { - fillTime = testTime.Add(time.Hour) - val, err := c.Get("key", func() (int, time.Time, error) { - return 123, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - }) - - // Fetching again won't call our fill function - t.Run("SecondFetch", func(t *testing.T) { - *testTime = fillTime.Add(-1 * time.Second) - called := false - val, err := c.Get("key", func() (int, time.Time, error) { - called = true - return -1, fillTime, nil - }) - if called { - t.Fatal("wanted no call to fill function") - } - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - }) - - // Fetching after the expiry time will re-fill - t.Run("ReFill", func(t *testing.T) { - *testTime = fillTime.Add(1) - fillTime = fillTime.Add(time.Hour) - val, err := c.Get("key", func() (int, time.Time, error) { - return 999, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 999 { - t.Fatalf("got val=%d; want 999", val) - } - }) - - // An error on fetch will serve the expired value. - t.Run("FetchError", func(t *testing.T) { - if !serveExpired { - t.Skipf("not testing ServeExpired") - } - - *testTime = fillTime.Add(time.Hour + 1) - val, err := c.Get("key", func() (int, time.Time, error) { - return 0, time.Time{}, errors.New("some error") - }) - if err != nil { - t.Fatal(err) - } - if val != 999 { - t.Fatalf("got val=%d; want 999", val) - } - }) - - // Fetching a different key re-fills - t.Run("DifferentKey", func(t *testing.T) { - *testTime = fillTime.Add(time.Hour + 1) - - var calls int - val, err := c.Get("key1", func() (int, time.Time, error) { - calls++ - return 123, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - if calls != 1 { - t.Errorf("got %d, want 1 call", calls) - } - - val, err = c.Get("key2", func() (int, time.Time, error) { - calls++ - return 456, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 456 { - t.Fatalf("got val=%d; want 456", val) - } - if calls != 2 { - t.Errorf("got %d, want 2 call", calls) - } - }) - - // Calling Forget with the wrong key does nothing, and with the correct - // key will drop the cache. 
- t.Run("Forget", func(t *testing.T) { - // Add some time so that previously-cached values don't matter. - fillTime = testTime.Add(2 * time.Hour) - *testTime = fillTime.Add(-1 * time.Second) - - const key = "key" - - var calls int - val, err := c.Get(key, func() (int, time.Time, error) { - calls++ - return 123, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - if calls != 1 { - t.Errorf("got %d, want 1 call", calls) - } - - // Forgetting the wrong key does nothing - c.Forget("other") - val, err = c.Get(key, func() (int, time.Time, error) { - t.Fatal("should not be called") - panic("unreachable") - }) - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - - // Forgetting the correct key re-fills - c.Forget(key) - - val, err = c.Get("key2", func() (int, time.Time, error) { - calls++ - return 456, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 456 { - t.Fatalf("got val=%d; want 456", val) - } - if calls != 2 { - t.Errorf("got %d, want 2 call", calls) - } - }) -} diff --git a/util/cache/interface.go b/util/cache/interface.go deleted file mode 100644 index 0db87ba0e2ff4..0000000000000 --- a/util/cache/interface.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package cache contains an interface for a cache around a typed value, and -// various cache implementations that implement that interface. -package cache - -import "time" - -// Cache is the interface for the cache types in this package. -// -// Functions in this interface take a key parameter, but it is valid for a -// cache type to hold a single value associated with a key, and simply drop the -// cached value if provided with a different key. -// -// It is valid for Cache implementations to be concurrency-safe or not, and -// each implementation should document this. If you need a concurrency-safe -// cache, an existing cache can be wrapped with a lock using NewLocking(inner). -// -// K and V should be types that can be successfully passed to json.Marshal. -type Cache[K comparable, V any] interface { - // Get should return a previously-cached value or call the provided - // FillFunc to obtain a new one. The provided key can be used either to - // allow multiple cached values, or to drop the cache if the key - // changes; either is valid. - Get(K, FillFunc[V]) (V, error) - - // Forget should remove the given key from the cache, if it is present. - // If it is not present, nothing should be done. - Forget(K) - - // Empty should empty the cache such that the next call to Get should - // call the provided FillFunc for all possible keys. - Empty() -} - -// FillFunc is the signature of a function for filling a cache. It should -// return the value to be cached, the time that the cached value is valid -// until, or an error. -type FillFunc[T any] func() (T, time.Time, error) diff --git a/util/cache/locking.go b/util/cache/locking.go deleted file mode 100644 index 85e44b360a9b0..0000000000000 --- a/util/cache/locking.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cache - -import "sync" - -// Locking wraps an inner Cache implementation with a mutex, making it -// safe for concurrent use. All methods are serialized on the same mutex. 
-type Locking[K comparable, V any, C Cache[K, V]] struct { - sync.Mutex - inner C -} - -// NewLocking creates a new Locking cache wrapping inner. -func NewLocking[K comparable, V any, C Cache[K, V]](inner C) *Locking[K, V, C] { - return &Locking[K, V, C]{inner: inner} -} - -// Get implements Cache. -// -// The cache's mutex is held for the entire duration of this function, -// including while the FillFunc is being called. This function is not -// reentrant; attempting to call Get from a FillFunc will deadlock. -func (c *Locking[K, V, C]) Get(key K, f FillFunc[V]) (V, error) { - c.Lock() - defer c.Unlock() - return c.inner.Get(key, f) -} - -// Forget implements Cache. -func (c *Locking[K, V, C]) Forget(key K) { - c.Lock() - defer c.Unlock() - c.inner.Forget(key) -} - -// Empty implements Cache. -func (c *Locking[K, V, C]) Empty() { - c.Lock() - defer c.Unlock() - c.inner.Empty() -} diff --git a/util/cache/none.go b/util/cache/none.go deleted file mode 100644 index c4073e0d90cf3..0000000000000 --- a/util/cache/none.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cache - -// None provides no caching and always calls the provided FillFunc. -// -// It is safe for concurrent use if the underlying FillFunc is. -type None[K comparable, V any] struct{} - -var _ Cache[int, int] = None[int, int]{} - -// Get always calls the provided FillFunc and returns what it does. -func (c None[K, V]) Get(_ K, f FillFunc[V]) (V, error) { - v, _, e := f() - return v, e -} - -// Forget implements Cache. -func (None[K, V]) Forget(K) {} - -// Empty implements Cache. -func (None[K, V]) Empty() {} diff --git a/util/cache/single.go b/util/cache/single.go deleted file mode 100644 index 6b9ac2c1193c6..0000000000000 --- a/util/cache/single.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cache - -import ( - "time" -) - -// Single is a simple in-memory cache that stores a single value until a -// defined time before it is re-fetched. It also supports returning a -// previously-expired value if refreshing the value in the cache fails. -// -// Single is not safe for concurrent use. -type Single[K comparable, V any] struct { - key K - val V - goodUntil time.Time - timeNow func() time.Time // for tests - - // ServeExpired indicates that if an error occurs when filling the - // cache, an expired value can be returned instead of an error. - // - // This value should only be set when this struct is created. - ServeExpired bool -} - -var _ Cache[int, int] = (*Single[int, int])(nil) - -// Get will return the cached value, if any, or fill the cache by calling f and -// return the corresponding value. If f returns an error and c.ServeExpired is -// true, then a previous expired value can be returned with no error. -func (c *Single[K, V]) Get(key K, f FillFunc[V]) (V, error) { - var now time.Time - if c.timeNow != nil { - now = c.timeNow() - } else { - now = time.Now() - } - - if c.key == key && now.Before(c.goodUntil) { - return c.val, nil - } - - // Re-fill cached entry - val, until, err := f() - if err == nil { - c.key = key - c.val = val - c.goodUntil = until - return val, nil - } - - // Never serve an expired entry for the wrong key. - if c.key == key && c.ServeExpired && !c.goodUntil.IsZero() { - return c.val, nil - } - - var zero V - return zero, err -} - -// Forget implements Cache. 
-func (c *Single[K, V]) Forget(key K) { - if c.key != key { - return - } - - c.Empty() -} - -// Empty implements Cache. -func (c *Single[K, V]) Empty() { - c.goodUntil = time.Time{} - - var zeroKey K - c.key = zeroKey - - var zeroVal V - c.val = zeroVal -} diff --git a/util/clientmetric/clientmetric.go b/util/clientmetric/clientmetric.go index 65223e6a9375a..9e6b03a15ce93 100644 --- a/util/clientmetric/clientmetric.go +++ b/util/clientmetric/clientmetric.go @@ -133,15 +133,18 @@ func (m *Metric) Publish() { metrics[m.name] = m sortedDirty = true + if m.f == nil { + if len(valFreeList) == 0 { + valFreeList = make([]int64, 256) + } + m.v = &valFreeList[0] + valFreeList = valFreeList[1:] + } + if buildfeatures.HasLogTail { if m.f != nil { lastLogVal = append(lastLogVal, scanEntry{f: m.f}) } else { - if len(valFreeList) == 0 { - valFreeList = make([]int64, 256) - } - m.v = &valFreeList[0] - valFreeList = valFreeList[1:] lastLogVal = append(lastLogVal, scanEntry{v: m.v}) } } diff --git a/util/deephash/tailscale_types_test.go b/util/deephash/tailscale_types_test.go index d760253990048..eeb7fdf84d11f 100644 --- a/util/deephash/tailscale_types_test.go +++ b/util/deephash/tailscale_types_test.go @@ -85,7 +85,6 @@ type tailscaleTypes struct { func getVal() *tailscaleTypes { return &tailscaleTypes{ &wgcfg.Config{ - Name: "foo", Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{3: 3}).Unmap(), 5)}, Peers: []wgcfg.Peer{ { diff --git a/util/dnsname/dnsname.go b/util/dnsname/dnsname.go index 6404a9af1cc2f..ef898ebbd842f 100644 --- a/util/dnsname/dnsname.go +++ b/util/dnsname/dnsname.go @@ -14,7 +14,7 @@ const ( // maxLabelLength is the maximum length of a label permitted by RFC 1035. maxLabelLength = 63 // maxNameLength is the maximum length of a DNS name. - maxNameLength = 253 + maxNameLength = 254 ) // A FQDN is a fully-qualified DNS name or name suffix. diff --git a/util/dnsname/dnsname_test.go b/util/dnsname/dnsname_test.go index 719e28be3966b..b038bb1bd10e1 100644 --- a/util/dnsname/dnsname_test.go +++ b/util/dnsname/dnsname_test.go @@ -59,6 +59,38 @@ func TestFQDN(t *testing.T) { } } +func TestFQDNTooLong(t *testing.T) { + // RFC 1035 says a dns name has a max size of 255 octets, and is represented as labels of len+ASCII chars so + // example.com + // is represented as + // 7example3com0 + // which is to say that if we have a trailing dot then the dots cancel out all the len bytes except the first and + // we can accept 254 chars. + + // This name is max length + name := "aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.example.com." + if len(name) != 254 { + t.Fatalf("name should be 254 chars including trailing . 
(len is %d)", len(name)) + } + got, err := ToFQDN(name) + if err != nil { + t.Fatalf("want: no error, got: %v", err) + } + if string(got) != name { + t.Fatalf("want: %s, got: %s", name, got) + } + + // This name is too long + name = "x" + name + got, err = ToFQDN(name) + if got != "" { + t.Fatalf("want: \"\", got: %s", got) + } + if err == nil || !strings.HasSuffix(err.Error(), "is too long to be a DNS name") { + t.Fatalf("want: error to end with \"is too long to be a DNS name\", got: %v", err) + } +} + func TestFQDNContains(t *testing.T) { tests := []struct { a, b string diff --git a/util/eventbus/bus.go b/util/eventbus/bus.go index b1639136a5133..880e075ccaf3c 100644 --- a/util/eventbus/bus.go +++ b/util/eventbus/bus.go @@ -8,8 +8,8 @@ import ( "log" "reflect" "slices" - "sync" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/set" ) @@ -34,12 +34,12 @@ type Bus struct { routeDebug hook[RoutedEvent] logf logger.Logf - topicsMu sync.Mutex + topicsMu syncs.Mutex topics map[reflect.Type][]*subscribeState // Used for introspection/debugging only, not in the normal event // publishing path. - clientsMu sync.Mutex + clientsMu syncs.Mutex clients set.Set[*Client] } @@ -120,7 +120,14 @@ func (b *Bus) Close() { } func (b *Bus) pump(ctx context.Context) { - var vals queue[PublishedEvent] + // Limit how many published events we can buffer in the PublishedEvent queue. + // + // Subscribers have unbounded DeliveredEvent queues (see tailscale/tailscale#18020), + // so this queue doesn't need to be unbounded. Keeping it bounded may also help + // catch cases where subscribers stop pumping events completely, such as due to a bug + // in [subscribeState.pump], [Subscriber.dispatch], or [SubscriberFunc.dispatch]). + const maxPublishedEvents = 16 + vals := queue[PublishedEvent]{capacity: maxPublishedEvents} acceptCh := func() chan PublishedEvent { if vals.Full() { return nil @@ -134,7 +141,7 @@ func (b *Bus) pump(ctx context.Context) { // queue space for it. for !vals.Empty() { val := vals.Peek() - dests := b.dest(reflect.ValueOf(val.Event).Type()) + dests := b.dest(reflect.TypeOf(val.Event)) if b.routeDebug.active() { clients := make([]*Client, len(dests)) @@ -306,7 +313,7 @@ func (w *worker) StopAndWait() { type stopFlag struct { // guards the lazy construction of stopped, and the value of // alreadyStopped. - mu sync.Mutex + mu syncs.Mutex stopped chan struct{} alreadyStopped bool } diff --git a/util/eventbus/bus_test.go b/util/eventbus/bus_test.go index 1e0cd8abf2cff..88e11e7199aee 100644 --- a/util/eventbus/bus_test.go +++ b/util/eventbus/bus_test.go @@ -9,6 +9,7 @@ import ( "fmt" "log" "regexp" + "sync" "testing" "testing/synctest" "time" @@ -89,6 +90,61 @@ func TestSubscriberFunc(t *testing.T) { } }) + t.Run("CloseWait", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client(t.Name()) + + eventbus.SubscribeFunc[EventA](c, func(e EventA) { + time.Sleep(2 * time.Second) + }) + + p := eventbus.Publish[EventA](c) + p.Publish(EventA{12345}) + + synctest.Wait() // subscriber has the event + c.Close() + + // If close does not wait for the subscriber, the test will fail + // because an active goroutine remains in the bubble. 
+ }) + }) + + t.Run("CloseWait/Belated", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + buf := swapLogBuf(t) + + b := eventbus.New() + defer b.Close() + + c := b.Client(t.Name()) + + // This subscriber stalls for a long time, so that when we try to + // close the client it gives up and returns in the timeout condition. + eventbus.SubscribeFunc[EventA](c, func(e EventA) { + time.Sleep(time.Minute) // notably, longer than the wait period + }) + + p := eventbus.Publish[EventA](c) + p.Publish(EventA{12345}) + + synctest.Wait() // subscriber has the event + c.Close() + + // Verify that the logger recorded that Close gave up on the slowpoke. + want := regexp.MustCompile(`^.* tailscale.com/util/eventbus_test bus_test.go:\d+: ` + + `giving up on subscriber for eventbus_test.EventA after \d+s at close.*`) + if got := buf.String(); !want.MatchString(got) { + t.Errorf("Wrong log output\ngot: %q\nwant %s", got, want) + } + + // Wait for the subscriber to actually finish to clean up the goroutine. + time.Sleep(2 * time.Minute) + }) + }) + t.Run("SubscriberPublishes", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { b := eventbus.New() @@ -440,14 +496,6 @@ func TestMonitor(t *testing.T) { } func TestSlowSubs(t *testing.T) { - swapLogBuf := func(t *testing.T) *bytes.Buffer { - logBuf := new(bytes.Buffer) - save := log.Writer() - log.SetOutput(logBuf) - t.Cleanup(func() { log.SetOutput(save) }) - return logBuf - } - t.Run("Subscriber", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { buf := swapLogBuf(t) @@ -546,6 +594,105 @@ func TestRegression(t *testing.T) { }) } +func TestPublishWithMutex(t *testing.T) { + testPublishWithMutex(t, 1024) // arbitrary large number of events +} + +// testPublishWithMutex publishes the specified number of events, +// acquiring and releasing a mutex around each publish and each +// subscriber event receive. +// +// The test fails if it loses any events or times out due to a deadlock. +// Unfortunately, a goroutine waiting on a mutex held by a durably blocked +// goroutine is not itself considered durably blocked, so [synctest] cannot +// detect this deadlock on its own. +func testPublishWithMutex(t *testing.T, n int) { + synctest.Test(t, func(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client("TestClient") + + evts := make([]any, n) + for i := range evts { + evts[i] = EventA{Counter: i} + } + exp := expectEvents(t, evts...) + + var mu sync.Mutex + eventbus.SubscribeFunc[EventA](c, func(e EventA) { + // Acquire the same mutex as the publisher. + mu.Lock() + mu.Unlock() + + // Mark event as received, so we can check for lost events. + exp.Got(e) + }) + + p := eventbus.Publish[EventA](c) + go func() { + // Publish events, acquiring the mutex around each publish. + for i := range n { + mu.Lock() + p.Publish(EventA{Counter: i}) + mu.Unlock() + } + }() + + synctest.Wait() + + if !exp.Empty() { + t.Errorf("unexpected extra events: %+v", exp.want) + } + }) +} + +func TestPublishFromSubscriber(t *testing.T) { + testPublishFromSubscriber(t, 1024) // arbitrary large number of events +} + +// testPublishFromSubscriber publishes the specified number of EventA events. +// Each EventA causes the subscriber to publish an EventB. +// The test fails if it loses any events or if a deadlock occurs. 
+func testPublishFromSubscriber(t *testing.T, n int) { + synctest.Test(t, func(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client("TestClient") + + // Ultimately we expect to receive n EventB events + // published as a result of receiving n EventA events. + evts := make([]any, n) + for i := range evts { + evts[i] = EventB{Counter: i} + } + exp := expectEvents(t, evts...) + + pubA := eventbus.Publish[EventA](c) + pubB := eventbus.Publish[EventB](c) + + eventbus.SubscribeFunc[EventA](c, func(e EventA) { + // Upon receiving EventA, publish EventB. + pubB.Publish(EventB{Counter: e.Counter}) + }) + eventbus.SubscribeFunc[EventB](c, func(e EventB) { + // Mark EventB as received. + exp.Got(e) + }) + + for i := range n { + pubA.Publish(EventA{Counter: i}) + } + + synctest.Wait() + + if !exp.Empty() { + t.Errorf("unexpected extra events: %+v", exp.want) + } + }) +} + type queueChecker struct { t *testing.T want []any @@ -571,3 +718,11 @@ func (q *queueChecker) Got(v any) { func (q *queueChecker) Empty() bool { return len(q.want) == 0 } + +func swapLogBuf(t *testing.T) *bytes.Buffer { + logBuf := new(bytes.Buffer) + save := log.Writer() + log.SetOutput(logBuf) + t.Cleanup(func() { log.SetOutput(save) }) + return logBuf +} diff --git a/util/eventbus/client.go b/util/eventbus/client.go index c119c67a939c2..a7a5ab673bdfd 100644 --- a/util/eventbus/client.go +++ b/util/eventbus/client.go @@ -5,8 +5,8 @@ package eventbus import ( "reflect" - "sync" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/set" ) @@ -22,7 +22,7 @@ type Client struct { bus *Bus publishDebug hook[PublishedEvent] - mu sync.Mutex + mu syncs.Mutex pub set.Set[publisher] sub *subscribeState // Lazily created on first subscribe stop stopFlag // signaled on Close diff --git a/util/eventbus/debug.go b/util/eventbus/debug.go index 2f2c9589ad0e2..0453defb1a77e 100644 --- a/util/eventbus/debug.go +++ b/util/eventbus/debug.go @@ -11,10 +11,10 @@ import ( "runtime" "slices" "strings" - "sync" "sync/atomic" "time" + "tailscale.com/syncs" "tailscale.com/types/logger" ) @@ -147,7 +147,7 @@ func (d *Debugger) SubscribeTypes(client *Client) []reflect.Type { // A hook collects hook functions that can be run as a group. type hook[T any] struct { - sync.Mutex + syncs.Mutex fns []hookFn[T] } diff --git a/util/eventbus/queue.go b/util/eventbus/queue.go index a62bf3c62d1d4..2589b75cef999 100644 --- a/util/eventbus/queue.go +++ b/util/eventbus/queue.go @@ -7,18 +7,18 @@ import ( "slices" ) -const maxQueuedItems = 16 - -// queue is an ordered queue of length up to maxQueuedItems. +// queue is an ordered queue of length up to capacity, +// if capacity is non-zero. Otherwise it is unbounded. type queue[T any] struct { - vals []T - start int + vals []T + start int + capacity int // zero means unbounded } // canAppend reports whether a value can be appended to q.vals without // shifting values around. 
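The now-bounded queue pairs with the pump's acceptCh shown earlier in this diff, which returns a nil channel when the queue is full; a receive from a nil channel blocks forever, so that select case is disabled until space frees up. A standalone sketch of the idiom, independent of the eventbus types:

```go
package main

import "fmt"

func main() {
	var full chan int // nil channel: stands in for acceptCh() when the queue is full
	ready := make(chan int, 1)
	ready <- 42
	select {
	case v := <-full: // receive from nil blocks forever, so this case never fires
		fmt.Println("accepted", v)
	case v := <-ready:
		fmt.Println("did other work:", v)
	}
}
```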
func (q *queue[T]) canAppend() bool { - return cap(q.vals) < maxQueuedItems || len(q.vals) < cap(q.vals) + return q.capacity == 0 || cap(q.vals) < q.capacity || len(q.vals) < cap(q.vals) } func (q *queue[T]) Full() bool { diff --git a/util/eventbus/subscribe.go b/util/eventbus/subscribe.go index 0b821b3f51586..b0348e125c393 100644 --- a/util/eventbus/subscribe.go +++ b/util/eventbus/subscribe.go @@ -7,10 +7,12 @@ import ( "context" "fmt" "reflect" - "sync" + "runtime" "time" + "tailscale.com/syncs" "tailscale.com/types/logger" + "tailscale.com/util/cibuild" ) type DeliveredEvent struct { @@ -49,7 +51,7 @@ type subscribeState struct { snapshot chan chan []DeliveredEvent debug hook[DeliveredEvent] - outputsMu sync.Mutex + outputsMu syncs.Mutex outputs map[reflect.Type]subscriber } @@ -324,6 +326,18 @@ func (s *SubscriberFunc[T]) dispatch(ctx context.Context, vals *queue[DeliveredE case val := <-acceptCh(): vals.Add(val) case <-ctx.Done(): + // Wait for the callback to be complete, but not forever. + s.slow.Reset(5 * slowSubscriberTimeout) + select { + case <-s.slow.C: + s.logf("giving up on subscriber for %T after %v at close", t, time.Since(start)) + if cibuild.On() { + all := make([]byte, 2<<20) + n := runtime.Stack(all, true) + s.logf("goroutine stacks:\n%s", all[:n]) + } + case <-callDone: + } return false case ch := <-snapshot: ch <- vals.Snapshot() diff --git a/util/execqueue/execqueue.go b/util/execqueue/execqueue.go index 889cea2555806..87616a6b50a45 100644 --- a/util/execqueue/execqueue.go +++ b/util/execqueue/execqueue.go @@ -7,11 +7,14 @@ package execqueue import ( "context" "errors" - "sync" + + "tailscale.com/syncs" ) type ExecQueue struct { - mu sync.Mutex + mu syncs.Mutex + ctx context.Context // context.Background + closed on Shutdown + cancel context.CancelFunc // closes ctx closed bool inFlight bool // whether a goroutine is running q.run doneWaiter chan struct{} // non-nil if waiter is waiting, then closed @@ -24,6 +27,7 @@ func (q *ExecQueue) Add(f func()) { if q.closed { return } + q.initCtxLocked() if q.inFlight { q.queue = append(q.queue, f) } else { @@ -35,21 +39,21 @@ func (q *ExecQueue) Add(f func()) { // RunSync waits for the queue to be drained and then synchronously runs f. // It returns an error if the queue is closed before f is run or ctx expires. func (q *ExecQueue) RunSync(ctx context.Context, f func()) error { - for { - if err := q.Wait(ctx); err != nil { - return err - } - q.mu.Lock() - if q.inFlight { - q.mu.Unlock() - continue - } - defer q.mu.Unlock() - if q.closed { - return errors.New("closed") - } - f() + q.mu.Lock() + q.initCtxLocked() + shutdownCtx := q.ctx + q.mu.Unlock() + + ch := make(chan struct{}) + q.Add(f) + q.Add(func() { close(ch) }) + select { + case <-ch: return nil + case <-ctx.Done(): + return ctx.Err() + case <-shutdownCtx.Done(): + return errExecQueueShutdown } } @@ -79,18 +83,35 @@ func (q *ExecQueue) Shutdown() { q.mu.Lock() defer q.mu.Unlock() q.closed = true + if q.cancel != nil { + q.cancel() + } } -// Wait waits for the queue to be empty. +func (q *ExecQueue) initCtxLocked() { + if q.ctx == nil { + q.ctx, q.cancel = context.WithCancel(context.Background()) + } +} + +var errExecQueueShutdown = errors.New("execqueue shut down") + +// Wait waits for the queue to be empty or shut down. 
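The rewritten RunSync above now schedules f on the queue itself and waits on channels instead of re-checking inFlight under q.mu, which is what allowed the Shutdown deadlock. A hypothetical caller, per those semantics:

```go
package main

import (
	"context"
	"log"

	"tailscale.com/util/execqueue"
)

func main() {
	var q execqueue.ExecQueue
	q.Add(func() { log.Println("runs first") })
	// RunSync enqueues f behind earlier Adds and waits for f to run,
	// for ctx to expire, or for the queue to be shut down.
	if err := q.RunSync(context.Background(), func() { log.Println("runs second") }); err != nil {
		log.Printf("queue shut down or ctx done before f ran: %v", err)
	}
	q.Shutdown()
}
```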
func (q *ExecQueue) Wait(ctx context.Context) error { q.mu.Lock() + q.initCtxLocked() waitCh := q.doneWaiter if q.inFlight && waitCh == nil { waitCh = make(chan struct{}) q.doneWaiter = waitCh } + closed := q.closed + shutdownCtx := q.ctx q.mu.Unlock() + if closed { + return errExecQueueShutdown + } if waitCh == nil { return nil } @@ -98,6 +119,8 @@ func (q *ExecQueue) Wait(ctx context.Context) error { select { case <-waitCh: return nil + case <-shutdownCtx.Done(): + return errExecQueueShutdown case <-ctx.Done(): return ctx.Err() } diff --git a/util/execqueue/execqueue_test.go b/util/execqueue/execqueue_test.go index d10b741f72f8f..1bce69556e1f7 100644 --- a/util/execqueue/execqueue_test.go +++ b/util/execqueue/execqueue_test.go @@ -20,3 +20,12 @@ func TestExecQueue(t *testing.T) { t.Errorf("n=%d; want 1", got) } } + +// Test that RunSync doesn't hold q.mu and block Shutdown +// as we saw in tailscale/tailscale#18502 +func TestExecQueueRunSyncLocking(t *testing.T) { + q := &ExecQueue{} + q.RunSync(t.Context(), func() { + q.Shutdown() + }) +} diff --git a/util/expvarx/expvarx.go b/util/expvarx/expvarx.go index 762f65d069aa6..bcdc4a91a7982 100644 --- a/util/expvarx/expvarx.go +++ b/util/expvarx/expvarx.go @@ -7,9 +7,9 @@ package expvarx import ( "encoding/json" "expvar" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/types/lazy" ) @@ -20,7 +20,7 @@ type SafeFunc struct { limit time.Duration onSlow func(time.Duration, any) - mu sync.Mutex + mu syncs.Mutex inflight *lazy.SyncValue[any] } diff --git a/util/goroutines/tracker.go b/util/goroutines/tracker.go index 044843d33d155..c2a0cb8c3a3ed 100644 --- a/util/goroutines/tracker.go +++ b/util/goroutines/tracker.go @@ -4,9 +4,9 @@ package goroutines import ( - "sync" "sync/atomic" + "tailscale.com/syncs" "tailscale.com/util/set" ) @@ -15,7 +15,7 @@ type Tracker struct { started atomic.Int64 // counter running atomic.Int64 // gauge - mu sync.Mutex + mu syncs.Mutex onDone set.HandleSet[func()] } diff --git a/util/limiter/limiter.go b/util/limiter/limiter.go index 5af5f7bd11950..b86efdf29cfd0 100644 --- a/util/limiter/limiter.go +++ b/util/limiter/limiter.go @@ -8,9 +8,9 @@ import ( "fmt" "html" "io" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/util/lru" ) @@ -75,7 +75,7 @@ type Limiter[K comparable] struct { // perpetually in debt and cannot proceed at all. Overdraft int64 - mu sync.Mutex + mu syncs.Mutex cache *lru.Cache[K, *bucket] } @@ -94,59 +94,59 @@ type bucket struct { // Allow charges the key one token (up to the overdraft limit), and // reports whether the key can perform an action. 
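For reference, a sketch of configuring the limiter fields documented above; the key and the numbers are illustrative.

```go
package main

import (
	"fmt"
	"time"

	"tailscale.com/util/limiter"
)

func main() {
	lim := &limiter.Limiter[string]{
		Size:           1000,        // number of keys tracked in the LRU
		Max:            10,          // bucket size (burst)
		RefillInterval: time.Second, // one token refilled per interval
		Overdraft:      2,           // how far a key may go into debt
	}
	fmt.Println(lim.Allow("client-42")) // true until the burst is spent
}
```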
-func (l *Limiter[K]) Allow(key K) bool { - return l.allow(key, time.Now()) +func (lm *Limiter[K]) Allow(key K) bool { + return lm.allow(key, time.Now()) } -func (l *Limiter[K]) allow(key K, now time.Time) bool { - l.mu.Lock() - defer l.mu.Unlock() - return l.allowBucketLocked(l.getBucketLocked(key, now), now) +func (lm *Limiter[K]) allow(key K, now time.Time) bool { + lm.mu.Lock() + defer lm.mu.Unlock() + return lm.allowBucketLocked(lm.getBucketLocked(key, now), now) } -func (l *Limiter[K]) getBucketLocked(key K, now time.Time) *bucket { - if l.cache == nil { - l.cache = &lru.Cache[K, *bucket]{MaxEntries: l.Size} - } else if b := l.cache.Get(key); b != nil { +func (lm *Limiter[K]) getBucketLocked(key K, now time.Time) *bucket { + if lm.cache == nil { + lm.cache = &lru.Cache[K, *bucket]{MaxEntries: lm.Size} + } else if b := lm.cache.Get(key); b != nil { return b } b := &bucket{ - cur: l.Max, - lastUpdate: now.Truncate(l.RefillInterval), + cur: lm.Max, + lastUpdate: now.Truncate(lm.RefillInterval), } - l.cache.Set(key, b) + lm.cache.Set(key, b) return b } -func (l *Limiter[K]) allowBucketLocked(b *bucket, now time.Time) bool { +func (lm *Limiter[K]) allowBucketLocked(b *bucket, now time.Time) bool { // Only update the bucket quota if needed to process request. if b.cur <= 0 { - l.updateBucketLocked(b, now) + lm.updateBucketLocked(b, now) } ret := b.cur > 0 - if b.cur > -l.Overdraft { + if b.cur > -lm.Overdraft { b.cur-- } return ret } -func (l *Limiter[K]) updateBucketLocked(b *bucket, now time.Time) { - now = now.Truncate(l.RefillInterval) +func (lm *Limiter[K]) updateBucketLocked(b *bucket, now time.Time) { + now = now.Truncate(lm.RefillInterval) if now.Before(b.lastUpdate) { return } timeDelta := max(now.Sub(b.lastUpdate), 0) - tokenDelta := int64(timeDelta / l.RefillInterval) - b.cur = min(b.cur+tokenDelta, l.Max) + tokenDelta := int64(timeDelta / lm.RefillInterval) + b.cur = min(b.cur+tokenDelta, lm.Max) b.lastUpdate = now } // peekForTest returns the number of tokens for key, also reporting // whether key was present. -func (l *Limiter[K]) tokensForTest(key K) (int64, bool) { - l.mu.Lock() - defer l.mu.Unlock() - if b, ok := l.cache.PeekOk(key); ok { +func (lm *Limiter[K]) tokensForTest(key K) (int64, bool) { + lm.mu.Lock() + defer lm.mu.Unlock() + if b, ok := lm.cache.PeekOk(key); ok { return b.cur, true } return 0, false @@ -159,12 +159,12 @@ func (l *Limiter[K]) tokensForTest(key K) (int64, bool) { // DumpHTML blocks other callers of the limiter while it collects the // state for dumping. It should not be called on large limiters // involved in hot codepaths. -func (l *Limiter[K]) DumpHTML(w io.Writer, onlyLimited bool) { - l.dumpHTML(w, onlyLimited, time.Now()) +func (lm *Limiter[K]) DumpHTML(w io.Writer, onlyLimited bool) { + lm.dumpHTML(w, onlyLimited, time.Now()) } -func (l *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) { - dump := l.collectDump(now) +func (lm *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) { + dump := lm.collectDump(now) io.WriteString(w, "") for _, line := range dump { if onlyLimited && line.Tokens > 0 { @@ -183,13 +183,13 @@ func (l *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) { } // collectDump grabs a copy of the limiter state needed by DumpHTML. 
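Since DumpHTML writes a self-contained HTML table to any io.Writer, it can back a debug endpoint. A hypothetical handler sketch; the /debug/limiter path and the "limited" query parameter are made up for illustration:

```go
package main

import (
	"log"
	"net/http"
	"time"

	"tailscale.com/util/limiter"
)

var lim = &limiter.Limiter[string]{Size: 1000, Max: 10, RefillInterval: time.Second}

func main() {
	http.HandleFunc("/debug/limiter", func(w http.ResponseWriter, r *http.Request) {
		onlyLimited := r.URL.Query().Get("limited") == "1"
		lim.DumpHTML(w, onlyLimited) // blocks other limiter callers while dumping
	})
	log.Fatal(http.ListenAndServe("localhost:8080", nil))
}
```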
-func (l *Limiter[K]) collectDump(now time.Time) []dumpEntry[K] { - l.mu.Lock() - defer l.mu.Unlock() +func (lm *Limiter[K]) collectDump(now time.Time) []dumpEntry[K] { + lm.mu.Lock() + defer lm.mu.Unlock() - ret := make([]dumpEntry[K], 0, l.cache.Len()) - l.cache.ForEach(func(k K, v *bucket) { - l.updateBucketLocked(v, now) // so stats are accurate + ret := make([]dumpEntry[K], 0, lm.cache.Len()) + lm.cache.ForEach(func(k K, v *bucket) { + lm.updateBucketLocked(v, now) // so stats are accurate ret = append(ret, dumpEntry[K]{k, v.cur}) }) return ret diff --git a/util/limiter/limiter_test.go b/util/limiter/limiter_test.go index 1f466d88257ab..77b1d562b23fb 100644 --- a/util/limiter/limiter_test.go +++ b/util/limiter/limiter_test.go @@ -16,7 +16,7 @@ const testRefillInterval = time.Second func TestLimiter(t *testing.T) { // 1qps, burst of 10, 2 keys tracked - l := &Limiter[string]{ + limiter := &Limiter[string]{ Size: 2, Max: 10, RefillInterval: testRefillInterval, @@ -24,48 +24,48 @@ func TestLimiter(t *testing.T) { // Consume entire burst now := time.Now().Truncate(testRefillInterval) - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", 0) + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", 0) - allowed(t, l, "bar", 10, now) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", 0) + allowed(t, limiter, "bar", 10, now) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", 0) // Refill 1 token for both foo and bar now = now.Add(time.Second + time.Millisecond) - allowed(t, l, "foo", 1, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", 0) + allowed(t, limiter, "foo", 1, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", 0) - allowed(t, l, "bar", 1, now) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", 0) + allowed(t, limiter, "bar", 1, now) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", 0) // Refill 2 tokens for foo and bar now = now.Add(2*time.Second + time.Millisecond) - allowed(t, l, "foo", 2, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", 0) + allowed(t, limiter, "foo", 2, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", 0) - allowed(t, l, "bar", 2, now) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", 0) + allowed(t, limiter, "bar", 2, now) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", 0) // qux can burst 10, evicts foo so it can immediately burst 10 again too - allowed(t, l, "qux", 10, now) - denied(t, l, "qux", 1, now) - notInLimiter(t, l, "foo") - denied(t, l, "bar", 1, now) // refresh bar so foo lookup doesn't evict it - still throttled - - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", 0) + allowed(t, limiter, "qux", 10, now) + denied(t, limiter, "qux", 1, now) + notInLimiter(t, limiter, "foo") + denied(t, limiter, "bar", 1, now) // refresh bar so foo lookup doesn't evict it - still throttled + + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", 0) } func TestLimiterOverdraft(t *testing.T) { // 1qps, burst of 10, overdraft of 2, 2 keys tracked - l := &Limiter[string]{ + limiter := &Limiter[string]{ Size: 2, Max: 10, Overdraft: 2, @@ -74,51 +74,51 @@ func TestLimiterOverdraft(t *testing.T) { // Consume entire burst, go 1 into debt now := time.Now().Truncate(testRefillInterval).Add(time.Millisecond) - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, 
"foo", -1) + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", -1) - allowed(t, l, "bar", 10, now) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", -1) + allowed(t, limiter, "bar", 10, now) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", -1) // Refill 1 token for both foo and bar. // Still denied, still in debt. now = now.Add(time.Second) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", -1) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", -1) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", -1) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", -1) // Refill 2 tokens for foo and bar (1 available after debt), try // to consume 4. Overdraft is capped to 2. now = now.Add(2 * time.Second) - allowed(t, l, "foo", 1, now) - denied(t, l, "foo", 3, now) - hasTokens(t, l, "foo", -2) + allowed(t, limiter, "foo", 1, now) + denied(t, limiter, "foo", 3, now) + hasTokens(t, limiter, "foo", -2) - allowed(t, l, "bar", 1, now) - denied(t, l, "bar", 3, now) - hasTokens(t, l, "bar", -2) + allowed(t, limiter, "bar", 1, now) + denied(t, limiter, "bar", 3, now) + hasTokens(t, limiter, "bar", -2) // Refill 1, not enough to allow. now = now.Add(time.Second) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", -2) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", -2) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", -2) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", -2) // qux evicts foo, foo can immediately burst 10 again. - allowed(t, l, "qux", 1, now) - hasTokens(t, l, "qux", 9) - notInLimiter(t, l, "foo") - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", -1) + allowed(t, limiter, "qux", 1, now) + hasTokens(t, limiter, "qux", 9) + notInLimiter(t, limiter, "foo") + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", -1) } func TestDumpHTML(t *testing.T) { - l := &Limiter[string]{ + limiter := &Limiter[string]{ Size: 3, Max: 10, Overdraft: 10, @@ -126,13 +126,13 @@ func TestDumpHTML(t *testing.T) { } now := time.Now().Truncate(testRefillInterval).Add(time.Millisecond) - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 2, now) - allowed(t, l, "bar", 4, now) - allowed(t, l, "qux", 1, now) + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 2, now) + allowed(t, limiter, "bar", 4, now) + allowed(t, limiter, "qux", 1, now) var out bytes.Buffer - l.DumpHTML(&out, false) + limiter.DumpHTML(&out, false) want := strings.Join([]string{ "
<table><tr><th>Key</th><th>Tokens</th></tr>
", "", @@ -146,7 +146,7 @@ func TestDumpHTML(t *testing.T) { } out.Reset() - l.DumpHTML(&out, true) + limiter.DumpHTML(&out, true) want = strings.Join([]string{ "
<table><tr><th>Key</th><th>Tokens</th></tr>
", "", @@ -161,7 +161,7 @@ func TestDumpHTML(t *testing.T) { // organically. now = now.Add(3 * time.Second) out.Reset() - l.dumpHTML(&out, false, now) + limiter.dumpHTML(&out, false, now) want = strings.Join([]string{ "
<table><tr><th>Key</th><th>Tokens</th></tr>
", "", @@ -175,29 +175,29 @@ func TestDumpHTML(t *testing.T) { } } -func allowed(t *testing.T, l *Limiter[string], key string, count int, now time.Time) { +func allowed(t *testing.T, limiter *Limiter[string], key string, count int, now time.Time) { t.Helper() for i := range count { - if !l.allow(key, now) { - toks, ok := l.tokensForTest(key) + if !limiter.allow(key, now) { + toks, ok := limiter.tokensForTest(key) t.Errorf("after %d times: allow(%q, %q) = false, want true (%d tokens available, in cache = %v)", i, key, now, toks, ok) } } } -func denied(t *testing.T, l *Limiter[string], key string, count int, now time.Time) { +func denied(t *testing.T, limiter *Limiter[string], key string, count int, now time.Time) { t.Helper() for i := range count { - if l.allow(key, now) { - toks, ok := l.tokensForTest(key) + if limiter.allow(key, now) { + toks, ok := limiter.tokensForTest(key) t.Errorf("after %d times: allow(%q, %q) = true, want false (%d tokens available, in cache = %v)", i, key, now, toks, ok) } } } -func hasTokens(t *testing.T, l *Limiter[string], key string, want int64) { +func hasTokens(t *testing.T, limiter *Limiter[string], key string, want int64) { t.Helper() - got, ok := l.tokensForTest(key) + got, ok := limiter.tokensForTest(key) if !ok { t.Errorf("key %q missing from limiter", key) } else if got != want { @@ -205,9 +205,9 @@ func hasTokens(t *testing.T, l *Limiter[string], key string, want int64) { } } -func notInLimiter(t *testing.T, l *Limiter[string], key string) { +func notInLimiter(t *testing.T, limiter *Limiter[string], key string) { t.Helper() - if tokens, ok := l.tokensForTest(key); ok { + if tokens, ok := limiter.tokensForTest(key); ok { t.Errorf("key %q unexpectedly tracked by limiter, with %d tokens", key, tokens) } } diff --git a/util/linuxfw/detector.go b/util/linuxfw/detector.go index 644126131bbba..149e0c96049c8 100644 --- a/util/linuxfw/detector.go +++ b/util/linuxfw/detector.go @@ -85,7 +85,7 @@ type tableDetector interface { type linuxFWDetector struct{} // iptDetect returns the number of iptables rules in the current namespace. -func (l linuxFWDetector) iptDetect() (int, error) { +func (ld linuxFWDetector) iptDetect() (int, error) { return detectIptables() } @@ -96,7 +96,7 @@ var hookDetectNetfilter feature.Hook[func() (int, error)] var ErrUnsupported = errors.New("linuxfw:unsupported") // nftDetect returns the number of nftables rules in the current namespace. 
-func (l linuxFWDetector) nftDetect() (int, error) { +func (ld linuxFWDetector) nftDetect() (int, error) { if f, ok := hookDetectNetfilter.GetOk(); ok { return f() } diff --git a/util/lru/lru_test.go b/util/lru/lru_test.go index 5500e5e0f309f..04de2e5070c87 100644 --- a/util/lru/lru_test.go +++ b/util/lru/lru_test.go @@ -84,8 +84,8 @@ func TestStressEvictions(t *testing.T) { for range numProbes { v := vals[rand.Intn(len(vals))] c.Set(v, true) - if l := c.Len(); l > cacheSize { - t.Fatalf("Cache size now %d, want max %d", l, cacheSize) + if ln := c.Len(); ln > cacheSize { + t.Fatalf("Cache size now %d, want max %d", ln, cacheSize) } } } @@ -119,8 +119,8 @@ func TestStressBatchedEvictions(t *testing.T) { c.DeleteOldest() } } - if l := c.Len(); l > cacheSizeMax { - t.Fatalf("Cache size now %d, want max %d", l, cacheSizeMax) + if ln := c.Len(); ln > cacheSizeMax { + t.Fatalf("Cache size now %d, want max %d", ln, cacheSizeMax) } } } diff --git a/util/osdiag/zsyscall_windows.go b/util/osdiag/zsyscall_windows.go index ab0d18d3f9c98..2a11b4644fca8 100644 --- a/util/osdiag/zsyscall_windows.go +++ b/util/osdiag/zsyscall_windows.go @@ -51,7 +51,7 @@ var ( ) func regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) { - r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(valueName)), uintptr(unsafe.Pointer(valueNameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valueType)), uintptr(unsafe.Pointer(pData)), uintptr(unsafe.Pointer(cbData)), 0) + r0, _, _ := syscall.SyscallN(procRegEnumValueW.Addr(), uintptr(key), uintptr(index), uintptr(unsafe.Pointer(valueName)), uintptr(unsafe.Pointer(valueNameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valueType)), uintptr(unsafe.Pointer(pData)), uintptr(unsafe.Pointer(cbData))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -59,7 +59,7 @@ func regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLe } func globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) { - r1, _, e1 := syscall.Syscall(procGlobalMemoryStatusEx.Addr(), 1, uintptr(unsafe.Pointer(memStatus)), 0, 0) + r1, _, e1 := syscall.SyscallN(procGlobalMemoryStatusEx.Addr(), uintptr(unsafe.Pointer(memStatus))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -67,19 +67,19 @@ func globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) { } func wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) { - r0, _, _ := syscall.Syscall6(procWSCEnumProtocols.Addr(), 4, uintptr(unsafe.Pointer(iProtocols)), uintptr(unsafe.Pointer(protocolBuffer)), uintptr(unsafe.Pointer(bufLen)), uintptr(unsafe.Pointer(errno)), 0, 0) + r0, _, _ := syscall.SyscallN(procWSCEnumProtocols.Addr(), uintptr(unsafe.Pointer(iProtocols)), uintptr(unsafe.Pointer(protocolBuffer)), uintptr(unsafe.Pointer(bufLen)), uintptr(unsafe.Pointer(errno))) ret = int32(r0) return } func wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) { - r0, _, _ := syscall.Syscall6(procWSCGetProviderInfo.Addr(), 6, uintptr(unsafe.Pointer(providerId)), uintptr(infoType), uintptr(info), uintptr(unsafe.Pointer(infoSize)), uintptr(flags), uintptr(unsafe.Pointer(errno))) + r0, _, _ := syscall.SyscallN(procWSCGetProviderInfo.Addr(), uintptr(unsafe.Pointer(providerId)), uintptr(infoType), 
uintptr(info), uintptr(unsafe.Pointer(infoSize)), uintptr(flags), uintptr(unsafe.Pointer(errno))) ret = int32(r0) return } func wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) { - r0, _, _ := syscall.Syscall6(procWSCGetProviderPath.Addr(), 4, uintptr(unsafe.Pointer(providerId)), uintptr(unsafe.Pointer(providerDllPath)), uintptr(unsafe.Pointer(providerDllPathLen)), uintptr(unsafe.Pointer(errno)), 0, 0) + r0, _, _ := syscall.SyscallN(procWSCGetProviderPath.Addr(), uintptr(unsafe.Pointer(providerId)), uintptr(unsafe.Pointer(providerDllPath)), uintptr(unsafe.Pointer(providerDllPathLen)), uintptr(unsafe.Pointer(errno))) ret = int32(r0) return } diff --git a/util/ringlog/ringlog.go b/util/ringlog/ringlog.go index 85e0c48611821..62dfbae5bd5c3 100644 --- a/util/ringlog/ringlog.go +++ b/util/ringlog/ringlog.go @@ -4,7 +4,7 @@ // Package ringlog contains a limited-size concurrency-safe generic ring log. package ringlog -import "sync" +import "tailscale.com/syncs" // New creates a new [RingLog] containing at most max items. func New[T any](max int) *RingLog[T] { @@ -15,7 +15,7 @@ func New[T any](max int) *RingLog[T] { // RingLog is a concurrency-safe fixed size log window containing entries of [T]. type RingLog[T any] struct { - mu sync.Mutex + mu syncs.Mutex pos int buf []T max int diff --git a/util/safediff/diff.go b/util/safediff/diff.go new file mode 100644 index 0000000000000..cf8add94b21dd --- /dev/null +++ b/util/safediff/diff.go @@ -0,0 +1,280 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package safediff computes the difference between two lists. +// +// It is guaranteed to run in O(n), but may not produce an optimal diff. +// Most diffing algorithms produce optimal diffs but run in O(n²). +// It is safe to pass in untrusted input. +package safediff + +import ( + "bytes" + "fmt" + "math" + "strings" + "unicode" + + "github.com/google/go-cmp/cmp" +) + +var diffTest = false + +// Lines constructs a humanly readable line-by-line diff from x to y. +// The output (if multiple lines) is guaranteed to be no larger than maxSize, +// by truncating the output if necessary. A negative maxSize enforces no limit. +// +// Example diff: +// +// … 440 identical lines +// "ssh": [ +// … 35 identical lines +// { +// - "src": ["maisem@tailscale.com"], +// - "dst": ["tag:maisem-test"], +// - "users": ["maisem", "root"], +// - "action": "check", +// - // "recorder": ["100.12.34.56:80"], +// + "src": ["maisem@tailscale.com"], +// + "dst": ["tag:maisem-test"], +// + "users": ["maisem", "root"], +// + "action": "check", +// + "recorder": ["node:recorder-2"], +// }, +// … 77 identical lines +// ], +// … 345 identical lines +// +// Meaning of each line prefix: +// +// - '…' precedes a summary statement +// - ' ' precedes an identical line printed for context +// - '-' precedes a line removed from x +// - '+' precedes a line inserted from y +// +// The diffing algorithm runs in O(n) and is safe to use with untrusted inputs. +func Lines(x, y string, maxSize int) (out string, truncated bool) { + // Convert x and y into a slice of lines and compute the edit-script. + xs := strings.Split(x, "\n") + ys := strings.Split(y, "\n") + es := diffStrings(xs, ys) + + // Modify the edit-script to support printing identical lines of context. 
+ const identicalContext edit = '*' // special edit code to indicate printed line + var xi, yi int // index into xs or ys + isIdentical := func(e edit) bool { return e == identical || e == identicalContext } + indentOf := func(s string) string { return s[:len(s)-len(strings.TrimLeftFunc(s, unicode.IsSpace))] } + for i, e := range es { + if isIdentical(e) { + // Print current line if adjacent symbols are non-identical. + switch { + case i-1 >= 0 && !isIdentical(es[i-1]): + es[i] = identicalContext + case i+1 < len(es) && !isIdentical(es[i+1]): + es[i] = identicalContext + } + } else { + // Print any preceding or succeeding lines, + // where the leading indent is a prefix of the current indent. + // Indentation often indicates a parent-child relationship + // in structured source code. + addParents := func(ss []string, si, direction int) { + childIndent := indentOf(ss[si]) + for j := direction; i+j >= 0 && i+j < len(es) && isIdentical(es[i+j]); j += direction { + parentIndent := indentOf(ss[si+j]) + if strings.HasPrefix(childIndent, parentIndent) && len(parentIndent) < len(childIndent) && parentIndent != "" { + es[i+j] = identicalContext + childIndent = parentIndent + } + } + } + switch e { + case removed, modified: // arbitrarily use the x value for modified values + addParents(xs, xi, -1) + addParents(xs, xi, +1) + case inserted: + addParents(ys, yi, -1) + addParents(ys, yi, +1) + } + } + if e != inserted { + xi++ + } + if e != removed { + yi++ + } + } + + // Show the line for a single hidden identical line, + // since it occupies the same vertical height. + for i, e := range es { + if e == identical { + prevNotIdentical := i-1 < 0 || es[i-1] != identical + nextNotIdentical := i+1 >= len(es) || es[i+1] != identical + if prevNotIdentical && nextNotIdentical { + es[i] = identicalContext + } + } + } + + // Adjust the maxSize, reserving space for the final summary. + if maxSize < 0 { + maxSize = math.MaxInt + } + maxSize -= len(stats{len(xs) + len(ys), len(xs), len(ys)}.appendText(nil)) + + // mayAppendLine appends a line if it does not exceed maxSize. + // Otherwise, it just updates prevStats. + var buf []byte + var prevStats stats + mayAppendLine := func(edit edit, line string) { + // Append the stats (if non-zero) and the line text. + // The stats reports the number of preceding identical lines. + if !truncated { + bufLen := len(buf) // original length (in case we exceed maxSize) + if !prevStats.isZero() { + buf = prevStats.appendText(buf) + prevStats = stats{} // just printed, so clear the stats + } + buf = fmt.Appendf(buf, "%c %s\n", edit, line) + truncated = len(buf) > maxSize + if !truncated { + return + } + buf = buf[:bufLen] // restore original buffer contents + } + + // Output is truncated, so just update the statistics. + switch edit { + case identical: + prevStats.numIdentical++ + case removed: + prevStats.numRemoved++ + case inserted: + prevStats.numInserted++ + } + } + + // Process the entire edit script. 
+ for len(es) > 0 { + num := len(es) - len(bytes.TrimLeft(es, string(es[:1]))) + switch es[0] { + case identical: + prevStats.numIdentical += num + xs, ys = xs[num:], ys[num:] + case identicalContext: + for n := len(xs) - num; len(xs) > n; xs, ys = xs[1:], ys[1:] { + mayAppendLine(identical, xs[0]) // implies xs[0] == ys[0] + } + case modified: + for n := len(xs) - num; len(xs) > n; xs = xs[1:] { + mayAppendLine(removed, xs[0]) + } + for n := len(ys) - num; len(ys) > n; ys = ys[1:] { + mayAppendLine(inserted, ys[0]) + } + case removed: + for n := len(xs) - num; len(xs) > n; xs = xs[1:] { + mayAppendLine(removed, xs[0]) + } + case inserted: + for n := len(ys) - num; len(ys) > n; ys = ys[1:] { + mayAppendLine(inserted, ys[0]) + } + } + es = es[num:] + } + if len(xs)+len(ys)+len(es) > 0 { + panic("BUG: slices not fully consumed") + } + + if !prevStats.isZero() { + buf = prevStats.appendText(buf) // may exceed maxSize + } + return string(buf), truncated +} + +type stats struct{ numIdentical, numRemoved, numInserted int } + +func (s stats) isZero() bool { return s.numIdentical+s.numRemoved+s.numInserted == 0 } + +func (s stats) appendText(b []byte) []byte { + switch { + case s.numIdentical > 0 && s.numRemoved > 0 && s.numInserted > 0: + return fmt.Appendf(b, "… %d identical, %d removed, and %d inserted lines\n", s.numIdentical, s.numRemoved, s.numInserted) + case s.numIdentical > 0 && s.numRemoved > 0: + return fmt.Appendf(b, "… %d identical and %d removed lines\n", s.numIdentical, s.numRemoved) + case s.numIdentical > 0 && s.numInserted > 0: + return fmt.Appendf(b, "… %d identical and %d inserted lines\n", s.numIdentical, s.numInserted) + case s.numRemoved > 0 && s.numInserted > 0: + return fmt.Appendf(b, "… %d removed and %d inserted lines\n", s.numRemoved, s.numInserted) + case s.numIdentical > 0: + return fmt.Appendf(b, "… %d identical lines\n", s.numIdentical) + case s.numRemoved > 0: + return fmt.Appendf(b, "… %d removed lines\n", s.numRemoved) + case s.numInserted > 0: + return fmt.Appendf(b, "… %d inserted lines\n", s.numInserted) + default: + return fmt.Appendf(b, "…\n") + } +} + +// diffStrings computes an edit-script of two slices of strings. +// +// This calls cmp.Equal to access the "github.com/go-cmp/cmp/internal/diff" +// implementation, which has an O(N) diffing algorithm. It is not guaranteed +// to produce an optimal edit-script, but protects our runtime against +// adversarial inputs that would wreck the optimal O(N²) algorithm used by +// most diffing packages available in open-source. +// +// TODO(https://go.dev/issue/58893): Use "golang.org/x/tools/diff" instead? +func diffStrings(xs, ys []string) []edit { + d := new(diffRecorder) + cmp.Equal(xs, ys, cmp.Reporter(d)) + if diffTest { + numRemoved := bytes.Count(d.script, []byte{removed}) + numInserted := bytes.Count(d.script, []byte{inserted}) + if len(xs) != len(d.script)-numInserted || len(ys) != len(d.script)-numRemoved { + panic("BUG: edit-script is inconsistent") + } + } + return d.script +} + +type edit = byte + +const ( + identical edit = ' ' // equal symbol in both x and y + modified edit = '~' // modified symbol in both x and y + removed edit = '-' // removed symbol from x + inserted edit = '+' // inserted symbol from y +) + +// diffRecorder reproduces an edit-script, essentially recording +// the edit-script from "github.com/google/go-cmp/cmp/internal/diff". +// This implements the cmp.Reporter interface. 
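A usage sketch for safediff.Lines as documented above; the inputs are illustrative, and the exact diff text depends on the (possibly non-optimal) edit-script cmp produces.

```go
package main

import (
	"fmt"

	"tailscale.com/util/safediff"
)

func main() {
	x := "{\n  \"a\": 1,\n  \"b\": 2,\n}"
	y := "{\n  \"a\": 1,\n  \"b\": 3,\n}"
	diff, truncated := safediff.Lines(x, y, 4096) // cap multi-line output at 4 KiB
	fmt.Println("truncated:", truncated)          // false: well under the cap
	fmt.Print(diff)                               // lines prefixed ' ', '-', '+', plus '…' summaries
}
```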
+type diffRecorder struct { + last cmp.PathStep + script []edit +} + +func (d *diffRecorder) PushStep(ps cmp.PathStep) { d.last = ps } + +func (d *diffRecorder) Report(rs cmp.Result) { + if si, ok := d.last.(cmp.SliceIndex); ok { + if rs.Equal() { + d.script = append(d.script, identical) + } else { + switch xi, yi := si.SplitKeys(); { + case xi >= 0 && yi >= 0: + d.script = append(d.script, modified) + case xi >= 0: + d.script = append(d.script, removed) + case yi >= 0: + d.script = append(d.script, inserted) + } + } + } +} + +func (d *diffRecorder) PopStep() { d.last = nil } diff --git a/util/safediff/diff_test.go b/util/safediff/diff_test.go new file mode 100644 index 0000000000000..e580bd9222dd9 --- /dev/null +++ b/util/safediff/diff_test.go @@ -0,0 +1,196 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package safediff + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func init() { diffTest = true } + +func TestLines(t *testing.T) { + // The diffs shown below technically depend on the stability of cmp, + // but that should be fine for sufficiently simple diffs like these. + // If the output does change, that would suggest a significant regression + // in the optimality of cmp's diffing algorithm. + + x := `{ + "firstName": "John", + "lastName": "Smith", + "isAlive": true, + "age": 27, + "address": { + "streetAddress": "21 2nd Street", + "city": "New York", + "state": "NY", + "postalCode": "10021-3100" + }, + "phoneNumbers": [{ + "type": "home", + "number": "212 555-1234" + }, { + "type": "office", + "number": "646 555-4567" + }], + "children": [ + "Catherine", + "Thomas", + "Trevor" + ], + "spouse": null +}` + y := x + y = strings.ReplaceAll(y, `"New York"`, `"Los Angeles"`) + y = strings.ReplaceAll(y, `"NY"`, `"CA"`) + y = strings.ReplaceAll(y, `"646 555-4567"`, `"315 252-8888"`) + + wantDiff := ` +… 5 identical lines + "address": { + "streetAddress": "21 2nd Street", +- "city": "New York", +- "state": "NY", ++ "city": "Los Angeles", ++ "state": "CA", + "postalCode": "10021-3100" + }, +… 3 identical lines + }, { + "type": "office", +- "number": "646 555-4567" ++ "number": "315 252-8888" + }], +… 7 identical lines +`[1:] + gotDiff, gotTrunc := Lines(x, y, -1) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == true { + t.Errorf("Lines: output unexpectedly truncated") + } + + wantDiff = ` +… 5 identical lines + "address": { + "streetAddress": "21 2nd Street", +- "city": "New York", +- "state": "NY", ++ "city": "Los Angeles", +… 15 identical, 1 removed, and 2 inserted lines +`[1:] + gotDiff, gotTrunc = Lines(x, y, 200) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == false { + t.Errorf("Lines: output unexpectedly not truncated") + } + + wantDiff = "… 17 identical, 3 removed, and 3 inserted lines\n" + gotDiff, gotTrunc = Lines(x, y, 0) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == false { + t.Errorf("Lines: output unexpectedly not truncated") + } + + x = `{ + "unrelated": [ + "unrelated", + ], + "related": { + "unrelated": [ + "unrelated", + ], + "related": { + "unrelated": [ + "unrelated", + ], + "related": { + "related": "changed", + }, + "unrelated": [ + "unrelated", + ], + }, + 
"unrelated": [ + "unrelated", + ], + }, + "unrelated": [ + "unrelated", + ], +}` + y = strings.ReplaceAll(x, "changed", "CHANGED") + + wantDiff = ` +… 4 identical lines + "related": { +… 3 identical lines + "related": { +… 3 identical lines + "related": { +- "related": "changed", ++ "related": "CHANGED", + }, +… 3 identical lines + }, +… 3 identical lines + }, +… 4 identical lines +`[1:] + gotDiff, gotTrunc = Lines(x, y, -1) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == true { + t.Errorf("Lines: output unexpectedly truncated") + } + + x = `{ + "ACLs": [ + { + "Action": "accept", + "Users": ["group:all"], + "Ports": ["tag:tmemes:80"], + }, + ], +}` + y = strings.ReplaceAll(x, "tag:tmemes:80", "tag:tmemes:80,8383") + wantDiff = ` + { + "ACLs": [ + { + "Action": "accept", + "Users": ["group:all"], +- "Ports": ["tag:tmemes:80"], ++ "Ports": ["tag:tmemes:80,8383"], + }, + ], + } +`[1:] + gotDiff, gotTrunc = Lines(x, y, -1) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == true { + t.Errorf("Lines: output unexpectedly truncated") + } +} + +func FuzzDiff(f *testing.F) { + f.Fuzz(func(t *testing.T, x, y string, maxSize int) { + const maxInput = 1e3 + if len(x) > maxInput { + x = x[:maxInput] + } + if len(y) > maxInput { + y = y[:maxInput] + } + diff, _ := Lines(x, y, maxSize) // make sure this does not panic + if strings.Count(diff, "\n") > 1 && maxSize >= 0 && len(diff) > maxSize { + t.Fatal("maxSize exceeded") + } + }) +} diff --git a/util/syspolicy/rsop/change_callbacks.go b/util/syspolicy/rsop/change_callbacks.go index fdf51c253cbd7..71135bb2ac788 100644 --- a/util/syspolicy/rsop/change_callbacks.go +++ b/util/syspolicy/rsop/change_callbacks.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "tailscale.com/syncs" "tailscale.com/util/set" "tailscale.com/util/syspolicy/internal/loggerx" "tailscale.com/util/syspolicy/pkey" @@ -70,7 +71,7 @@ func (c PolicyChange) HasChangedAnyOf(keys ...pkey.Key) bool { // policyChangeCallbacks are the callbacks to invoke when the effective policy changes. // It is safe for concurrent use. 
type policyChangeCallbacks struct { - mu sync.Mutex + mu syncs.Mutex cbs set.HandleSet[PolicyChangeCallback] } diff --git a/util/syspolicy/rsop/resultant_policy.go b/util/syspolicy/rsop/resultant_policy.go index 297d26f9f6fe5..bdda909763008 100644 --- a/util/syspolicy/rsop/resultant_policy.go +++ b/util/syspolicy/rsop/resultant_policy.go @@ -7,10 +7,10 @@ import ( "errors" "fmt" "slices" - "sync" "sync/atomic" "time" + "tailscale.com/syncs" "tailscale.com/util/syspolicy/internal/loggerx" "tailscale.com/util/syspolicy/setting" "tailscale.com/util/testenv" @@ -58,7 +58,7 @@ type Policy struct { changeCallbacks policyChangeCallbacks - mu sync.Mutex + mu syncs.Mutex watcherStarted bool // whether [Policy.watchReload] was started sources source.ReadableSources closing bool // whether [Policy.Close] was called (even if we're still closing) diff --git a/util/syspolicy/rsop/rsop.go b/util/syspolicy/rsop/rsop.go index 429b9b10121b3..333dca64343c1 100644 --- a/util/syspolicy/rsop/rsop.go +++ b/util/syspolicy/rsop/rsop.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "slices" - "sync" "tailscale.com/syncs" "tailscale.com/util/slicesx" @@ -20,7 +19,7 @@ import ( ) var ( - policyMu sync.Mutex // protects [policySources] and [effectivePolicies] + policyMu syncs.Mutex // protects [policySources] and [effectivePolicies] policySources []*source.Source // all registered policy sources effectivePolicies []*Policy // all active (non-closed) effective policies returned by [PolicyFor] diff --git a/util/syspolicy/setting/setting.go b/util/syspolicy/setting/setting.go index 091cf58d31b71..97362b1dca8e0 100644 --- a/util/syspolicy/setting/setting.go +++ b/util/syspolicy/setting/setting.go @@ -11,9 +11,9 @@ import ( "fmt" "slices" "strings" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/types/lazy" "tailscale.com/util/syspolicy/internal" "tailscale.com/util/syspolicy/pkey" @@ -215,7 +215,7 @@ type DefinitionMap map[pkey.Key]*Definition var ( definitions lazy.SyncValue[DefinitionMap] - definitionsMu sync.Mutex + definitionsMu syncs.Mutex definitionsList []*Definition definitionsUsed bool ) @@ -322,33 +322,33 @@ func Definitions() ([]*Definition, error) { type PlatformList []string // Has reports whether l contains the target platform. -func (l PlatformList) Has(target string) bool { - if len(l) == 0 { +func (ls PlatformList) Has(target string) bool { + if len(ls) == 0 { return true } - return slices.ContainsFunc(l, func(os string) bool { + return slices.ContainsFunc(ls, func(os string) bool { return strings.EqualFold(os, target) }) } // HasCurrent is like Has, but for the current platform. -func (l PlatformList) HasCurrent() bool { - return l.Has(internal.OS()) +func (ls PlatformList) HasCurrent() bool { + return ls.Has(internal.OS()) } // mergeFrom merges l2 into l. Since an empty list indicates no platform restrictions, // if either l or l2 is empty, the merged result in l will also be empty. -func (l *PlatformList) mergeFrom(l2 PlatformList) { +func (ls *PlatformList) mergeFrom(l2 PlatformList) { switch { - case len(*l) == 0: + case len(*ls) == 0: // No-op. An empty list indicates no platform restrictions. case len(l2) == 0: // Merging with an empty list results in an empty list. - *l = l2 + *ls = l2 default: // Append, sort and dedup. - *l = append(*l, l2...) - slices.Sort(*l) - *l = slices.Compact(*l) + *ls = append(*ls, l2...) 
+ slices.Sort(*ls) + *ls = slices.Compact(*ls) } } diff --git a/util/syspolicy/setting/setting_test.go b/util/syspolicy/setting/setting_test.go index e43495a160e12..9d99884f6436f 100644 --- a/util/syspolicy/setting/setting_test.go +++ b/util/syspolicy/setting/setting_test.go @@ -311,8 +311,8 @@ func TestListSettingDefinitions(t *testing.T) { t.Fatalf("SetDefinitionsForTest failed: %v", err) } - cmp := func(l, r *Definition) int { - return strings.Compare(string(l.Key()), string(r.Key())) + cmp := func(a, b *Definition) int { + return strings.Compare(string(a.Key()), string(b.Key())) } want := append([]*Definition{}, definitions...) slices.SortFunc(want, cmp) diff --git a/util/winutil/authenticode/zsyscall_windows.go b/util/winutil/authenticode/zsyscall_windows.go index 643721e06aad5..f1fba2828713c 100644 --- a/util/winutil/authenticode/zsyscall_windows.go +++ b/util/winutil/authenticode/zsyscall_windows.go @@ -56,7 +56,7 @@ var ( ) func cryptMsgClose(cryptMsg windows.Handle) (err error) { - r1, _, e1 := syscall.Syscall(procCryptMsgClose.Addr(), 1, uintptr(cryptMsg), 0, 0) + r1, _, e1 := syscall.SyscallN(procCryptMsgClose.Addr(), uintptr(cryptMsg)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -64,7 +64,7 @@ func cryptMsgClose(cryptMsg windows.Handle) (err error) { } func cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procCryptMsgGetParam.Addr(), 5, uintptr(cryptMsg), uintptr(paramType), uintptr(index), uintptr(data), uintptr(unsafe.Pointer(dataLen)), 0) + r1, _, e1 := syscall.SyscallN(procCryptMsgGetParam.Addr(), uintptr(cryptMsg), uintptr(paramType), uintptr(index), uintptr(data), uintptr(unsafe.Pointer(dataLen))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -72,7 +72,7 @@ func cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, d } func cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) { - r1, _, e1 := syscall.Syscall9(procCryptVerifyMessageSignature.Addr(), 7, uintptr(unsafe.Pointer(pVerifyPara)), uintptr(signerIndex), uintptr(unsafe.Pointer(pbSignedBlob)), uintptr(cbSignedBlob), uintptr(unsafe.Pointer(pbDecoded)), uintptr(unsafe.Pointer(pdbDecoded)), uintptr(unsafe.Pointer(ppSignerCert)), 0, 0) + r1, _, e1 := syscall.SyscallN(procCryptVerifyMessageSignature.Addr(), uintptr(unsafe.Pointer(pVerifyPara)), uintptr(signerIndex), uintptr(unsafe.Pointer(pbSignedBlob)), uintptr(cbSignedBlob), uintptr(unsafe.Pointer(pbDecoded)), uintptr(unsafe.Pointer(pdbDecoded)), uintptr(unsafe.Pointer(ppSignerCert))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -80,13 +80,13 @@ func cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signer } func msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) { - r0, _, _ := syscall.Syscall6(procMsiGetFileSignatureInformationW.Addr(), 5, uintptr(unsafe.Pointer(signedObjectPath)), uintptr(flags), uintptr(unsafe.Pointer(certCtx)), uintptr(unsafe.Pointer(pbHashData)), uintptr(unsafe.Pointer(cbHashData)), 0) + r0, _, _ := syscall.SyscallN(procMsiGetFileSignatureInformationW.Addr(), uintptr(unsafe.Pointer(signedObjectPath)), uintptr(flags), uintptr(unsafe.Pointer(certCtx)), uintptr(unsafe.Pointer(pbHashData)), uintptr(unsafe.Pointer(cbHashData))) ret = 
wingoes.HRESULT(r0) return } func cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procCryptCATAdminAcquireContext2.Addr(), 5, uintptr(unsafe.Pointer(hCatAdmin)), uintptr(unsafe.Pointer(pgSubsystem)), uintptr(unsafe.Pointer(hashAlgorithm)), uintptr(unsafe.Pointer(strongHashPolicy)), uintptr(flags), 0) + r1, _, e1 := syscall.SyscallN(procCryptCATAdminAcquireContext2.Addr(), uintptr(unsafe.Pointer(hCatAdmin)), uintptr(unsafe.Pointer(pgSubsystem)), uintptr(unsafe.Pointer(hashAlgorithm)), uintptr(unsafe.Pointer(strongHashPolicy)), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -94,7 +94,7 @@ func cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GU } func cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procCryptCATAdminCalcHashFromFileHandle2.Addr(), 5, uintptr(hCatAdmin), uintptr(file), uintptr(unsafe.Pointer(pcbHash)), uintptr(unsafe.Pointer(pbHash)), uintptr(flags), 0) + r1, _, e1 := syscall.SyscallN(procCryptCATAdminCalcHashFromFileHandle2.Addr(), uintptr(hCatAdmin), uintptr(file), uintptr(unsafe.Pointer(pcbHash)), uintptr(unsafe.Pointer(pbHash)), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -102,7 +102,7 @@ func cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Han } func cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) { - r0, _, e1 := syscall.Syscall6(procCryptCATAdminEnumCatalogFromHash.Addr(), 5, uintptr(hCatAdmin), uintptr(unsafe.Pointer(pbHash)), uintptr(cbHash), uintptr(flags), uintptr(unsafe.Pointer(prevCatInfo)), 0) + r0, _, e1 := syscall.SyscallN(procCryptCATAdminEnumCatalogFromHash.Addr(), uintptr(hCatAdmin), uintptr(unsafe.Pointer(pbHash)), uintptr(cbHash), uintptr(flags), uintptr(unsafe.Pointer(prevCatInfo))) ret = _HCATINFO(r0) if ret == 0 { err = errnoErr(e1) @@ -111,7 +111,7 @@ func cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash } func cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procCryptCATAdminReleaseCatalogContext.Addr(), 3, uintptr(hCatAdmin), uintptr(hCatInfo), uintptr(flags)) + r1, _, e1 := syscall.SyscallN(procCryptCATAdminReleaseCatalogContext.Addr(), uintptr(hCatAdmin), uintptr(hCatInfo), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -119,7 +119,7 @@ func cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO } func cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procCryptCATAdminReleaseContext.Addr(), 2, uintptr(hCatAdmin), uintptr(flags), 0) + r1, _, e1 := syscall.SyscallN(procCryptCATAdminReleaseContext.Addr(), uintptr(hCatAdmin), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -127,7 +127,7 @@ func cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) } func cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procCryptCATCatalogInfoFromContext.Addr(), 3, uintptr(hCatInfo), uintptr(unsafe.Pointer(catInfo)), uintptr(flags)) + r1, _, e1 := 
syscall.SyscallN(procCryptCATCatalogInfoFromContext.Addr(), uintptr(hCatInfo), uintptr(unsafe.Pointer(catInfo)), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } diff --git a/util/winutil/gp/gp_windows_test.go b/util/winutil/gp/gp_windows_test.go index e2520b46d56ae..f892068835bce 100644 --- a/util/winutil/gp/gp_windows_test.go +++ b/util/winutil/gp/gp_windows_test.go @@ -182,16 +182,16 @@ func doWithMachinePolicyLocked(t *testing.T, f func()) { f() } -func doWithCustomEnterLeaveFuncs(t *testing.T, f func(l *PolicyLock), enter func(bool) (policyLockHandle, error), leave func(policyLockHandle) error) { +func doWithCustomEnterLeaveFuncs(t *testing.T, f func(*PolicyLock), enter func(bool) (policyLockHandle, error), leave func(policyLockHandle) error) { t.Helper() - l := NewMachinePolicyLock() - l.enterFn, l.leaveFn = enter, leave + lock := NewMachinePolicyLock() + lock.enterFn, lock.leaveFn = enter, leave t.Cleanup(func() { - if err := l.Close(); err != nil { + if err := lock.Close(); err != nil { t.Fatalf("(*PolicyLock).Close failed: %v", err) } }) - f(l) + f(lock) } diff --git a/util/winutil/gp/policylock_windows.go b/util/winutil/gp/policylock_windows.go index 69c5ff01697f4..6c3ca0baf6d21 100644 --- a/util/winutil/gp/policylock_windows.go +++ b/util/winutil/gp/policylock_windows.go @@ -127,32 +127,32 @@ func NewUserPolicyLock(token windows.Token) (*PolicyLock, error) { return lock, nil } -// Lock locks l. -// It returns [ErrInvalidLockState] if l has a zero value or has already been closed, +// Lock locks lk. +// It returns [ErrInvalidLockState] if lk has a zero value or has already been closed, // [ErrLockRestricted] if the lock cannot be acquired due to a restriction in place, // or a [syscall.Errno] if the underlying Group Policy lock cannot be acquired. // // As a special case, it fails with [windows.ERROR_ACCESS_DENIED] -// if l is a user policy lock, and the corresponding user is not logged in +// if lk is a user policy lock, and the corresponding user is not logged in // interactively at the time of the call. -func (l *PolicyLock) Lock() error { +func (lk *PolicyLock) Lock() error { if policyLockRestricted.Load() > 0 { return ErrLockRestricted } - l.mu.Lock() - defer l.mu.Unlock() - if l.lockCnt.Add(2)&1 == 0 { + lk.mu.Lock() + defer lk.mu.Unlock() + if lk.lockCnt.Add(2)&1 == 0 { // The lock cannot be acquired because it has either never been properly // created or its Close method has already been called. However, we need // to call Unlock to both decrement lockCnt and leave the underlying // CriticalPolicySection if we won the race with another goroutine and // now own the lock. - l.Unlock() + lk.Unlock() return ErrInvalidLockState } - if l.handle != 0 { + if lk.handle != 0 { // The underlying CriticalPolicySection is already acquired. // It is an R-Lock (with the W-counterpart owned by the Group Policy service), // meaning that it can be acquired by multiple readers simultaneously. @@ -160,20 +160,20 @@ func (l *PolicyLock) Lock() error { return nil } - return l.lockSlow() + return lk.lockSlow() } // lockSlow calls enterCriticalPolicySection to acquire the underlying GP read lock. // It waits for either the lock to be acquired, or for the Close method to be called. // // l.mu must be held. -func (l *PolicyLock) lockSlow() (err error) { +func (lk *PolicyLock) lockSlow() (err error) { defer func() { if err != nil { // Decrement the counter if the lock cannot be acquired, // and complete the pending close request if we're the last owner. 
- if l.lockCnt.Add(-2) == 0 { - l.closeInternal() + if lk.lockCnt.Add(-2) == 0 { + lk.closeInternal() } } }() @@ -190,12 +190,12 @@ func (l *PolicyLock) lockSlow() (err error) { resultCh := make(chan policyLockResult) go func() { - closing := l.closing - if l.scope == UserPolicy && l.token != 0 { + closing := lk.closing + if lk.scope == UserPolicy && lk.token != 0 { // Impersonate the user whose critical policy section we want to acquire. runtime.LockOSThread() defer runtime.UnlockOSThread() - if err := impersonateLoggedOnUser(l.token); err != nil { + if err := impersonateLoggedOnUser(lk.token); err != nil { initCh <- err return } @@ -209,10 +209,10 @@ func (l *PolicyLock) lockSlow() (err error) { close(initCh) var machine bool - if l.scope == MachinePolicy { + if lk.scope == MachinePolicy { machine = true } - handle, err := l.enterFn(machine) + handle, err := lk.enterFn(machine) send_result: for { @@ -226,7 +226,7 @@ func (l *PolicyLock) lockSlow() (err error) { -// The lock is being closed, and we lost the race to l.closing -// it the calling goroutine. +// The lock is being closed, and we lost the race to the calling +// goroutine, which is closing it. if err == nil { - l.leaveFn(handle) + lk.leaveFn(handle) } break send_result default: @@ -247,21 +247,21 @@ func (l *PolicyLock) lockSlow() (err error) { select { case result := <-resultCh: if result.err == nil { - l.handle = result.handle + lk.handle = result.handle } return result.err - case <-l.closing: + case <-lk.closing: return ErrInvalidLockState } } -// Unlock unlocks l. -// It panics if l is not locked on entry to Unlock. -func (l *PolicyLock) Unlock() { - l.mu.Lock() - defer l.mu.Unlock() +// Unlock unlocks lk. +// It panics if lk is not locked on entry to Unlock. +func (lk *PolicyLock) Unlock() { + lk.mu.Lock() + defer lk.mu.Unlock() - lockCnt := l.lockCnt.Add(-2) + lockCnt := lk.lockCnt.Add(-2) if lockCnt < 0 { panic("negative lockCnt") } @@ -273,33 +273,33 @@ func (l *PolicyLock) Unlock() { return } - if l.handle != 0 { + if lk.handle != 0 { // Impersonation is not required to unlock a critical policy section. // The handle we pass determines which mutex will be unlocked. - leaveCriticalPolicySection(l.handle) - l.handle = 0 + leaveCriticalPolicySection(lk.handle) + lk.handle = 0 } if lockCnt == 0 { // Complete the pending close request if there's no more readers. - l.closeInternal() + lk.closeInternal() } } -// Close releases resources associated with l. +// Close releases resources associated with lk. // It is a no-op for the machine policy lock. -func (l *PolicyLock) Close() error { - lockCnt := l.lockCnt.Load() +func (lk *PolicyLock) Close() error { + lockCnt := lk.lockCnt.Load() if lockCnt&1 == 0 { // The lock has never been initialized, or close has already been called. return nil } - close(l.closing) + close(lk.closing) // Unset the LSB to indicate a pending close request.
- for !l.lockCnt.CompareAndSwap(lockCnt, lockCnt&^int32(1)) { - lockCnt = l.lockCnt.Load() + for !lk.lockCnt.CompareAndSwap(lockCnt, lockCnt&^int32(1)) { + lockCnt = lk.lockCnt.Load() } if lockCnt != 0 { @@ -307,16 +307,16 @@ func (l *PolicyLock) Close() error { return nil } - return l.closeInternal() + return lk.closeInternal() } -func (l *PolicyLock) closeInternal() error { - if l.token != 0 { - if err := l.token.Close(); err != nil { +func (lk *PolicyLock) closeInternal() error { + if lk.token != 0 { + if err := lk.token.Close(); err != nil { return err } - l.token = 0 + lk.token = 0 } - l.closing = nil + lk.closing = nil return nil } diff --git a/util/winutil/gp/zsyscall_windows.go b/util/winutil/gp/zsyscall_windows.go index 5e40ec3d1e093..41c240c264e6d 100644 --- a/util/winutil/gp/zsyscall_windows.go +++ b/util/winutil/gp/zsyscall_windows.go @@ -50,7 +50,7 @@ var ( ) func impersonateLoggedOnUser(token windows.Token) (err error) { - r1, _, e1 := syscall.Syscall(procImpersonateLoggedOnUser.Addr(), 1, uintptr(token), 0, 0) + r1, _, e1 := syscall.SyscallN(procImpersonateLoggedOnUser.Addr(), uintptr(token)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -62,7 +62,7 @@ func enterCriticalPolicySection(machine bool) (handle policyLockHandle, err erro if machine { _p0 = 1 } - r0, _, e1 := syscall.Syscall(procEnterCriticalPolicySection.Addr(), 1, uintptr(_p0), 0, 0) + r0, _, e1 := syscall.SyscallN(procEnterCriticalPolicySection.Addr(), uintptr(_p0)) handle = policyLockHandle(r0) if int32(handle) == 0 { err = errnoErr(e1) @@ -71,7 +71,7 @@ func enterCriticalPolicySection(machine bool) (handle policyLockHandle, err erro } func leaveCriticalPolicySection(handle policyLockHandle) (err error) { - r1, _, e1 := syscall.Syscall(procLeaveCriticalPolicySection.Addr(), 1, uintptr(handle), 0, 0) + r1, _, e1 := syscall.SyscallN(procLeaveCriticalPolicySection.Addr(), uintptr(handle)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -83,7 +83,7 @@ func refreshPolicyEx(machine bool, flags uint32) (err error) { if machine { _p0 = 1 } - r1, _, e1 := syscall.Syscall(procRefreshPolicyEx.Addr(), 2, uintptr(_p0), uintptr(flags), 0) + r1, _, e1 := syscall.SyscallN(procRefreshPolicyEx.Addr(), uintptr(_p0), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -95,7 +95,7 @@ func registerGPNotification(event windows.Handle, machine bool) (err error) { if machine { _p0 = 1 } - r1, _, e1 := syscall.Syscall(procRegisterGPNotification.Addr(), 2, uintptr(event), uintptr(_p0), 0) + r1, _, e1 := syscall.SyscallN(procRegisterGPNotification.Addr(), uintptr(event), uintptr(_p0)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -103,7 +103,7 @@ func registerGPNotification(event windows.Handle, machine bool) (err error) { } func unregisterGPNotification(event windows.Handle) (err error) { - r1, _, e1 := syscall.Syscall(procUnregisterGPNotification.Addr(), 1, uintptr(event), 0, 0) + r1, _, e1 := syscall.SyscallN(procUnregisterGPNotification.Addr(), uintptr(event)) if int32(r1) == 0 { err = errnoErr(e1) } diff --git a/util/winutil/s4u/lsa_windows.go b/util/winutil/s4u/lsa_windows.go index 3ff2171f91d70..3276b26766c08 100644 --- a/util/winutil/s4u/lsa_windows.go +++ b/util/winutil/s4u/lsa_windows.go @@ -256,8 +256,8 @@ func checkDomainAccount(username string) (sanitizedUserName string, isDomainAcco // errors.Is to check for it. When capLevel == CapCreateProcess, the logon // enforces the user's logon hours policy (when present). 
func (ls *lsaSession) logonAs(srcName string, u *user.User, capLevel CapabilityLevel) (token windows.Token, err error) { - if l := len(srcName); l == 0 || l > _TOKEN_SOURCE_LENGTH { - return 0, fmt.Errorf("%w, actual length is %d", ErrBadSrcName, l) + if ln := len(srcName); ln == 0 || ln > _TOKEN_SOURCE_LENGTH { + return 0, fmt.Errorf("%w, actual length is %d", ErrBadSrcName, ln) } if err := checkASCII(srcName); err != nil { return 0, fmt.Errorf("%w: %v", ErrBadSrcName, err) diff --git a/util/winutil/s4u/s4u_windows.go b/util/winutil/s4u/s4u_windows.go index 8926aaedc5071..8c8e02dbe83bc 100644 --- a/util/winutil/s4u/s4u_windows.go +++ b/util/winutil/s4u/s4u_windows.go @@ -938,10 +938,10 @@ func mergeEnv(existingEnv []string, extraEnv map[string]string) []string { result = append(result, strings.Join([]string{k, v}, "=")) } - slices.SortFunc(result, func(l, r string) int { - kl, _, _ := strings.Cut(l, "=") - kr, _, _ := strings.Cut(r, "=") - return strings.Compare(kl, kr) + slices.SortFunc(result, func(a, b string) int { + ka, _, _ := strings.Cut(a, "=") + kb, _, _ := strings.Cut(b, "=") + return strings.Compare(ka, kb) }) return result } diff --git a/util/winutil/s4u/zsyscall_windows.go b/util/winutil/s4u/zsyscall_windows.go index 6a8c78427dbd3..db647dee483e2 100644 --- a/util/winutil/s4u/zsyscall_windows.go +++ b/util/winutil/s4u/zsyscall_windows.go @@ -52,7 +52,7 @@ var ( ) func allocateLocallyUniqueId(luid *windows.LUID) (err error) { - r1, _, e1 := syscall.Syscall(procAllocateLocallyUniqueId.Addr(), 1, uintptr(unsafe.Pointer(luid)), 0, 0) + r1, _, e1 := syscall.SyscallN(procAllocateLocallyUniqueId.Addr(), uintptr(unsafe.Pointer(luid))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -60,7 +60,7 @@ func allocateLocallyUniqueId(luid *windows.LUID) (err error) { } func impersonateLoggedOnUser(token windows.Token) (err error) { - r1, _, e1 := syscall.Syscall(procImpersonateLoggedOnUser.Addr(), 1, uintptr(token), 0, 0) + r1, _, e1 := syscall.SyscallN(procImpersonateLoggedOnUser.Addr(), uintptr(token)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -68,37 +68,37 @@ func impersonateLoggedOnUser(token windows.Token) (err error) { } func lsaConnectUntrusted(lsaHandle *_LSAHANDLE) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaConnectUntrusted.Addr(), 1, uintptr(unsafe.Pointer(lsaHandle)), 0, 0) + r0, _, _ := syscall.SyscallN(procLsaConnectUntrusted.Addr(), uintptr(unsafe.Pointer(lsaHandle))) ret = windows.NTStatus(r0) return } func lsaDeregisterLogonProcess(lsaHandle _LSAHANDLE) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaDeregisterLogonProcess.Addr(), 1, uintptr(lsaHandle), 0, 0) + r0, _, _ := syscall.SyscallN(procLsaDeregisterLogonProcess.Addr(), uintptr(lsaHandle)) ret = windows.NTStatus(r0) return } func lsaFreeReturnBuffer(buffer uintptr) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaFreeReturnBuffer.Addr(), 1, uintptr(buffer), 0, 0) + r0, _, _ := syscall.SyscallN(procLsaFreeReturnBuffer.Addr(), uintptr(buffer)) ret = windows.NTStatus(r0) return } func lsaLogonUser(lsaHandle _LSAHANDLE, originName *windows.NTString, logonType _SECURITY_LOGON_TYPE, authenticationPackage uint32, authenticationInformation unsafe.Pointer, authenticationInformationLength uint32, localGroups *windows.Tokengroups, sourceContext *_TOKEN_SOURCE, profileBuffer *uintptr, profileBufferLength *uint32, logonID *windows.LUID, token *windows.Token, quotas *_QUOTA_LIMITS, subStatus *windows.NTStatus) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall15(procLsaLogonUser.Addr(), 
14, uintptr(lsaHandle), uintptr(unsafe.Pointer(originName)), uintptr(logonType), uintptr(authenticationPackage), uintptr(authenticationInformation), uintptr(authenticationInformationLength), uintptr(unsafe.Pointer(localGroups)), uintptr(unsafe.Pointer(sourceContext)), uintptr(unsafe.Pointer(profileBuffer)), uintptr(unsafe.Pointer(profileBufferLength)), uintptr(unsafe.Pointer(logonID)), uintptr(unsafe.Pointer(token)), uintptr(unsafe.Pointer(quotas)), uintptr(unsafe.Pointer(subStatus)), 0) + r0, _, _ := syscall.SyscallN(procLsaLogonUser.Addr(), uintptr(lsaHandle), uintptr(unsafe.Pointer(originName)), uintptr(logonType), uintptr(authenticationPackage), uintptr(authenticationInformation), uintptr(authenticationInformationLength), uintptr(unsafe.Pointer(localGroups)), uintptr(unsafe.Pointer(sourceContext)), uintptr(unsafe.Pointer(profileBuffer)), uintptr(unsafe.Pointer(profileBufferLength)), uintptr(unsafe.Pointer(logonID)), uintptr(unsafe.Pointer(token)), uintptr(unsafe.Pointer(quotas)), uintptr(unsafe.Pointer(subStatus))) ret = windows.NTStatus(r0) return } func lsaLookupAuthenticationPackage(lsaHandle _LSAHANDLE, packageName *windows.NTString, authenticationPackage *uint32) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaLookupAuthenticationPackage.Addr(), 3, uintptr(lsaHandle), uintptr(unsafe.Pointer(packageName)), uintptr(unsafe.Pointer(authenticationPackage))) + r0, _, _ := syscall.SyscallN(procLsaLookupAuthenticationPackage.Addr(), uintptr(lsaHandle), uintptr(unsafe.Pointer(packageName)), uintptr(unsafe.Pointer(authenticationPackage))) ret = windows.NTStatus(r0) return } func lsaRegisterLogonProcess(logonProcessName *windows.NTString, lsaHandle *_LSAHANDLE, securityMode *_LSA_OPERATIONAL_MODE) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaRegisterLogonProcess.Addr(), 3, uintptr(unsafe.Pointer(logonProcessName)), uintptr(unsafe.Pointer(lsaHandle)), uintptr(unsafe.Pointer(securityMode))) + r0, _, _ := syscall.SyscallN(procLsaRegisterLogonProcess.Addr(), uintptr(unsafe.Pointer(logonProcessName)), uintptr(unsafe.Pointer(lsaHandle)), uintptr(unsafe.Pointer(securityMode))) ret = windows.NTStatus(r0) return } diff --git a/util/winutil/startupinfo_windows.go b/util/winutil/startupinfo_windows.go index e04e9ea9b3d3a..edf48fa651cb5 100644 --- a/util/winutil/startupinfo_windows.go +++ b/util/winutil/startupinfo_windows.go @@ -83,8 +83,8 @@ func (sib *StartupInfoBuilder) Resolve() (startupInfo *windows.StartupInfo, inhe // Always create a Unicode environment. 
createProcessFlags = windows.CREATE_UNICODE_ENVIRONMENT - if l := uint32(len(sib.attrs)); l > 0 { - attrCont, err := windows.NewProcThreadAttributeList(l) + if ln := uint32(len(sib.attrs)); ln > 0 { + attrCont, err := windows.NewProcThreadAttributeList(ln) if err != nil { return nil, false, 0, err } diff --git a/util/winutil/winenv/zsyscall_windows.go b/util/winutil/winenv/zsyscall_windows.go index 2bdfdd9b1180b..7e93c7952f32e 100644 --- a/util/winutil/winenv/zsyscall_windows.go +++ b/util/winutil/winenv/zsyscall_windows.go @@ -55,7 +55,7 @@ func isDeviceRegisteredWithManagement(isMDMRegistered *bool, upnBufLen uint32, u if *isMDMRegistered { _p0 = 1 } - r0, _, e1 := syscall.Syscall(procIsDeviceRegisteredWithManagement.Addr(), 3, uintptr(unsafe.Pointer(&_p0)), uintptr(upnBufLen), uintptr(unsafe.Pointer(upnBuf))) + r0, _, e1 := syscall.SyscallN(procIsDeviceRegisteredWithManagement.Addr(), uintptr(unsafe.Pointer(&_p0)), uintptr(upnBufLen), uintptr(unsafe.Pointer(upnBuf))) *isMDMRegistered = _p0 != 0 hr = int32(r0) if hr == 0 { @@ -65,13 +65,13 @@ func isDeviceRegisteredWithManagement(isMDMRegistered *bool, upnBufLen uint32, u } func verSetConditionMask(condMask verCondMask, typ verTypeMask, cond verCond) (res verCondMask) { - r0, _, _ := syscall.Syscall(procVerSetConditionMask.Addr(), 3, uintptr(condMask), uintptr(typ), uintptr(cond)) + r0, _, _ := syscall.SyscallN(procVerSetConditionMask.Addr(), uintptr(condMask), uintptr(typ), uintptr(cond)) res = verCondMask(r0) return } func verifyVersionInfo(verInfo *osVersionInfoEx, typ verTypeMask, cond verCondMask) (res bool) { - r0, _, _ := syscall.Syscall(procVerifyVersionInfoW.Addr(), 3, uintptr(unsafe.Pointer(verInfo)), uintptr(typ), uintptr(cond)) + r0, _, _ := syscall.SyscallN(procVerifyVersionInfoW.Addr(), uintptr(unsafe.Pointer(verInfo)), uintptr(typ), uintptr(cond)) res = r0 != 0 return } diff --git a/util/winutil/winutil_windows_test.go b/util/winutil/winutil_windows_test.go index d437ffa383d82..ead10a45d7ee8 100644 --- a/util/winutil/winutil_windows_test.go +++ b/util/winutil/winutil_windows_test.go @@ -68,8 +68,8 @@ func checkContiguousBuffer[T any, BU BufUnit](t *testing.T, extra []BU, pt *T, p if gotLen := int(ptLen); gotLen != expectedLen { t.Errorf("allocation length got %d, want %d", gotLen, expectedLen) } - if l := len(slcs); l != 1 { - t.Errorf("len(slcs) got %d, want 1", l) + if ln := len(slcs); ln != 1 { + t.Errorf("len(slcs) got %d, want 1", ln) } if len(extra) == 0 && slcs[0] != nil { t.Error("slcs[0] got non-nil, want nil") diff --git a/util/winutil/zsyscall_windows.go b/util/winutil/zsyscall_windows.go index b4674dff340ec..56aedb4c7f59c 100644 --- a/util/winutil/zsyscall_windows.go +++ b/util/winutil/zsyscall_windows.go @@ -62,7 +62,7 @@ var ( ) func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procQueryServiceConfig2W.Addr(), 5, uintptr(hService), uintptr(infoLevel), uintptr(unsafe.Pointer(buf)), uintptr(bufLen), uintptr(unsafe.Pointer(bytesNeeded)), 0) + r1, _, e1 := syscall.SyscallN(procQueryServiceConfig2W.Addr(), uintptr(hService), uintptr(infoLevel), uintptr(unsafe.Pointer(buf)), uintptr(bufLen), uintptr(unsafe.Pointer(bytesNeeded))) if r1 == 0 { err = errnoErr(e1) } @@ -70,19 +70,19 @@ func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, b } func getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) { - 
r0, _, _ := syscall.Syscall6(procGetApplicationRestartSettings.Addr(), 4, uintptr(process), uintptr(unsafe.Pointer(commandLine)), uintptr(unsafe.Pointer(commandLineLen)), uintptr(unsafe.Pointer(flags)), 0, 0) + r0, _, _ := syscall.SyscallN(procGetApplicationRestartSettings.Addr(), uintptr(process), uintptr(unsafe.Pointer(commandLine)), uintptr(unsafe.Pointer(commandLineLen)), uintptr(unsafe.Pointer(flags))) ret = wingoes.HRESULT(r0) return } func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) { - r0, _, _ := syscall.Syscall(procRegisterApplicationRestart.Addr(), 2, uintptr(unsafe.Pointer(cmdLineExclExeName)), uintptr(flags), 0) + r0, _, _ := syscall.SyscallN(procRegisterApplicationRestart.Addr(), uintptr(unsafe.Pointer(cmdLineExclExeName)), uintptr(flags)) ret = wingoes.HRESULT(r0) return } func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) { - r0, _, _ := syscall.Syscall6(procDsGetDcNameW.Addr(), 6, uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo))) + r0, _, _ := syscall.SyscallN(procDsGetDcNameW.Addr(), uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -90,7 +90,7 @@ func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.G } func netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) { - r0, _, _ := syscall.Syscall6(procNetValidateName.Addr(), 5, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType), 0) + r0, _, _ := syscall.SyscallN(procNetValidateName.Addr(), uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType)) if r0 != 0 { ret = syscall.Errno(r0) } @@ -98,7 +98,7 @@ func netValidateName(server *uint16, name *uint16, account *uint16, password *ui } func rmEndSession(session _RMHANDLE) (ret error) { - r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0) + r0, _, _ := syscall.SyscallN(procRmEndSession.Addr(), uintptr(session)) if r0 != 0 { ret = syscall.Errno(r0) } @@ -106,7 +106,7 @@ func rmEndSession(session _RMHANDLE) (ret error) { } func rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rgAffectedApps *_RM_PROCESS_INFO, pRebootReasons *uint32) (ret error) { - r0, _, _ := syscall.Syscall6(procRmGetList.Addr(), 5, uintptr(session), uintptr(unsafe.Pointer(nProcInfoNeeded)), uintptr(unsafe.Pointer(nProcInfo)), uintptr(unsafe.Pointer(rgAffectedApps)), uintptr(unsafe.Pointer(pRebootReasons)), 0) + r0, _, _ := syscall.SyscallN(procRmGetList.Addr(), uintptr(session), uintptr(unsafe.Pointer(nProcInfoNeeded)), uintptr(unsafe.Pointer(nProcInfo)), uintptr(unsafe.Pointer(rgAffectedApps)), uintptr(unsafe.Pointer(pRebootReasons))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -114,7 +114,7 @@ func rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rg } func rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) { - r0, _, _ := 
syscall.Syscall(procRmJoinSession.Addr(), 2, uintptr(unsafe.Pointer(pSession)), uintptr(unsafe.Pointer(sessionKey)), 0) + r0, _, _ := syscall.SyscallN(procRmJoinSession.Addr(), uintptr(unsafe.Pointer(pSession)), uintptr(unsafe.Pointer(sessionKey))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -122,7 +122,7 @@ func rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) { } func rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16, nApplications uint32, rgApplications *_RM_UNIQUE_PROCESS, nServices uint32, rgsServiceNames **uint16) (ret error) { - r0, _, _ := syscall.Syscall9(procRmRegisterResources.Addr(), 7, uintptr(session), uintptr(nFiles), uintptr(unsafe.Pointer(rgsFileNames)), uintptr(nApplications), uintptr(unsafe.Pointer(rgApplications)), uintptr(nServices), uintptr(unsafe.Pointer(rgsServiceNames)), 0, 0) + r0, _, _ := syscall.SyscallN(procRmRegisterResources.Addr(), uintptr(session), uintptr(nFiles), uintptr(unsafe.Pointer(rgsFileNames)), uintptr(nApplications), uintptr(unsafe.Pointer(rgApplications)), uintptr(nServices), uintptr(unsafe.Pointer(rgsServiceNames))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -130,7 +130,7 @@ func rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16 } func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret error) { - r0, _, _ := syscall.Syscall(procRmStartSession.Addr(), 3, uintptr(unsafe.Pointer(pSession)), uintptr(flags), uintptr(unsafe.Pointer(sessionKey))) + r0, _, _ := syscall.SyscallN(procRmStartSession.Addr(), uintptr(unsafe.Pointer(pSession)), uintptr(flags), uintptr(unsafe.Pointer(sessionKey))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -138,7 +138,7 @@ func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret } func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procExpandEnvironmentStringsForUserW.Addr(), 4, uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen), 0, 0) + r1, _, e1 := syscall.SyscallN(procExpandEnvironmentStringsForUserW.Addr(), uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -146,7 +146,7 @@ func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint } func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) { - r1, _, e1 := syscall.Syscall(procLoadUserProfileW.Addr(), 2, uintptr(token), uintptr(unsafe.Pointer(profileInfo)), 0) + r1, _, e1 := syscall.SyscallN(procLoadUserProfileW.Addr(), uintptr(token), uintptr(unsafe.Pointer(profileInfo))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -154,7 +154,7 @@ func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) } func unloadUserProfile(token windows.Token, profile registry.Key) (err error) { - r1, _, e1 := syscall.Syscall(procUnloadUserProfile.Addr(), 2, uintptr(token), uintptr(profile), 0) + r1, _, e1 := syscall.SyscallN(procUnloadUserProfile.Addr(), uintptr(token), uintptr(profile)) if int32(r1) == 0 { err = errnoErr(e1) } diff --git a/wf/firewall.go b/wf/firewall.go index 076944c8decad..07e160eb36071 100644 --- a/wf/firewall.go +++ b/wf/firewall.go @@ -18,7 +18,7 @@ import ( // Known addresses. 
var ( - linkLocalRange = netip.MustParsePrefix("ff80::/10") + linkLocalRange = netip.MustParsePrefix("fe80::/10") linkLocalDHCPMulticast = netip.MustParseAddr("ff02::1:2") siteLocalDHCPMulticast = netip.MustParseAddr("ff05::1:3") linkLocalRouterMulticast = netip.MustParseAddr("ff02::2") @@ -66,8 +66,8 @@ func (p protocol) getLayers(d direction) []wf.LayerID { return layers } -func ruleName(action wf.Action, l wf.LayerID, name string) string { - switch l { +func ruleName(action wf.Action, layerID wf.LayerID, name string) string { + switch layerID { case wf.LayerALEAuthConnectV4: return fmt.Sprintf("%s outbound %s (IPv4)", action, name) case wf.LayerALEAuthConnectV6: @@ -307,8 +307,8 @@ func (f *Firewall) newRule(name string, w weight, layer wf.LayerID, conditions [ func (f *Firewall) addRules(name string, w weight, conditions []*wf.Match, action wf.Action, p protocol, d direction) ([]*wf.Rule, error) { var rules []*wf.Rule - for _, l := range p.getLayers(d) { - r, err := f.newRule(name, w, l, conditions, action) + for _, layer := range p.getLayers(d) { + r, err := f.newRule(name, w, layer, conditions, action) if err != nil { return nil, err } diff --git a/wgengine/bench/wg.go b/wgengine/bench/wg.go index 4de7677f26257..ce6add866f9e8 100644 --- a/wgengine/bench/wg.go +++ b/wgengine/bench/wg.go @@ -38,7 +38,6 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. k1 := key.NewNode() c1 := wgcfg.Config{ - Name: "e1", PrivateKey: k1, Addresses: []netip.Prefix{a1}, } @@ -65,7 +64,6 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. l2 := logger.WithPrefix(logf, "e2: ") k2 := key.NewNode() c2 := wgcfg.Config{ - Name: "e2", PrivateKey: k2, Addresses: []netip.Prefix{a2}, } @@ -113,9 +111,8 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. Endpoints: epFromTyped(st.LocalAddrs), } e2.SetNetworkMap(&netmap.NetworkMap{ - NodeKey: k2.Public(), - PrivateKey: k2, - Peers: []tailcfg.NodeView{n.View()}, + NodeKey: k2.Public(), + Peers: []tailcfg.NodeView{n.View()}, }) p := wgcfg.Peer{ @@ -145,9 +142,8 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. Endpoints: epFromTyped(st.LocalAddrs), } e1.SetNetworkMap(&netmap.NetworkMap{ - NodeKey: k1.Public(), - PrivateKey: k1, - Peers: []tailcfg.NodeView{n.View()}, + NodeKey: k1.Public(), + Peers: []tailcfg.NodeView{n.View()}, }) p := wgcfg.Peer{ diff --git a/wgengine/magicsock/blockforever_conn.go b/wgengine/magicsock/blockforever_conn.go index f2e85dcd57002..272a12513b353 100644 --- a/wgengine/magicsock/blockforever_conn.go +++ b/wgengine/magicsock/blockforever_conn.go @@ -10,11 +10,13 @@ import ( "sync" "syscall" "time" + + "tailscale.com/syncs" ) // blockForeverConn is a net.PacketConn whose reads block until it is closed. type blockForeverConn struct { - mu sync.Mutex + mu syncs.Mutex cond *sync.Cond closed bool } diff --git a/wgengine/magicsock/disco_atomic.go b/wgengine/magicsock/disco_atomic.go new file mode 100644 index 0000000000000..5b765fbc2c9a0 --- /dev/null +++ b/wgengine/magicsock/disco_atomic.go @@ -0,0 +1,58 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "sync/atomic" + + "tailscale.com/types/key" +) + +type discoKeyPair struct { + private key.DiscoPrivate + public key.DiscoPublic + short string // public.ShortString() +} + +// discoAtomic is an atomic container for a disco private key, public key, and +// the public key's ShortString. 
The private and public keys are always kept +// synchronized. +// +// The zero value is not ready for use. Use [Set] to provide a usable value. +type discoAtomic struct { + pair atomic.Pointer[discoKeyPair] +} + +// Pair returns the private and public keys together atomically. +// Code that needs both the private and public keys synchronized should +// use Pair instead of calling Private and Public separately. +func (dk *discoAtomic) Pair() (key.DiscoPrivate, key.DiscoPublic) { + p := dk.pair.Load() + return p.private, p.public +} + +// Private returns the private key. +func (dk *discoAtomic) Private() key.DiscoPrivate { + return dk.pair.Load().private +} + +// Public returns the public key. +func (dk *discoAtomic) Public() key.DiscoPublic { + return dk.pair.Load().public +} + +// Short returns the short string of the public key (see [DiscoPublic.ShortString]). +func (dk *discoAtomic) Short() string { + return dk.pair.Load().short +} + +// Set updates the private key (and the cached public key and short string). +func (dk *discoAtomic) Set(private key.DiscoPrivate) { + public := private.Public() + dk.pair.Store(&discoKeyPair{ + private: private, + public: public, + short: public.ShortString(), + }) +} diff --git a/wgengine/magicsock/disco_atomic_test.go b/wgengine/magicsock/disco_atomic_test.go new file mode 100644 index 0000000000000..a1de9b843379f --- /dev/null +++ b/wgengine/magicsock/disco_atomic_test.go @@ -0,0 +1,70 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "testing" + + "tailscale.com/types/key" +) + +func TestDiscoAtomic(t *testing.T) { + var dk discoAtomic + dk.Set(key.NewDisco()) + + private := dk.Private() + public := dk.Public() + short := dk.Short() + + if private.IsZero() { + t.Fatal("DiscoKey private key should not be zero") + } + if public.IsZero() { + t.Fatal("DiscoKey public key should not be zero") + } + if short == "" { + t.Fatal("DiscoKey short string should not be empty") + } + + if public != private.Public() { + t.Fatal("DiscoKey public key doesn't match private key") + } + if short != public.ShortString() { + t.Fatal("DiscoKey short string doesn't match public key") + } + + gotPrivate, gotPublic := dk.Pair() + if !gotPrivate.Equal(private) { + t.Fatal("Pair() returned different private key") + } + if gotPublic != public { + t.Fatal("Pair() returned different public key") + } +} + +func TestDiscoAtomicSet(t *testing.T) { + var dk discoAtomic + dk.Set(key.NewDisco()) + oldPrivate := dk.Private() + oldPublic := dk.Public() + + newPrivate := key.NewDisco() + dk.Set(newPrivate) + + currentPrivate := dk.Private() + currentPublic := dk.Public() + + if currentPrivate.Equal(oldPrivate) { + t.Fatal("DiscoKey private key should have changed after Set") + } + if currentPublic == oldPublic { + t.Fatal("DiscoKey public key should have changed after Set") + } + if !currentPrivate.Equal(newPrivate) { + t.Fatal("DiscoKey private key doesn't match the set key") + } + if currentPublic != newPrivate.Public() { + t.Fatal("DiscoKey public key doesn't match derived from set private key") + } +} diff --git a/wgengine/magicsock/discopingpurpose_string.go b/wgengine/magicsock/discopingpurpose_string.go index 3dc327de1d2ae..8eebf97a2dbd9 100644 --- a/wgengine/magicsock/discopingpurpose_string.go +++ b/wgengine/magicsock/discopingpurpose_string.go @@ -22,8 +22,9 @@ const _discoPingPurpose_name = "DiscoveryHeartbeatCLIHeartbeatForUDPLifetime" var _discoPingPurpose_index = [...]uint8{0, 9, 18, 21, 44} func (i 
discoPingPurpose) String() string { - if i < 0 || i >= discoPingPurpose(len(_discoPingPurpose_index)-1) { + idx := int(i) - 0 + if i < 0 || idx >= len(_discoPingPurpose_index)-1 { return "discoPingPurpose(" + strconv.FormatInt(int64(i), 10) + ")" } - return _discoPingPurpose_name[_discoPingPurpose_index[i]:_discoPingPurpose_index[i+1]] + return _discoPingPurpose_name[_discoPingPurpose_index[idx]:_discoPingPurpose_index[idx+1]] } diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 2010775a10d6e..eda589e14b1b6 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -17,7 +17,6 @@ import ( "reflect" "runtime" "slices" - "sync" "sync/atomic" "time" @@ -28,6 +27,7 @@ import ( "tailscale.com/net/packet" "tailscale.com/net/stun" "tailscale.com/net/tstun" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime/mono" "tailscale.com/types/key" @@ -73,7 +73,7 @@ type endpoint struct { disco atomic.Pointer[endpointDisco] // if the peer supports disco, the key and short string // mu protects all following fields. - mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu + mu syncs.Mutex // Lock ordering: Conn.mu, then endpoint.mu heartBeatTimer *time.Timer // nil when idle lastSendExt mono.Time // last time there were outgoing packets sent to this peer from an external trigger (e.g. wireguard-go or disco pingCLI) @@ -697,7 +697,7 @@ func (de *endpoint) maybeProbeUDPLifetimeLocked() (afterInactivityFor time.Durat // shuffling probing probability where the local node ends up with a large // key value lexicographically relative to the other nodes it tends to // communicate with. If de's disco key changes, the cycle will reset. - if de.c.discoPublic.Compare(epDisco.key) >= 0 { + if de.c.discoAtomic.Public().Compare(epDisco.key) >= 0 { // lower disco pub key node probes higher return afterInactivityFor, false } @@ -1739,7 +1739,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAdd } if sp.purpose != pingHeartbeat && sp.purpose != pingHeartbeatForUDPLifetime { - de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pktlen=%v pong.src=%v%v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), pktLen, m.Src, logger.ArgWriter(func(bw *bufio.Writer) { + de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pktlen=%v pong.src=%v%v", de.c.discoAtomic.Short(), de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), pktLen, m.Src, logger.ArgWriter(func(bw *bufio.Writer) { if sp.to != src { fmt.Fprintf(bw, " ping.to=%v", sp.to) } diff --git a/wgengine/magicsock/endpoint_test.go b/wgengine/magicsock/endpoint_test.go index df1c9340657e4..f1dab924f5d3b 100644 --- a/wgengine/magicsock/endpoint_test.go +++ b/wgengine/magicsock/endpoint_test.go @@ -146,15 +146,22 @@ func TestProbeUDPLifetimeConfig_Valid(t *testing.T) { } func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { + var lowerPriv, higherPriv key.DiscoPrivate var lower, higher key.DiscoPublic - a := key.NewDisco().Public() - b := key.NewDisco().Public() + privA := key.NewDisco() + privB := key.NewDisco() + a := privA.Public() + b := privB.Public() if a.String() < b.String() { lower = a higher = b + lowerPriv = privA + higherPriv = privB } else { lower = b higher = a + lowerPriv = privB + higherPriv = privA } addr := addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:1")}} newProbeUDPLifetime := 
func() *probeUDPLifetime { @@ -281,10 +288,18 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + c := &Conn{} + if tt.localDisco.IsZero() { + c.discoAtomic.Set(key.NewDisco()) + } else if tt.localDisco.Compare(lower) == 0 { + c.discoAtomic.Set(lowerPriv) + } else if tt.localDisco.Compare(higher) == 0 { + c.discoAtomic.Set(higherPriv) + } else { + t.Fatalf("unexpected localDisco value") + } de := &endpoint{ - c: &Conn{ - discoPublic: tt.localDisco, - }, + c: c, bestAddr: tt.bestAddr, } if tt.remoteDisco != nil { diff --git a/wgengine/magicsock/endpoint_tracker.go b/wgengine/magicsock/endpoint_tracker.go index 5caddd1a06960..e95852d2491b7 100644 --- a/wgengine/magicsock/endpoint_tracker.go +++ b/wgengine/magicsock/endpoint_tracker.go @@ -6,9 +6,9 @@ package magicsock import ( "net/netip" "slices" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tempfork/heap" "tailscale.com/util/mak" @@ -107,7 +107,7 @@ func (eh endpointHeap) Min() *endpointTrackerEntry { // // See tailscale/tailscale#7877 for more information. type endpointTracker struct { - mu sync.Mutex + mu syncs.Mutex endpoints map[netip.Addr]*endpointHeap } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 6ee14164d0a99..064838a2d540c 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -273,20 +273,14 @@ type Conn struct { // channel operations and goroutine creation. hasPeerRelayServers atomic.Bool - // discoPrivate is the private naclbox key used for active - // discovery traffic. It is always present, and immutable. - discoPrivate key.DiscoPrivate - // public of discoPrivate. It is always present and immutable. - discoPublic key.DiscoPublic - // ShortString of discoPublic (to save logging work later). It is always - // present and immutable. - discoShort string + // discoAtomic is the current disco private and public keypair for this conn. + discoAtomic discoAtomic // ============================================================ // mu guards all following fields; see userspaceEngine lock // ordering rules against the engine. For derphttp, mu must // be held before derphttp.Client.mu. - mu sync.Mutex + mu syncs.Mutex muCond *sync.Cond onlyTCP443 atomic.Bool @@ -603,11 +597,9 @@ func newConn(logf logger.Logf) *Conn { peerLastDerp: make(map[key.NodePublic]int), peerMap: newPeerMap(), discoInfo: make(map[key.DiscoPublic]*discoInfo), - discoPrivate: discoPrivate, - discoPublic: discoPrivate.Public(), cloudInfo: newCloudInfo(logf), } - c.discoShort = c.discoPublic.ShortString() + c.discoAtomic.Set(discoPrivate) c.bind = &connBind{Conn: c, closed: true} c.receiveBatchPool = sync.Pool{New: func() any { msgs := make([]ipv6.Message, c.bind.BatchSize()) @@ -635,7 +627,7 @@ func (c *Conn) onUDPRelayAllocResp(allocResp UDPRelayAllocResp) { // now versus taking a network round-trip through DERP. 
selfNodeKey := c.publicKeyAtomic.Load() if selfNodeKey.Compare(allocResp.ReqRxFromNodeKey) == 0 && - allocResp.ReqRxFromDiscoKey.Compare(c.discoPublic) == 0 { + allocResp.ReqRxFromDiscoKey.Compare(c.discoAtomic.Public()) == 0 { c.relayManager.handleRxDiscoMsg(c, allocResp.Message, selfNodeKey, allocResp.ReqRxFromDiscoKey, epAddr{}) metricLocalDiscoAllocUDPRelayEndpointResponse.Add(1) } @@ -665,7 +657,10 @@ func (c *Conn) Synchronize() { } sp := syncPoint(make(chan struct{})) c.syncPub.Publish(sp) - sp.Wait() + select { + case <-sp: + case <-c.donec: + } } // NewConn creates a magic Conn listening on opts.Port. @@ -762,7 +757,7 @@ func NewConn(opts Options) (*Conn, error) { c.logf("[v1] couldn't create raw v6 disco listener, using regular listener instead: %v", err) } - c.logf("magicsock: disco key = %v", c.discoShort) + c.logf("magicsock: disco key = %v", c.discoAtomic.Short()) return c, nil } @@ -1241,12 +1236,37 @@ func (c *Conn) GetEndpointChanges(peer tailcfg.NodeView) ([]EndpointChange, erro // DiscoPublicKey returns the discovery public key. func (c *Conn) DiscoPublicKey() key.DiscoPublic { - return c.discoPublic + return c.discoAtomic.Public() +} + +// RotateDiscoKey generates a new discovery key pair and updates the connection +// to use it. This invalidates all existing disco sessions and will cause peers +// to re-establish discovery sessions with the new key. +// +// This is primarily for debugging and testing purposes; a future enhancement +// should provide a mechanism for seamless rotation by supporting short-term use +// of the old key. +func (c *Conn) RotateDiscoKey() { + oldShort := c.discoAtomic.Short() + newPrivate := key.NewDisco() + + c.mu.Lock() + c.discoAtomic.Set(newPrivate) + newShort := c.discoAtomic.Short() + c.discoInfo = make(map[key.DiscoPublic]*discoInfo) + connCtx := c.connCtx + c.mu.Unlock() + + c.logf("magicsock: rotated disco key from %v to %v", oldShort, newShort) + + if connCtx != nil { + c.ReSTUN("disco-key-rotation") + } } // determineEndpoints returns the machine's endpoint addresses. It does a STUN -// lookup (via netcheck) to determine its public address. Additionally any -// static enpoints provided by user are always added to the returned endpoints +// lookup (via netcheck) to determine its public address. Additionally, any +// static endpoints provided by the user are always added to the returned endpoints // without validating if the node can be reached via those endpoints. // // c.mu must NOT be held. @@ -1911,7 +1931,7 @@ func (c *Conn) sendDiscoAllocateUDPRelayEndpointRequest(dst epAddr, dstKey key.N if isDERP && dstKey.Compare(selfNodeKey) == 0 { c.allocRelayEndpointPub.Publish(UDPRelayAllocReq{ RxFromNodeKey: selfNodeKey, - RxFromDiscoKey: c.discoPublic, + RxFromDiscoKey: c.discoAtomic.Public(), Message: allocReq, }) metricLocalDiscoAllocUDPRelayEndpointRequest.Add(1) @@ -1982,7 +2002,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key. } } pkt = append(pkt, disco.Magic...) - pkt = c.discoPublic.AppendTo(pkt) + pkt = c.discoAtomic.Public().AppendTo(pkt) if isDERP { metricSendDiscoDERP.Add(1) @@ -2000,7 +2020,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key.
if !dstKey.IsZero() { node = dstKey.ShortString() } - c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt)) + c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt)) } if isDERP { metricSentDiscoDERP.Add(1) @@ -2349,13 +2369,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake } if isVia { c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints", - c.discoShort, epDisco.short, via.ServerDisco.ShortString(), + c.discoAtomic.Short(), epDisco.short, via.ServerDisco.ShortString(), ep.publicKey.ShortString(), derpStr(src.String()), len(via.AddrPorts)) c.relayManager.handleCallMeMaybeVia(ep, lastBest, lastBestIsTrusted, via) } else { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", - c.discoShort, epDisco.short, + c.discoAtomic.Short(), epDisco.short, ep.publicKey.ShortString(), derpStr(src.String()), len(cmm.MyNumber)) go ep.handleCallMeMaybe(cmm) @@ -2401,7 +2421,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake if isResp { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, %d endpoints", - c.discoShort, epDisco.short, + c.discoAtomic.Short(), epDisco.short, ep.publicKey.ShortString(), derpStr(src.String()), msgType, len(resp.AddrPorts)) @@ -2415,7 +2435,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake return } else { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s disco[0]=%v disco[1]=%v", - c.discoShort, epDisco.short, + c.discoAtomic.Short(), epDisco.short, ep.publicKey.ShortString(), derpStr(src.String()), msgType, req.ClientDisco[0].ShortString(), req.ClientDisco[1].ShortString()) @@ -2580,7 +2600,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src epAddr, di *discoInfo, derpN if numNodes > 1 { pingNodeSrcStr = "[one-of-multi]" } - c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x padding=%v", c.discoShort, di.discoShort, pingNodeSrcStr, src, dm.TxID[:6], dm.Padding) + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x padding=%v", c.discoAtomic.Short(), di.discoShort, pingNodeSrcStr, src, dm.TxID[:6], dm.Padding) } ipDst := src @@ -2653,7 +2673,7 @@ func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo { di = &discoInfo{ discoKey: k, discoShort: k.ShortString(), - sharedKey: c.discoPrivate.Shared(k), + sharedKey: c.discoAtomic.Private().Shared(k), } c.discoInfo[k] = di } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 60620b14100f1..4e10248861500 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -111,7 +111,7 @@ func (c *Conn) WaitReady(t testing.TB) { } } -func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, stunIP netip.Addr) (derpMap *tailcfg.DERPMap, cleanup func()) { +func runDERPAndStun(t *testing.T, logf logger.Logf, ln nettype.PacketListener, stunIP netip.Addr) (derpMap *tailcfg.DERPMap, cleanup func()) { d := derpserver.New(key.NewNode(), logf) httpsrv := httptest.NewUnstartedServer(derpserver.Handler(d)) @@ -119,7 +119,7 @@ func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, st httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) 
httpsrv.StartTLS() - stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, l) + stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, ln) m := &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ @@ -172,12 +172,12 @@ type magicStack struct { // newMagicStack builds and initializes an idle magicsock and // friends. You need to call conn.onNodeViewsUpdate and dev.Reconfig // before anything interesting happens. -func newMagicStack(t testing.TB, logf logger.Logf, l nettype.PacketListener, derpMap *tailcfg.DERPMap) *magicStack { +func newMagicStack(t testing.TB, logf logger.Logf, ln nettype.PacketListener, derpMap *tailcfg.DERPMap) *magicStack { privateKey := key.NewNode() - return newMagicStackWithKey(t, logf, l, derpMap, privateKey) + return newMagicStackWithKey(t, logf, ln, derpMap, privateKey) } -func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListener, derpMap *tailcfg.DERPMap, privateKey key.NodePrivate) *magicStack { +func newMagicStackWithKey(t testing.TB, logf logger.Logf, ln nettype.PacketListener, derpMap *tailcfg.DERPMap, privateKey key.NodePrivate) *magicStack { t.Helper() bus := eventbustest.NewBus(t) @@ -197,7 +197,7 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListen Logf: logf, HealthTracker: ht, DisablePortMapper: true, - TestOnlyPacketListener: l, + TestOnlyPacketListener: ln, EndpointsFunc: func(eps []tailcfg.Endpoint) { epCh <- eps }, @@ -211,7 +211,7 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListen } tun := tuntest.NewChannelTUN() - tsTun := tstun.Wrap(logf, tun.TUN(), &reg) + tsTun := tstun.Wrap(logf, tun.TUN(), &reg, bus) tsTun.SetFilter(filter.NewAllowAllForTest(logf)) tsTun.Start() @@ -308,8 +308,7 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM buildNetmapLocked := func(myIdx int) *netmap.NetworkMap { me := ms[myIdx] nm := &netmap.NetworkMap{ - PrivateKey: me.privateKey, - NodeKey: me.privateKey.Public(), + NodeKey: me.privateKey.Public(), SelfNode: (&tailcfg.Node{ Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(1, 0, 0, byte(myIdx+1)), 32)}, }).View(), @@ -356,7 +355,7 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM peerSet.Add(peer.Key()) } m.conn.UpdatePeers(peerSet) - wg, err := nmcfg.WGCfg(nm, logf, 0, "") + wg, err := nmcfg.WGCfg(ms[i].privateKey, nm, logf, 0, "") if err != nil { // We're too far from the *testing.T to be graceful, // blow up. Shouldn't happen anyway.
@@ -688,13 +687,13 @@ func (localhostListener) ListenPacket(ctx context.Context, network, address stri func TestTwoDevicePing(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/11762") - l, ip := localhostListener{}, netaddr.IPv4(127, 0, 0, 1) + ln, ip := localhostListener{}, netaddr.IPv4(127, 0, 0, 1) n := &devices{ - m1: l, + m1: ln, m1IP: ip, - m2: l, + m2: ln, m2IP: ip, - stun: l, + stun: ln, stunIP: ip, } testTwoDevicePing(t, n) @@ -1059,7 +1058,6 @@ func testTwoDevicePing(t *testing.T, d *devices) { }) m1cfg := &wgcfg.Config{ - Name: "peer1", PrivateKey: m1.privateKey, Addresses: []netip.Prefix{netip.MustParsePrefix("1.0.0.1/32")}, Peers: []wgcfg.Peer{ @@ -1071,7 +1069,6 @@ func testTwoDevicePing(t *testing.T, d *devices) { }, } m2cfg := &wgcfg.Config{ - Name: "peer2", PrivateKey: m2.privateKey, Addresses: []netip.Prefix{netip.MustParsePrefix("1.0.0.2/32")}, Peers: []wgcfg.Peer{ @@ -1774,7 +1771,6 @@ func TestEndpointSetsEqual(t *testing.T) { t.Errorf("%q vs %q = %v; want %v", tt.a, tt.b, got, tt.want) } } - } func TestBetterAddr(t *testing.T) { @@ -1918,7 +1914,6 @@ func TestBetterAddr(t *testing.T) { t.Errorf("[%d] betterAddr(%+v, %+v) and betterAddr(%+v, %+v) both unexpectedly true", i, tt.a, tt.b, tt.b, tt.a) } } - } func epFromTyped(eps []tailcfg.Endpoint) (ret []netip.AddrPort) { @@ -2203,10 +2198,9 @@ func TestIsWireGuardOnlyPeer(t *testing.T) { defer m.Close() nm := &netmap.NetworkMap{ - Name: "ts", - PrivateKey: m.privateKey, - NodeKey: m.privateKey.Public(), + NodeKey: m.privateKey.Public(), SelfNode: (&tailcfg.Node{ + Name: "ts.", Addresses: []netip.Prefix{tsaip}, }).View(), Peers: nodeViews([]*tailcfg.Node{ @@ -2226,7 +2220,7 @@ func TestIsWireGuardOnlyPeer(t *testing.T) { } m.conn.onNodeViewsUpdate(nv) - cfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSubnetRoutes, "") + cfg, err := nmcfg.WGCfg(m.privateKey, nm, t.Logf, netmap.AllowSubnetRoutes, "") if err != nil { t.Fatal(err) } @@ -2268,10 +2262,9 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { defer m.Close() nm := &netmap.NetworkMap{ - Name: "ts", - PrivateKey: m.privateKey, - NodeKey: m.privateKey.Public(), + NodeKey: m.privateKey.Public(), SelfNode: (&tailcfg.Node{ + Name: "ts.", Addresses: []netip.Prefix{tsaip}, }).View(), Peers: nodeViews([]*tailcfg.Node{ @@ -2292,7 +2285,7 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { } m.conn.onNodeViewsUpdate(nv) - cfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSubnetRoutes, "") + cfg, err := nmcfg.WGCfg(m.privateKey, nm, t.Logf, netmap.AllowSubnetRoutes, "") if err != nil { t.Fatal(err) } @@ -2336,7 +2329,7 @@ func applyNetworkMap(t *testing.T, m *magicStack, nm *netmap.NetworkMap) { m.conn.noV6.Store(true) // Turn the network map into a wireguard config (for the tailscale internal wireguard device). 
- cfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSubnetRoutes, "") + cfg, err := nmcfg.WGCfg(m.privateKey, nm, t.Logf, netmap.AllowSubnetRoutes, "") if err != nil { t.Fatal(err) } @@ -2405,10 +2398,9 @@ func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) { wgEpV6 := netip.MustParseAddrPort(v6.LocalAddr().String()) nm := &netmap.NetworkMap{ - Name: "ts", - PrivateKey: m.privateKey, - NodeKey: m.privateKey.Public(), + NodeKey: m.privateKey.Public(), SelfNode: (&tailcfg.Node{ + Name: "ts.", Addresses: []netip.Prefix{tsaip}, }).View(), Peers: nodeViews([]*tailcfg.Node{ @@ -2468,7 +2460,7 @@ func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) { if len(state.recentPongs) != 1 { t.Errorf("IPv4 address did not have a recentPong entry: got %v, want %v", len(state.recentPongs), 1) } - // Set the latency extremely high so we dont choose endpoint during the next + // Set the latency extremely high so we don't choose endpoint during the next // addrForSendLocked call. state.recentPongs[state.recentPong].latency = time.Second } @@ -3144,7 +3136,6 @@ func TestMaybeRebindOnError(t *testing.T) { t.Errorf("expected at least 5 seconds between %s and %s", lastRebindTime, newTime) } } - }) }) } @@ -4241,3 +4232,73 @@ func Test_lazyEndpoint_FromPeer(t *testing.T) { }) } } + +func TestRotateDiscoKey(t *testing.T) { + c := newConn(t.Logf) + + oldPrivate, oldPublic := c.discoAtomic.Pair() + oldShort := c.discoAtomic.Short() + + if oldPublic != oldPrivate.Public() { + t.Fatalf("old public key doesn't match old private key") + } + if oldShort != oldPublic.ShortString() { + t.Fatalf("old short string doesn't match old public key") + } + + testDiscoKey := key.NewDisco().Public() + c.mu.Lock() + c.discoInfo[testDiscoKey] = &discoInfo{ + discoKey: testDiscoKey, + discoShort: testDiscoKey.ShortString(), + } + if len(c.discoInfo) != 1 { + t.Fatalf("expected 1 discoInfo entry, got %d", len(c.discoInfo)) + } + c.mu.Unlock() + + c.RotateDiscoKey() + + newPrivate, newPublic := c.discoAtomic.Pair() + newShort := c.discoAtomic.Short() + + if newPublic.Compare(oldPublic) == 0 { + t.Fatalf("disco key didn't change after rotation") + } + if newShort == oldShort { + t.Fatalf("short string didn't change after rotation") + } + + if newPublic != newPrivate.Public() { + t.Fatalf("new public key doesn't match new private key") + } + if newShort != newPublic.ShortString() { + t.Fatalf("new short string doesn't match new public key") + } + + c.mu.Lock() + if len(c.discoInfo) != 0 { + t.Fatalf("expected discoInfo to be cleared, got %d entries", len(c.discoInfo)) + } + c.mu.Unlock() +} + +func TestRotateDiscoKeyMultipleTimes(t *testing.T) { + c := newConn(t.Logf) + + keys := make([]key.DiscoPublic, 0, 5) + keys = append(keys, c.discoAtomic.Public()) + + for i := 0; i < 4; i++ { + c.RotateDiscoKey() + newKey := c.discoAtomic.Public() + + for j, oldKey := range keys { + if newKey.Compare(oldKey) == 0 { + t.Fatalf("rotation %d produced same key as rotation %d", i+1, j) + } + } + + keys = append(keys, newKey) + } +} diff --git a/wgengine/magicsock/rebinding_conn.go b/wgengine/magicsock/rebinding_conn.go index 2798abbf20ed8..c98e645705b46 100644 --- a/wgengine/magicsock/rebinding_conn.go +++ b/wgengine/magicsock/rebinding_conn.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "net/netip" - "sync" "sync/atomic" "syscall" @@ -16,6 +15,7 @@ import ( "tailscale.com/net/batching" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/syncs" "tailscale.com/types/nettype" ) @@ -31,7 +31,7 @@ type RebindingUDPConn struct { // Neither 
is expected to be nil, sockets are bound on creation. pconnAtomic atomic.Pointer[nettype.PacketConn] - mu sync.Mutex // held while changing pconn (and pconnAtomic) + mu syncs.Mutex // held while changing pconn (and pconnAtomic) pconn nettype.PacketConn port uint16 } diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index a9dca70ae2228..69831a4df19f8 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -15,6 +15,7 @@ import ( "tailscale.com/net/packet" "tailscale.com/net/stun" udprelay "tailscale.com/net/udprelay/endpoint" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" @@ -58,7 +59,7 @@ type relayManager struct { getServersCh chan chan set.Set[candidatePeerRelay] derpHomeChangeCh chan derpHomeChangeEvent - discoInfoMu sync.Mutex // guards the following field + discoInfoMu syncs.Mutex // guards the following field discoInfoByServerDisco map[key.DiscoPublic]*relayHandshakeDiscoInfo // runLoopStoppedCh is written to by runLoop() upon return, enabling event @@ -360,7 +361,7 @@ func (r *relayManager) ensureDiscoInfoFor(work *relayHandshakeWork) { di.di = &discoInfo{ discoKey: work.se.ServerDisco, discoShort: work.se.ServerDisco.ShortString(), - sharedKey: work.wlb.ep.c.discoPrivate.Shared(work.se.ServerDisco), + sharedKey: work.wlb.ep.c.discoAtomic.Private().Shared(work.se.ServerDisco), } } } @@ -1030,7 +1031,7 @@ func (r *relayManager) allocateAllServersRunLoop(wlb endpointWithLastBest) { if remoteDisco == nil { return } - discoKeys := key.NewSortedPairOfDiscoPublic(wlb.ep.c.discoPublic, remoteDisco.key) + discoKeys := key.NewSortedPairOfDiscoPublic(wlb.ep.c.discoAtomic.Public(), remoteDisco.key) for _, v := range r.serversByNodeKey { byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[v.nodeKey] if !ok { diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go index d400818394c47..e8fddfd91b46e 100644 --- a/wgengine/magicsock/relaymanager_test.go +++ b/wgengine/magicsock/relaymanager_test.go @@ -22,11 +22,15 @@ func TestRelayManagerInitAndIdle(t *testing.T) { <-rm.runLoopStoppedCh rm = relayManager{} - rm.handleCallMeMaybeVia(&endpoint{c: &Conn{discoPrivate: key.NewDisco()}}, addrQuality{}, false, &disco.CallMeMaybeVia{UDPRelayEndpoint: disco.UDPRelayEndpoint{ServerDisco: key.NewDisco().Public()}}) + c1 := &Conn{} + c1.discoAtomic.Set(key.NewDisco()) + rm.handleCallMeMaybeVia(&endpoint{c: c1}, addrQuality{}, false, &disco.CallMeMaybeVia{UDPRelayEndpoint: disco.UDPRelayEndpoint{ServerDisco: key.NewDisco().Public()}}) <-rm.runLoopStoppedCh rm = relayManager{} - rm.handleRxDiscoMsg(&Conn{discoPrivate: key.NewDisco()}, &disco.BindUDPRelayEndpointChallenge{}, key.NodePublic{}, key.DiscoPublic{}, epAddr{}) + c2 := &Conn{} + c2.discoAtomic.Set(key.NewDisco()) + rm.handleRxDiscoMsg(c2, &disco.BindUDPRelayEndpointChallenge{}, key.NodePublic{}, key.DiscoPublic{}, epAddr{}) <-rm.runLoopStoppedCh rm = relayManager{} diff --git a/wgengine/netlog/netlog.go b/wgengine/netlog/netlog.go index 2984df99471b6..12fe9c797641a 100644 --- a/wgengine/netlog/netlog.go +++ b/wgengine/netlog/netlog.go @@ -10,14 +10,11 @@ package netlog import ( "cmp" "context" - "encoding/json" - "errors" "fmt" "io" "log" "net/http" "net/netip" - "sync" "time" "tailscale.com/health" @@ -26,12 +23,19 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/sockstats" "tailscale.com/net/tsaddr" - "tailscale.com/tailcfg" + "tailscale.com/syncs" + "tailscale.com/types/ipproto" 
+ "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netlogfunc" "tailscale.com/types/netlogtype" + "tailscale.com/types/netmap" "tailscale.com/util/eventbus" + "tailscale.com/util/set" "tailscale.com/wgengine/router" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" ) // pollPeriod specifies how often to poll for network traffic. @@ -49,25 +53,38 @@ func (noopDevice) SetConnectionCounter(netlogfunc.ConnectionCounter) {} // Logger logs statistics about every connection. // At present, it only logs connections within a tailscale network. -// Exit node traffic is not logged for privacy reasons. +// By default, exit node traffic is not logged for privacy reasons +// unless the Tailnet administrator opts-into explicit logging. // The zero value is ready for use. type Logger struct { - mu sync.Mutex // protects all fields below + mu syncs.Mutex // protects all fields below + logf logger.Logf + + // shutdownLocked shuts down the logger. + // The mutex must be held when calling. + shutdownLocked func(context.Context) error - logger *logtail.Logger - stats *statistics - tun Device - sock Device + record record // the current record of network connection flows + recordLen int // upper bound on JSON length of record + recordsChan chan record // set to nil when shutdown + flushTimer *time.Timer // fires when record should flush to recordsChan - addrs map[netip.Addr]bool - prefixes map[netip.Prefix]bool + // Information about Tailscale nodes. + // These are read-only once updated by ReconfigNetworkMap. + selfNode nodeUser + allNodes map[netip.Addr]nodeUser // includes selfNode; nodeUser values are always valid + + // Information about routes. + // These are read-only once updated by ReconfigRoutes. + routeAddrs set.Set[netip.Addr] + routePrefixes []netip.Prefix } // Running reports whether the logger is running. func (nl *Logger) Running() bool { nl.mu.Lock() defer nl.mu.Unlock() - return nl.logger != nil + return nl.shutdownLocked != nil } var testClient *http.Client @@ -75,9 +92,9 @@ var testClient *http.Client // Startup starts an asynchronous network logger that monitors // statistics for the provided tun and/or sock device. // -// The tun Device captures packets within the tailscale network, -// where at least one address is a tailscale IP address. -// The source is always from the perspective of the current node. +// The tun [Device] captures packets within the tailscale network, +// where at least one address is usually a tailscale IP address. +// The source is usually from the perspective of the current node. // If one of the other endpoint is not a tailscale IP address, // then it suggests the use of a subnet router or exit node. // For example, when using a subnet router, the source address is @@ -89,28 +106,33 @@ var testClient *http.Client // In this case, the node acting as a subnet router is acting on behalf // of some remote endpoint within the subnet range. // The tun is used to populate the VirtualTraffic, SubnetTraffic, -// and ExitTraffic fields in Message. +// and ExitTraffic fields in [netlogtype.Message]. // -// The sock Device captures packets at the magicsock layer. +// The sock [Device] captures packets at the magicsock layer. // The source is always a tailscale IP address and the destination // is a non-tailscale IP address to contact for that particular tailscale node. // The IP protocol and source port are always zero. -// The sock is used to populated the PhysicalTraffic field in Message. 
+// The sock is used to populate the PhysicalTraffic field in [netlogtype.Message].
+//
 // The netMon parameter is optional; if non-nil it's used to do faster interface lookups.
-func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID logid.PrivateID, tun, sock Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus, logExitFlowEnabledEnabled bool) error {
+func (nl *Logger) Startup(logf logger.Logf, nm *netmap.NetworkMap, nodeLogID, domainLogID logid.PrivateID, tun, sock Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus, logExitFlowEnabledEnabled bool) error {
 	nl.mu.Lock()
 	defer nl.mu.Unlock()
-	if nl.logger != nil {
-		return fmt.Errorf("network logger already running for %v", nl.logger.PrivateID().Public())
+
+	if nl.shutdownLocked != nil {
+		return fmt.Errorf("network logger already running")
 	}
+	nl.selfNode, nl.allNodes = makeNodeMaps(nm)
 
 	// Startup a log stream to Tailscale's logging service.
-	logf := log.Printf
+	if logf == nil {
+		logf = log.Printf
+	}
+	nl.logf = logf // retain for diagnostics reported by the recorder goroutine
 	httpc := &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, netMon, health, logf)}
 	if testClient != nil {
 		httpc = testClient
 	}
-	nl.logger = logtail.NewLogger(logtail.Config{
+	logger := logtail.NewLogger(logtail.Config{
 		Collection:    "tailtraffic.log.tailscale.io",
 		PrivateID:     nodeLogID,
 		CopyPrivateID: domainLogID,
@@ -124,108 +146,311 @@ func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID lo
 		IncludeProcID:       true,
 		IncludeProcSequence: true,
 	}, logf)
-	nl.logger.SetSockstatsLabel(sockstats.LabelNetlogLogger)
-
-	// Startup a data structure to track per-connection statistics.
-	// There is a maximum size for individual log messages that logtail
-	// can upload to the Tailscale log service, so stay below this limit.
-	const maxLogSize = 256 << 10
-	const maxConns = (maxLogSize - netlogtype.MaxMessageJSONSize) / netlogtype.MaxConnectionCountsJSONSize
-	nl.stats = newStatistics(pollPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
-		nl.mu.Lock()
-		addrs := nl.addrs
-		prefixes := nl.prefixes
-		nl.mu.Unlock()
-		recordStatistics(nl.logger, nodeID, start, end, virtual, physical, addrs, prefixes, logExitFlowEnabledEnabled)
-	})
+	logger.SetSockstatsLabel(sockstats.LabelNetlogLogger)
 
 	// Register the connection tracker into the TUN device.
-	nl.tun = cmp.Or[Device](tun, noopDevice{})
-	nl.tun.SetConnectionCounter(nl.stats.UpdateVirtual)
+	tun = cmp.Or[Device](tun, noopDevice{})
+	tun.SetConnectionCounter(nl.updateVirtConn)
 
 	// Register the connection tracker into magicsock.
-	nl.sock = cmp.Or[Device](sock, noopDevice{})
-	nl.sock.SetConnectionCounter(nl.stats.UpdatePhysical)
+	sock = cmp.Or[Device](sock, noopDevice{})
+	sock.SetConnectionCounter(nl.updatePhysConn)
+
+	// Start a goroutine to record log messages.
+	// This is done asynchronously so that the cost of serializing
+	// the network flow log message never stalls processing of packets.
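+	// The records channel is buffered (100 entries) so that short bursts
+	// of flushes do not block packet accounting; if the recorder goroutine
+	// falls behind, flushRecordLocked drops the record rather than
+	// stalling the datapath.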
+ nl.record = record{} + nl.recordLen = 0 + nl.recordsChan = make(chan record, 100) + recorderDone := make(chan struct{}) + go func(recordsChan chan record) { + defer close(recorderDone) + for rec := range recordsChan { + msg := rec.toMessage(false, !logExitFlowEnabledEnabled) + if b, err := jsonv2.Marshal(msg, jsontext.AllowInvalidUTF8(true)); err != nil { + if nl.logf != nil { + nl.logf("netlog: json.Marshal error: %v", err) + } + } else { + logger.Logf("%s", b) + } + } + }(nl.recordsChan) + + // Register the mechanism for shutting down. + nl.shutdownLocked = func(ctx context.Context) error { + tun.SetConnectionCounter(nil) + sock.SetConnectionCounter(nil) + + // Flush and process all pending records. + nl.flushRecordLocked() + close(nl.recordsChan) + nl.recordsChan = nil + <-recorderDone + recorderDone = nil + + // Try to upload all pending records. + err := logger.Shutdown(ctx) + + // Purge state. + nl.shutdownLocked = nil + nl.selfNode = nodeUser{} + nl.allNodes = nil + nl.routeAddrs = nil + nl.routePrefixes = nil + + return err + } return nil } -func recordStatistics(logger *logtail.Logger, nodeID tailcfg.StableNodeID, start, end time.Time, connStats, sockStats map[netlogtype.Connection]netlogtype.Counts, addrs map[netip.Addr]bool, prefixes map[netip.Prefix]bool, logExitFlowEnabled bool) { - m := netlogtype.Message{NodeID: nodeID, Start: start.UTC(), End: end.UTC()} - - classifyAddr := func(a netip.Addr) (isTailscale, withinRoute bool) { - // NOTE: There could be mis-classifications where an address is treated - // as a Tailscale IP address because the subnet range overlaps with - // the subnet range that Tailscale IP addresses are allocated from. - // This should never happen for IPv6, but could happen for IPv4. - withinRoute = addrs[a] - for p := range prefixes { - if p.Contains(a) && p.Bits() > 0 { - withinRoute = true - break - } - } - return withinRoute && tsaddr.IsTailscaleIP(a), withinRoute && !tsaddr.IsTailscaleIP(a) +var ( + tailscaleServiceIPv4 = tsaddr.TailscaleServiceIP() + tailscaleServiceIPv6 = tsaddr.TailscaleServiceIPv6() +) + +func (nl *Logger) updateVirtConn(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, recv bool) { + // Network logging is defined as traffic between two Tailscale nodes. + // Traffic with the internal Tailscale service is not with another node + // and should not be logged. It also happens to be a high volume + // amount of discrete traffic flows (e.g., DNS lookups). + switch dst.Addr() { + case tailscaleServiceIPv4, tailscaleServiceIPv6: + return } - exitTraffic := make(map[netlogtype.Connection]netlogtype.Counts) - for conn, cnts := range connStats { - srcIsTailscaleIP, srcWithinSubnet := classifyAddr(conn.Src.Addr()) - dstIsTailscaleIP, dstWithinSubnet := classifyAddr(conn.Dst.Addr()) - switch { - case srcIsTailscaleIP && dstIsTailscaleIP: - m.VirtualTraffic = append(m.VirtualTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts}) - case srcWithinSubnet || dstWithinSubnet: - m.SubnetTraffic = append(m.SubnetTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts}) - default: - const anonymize = true - if anonymize && !logExitFlowEnabled { - // Only preserve the address if it is a Tailscale IP address. 
-				srcOrig, dstOrig := conn.Src, conn.Dst
-				conn = netlogtype.Connection{} // scrub everything by default
-				if srcIsTailscaleIP {
-					conn.Src = netip.AddrPortFrom(srcOrig.Addr(), 0)
-				}
-				if dstIsTailscaleIP {
-					conn.Dst = netip.AddrPortFrom(dstOrig.Addr(), 0)
-				}
-			}
-			exitTraffic[conn] = exitTraffic[conn].Add(cnts)
+	nl.mu.Lock()
+	defer nl.mu.Unlock()
+
+	// Look up the connection and increment the counts.
+	nl.initRecordLocked()
+	conn := netlogtype.Connection{Proto: proto, Src: src, Dst: dst}
+	cnts, found := nl.record.virtConns[conn]
+	if !found {
+		cnts.connType = nl.addNewVirtConnLocked(conn)
+	}
+	if recv {
+		cnts.RxPackets += uint64(packets)
+		cnts.RxBytes += uint64(bytes)
+	} else {
+		cnts.TxPackets += uint64(packets)
+		cnts.TxBytes += uint64(bytes)
+	}
+	nl.record.virtConns[conn] = cnts
+}
+
+// addNewVirtConnLocked adds the first insertion of a virtual connection.
+// The [Logger.mu] must be held.
+func (nl *Logger) addNewVirtConnLocked(c netlogtype.Connection) connType {
+	// Check whether this is the first insertion of the src and dst node.
+	// If so, compute the additional JSON bytes that would be added
+	// to the record for the node information.
+	var srcNodeLen, dstNodeLen int
+	srcNode, srcSeen := nl.record.seenNodes[c.Src.Addr()]
+	if !srcSeen {
+		srcNode = nl.allNodes[c.Src.Addr()]
+		if srcNode.Valid() {
+			srcNodeLen = srcNode.jsonLen()
 		}
 	}
-	for conn, cnts := range exitTraffic {
-		m.ExitTraffic = append(m.ExitTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts})
+	dstNode, dstSeen := nl.record.seenNodes[c.Dst.Addr()]
+	if !dstSeen {
+		dstNode = nl.allNodes[c.Dst.Addr()]
+		if dstNode.Valid() {
+			dstNodeLen = dstNode.jsonLen()
+		}
+	}
+
+	// Check whether the additional [netlogtype.ConnectionCounts]
+	// and [netlogtype.Node] information would exceed [maxLogSize].
+	if nl.recordLen+netlogtype.MaxConnectionCountsJSONSize+srcNodeLen+dstNodeLen > maxLogSize {
+		nl.flushRecordLocked()
+		nl.initRecordLocked()
+	}
+
+	// Insert newly seen src and/or dst nodes.
+	if !srcSeen && srcNode.Valid() {
+		nl.record.seenNodes[c.Src.Addr()] = srcNode
 	}
-	for conn, cnts := range sockStats {
-		m.PhysicalTraffic = append(m.PhysicalTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts})
+	if !dstSeen && dstNode.Valid() {
+		nl.record.seenNodes[c.Dst.Addr()] = dstNode
 	}
+	nl.recordLen += netlogtype.MaxConnectionCountsJSONSize + srcNodeLen + dstNodeLen
 
-	if len(m.VirtualTraffic)+len(m.SubnetTraffic)+len(m.ExitTraffic)+len(m.PhysicalTraffic) > 0 {
-		if b, err := json.Marshal(m); err != nil {
-			logger.Logf("json.Marshal error: %v", err)
+	// Classify the traffic type.
+	var srcIsSelfNode bool
+	if nl.selfNode.Valid() {
+		srcIsSelfNode = nl.selfNode.Addresses().ContainsFunc(func(p netip.Prefix) bool {
+			return c.Src.Addr() == p.Addr() && p.IsSingleIP()
+		})
+	}
+	switch {
+	case srcIsSelfNode && dstNode.Valid():
+		return virtualTraffic
+	case srcIsSelfNode:
+		// TODO: Should we swap src for the node serving as the proxy?
+		// Always using the self IP address here is relatively useless.
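+		// The destination address decides the class below: an address
+		// covered by the configured routes means this node is a client
+		// of another subnet router; otherwise it is exit-node traffic.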
+		if nl.withinRoutesLocked(c.Dst.Addr()) {
+			return subnetTraffic // a client using another subnet router
 		} else {
-			logger.Logf("%s", b)
+			return exitTraffic // a client using an exit node
 		}
+	case dstNode.Valid():
+		if nl.withinRoutesLocked(c.Src.Addr()) {
+			return subnetTraffic // serving as a subnet router
+		} else {
+			return exitTraffic // serving as an exit node
+		}
+	default:
+		return unknownTraffic
+	}
+}
+
+func (nl *Logger) updatePhysConn(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, recv bool) {
+	nl.mu.Lock()
+	defer nl.mu.Unlock()
+
+	// Look up the connection and increment the counts.
+	nl.initRecordLocked()
+	conn := netlogtype.Connection{Proto: proto, Src: src, Dst: dst}
+	cnts, found := nl.record.physConns[conn]
+	if !found {
+		nl.addNewPhysConnLocked(conn)
+	}
+	if recv {
+		cnts.RxPackets += uint64(packets)
+		cnts.RxBytes += uint64(bytes)
+	} else {
+		cnts.TxPackets += uint64(packets)
+		cnts.TxBytes += uint64(bytes)
 	}
+	nl.record.physConns[conn] = cnts
 }
 
-func makeRouteMaps(cfg *router.Config) (addrs map[netip.Addr]bool, prefixes map[netip.Prefix]bool) {
-	addrs = make(map[netip.Addr]bool)
-	for _, p := range cfg.LocalAddrs {
-		if p.IsSingleIP() {
-			addrs[p.Addr()] = true
+// addNewPhysConnLocked adds the first insertion of a physical connection.
+// The [Logger.mu] must be held.
+func (nl *Logger) addNewPhysConnLocked(c netlogtype.Connection) {
+	// Check whether this is the first insertion of the src node.
+	var srcNodeLen int
+	srcNode, srcSeen := nl.record.seenNodes[c.Src.Addr()]
+	if !srcSeen {
+		srcNode = nl.allNodes[c.Src.Addr()]
+		if srcNode.Valid() {
+			srcNodeLen = srcNode.jsonLen()
 		}
 	}
-	prefixes = make(map[netip.Prefix]bool)
+
+	// Check whether the additional [netlogtype.ConnectionCounts]
+	// and [netlogtype.Node] information would exceed [maxLogSize].
+	if nl.recordLen+netlogtype.MaxConnectionCountsJSONSize+srcNodeLen > maxLogSize {
+		nl.flushRecordLocked()
+		nl.initRecordLocked()
+	}
+
+	// Insert the newly seen src node.
+	if !srcSeen && srcNode.Valid() {
+		nl.record.seenNodes[c.Src.Addr()] = srcNode
+	}
+	nl.recordLen += netlogtype.MaxConnectionCountsJSONSize + srcNodeLen
+}
+
+// initRecordLocked initializes the current record if uninitialized.
+// The [Logger.mu] must be held.
+func (nl *Logger) initRecordLocked() {
+	if nl.recordLen != 0 {
+		return
+	}
+	nl.record = record{
+		selfNode:  nl.selfNode,
+		start:     time.Now().UTC(),
+		seenNodes: make(map[netip.Addr]nodeUser),
+		virtConns: make(map[netlogtype.Connection]countsType),
+		physConns: make(map[netlogtype.Connection]netlogtype.Counts),
+	}
+	nl.recordLen = netlogtype.MinMessageJSONSize + nl.selfNode.jsonLen()
+
+	// Start a timer to auto-flush the record.
+	// Avoid tickers since continually waking up a goroutine
+	// is expensive on battery-powered devices.
+	nl.flushTimer = time.AfterFunc(pollPeriod, func() {
+		nl.mu.Lock()
+		defer nl.mu.Unlock()
+		if !nl.record.start.IsZero() && time.Since(nl.record.start) > pollPeriod/2 {
+			nl.flushRecordLocked()
+		}
+	})
+}
+
+// flushRecordLocked flushes the current record if initialized.
+// The [Logger.mu] must be held.
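+// It hands the record to the recorder goroutine with a non-blocking send
+// and then resets the in-memory state for the next window.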
+func (nl *Logger) flushRecordLocked() { + if nl.recordLen == 0 { + return + } + nl.record.end = time.Now().UTC() + if nl.recordsChan != nil { + select { + case nl.recordsChan <- nl.record: + default: + if nl.logf != nil { + nl.logf("netlog: dropped record due to processing backlog") + } + } + } + if nl.flushTimer != nil { + nl.flushTimer.Stop() + nl.flushTimer = nil + } + nl.record = record{} + nl.recordLen = 0 +} + +func makeNodeMaps(nm *netmap.NetworkMap) (selfNode nodeUser, allNodes map[netip.Addr]nodeUser) { + if nm == nil { + return + } + allNodes = make(map[netip.Addr]nodeUser) + if nm.SelfNode.Valid() { + selfNode = nodeUser{nm.SelfNode, nm.UserProfiles[nm.SelfNode.User()]} + for _, addr := range nm.SelfNode.Addresses().All() { + if addr.IsSingleIP() { + allNodes[addr.Addr()] = selfNode + } + } + } + for _, peer := range nm.Peers { + if peer.Valid() { + for _, addr := range peer.Addresses().All() { + if addr.IsSingleIP() { + allNodes[addr.Addr()] = nodeUser{peer, nm.UserProfiles[peer.User()]} + } + } + } + } + return selfNode, allNodes +} + +// ReconfigNetworkMap configures the network logger with an updated netmap. +func (nl *Logger) ReconfigNetworkMap(nm *netmap.NetworkMap) { + selfNode, allNodes := makeNodeMaps(nm) // avoid holding lock while making maps + nl.mu.Lock() + nl.selfNode, nl.allNodes = selfNode, allNodes + nl.mu.Unlock() +} + +func makeRouteMaps(cfg *router.Config) (addrs set.Set[netip.Addr], prefixes []netip.Prefix) { + addrs = make(set.Set[netip.Addr]) insertPrefixes := func(rs []netip.Prefix) { for _, p := range rs { if p.IsSingleIP() { - addrs[p.Addr()] = true + addrs.Add(p.Addr()) } else { - prefixes[p] = true + prefixes = append(prefixes, p) } } } + insertPrefixes(cfg.LocalAddrs) insertPrefixes(cfg.Routes) insertPrefixes(cfg.SubnetRoutes) return addrs, prefixes @@ -235,11 +460,25 @@ func makeRouteMaps(cfg *router.Config) (addrs map[netip.Addr]bool, prefixes map[ // The cfg is used to classify the types of connections captured by // the tun Device passed to Startup. func (nl *Logger) ReconfigRoutes(cfg *router.Config) { + addrs, prefixes := makeRouteMaps(cfg) // avoid holding lock while making maps nl.mu.Lock() - defer nl.mu.Unlock() - // TODO(joetsai): There is a race where deleted routes are not known at - // the time of extraction. We need to keep old routes around for a bit. - nl.addrs, nl.prefixes = makeRouteMaps(cfg) + nl.routeAddrs, nl.routePrefixes = addrs, prefixes + nl.mu.Unlock() +} + +// withinRoutesLocked reports whether a is within the configured routes, +// which should only contain Tailscale addresses and subnet routes. +// The [Logger.mu] must be held. +func (nl *Logger) withinRoutesLocked(a netip.Addr) bool { + if nl.routeAddrs.Contains(a) { + return true + } + for _, p := range nl.routePrefixes { + if p.Contains(a) && p.Bits() > 0 { + return true + } + } + return false } // Shutdown shuts down the network logger. @@ -248,26 +487,8 @@ func (nl *Logger) ReconfigRoutes(cfg *router.Config) { func (nl *Logger) Shutdown(ctx context.Context) error { nl.mu.Lock() defer nl.mu.Unlock() - if nl.logger == nil { + if nl.shutdownLocked == nil { return nil } - - // Shutdown in reverse order of Startup. - // Do not hold lock while shutting down since this may flush one last time. - nl.mu.Unlock() - nl.sock.SetConnectionCounter(nil) - nl.tun.SetConnectionCounter(nil) - err1 := nl.stats.Shutdown(ctx) - err2 := nl.logger.Shutdown(ctx) - nl.mu.Lock() - - // Purge state. 
- nl.logger = nil - nl.stats = nil - nl.tun = nil - nl.sock = nil - nl.addrs = nil - nl.prefixes = nil - - return errors.Join(err1, err2) + return nl.shutdownLocked(ctx) } diff --git a/wgengine/netlog/netlog_omit.go b/wgengine/netlog/netlog_omit.go index 43209df919ace..03610a1ef017a 100644 --- a/wgengine/netlog/netlog_omit.go +++ b/wgengine/netlog/netlog_omit.go @@ -7,7 +7,8 @@ package netlog type Logger struct{} -func (*Logger) Startup(...any) error { return nil } -func (*Logger) Running() bool { return false } -func (*Logger) Shutdown(any) error { return nil } -func (*Logger) ReconfigRoutes(any) {} +func (*Logger) Startup(...any) error { return nil } +func (*Logger) Running() bool { return false } +func (*Logger) Shutdown(any) error { return nil } +func (*Logger) ReconfigNetworkMap(any) {} +func (*Logger) ReconfigRoutes(any) {} diff --git a/wgengine/netlog/netlog_test.go b/wgengine/netlog/netlog_test.go new file mode 100644 index 0000000000000..b4758c7ec7beb --- /dev/null +++ b/wgengine/netlog/netlog_test.go @@ -0,0 +1,237 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netlog && !ts_omit_logtail + +package netlog + +import ( + "encoding/binary" + "math/rand/v2" + "net/netip" + "sync" + "testing" + "testing/synctest" + "time" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tailcfg" + "tailscale.com/types/bools" + "tailscale.com/types/ipproto" + "tailscale.com/types/netlogtype" + "tailscale.com/types/netmap" + "tailscale.com/wgengine/router" +) + +func TestEmbedNodeInfo(t *testing.T) { + // Initialize the logger with a particular view of the netmap. + var logger Logger + logger.ReconfigNetworkMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: "n123456CNTL", + ID: 123456, + Name: "test.tail123456.ts.net", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + Tags: []string{"tag:foo", "tag:bar"}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + StableID: "n123457CNTL", + ID: 123457, + Name: "peer1.tail123456.ts.net", + Addresses: []netip.Prefix{prefix("100.1.2.4")}, + Tags: []string{"tag:peer"}, + }).View(), + (&tailcfg.Node{ + StableID: "n123458CNTL", + ID: 123458, + Name: "peer2.tail123456.ts.net", + Addresses: []netip.Prefix{prefix("100.1.2.5")}, + User: 54321, + }).View(), + }, + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + 54321: (&tailcfg.UserProfile{ID: 54321, LoginName: "peer@example.com"}).View(), + }, + }) + logger.ReconfigRoutes(&router.Config{ + SubnetRoutes: []netip.Prefix{ + prefix("172.16.1.1/16"), + prefix("192.168.1.1/24"), + }, + }) + + // Update the counters for a few connections. 
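+	// Updates are issued from many goroutines (via group.Go) to exercise
+	// the Logger's internal locking while the counters accumulate.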
+ var group sync.WaitGroup + defer group.Wait() + conns := []struct { + virt bool + proto ipproto.Proto + src, dst netip.AddrPort + txP, txB, rxP, rxB int + }{ + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("100.1.2.4:1812"), 88, 278, 34, 887}, + {true, 0x6, addrPort("100.1.2.3:443"), addrPort("100.1.2.5:1742"), 96, 635, 23, 790}, + {true, 0x6, addrPort("100.1.2.3:443"), addrPort("100.1.2.6:1175"), 48, 94, 86, 618}, // unknown peer (in Tailscale IP space, but not a known peer) + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("192.168.1.241:713"), 43, 154, 66, 883}, + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("192.168.2.241:713"), 43, 154, 66, 883}, // not in the subnet, must be exit traffic + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("172.16.5.18:713"), 7, 243, 40, 59}, + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("172.20.5.18:713"), 61, 753, 42, 492}, // not in the subnet, must be exit traffic + {true, 0x6, addrPort("192.168.1.241:713"), addrPort("100.1.2.3:80"), 43, 154, 66, 883}, + {true, 0x6, addrPort("192.168.2.241:713"), addrPort("100.1.2.3:80"), 43, 154, 66, 883}, // not in the subnet, must be exit traffic + {true, 0x6, addrPort("172.16.5.18:713"), addrPort("100.1.2.3:80"), 7, 243, 40, 59}, + {true, 0x6, addrPort("172.20.5.18:713"), addrPort("100.1.2.3:80"), 61, 753, 42, 492}, // not in the subnet, must be exit traffic + {true, 0x6, addrPort("14.255.192.128:39230"), addrPort("243.42.106.193:48206"), 81, 791, 79, 316}, // unknown connection + {false, 0x6, addrPort("100.1.2.4:0"), addrPort("35.92.180.165:9743"), 63, 136, 61, 409}, // physical traffic with peer1 + {false, 0x6, addrPort("100.1.2.5:0"), addrPort("131.19.35.17:9743"), 88, 452, 2, 716}, // physical traffic with peer2 + } + for range 10 { + for _, conn := range conns { + update := bools.IfElse(conn.virt, logger.updateVirtConn, logger.updatePhysConn) + group.Go(func() { update(conn.proto, conn.src, conn.dst, conn.txP, conn.txB, false) }) + group.Go(func() { update(conn.proto, conn.src, conn.dst, conn.rxP, conn.rxB, true) }) + } + } + group.Wait() + + // Verify that the counters match. 
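+	// The expected counts below are 10x the per-update values in the conns
+	// table above (one factor of 10 per iteration of the update loop).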
+ got := logger.record.toMessage(false, false) + got.Start = time.Time{} // avoid flakiness + want := netlogtype.Message{ + NodeID: "n123456CNTL", + SrcNode: netlogtype.Node{ + NodeID: "n123456CNTL", + Name: "test.tail123456.ts.net", + Addresses: []netip.Addr{addr("100.1.2.3")}, + Tags: []string{"tag:bar", "tag:foo"}, + }, + DstNodes: []netlogtype.Node{{ + NodeID: "n123457CNTL", + Name: "peer1.tail123456.ts.net", + Addresses: []netip.Addr{addr("100.1.2.4")}, + Tags: []string{"tag:peer"}, + }, { + NodeID: "n123458CNTL", + Name: "peer2.tail123456.ts.net", + Addresses: []netip.Addr{addr("100.1.2.5")}, + User: "peer@example.com", + }}, + VirtualTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "100.1.2.3:80", "100.1.2.4:1812"), Counts: counts(880, 2780, 340, 8870)}, + {Connection: conn(0x6, "100.1.2.3:443", "100.1.2.5:1742"), Counts: counts(960, 6350, 230, 7900)}, + }, + SubnetTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "100.1.2.3:80", "172.16.5.18:713"), Counts: counts(70, 2430, 400, 590)}, + {Connection: conn(0x6, "100.1.2.3:80", "192.168.1.241:713"), Counts: counts(430, 1540, 660, 8830)}, + {Connection: conn(0x6, "172.16.5.18:713", "100.1.2.3:80"), Counts: counts(70, 2430, 400, 590)}, + {Connection: conn(0x6, "192.168.1.241:713", "100.1.2.3:80"), Counts: counts(430, 1540, 660, 8830)}, + }, + ExitTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "14.255.192.128:39230", "243.42.106.193:48206"), Counts: counts(810, 7910, 790, 3160)}, + {Connection: conn(0x6, "100.1.2.3:80", "172.20.5.18:713"), Counts: counts(610, 7530, 420, 4920)}, + {Connection: conn(0x6, "100.1.2.3:80", "192.168.2.241:713"), Counts: counts(430, 1540, 660, 8830)}, + {Connection: conn(0x6, "100.1.2.3:443", "100.1.2.6:1175"), Counts: counts(480, 940, 860, 6180)}, + {Connection: conn(0x6, "172.20.5.18:713", "100.1.2.3:80"), Counts: counts(610, 7530, 420, 4920)}, + {Connection: conn(0x6, "192.168.2.241:713", "100.1.2.3:80"), Counts: counts(430, 1540, 660, 8830)}, + }, + PhysicalTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "100.1.2.4:0", "35.92.180.165:9743"), Counts: counts(630, 1360, 610, 4090)}, + {Connection: conn(0x6, "100.1.2.5:0", "131.19.35.17:9743"), Counts: counts(880, 4520, 20, 7160)}, + }, + } + if d := cmp.Diff(got, want, cmpopts.EquateComparable(netip.Addr{}, netip.AddrPort{})); d != "" { + t.Errorf("Message (-got +want):\n%s", d) + } +} + +func TestUpdateRace(t *testing.T) { + var logger Logger + logger.recordsChan = make(chan record, 1) + go func(recordsChan chan record) { + for range recordsChan { + } + }(logger.recordsChan) + + var group sync.WaitGroup + defer group.Wait() + for i := range 1000 { + group.Go(func() { + src, dst := randAddrPort(), randAddrPort() + for j := range 1000 { + if i%2 == 0 { + logger.updateVirtConn(0x1, src, dst, rand.IntN(10), rand.IntN(1000), j%2 == 0) + } else { + logger.updatePhysConn(0x1, src, dst, rand.IntN(10), rand.IntN(1000), j%2 == 0) + } + } + }) + group.Go(func() { + for range 1000 { + logger.ReconfigNetworkMap(new(netmap.NetworkMap)) + } + }) + group.Go(func() { + for range 1000 { + logger.ReconfigRoutes(new(router.Config)) + } + }) + } + + group.Wait() + logger.mu.Lock() + close(logger.recordsChan) + logger.recordsChan = nil + logger.mu.Unlock() +} + +func randAddrPort() netip.AddrPort { + var b [4]uint8 + binary.LittleEndian.PutUint32(b[:], rand.Uint32()) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(rand.Uint32())) +} + +func TestAutoFlushMaxConns(t *testing.T) { + var logger Logger + 
logger.recordsChan = make(chan record, 1)
+	for i := 0; len(logger.recordsChan) == 0; i++ {
+		logger.updateVirtConn(0, netip.AddrPortFrom(netip.Addr{}, uint16(i)), netip.AddrPort{}, 1, 1, false)
+	}
+	// Marshal the flushed record's message to check its serialized size.
+	rec := <-logger.recordsChan
+	b, _ := jsonv2.Marshal(rec.toMessage(false, false))
+	if len(b) > maxLogSize {
+		t.Errorf("len(Message) = %v, want <= %d", len(b), maxLogSize)
+	}
+}
+
+func TestAutoFlushTimeout(t *testing.T) {
+	var logger Logger
+	logger.recordsChan = make(chan record, 1)
+	synctest.Test(t, func(t *testing.T) {
+		logger.updateVirtConn(0, netip.AddrPort{}, netip.AddrPort{}, 1, 1, false)
+		time.Sleep(pollPeriod)
+	})
+	rec := <-logger.recordsChan
+	if d := rec.end.Sub(rec.start); d != pollPeriod {
+		t.Errorf("window = %v, want %v", d, pollPeriod)
+	}
+	if len(rec.virtConns) != 1 {
+		t.Errorf("len(virtConns) = %d, want 1", len(rec.virtConns))
+	}
+}
+
+func BenchmarkUpdateSameConn(b *testing.B) {
+	var logger Logger
+	b.ReportAllocs()
+	for range b.N {
+		logger.updateVirtConn(0, netip.AddrPort{}, netip.AddrPort{}, 1, 1, false)
+	}
+}
+
+func BenchmarkUpdateNewConns(b *testing.B) {
+	var logger Logger
+	b.ReportAllocs()
+	for i := range b.N {
+		logger.updateVirtConn(0, netip.AddrPortFrom(netip.Addr{}, uint16(i)), netip.AddrPort{}, 1, 1, false)
+	}
+}
diff --git a/wgengine/netlog/record.go b/wgengine/netlog/record.go
new file mode 100644
index 0000000000000..25b6b1148793a
--- /dev/null
+++ b/wgengine/netlog/record.go
@@ -0,0 +1,218 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !ts_omit_netlog && !ts_omit_logtail
+
+package netlog
+
+import (
+	"cmp"
+	"net/netip"
+	"slices"
+	"strings"
+	"time"
+	"unicode/utf8"
+
+	"tailscale.com/tailcfg"
+	"tailscale.com/types/bools"
+	"tailscale.com/types/netlogtype"
+	"tailscale.com/util/set"
+)
+
+// maxLogSize is the maximum number of bytes for a log message.
+const maxLogSize = 256 << 10
+
+// record is the in-memory representation of a [netlogtype.Message].
+// It uses maps to efficiently look up addresses and connections.
+// In contrast, [netlogtype.Message] is designed to be JSON serializable,
+// where complex key types are not well supported in JSON objects.
+type record struct {
+	selfNode nodeUser
+
+	start time.Time
+	end   time.Time
+
+	seenNodes map[netip.Addr]nodeUser
+
+	virtConns map[netlogtype.Connection]countsType
+	physConns map[netlogtype.Connection]netlogtype.Counts
+}
+
+// nodeUser is a node with additional user profile information.
+type nodeUser struct {
+	tailcfg.NodeView
+	user tailcfg.UserProfileView // UserProfileView for NodeView.User
+}
+
+// countsType combines [netlogtype.Counts] with classification information about the connection.
+type countsType struct {
+	netlogtype.Counts
+	connType connType
+}
+
+type connType uint8
+
+const (
+	unknownTraffic connType = iota
+	virtualTraffic
+	subnetTraffic
+	exitTraffic
+)
+
+// toMessage converts a [record] into a [netlogtype.Message].
+func (r record) toMessage(excludeNodeInfo, anonymizeExitTraffic bool) netlogtype.Message {
+	if !r.selfNode.Valid() {
+		return netlogtype.Message{}
+	}
+
+	m := netlogtype.Message{
+		NodeID: r.selfNode.StableID(),
+		Start:  r.start.UTC(),
+		End:    r.end.UTC(),
+	}
+
+	// Convert node fields.
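+	// The self node is always reported as SrcNode; every other valid node
+	// seen during this window appears once in DstNodes, sorted by node ID.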
+	if !excludeNodeInfo {
+		m.SrcNode = r.selfNode.toNode()
+		seenIDs := set.Of(r.selfNode.ID())
+		for _, node := range r.seenNodes {
+			if _, ok := seenIDs[node.ID()]; !ok && node.Valid() {
+				m.DstNodes = append(m.DstNodes, node.toNode())
+				seenIDs.Add(node.ID())
+			}
+		}
+		slices.SortFunc(m.DstNodes, func(x, y netlogtype.Node) int {
+			return cmp.Compare(x.NodeID, y.NodeID)
+		})
+	}
+
+	// Convert traffic fields.
+	anonymizedExitTraffic := make(map[netlogtype.Connection]netlogtype.Counts)
+	for conn, cnts := range r.virtConns {
+		switch cnts.connType {
+		case virtualTraffic:
+			m.VirtualTraffic = append(m.VirtualTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts.Counts})
+		case subnetTraffic:
+			m.SubnetTraffic = append(m.SubnetTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts.Counts})
+		default:
+			if anonymizeExitTraffic {
+				conn = netlogtype.Connection{ // scrub the IP protocol type
+					Src: netip.AddrPortFrom(conn.Src.Addr(), 0), // scrub the port number
+					Dst: netip.AddrPortFrom(conn.Dst.Addr(), 0), // scrub the port number
+				}
+				if !r.seenNodes[conn.Src.Addr()].Valid() {
+					conn.Src = netip.AddrPort{} // not a Tailscale node, so scrub the address
+				}
+				if !r.seenNodes[conn.Dst.Addr()].Valid() {
+					conn.Dst = netip.AddrPort{} // not a Tailscale node, so scrub the address
+				}
+				anonymizedExitTraffic[conn] = anonymizedExitTraffic[conn].Add(cnts.Counts)
+				continue
+			}
+			m.ExitTraffic = append(m.ExitTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts.Counts})
+		}
+	}
+	for conn, cnts := range anonymizedExitTraffic {
+		m.ExitTraffic = append(m.ExitTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts})
+	}
+	for conn, cnts := range r.physConns {
+		m.PhysicalTraffic = append(m.PhysicalTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts})
+	}
+
+	// Sort the connections for deterministic results.
+	slices.SortFunc(m.VirtualTraffic, compareConnCnts)
+	slices.SortFunc(m.SubnetTraffic, compareConnCnts)
+	slices.SortFunc(m.ExitTraffic, compareConnCnts)
+	slices.SortFunc(m.PhysicalTraffic, compareConnCnts)
+
+	return m
+}
+
+func compareConnCnts(x, y netlogtype.ConnectionCounts) int {
+	return cmp.Or(
+		netip.AddrPort.Compare(x.Src, y.Src),
+		netip.AddrPort.Compare(x.Dst, y.Dst),
+		cmp.Compare(x.Proto, y.Proto))
+}
+
+// jsonLen computes an upper bound on the size of the JSON representation.
+func (nu nodeUser) jsonLen() (n int) {
+	if !nu.Valid() {
+		return len(`{"nodeId":""}`)
+	}
+	n += len(`{}`)
+	n += len(`"nodeId":`) + jsonQuotedLen(string(nu.StableID())) + len(`,`)
+	if len(nu.Name()) > 0 {
+		n += len(`"name":`) + jsonQuotedLen(nu.Name()) + len(`,`)
+	}
+	if nu.Addresses().Len() > 0 {
+		n += len(`"addresses":[]`)
+		for _, addr := range nu.Addresses().All() {
+			n += bools.IfElse(addr.Addr().Is4(), len(`"255.255.255.255"`), len(`"ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"`)) + len(",")
+		}
+	}
+	if nu.Hostinfo().Valid() && len(nu.Hostinfo().OS()) > 0 {
+		n += len(`"os":`) + jsonQuotedLen(nu.Hostinfo().OS()) + len(`,`)
+	}
+	if nu.Tags().Len() > 0 {
+		n += len(`"tags":[]`)
+		for _, tag := range nu.Tags().All() {
+			n += jsonQuotedLen(tag) + len(",")
+		}
+	} else if nu.user.Valid() && nu.user.ID() == nu.User() && len(nu.user.LoginName()) > 0 {
+		n += len(`"user":`) + jsonQuotedLen(nu.user.LoginName()) + len(",")
+	}
+	return n
+}
+
+// toNode converts the [nodeUser] into a [netlogtype.Node].
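+// At most one IPv4 and one IPv6 address are kept, tags are sorted and
+// deduplicated, and the user login is reported only for untagged nodes.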
+func (nu nodeUser) toNode() netlogtype.Node { + if !nu.Valid() { + return netlogtype.Node{} + } + n := netlogtype.Node{ + NodeID: nu.StableID(), + Name: strings.TrimSuffix(nu.Name(), "."), + } + var ipv4, ipv6 netip.Addr + for _, addr := range nu.Addresses().All() { + switch { + case addr.IsSingleIP() && addr.Addr().Is4(): + ipv4 = addr.Addr() + case addr.IsSingleIP() && addr.Addr().Is6(): + ipv6 = addr.Addr() + } + } + n.Addresses = []netip.Addr{ipv4, ipv6} + n.Addresses = slices.DeleteFunc(n.Addresses, func(a netip.Addr) bool { return !a.IsValid() }) + if nu.Hostinfo().Valid() { + n.OS = nu.Hostinfo().OS() + } + if nu.Tags().Len() > 0 { + n.Tags = nu.Tags().AsSlice() + slices.Sort(n.Tags) + n.Tags = slices.Compact(n.Tags) + } else if nu.user.Valid() && nu.user.ID() == nu.User() { + n.User = nu.user.LoginName() + } + return n +} + +// jsonQuotedLen computes the length of the JSON serialization of s +// according to [jsontext.AppendQuote]. +func jsonQuotedLen(s string) int { + n := len(`"`) + len(s) + len(`"`) + for i, r := range s { + switch { + case r == '\b', r == '\t', r == '\n', r == '\f', r == '\r', r == '"', r == '\\': + n += len(`\X`) - 1 + case r < ' ': + n += len(`\uXXXX`) - 1 + case r == utf8.RuneError: + if _, m := utf8.DecodeRuneInString(s[i:]); m == 1 { // exactly an invalid byte + n += len("�") - 1 + } + } + } + return n +} diff --git a/wgengine/netlog/record_test.go b/wgengine/netlog/record_test.go new file mode 100644 index 0000000000000..ec0229534f244 --- /dev/null +++ b/wgengine/netlog/record_test.go @@ -0,0 +1,257 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netlog && !ts_omit_logtail + +package netlog + +import ( + "net/netip" + "testing" + "time" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" + "tailscale.com/types/netlogtype" + "tailscale.com/util/must" +) + +func addr(s string) netip.Addr { + if s == "" { + return netip.Addr{} + } + return must.Get(netip.ParseAddr(s)) +} +func addrPort(s string) netip.AddrPort { + if s == "" { + return netip.AddrPort{} + } + return must.Get(netip.ParseAddrPort(s)) +} +func prefix(s string) netip.Prefix { + if p, err := netip.ParsePrefix(s); err == nil { + return p + } + a := addr(s) + return netip.PrefixFrom(a, a.BitLen()) +} + +func conn(proto ipproto.Proto, src, dst string) netlogtype.Connection { + return netlogtype.Connection{Proto: proto, Src: addrPort(src), Dst: addrPort(dst)} +} + +func counts(txP, txB, rxP, rxB uint64) netlogtype.Counts { + return netlogtype.Counts{TxPackets: txP, TxBytes: txB, RxPackets: rxP, RxBytes: rxB} +} + +func TestToMessage(t *testing.T) { + rec := record{ + selfNode: nodeUser{NodeView: (&tailcfg.Node{ + ID: 123456, + StableID: "n123456CNTL", + Name: "src.tail123456.ts.net.", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + Tags: []string{"tag:src"}, + }).View()}, + start: time.Now(), + end: time.Now().Add(5 * time.Second), + + seenNodes: map[netip.Addr]nodeUser{ + addr("100.1.2.4"): {NodeView: (&tailcfg.Node{ + ID: 123457, + StableID: "n123457CNTL", + Name: "dst1.tail123456.ts.net.", + Addresses: []netip.Prefix{prefix("100.1.2.4")}, + Tags: []string{"tag:dst1"}, + }).View()}, + addr("100.1.2.5"): {NodeView: (&tailcfg.Node{ + ID: 123458, + StableID: "n123458CNTL", + Name: "dst2.tail123456.ts.net.", + Addresses: []netip.Prefix{prefix("100.1.2.5")}, + Tags: 
[]string{"tag:dst2"}, + }).View()}, + }, + + virtConns: map[netlogtype.Connection]countsType{ + conn(0x1, "100.1.2.3:1234", "100.1.2.4:80"): {Counts: counts(12, 34, 56, 78), connType: virtualTraffic}, + conn(0x1, "100.1.2.3:1234", "100.1.2.5:80"): {Counts: counts(23, 45, 78, 790), connType: virtualTraffic}, + conn(0x6, "172.16.1.1:80", "100.1.2.4:1234"): {Counts: counts(91, 54, 723, 621), connType: subnetTraffic}, + conn(0x6, "172.16.1.2:443", "100.1.2.5:1234"): {Counts: counts(42, 813, 3, 1823), connType: subnetTraffic}, + conn(0x6, "172.16.1.3:80", "100.1.2.6:1234"): {Counts: counts(34, 52, 78, 790), connType: subnetTraffic}, + conn(0x6, "100.1.2.3:1234", "12.34.56.78:80"): {Counts: counts(11, 110, 10, 100), connType: exitTraffic}, + conn(0x6, "100.1.2.4:1234", "23.34.56.78:80"): {Counts: counts(423, 1, 6, 123), connType: exitTraffic}, + conn(0x6, "100.1.2.4:1234", "23.34.56.78:443"): {Counts: counts(22, 220, 20, 200), connType: exitTraffic}, + conn(0x6, "100.1.2.5:1234", "45.34.56.78:80"): {Counts: counts(33, 330, 30, 300), connType: exitTraffic}, + conn(0x6, "100.1.2.6:1234", "67.34.56.78:80"): {Counts: counts(44, 440, 40, 400), connType: exitTraffic}, + conn(0x6, "42.54.72.42:555", "18.42.7.1:777"): {Counts: counts(44, 440, 40, 400)}, + }, + + physConns: map[netlogtype.Connection]netlogtype.Counts{ + conn(0, "100.1.2.4:0", "4.3.2.1:1234"): counts(12, 34, 56, 78), + conn(0, "100.1.2.5:0", "4.3.2.10:1234"): counts(78, 56, 34, 12), + }, + } + rec.seenNodes[rec.selfNode.toNode().Addresses[0]] = rec.selfNode + + got := rec.toMessage(false, false) + want := netlogtype.Message{ + NodeID: rec.selfNode.StableID(), + Start: rec.start, + End: rec.end, + SrcNode: rec.selfNode.toNode(), + DstNodes: []netlogtype.Node{ + rec.seenNodes[addr("100.1.2.4")].toNode(), + rec.seenNodes[addr("100.1.2.5")].toNode(), + }, + VirtualTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x1, "100.1.2.3:1234", "100.1.2.4:80"), Counts: counts(12, 34, 56, 78)}, + {Connection: conn(0x1, "100.1.2.3:1234", "100.1.2.5:80"), Counts: counts(23, 45, 78, 790)}, + }, + SubnetTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "172.16.1.1:80", "100.1.2.4:1234"), Counts: counts(91, 54, 723, 621)}, + {Connection: conn(0x6, "172.16.1.2:443", "100.1.2.5:1234"), Counts: counts(42, 813, 3, 1823)}, + {Connection: conn(0x6, "172.16.1.3:80", "100.1.2.6:1234"), Counts: counts(34, 52, 78, 790)}, + }, + ExitTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "42.54.72.42:555", "18.42.7.1:777"), Counts: counts(44, 440, 40, 400)}, + {Connection: conn(0x6, "100.1.2.3:1234", "12.34.56.78:80"), Counts: counts(11, 110, 10, 100)}, + {Connection: conn(0x6, "100.1.2.4:1234", "23.34.56.78:80"), Counts: counts(423, 1, 6, 123)}, + {Connection: conn(0x6, "100.1.2.4:1234", "23.34.56.78:443"), Counts: counts(22, 220, 20, 200)}, + {Connection: conn(0x6, "100.1.2.5:1234", "45.34.56.78:80"), Counts: counts(33, 330, 30, 300)}, + {Connection: conn(0x6, "100.1.2.6:1234", "67.34.56.78:80"), Counts: counts(44, 440, 40, 400)}, + }, + PhysicalTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0, "100.1.2.4:0", "4.3.2.1:1234"), Counts: counts(12, 34, 56, 78)}, + {Connection: conn(0, "100.1.2.5:0", "4.3.2.10:1234"), Counts: counts(78, 56, 34, 12)}, + }, + } + if d := cmp.Diff(got, want, cmpopts.EquateComparable(netip.Addr{}, netip.AddrPort{})); d != "" { + t.Errorf("toMessage(false, false) mismatch (-got +want):\n%s", d) + } + + got = rec.toMessage(true, false) + want.SrcNode = netlogtype.Node{} + want.DstNodes = nil + if 
d := cmp.Diff(got, want, cmpopts.EquateComparable(netip.Addr{}, netip.AddrPort{})); d != "" { + t.Errorf("toMessage(true, false) mismatch (-got +want):\n%s", d) + } + + got = rec.toMessage(true, true) + want.ExitTraffic = []netlogtype.ConnectionCounts{ + {Connection: conn(0, "", ""), Counts: counts(44+44, 440+440, 40+40, 400+400)}, + {Connection: conn(0, "100.1.2.3:0", ""), Counts: counts(11, 110, 10, 100)}, + {Connection: conn(0, "100.1.2.4:0", ""), Counts: counts(423+22, 1+220, 6+20, 123+200)}, + {Connection: conn(0, "100.1.2.5:0", ""), Counts: counts(33, 330, 30, 300)}, + } + if d := cmp.Diff(got, want, cmpopts.EquateComparable(netip.Addr{}, netip.AddrPort{})); d != "" { + t.Errorf("toMessage(true, true) mismatch (-got +want):\n%s", d) + } +} + +func TestToNode(t *testing.T) { + tests := []struct { + node *tailcfg.Node + user *tailcfg.UserProfile + want netlogtype.Node + }{ + {}, + { + node: &tailcfg.Node{ + StableID: "n123456CNTL", + Name: "test.tail123456.ts.net.", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + Tags: []string{"tag:dupe", "tag:test", "tag:dupe"}, + User: 12345, // should be ignored + }, + want: netlogtype.Node{ + NodeID: "n123456CNTL", + Name: "test.tail123456.ts.net", + Addresses: []netip.Addr{addr("100.1.2.3")}, + Tags: []string{"tag:dupe", "tag:test"}, + }, + }, + { + node: &tailcfg.Node{ + StableID: "n123456CNTL", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + User: 12345, + }, + want: netlogtype.Node{ + NodeID: "n123456CNTL", + Addresses: []netip.Addr{addr("100.1.2.3")}, + }, + }, + { + node: &tailcfg.Node{ + StableID: "n123456CNTL", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + Hostinfo: (&tailcfg.Hostinfo{OS: "linux"}).View(), + User: 12345, + }, + user: &tailcfg.UserProfile{ + ID: 12345, + LoginName: "user@domain", + }, + want: netlogtype.Node{ + NodeID: "n123456CNTL", + Addresses: []netip.Addr{addr("100.1.2.3")}, + OS: "linux", + User: "user@domain", + }, + }, + } + for _, tt := range tests { + nu := nodeUser{tt.node.View(), tt.user.View()} + got := nu.toNode() + b := must.Get(jsonv2.Marshal(got)) + if len(b) > nu.jsonLen() { + t.Errorf("jsonLen = %v, want >= %d", nu.jsonLen(), len(b)) + } + if d := cmp.Diff(got, tt.want, cmpopts.EquateComparable(netip.Addr{})); d != "" { + t.Errorf("toNode mismatch (-got +want):\n%s", d) + } + } +} + +func FuzzQuotedLen(f *testing.F) { + for _, s := range quotedLenTestdata { + f.Add(s) + } + f.Fuzz(func(t *testing.T, s string) { + testQuotedLen(t, s) + }) +} + +func TestQuotedLen(t *testing.T) { + for _, s := range quotedLenTestdata { + testQuotedLen(t, s) + } +} + +var quotedLenTestdata = []string{ + "", // empty string + func() string { + b := make([]byte, 128) + for i := range b { + b[i] = byte(i) + } + return string(b) + }(), // all ASCII + "�", // replacement rune + "\xff", // invalid UTF-8 + "ʕ◔ϖ◔ʔ", // Unicode gopher +} + +func testQuotedLen(t *testing.T, in string) { + got := jsonQuotedLen(in) + b, _ := jsontext.AppendQuote(nil, in) + want := len(b) + if got != want { + t.Errorf("jsonQuotedLen(%q) = %v, want %v", in, got, want) + } +} diff --git a/wgengine/netlog/stats.go b/wgengine/netlog/stats.go deleted file mode 100644 index c06068803f125..0000000000000 --- a/wgengine/netlog/stats.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ts_omit_netlog && !ts_omit_logtail - -package netlog - -import ( - "context" - "net/netip" - "sync" - "time" - - "golang.org/x/sync/errgroup" - "tailscale.com/net/packet" - 
"tailscale.com/net/tsaddr" - "tailscale.com/types/ipproto" - "tailscale.com/types/netlogtype" -) - -// statistics maintains counters for every connection. -// All methods are safe for concurrent use. -// The zero value is ready for use. -type statistics struct { - maxConns int // immutable once set - - mu sync.Mutex - connCnts - - connCntsCh chan connCnts - shutdownCtx context.Context - shutdown context.CancelFunc - group errgroup.Group -} - -type connCnts struct { - start time.Time - end time.Time - virtual map[netlogtype.Connection]netlogtype.Counts - physical map[netlogtype.Connection]netlogtype.Counts -} - -// newStatistics creates a data structure for tracking connection statistics -// that periodically dumps the virtual and physical connection counts -// depending on whether the maxPeriod or maxConns is exceeded. -// The dump function is called from a single goroutine. -// Shutdown must be called to cleanup resources. -func newStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) *statistics { - s := &statistics{maxConns: maxConns} - s.connCntsCh = make(chan connCnts, 256) - s.shutdownCtx, s.shutdown = context.WithCancel(context.Background()) - s.group.Go(func() error { - // TODO(joetsai): Using a ticker is problematic on mobile platforms - // where waking up a process every maxPeriod when there is no activity - // is a drain on battery life. Switch this instead to instead use - // a time.Timer that is triggered upon network activity. - ticker := new(time.Ticker) - if maxPeriod > 0 { - ticker = time.NewTicker(maxPeriod) - defer ticker.Stop() - } - - for { - var cc connCnts - select { - case cc = <-s.connCntsCh: - case <-ticker.C: - cc = s.extract() - case <-s.shutdownCtx.Done(): - cc = s.extract() - } - if len(cc.virtual)+len(cc.physical) > 0 && dump != nil { - dump(cc.start, cc.end, cc.virtual, cc.physical) - } - if s.shutdownCtx.Err() != nil { - return nil - } - } - }) - return s -} - -// UpdateTxVirtual updates the counters for a transmitted IP packet -// The source and destination of the packet directly correspond with -// the source and destination in netlogtype.Connection. -func (s *statistics) UpdateTxVirtual(b []byte) { - var p packet.Parsed - p.Decode(b) - s.UpdateVirtual(p.IPProto, p.Src, p.Dst, 1, len(b), false) -} - -// UpdateRxVirtual updates the counters for a received IP packet. -// The source and destination of the packet are inverted with respect to -// the source and destination in netlogtype.Connection. -func (s *statistics) UpdateRxVirtual(b []byte) { - var p packet.Parsed - p.Decode(b) - s.UpdateVirtual(p.IPProto, p.Dst, p.Src, 1, len(b), true) -} - -var ( - tailscaleServiceIPv4 = tsaddr.TailscaleServiceIP() - tailscaleServiceIPv6 = tsaddr.TailscaleServiceIPv6() -) - -func (s *statistics) UpdateVirtual(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, receive bool) { - // Network logging is defined as traffic between two Tailscale nodes. - // Traffic with the internal Tailscale service is not with another node - // and should not be logged. It also happens to be a high volume - // amount of discrete traffic flows (e.g., DNS lookups). 
- switch dst.Addr() { - case tailscaleServiceIPv4, tailscaleServiceIPv6: - return - } - - conn := netlogtype.Connection{Proto: proto, Src: src, Dst: dst} - - s.mu.Lock() - defer s.mu.Unlock() - cnts, found := s.virtual[conn] - if !found && !s.preInsertConn() { - return - } - if receive { - cnts.RxPackets += uint64(packets) - cnts.RxBytes += uint64(bytes) - } else { - cnts.TxPackets += uint64(packets) - cnts.TxBytes += uint64(bytes) - } - s.virtual[conn] = cnts -} - -// UpdateTxPhysical updates the counters for zero or more transmitted wireguard packets. -// The src is always a Tailscale IP address, representing some remote peer. -// The dst is a remote IP address and port that corresponds -// with some physical peer backing the Tailscale IP address. -func (s *statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) { - s.UpdatePhysical(0, netip.AddrPortFrom(src, 0), dst, packets, bytes, false) -} - -// UpdateRxPhysical updates the counters for zero or more received wireguard packets. -// The src is always a Tailscale IP address, representing some remote peer. -// The dst is a remote IP address and port that corresponds -// with some physical peer backing the Tailscale IP address. -func (s *statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) { - s.UpdatePhysical(0, netip.AddrPortFrom(src, 0), dst, packets, bytes, true) -} - -func (s *statistics) UpdatePhysical(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, receive bool) { - conn := netlogtype.Connection{Proto: proto, Src: src, Dst: dst} - - s.mu.Lock() - defer s.mu.Unlock() - cnts, found := s.physical[conn] - if !found && !s.preInsertConn() { - return - } - if receive { - cnts.RxPackets += uint64(packets) - cnts.RxBytes += uint64(bytes) - } else { - cnts.TxPackets += uint64(packets) - cnts.TxBytes += uint64(bytes) - } - s.physical[conn] = cnts -} - -// preInsertConn updates the maps to handle insertion of a new connection. -// It reports false if insertion is not allowed (i.e., after shutdown). -func (s *statistics) preInsertConn() bool { - // Check whether insertion of a new connection will exceed maxConns. - if len(s.virtual)+len(s.physical) == s.maxConns && s.maxConns > 0 { - // Extract the current statistics and send it to the serializer. - // Avoid blocking the network packet handling path. - select { - case s.connCntsCh <- s.extractLocked(): - default: - // TODO(joetsai): Log that we are dropping an entire connCounts. - } - } - - // Initialize the maps if nil. - if s.virtual == nil && s.physical == nil { - s.start = time.Now().UTC() - s.virtual = make(map[netlogtype.Connection]netlogtype.Counts) - s.physical = make(map[netlogtype.Connection]netlogtype.Counts) - } - - return s.shutdownCtx.Err() == nil -} - -func (s *statistics) extract() connCnts { - s.mu.Lock() - defer s.mu.Unlock() - return s.extractLocked() -} - -func (s *statistics) extractLocked() connCnts { - if len(s.virtual)+len(s.physical) == 0 { - return connCnts{} - } - s.end = time.Now().UTC() - cc := s.connCnts - s.connCnts = connCnts{} - return cc -} - -// TestExtract synchronously extracts the current network statistics map -// and resets the counters. This should only be used for testing purposes. -func (s *statistics) TestExtract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) { - cc := s.extract() - return cc.virtual, cc.physical -} - -// Shutdown performs a final flush of statistics. -// Statistics for any subsequent calls to Update will be dropped. 
-// It is safe to call Shutdown concurrently and repeatedly. -func (s *statistics) Shutdown(context.Context) error { - s.shutdown() - return s.group.Wait() -} diff --git a/wgengine/netlog/stats_test.go b/wgengine/netlog/stats_test.go deleted file mode 100644 index 6cf7eb9983817..0000000000000 --- a/wgengine/netlog/stats_test.go +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netlog - -import ( - "context" - "encoding/binary" - "fmt" - "math/rand" - "net/netip" - "runtime" - "sync" - "testing" - "time" - - qt "github.com/frankban/quicktest" - "tailscale.com/cmd/testwrapper/flakytest" - "tailscale.com/types/ipproto" - "tailscale.com/types/netlogtype" -) - -func testPacketV4(proto ipproto.Proto, srcAddr, dstAddr [4]byte, srcPort, dstPort, size uint16) (out []byte) { - var ipHdr [20]byte - ipHdr[0] = 4<<4 | 5 - binary.BigEndian.PutUint16(ipHdr[2:], size) - ipHdr[9] = byte(proto) - *(*[4]byte)(ipHdr[12:]) = srcAddr - *(*[4]byte)(ipHdr[16:]) = dstAddr - out = append(out, ipHdr[:]...) - switch proto { - case ipproto.TCP: - var tcpHdr [20]byte - binary.BigEndian.PutUint16(tcpHdr[0:], srcPort) - binary.BigEndian.PutUint16(tcpHdr[2:], dstPort) - out = append(out, tcpHdr[:]...) - case ipproto.UDP: - var udpHdr [8]byte - binary.BigEndian.PutUint16(udpHdr[0:], srcPort) - binary.BigEndian.PutUint16(udpHdr[2:], dstPort) - out = append(out, udpHdr[:]...) - default: - panic(fmt.Sprintf("unknown proto: %d", proto)) - } - return append(out, make([]byte, int(size)-len(out))...) -} - -// TestInterval ensures that we receive at least one call to `dump` using only -// maxPeriod. -func TestInterval(t *testing.T) { - c := qt.New(t) - - const maxPeriod = 10 * time.Millisecond - const maxConns = 2048 - - gotDump := make(chan struct{}, 1) - stats := newStatistics(maxPeriod, maxConns, func(_, _ time.Time, _, _ map[netlogtype.Connection]netlogtype.Counts) { - select { - case gotDump <- struct{}{}: - default: - } - }) - defer stats.Shutdown(context.Background()) - - srcAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))}) - dstAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))}) - srcPort := uint16(rand.Intn(16)) - dstPort := uint16(rand.Intn(16)) - size := uint16(64 + rand.Intn(1024)) - p := testPacketV4(ipproto.TCP, srcAddr.As4(), dstAddr.As4(), srcPort, dstPort, size) - stats.UpdateRxVirtual(p) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - select { - case <-ctx.Done(): - c.Fatal("didn't receive dump within context deadline") - case <-gotDump: - } -} - -func TestConcurrent(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7030") - c := qt.New(t) - - const maxPeriod = 10 * time.Millisecond - const maxConns = 10 - virtualAggregate := make(map[netlogtype.Connection]netlogtype.Counts) - stats := newStatistics(maxPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) { - c.Assert(start.IsZero(), qt.IsFalse) - c.Assert(end.IsZero(), qt.IsFalse) - c.Assert(end.Before(start), qt.IsFalse) - c.Assert(len(virtual) > 0 && len(virtual) <= maxConns, qt.IsTrue) - c.Assert(len(physical) == 0, qt.IsTrue) - for conn, cnts := range virtual { - virtualAggregate[conn] = virtualAggregate[conn].Add(cnts) - } - }) - defer stats.Shutdown(context.Background()) - var wants []map[netlogtype.Connection]netlogtype.Counts - gots := make([]map[netlogtype.Connection]netlogtype.Counts, runtime.NumCPU()) - var group 
sync.WaitGroup - for i := range gots { - group.Add(1) - go func(i int) { - defer group.Done() - gots[i] = make(map[netlogtype.Connection]netlogtype.Counts) - rn := rand.New(rand.NewSource(time.Now().UnixNano())) - var p []byte - var t netlogtype.Connection - for j := 0; j < 1000; j++ { - delay := rn.Intn(10000) - if p == nil || rn.Intn(64) == 0 { - proto := ipproto.TCP - if rn.Intn(2) == 0 { - proto = ipproto.UDP - } - srcAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))}) - dstAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))}) - srcPort := uint16(rand.Intn(16)) - dstPort := uint16(rand.Intn(16)) - size := uint16(64 + rand.Intn(1024)) - p = testPacketV4(proto, srcAddr.As4(), dstAddr.As4(), srcPort, dstPort, size) - t = netlogtype.Connection{Proto: proto, Src: netip.AddrPortFrom(srcAddr, srcPort), Dst: netip.AddrPortFrom(dstAddr, dstPort)} - } - t2 := t - receive := rn.Intn(2) == 0 - if receive { - t2.Src, t2.Dst = t2.Dst, t2.Src - } - - cnts := gots[i][t2] - if receive { - stats.UpdateRxVirtual(p) - cnts.RxPackets++ - cnts.RxBytes += uint64(len(p)) - } else { - cnts.TxPackets++ - cnts.TxBytes += uint64(len(p)) - stats.UpdateTxVirtual(p) - } - gots[i][t2] = cnts - time.Sleep(time.Duration(rn.Intn(1 + delay))) - } - }(i) - } - group.Wait() - c.Assert(stats.Shutdown(context.Background()), qt.IsNil) - wants = append(wants, virtualAggregate) - - got := make(map[netlogtype.Connection]netlogtype.Counts) - want := make(map[netlogtype.Connection]netlogtype.Counts) - mergeMaps(got, gots...) - mergeMaps(want, wants...) - c.Assert(got, qt.DeepEquals, want) -} - -func mergeMaps(dst map[netlogtype.Connection]netlogtype.Counts, srcs ...map[netlogtype.Connection]netlogtype.Counts) { - for _, src := range srcs { - for conn, cnts := range src { - dst[conn] = dst[conn].Add(cnts) - } - } -} - -func Benchmark(b *testing.B) { - // TODO: Test IPv6 packets? 
- b.Run("SingleRoutine/SameConn", func(b *testing.B) { - p := testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 123, 456, 789) - b.ResetTimer() - b.ReportAllocs() - for range b.N { - s := newStatistics(0, 0, nil) - for j := 0; j < 1e3; j++ { - s.UpdateTxVirtual(p) - } - } - }) - b.Run("SingleRoutine/UniqueConns", func(b *testing.B) { - p := testPacketV4(ipproto.UDP, [4]byte{}, [4]byte{}, 0, 0, 789) - b.ResetTimer() - b.ReportAllocs() - for range b.N { - s := newStatistics(0, 0, nil) - for j := 0; j < 1e3; j++ { - binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination - s.UpdateTxVirtual(p) - } - } - }) - b.Run("MultiRoutine/SameConn", func(b *testing.B) { - p := testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 123, 456, 789) - b.ResetTimer() - b.ReportAllocs() - for range b.N { - s := newStatistics(0, 0, nil) - var group sync.WaitGroup - for j := 0; j < runtime.NumCPU(); j++ { - group.Add(1) - go func() { - defer group.Done() - for k := 0; k < 1e3; k++ { - s.UpdateTxVirtual(p) - } - }() - } - group.Wait() - } - }) - b.Run("MultiRoutine/UniqueConns", func(b *testing.B) { - ps := make([][]byte, runtime.NumCPU()) - for i := range ps { - ps[i] = testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 0, 0, 789) - } - b.ResetTimer() - b.ReportAllocs() - for range b.N { - s := newStatistics(0, 0, nil) - var group sync.WaitGroup - for j := 0; j < runtime.NumCPU(); j++ { - group.Add(1) - go func(j int) { - defer group.Done() - p := ps[j] - j *= 1e3 - for k := 0; k < 1e3; k++ { - binary.BigEndian.PutUint32(p[20:], uint32(j+k)) // unique port combination - s.UpdateTxVirtual(p) - } - }(j) - } - group.Wait() - } - }) -} diff --git a/wgengine/netstack/link_endpoint.go b/wgengine/netstack/link_endpoint.go index 260b3196ab2fc..c5a9dbcbca538 100644 --- a/wgengine/netstack/link_endpoint.go +++ b/wgengine/netstack/link_endpoint.go @@ -126,24 +126,24 @@ func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, supported return le } -// gro attempts to enqueue p on g if l supports a GRO kind matching the +// gro attempts to enqueue p on g if ep supports a GRO kind matching the // transport protocol carried in p. gro may allocate g if it is nil. gro can // either return the existing g, a newly allocated one, or nil. Callers are // responsible for calling Flush() on the returned value if it is non-nil once // they have finished iterating through all GRO candidates for a given vector. -// If gro allocates a *gro.GRO it will have l's stack.NetworkDispatcher set via +// If gro allocates a *gro.GRO it will have ep's stack.NetworkDispatcher set via // SetDispatcher(). -func (l *linkEndpoint) gro(p *packet.Parsed, g *gro.GRO) *gro.GRO { - if !buildfeatures.HasGRO || l.supportedGRO == groNotSupported || p.IPProto != ipproto.TCP { +func (ep *linkEndpoint) gro(p *packet.Parsed, g *gro.GRO) *gro.GRO { + if !buildfeatures.HasGRO || ep.supportedGRO == groNotSupported || p.IPProto != ipproto.TCP { // IPv6 may have extension headers preceding a TCP header, but we trade // for a fast path and assume p cannot be coalesced in such a case. - l.injectInbound(p) + ep.injectInbound(p) return g } if g == nil { - l.mu.RLock() - d := l.dispatcher - l.mu.RUnlock() + ep.mu.RLock() + d := ep.dispatcher + ep.mu.RUnlock() g = gro.NewGRO() g.SetDispatcher(d) } @@ -154,39 +154,39 @@ func (l *linkEndpoint) gro(p *packet.Parsed, g *gro.GRO) *gro.GRO { // Close closes l. 
Further packet injections will return an error, and all // pending packets are discarded. Close may be called concurrently with // WritePackets. -func (l *linkEndpoint) Close() { - l.mu.Lock() - l.dispatcher = nil - l.mu.Unlock() - l.q.Close() - l.Drain() +func (ep *linkEndpoint) Close() { + ep.mu.Lock() + ep.dispatcher = nil + ep.mu.Unlock() + ep.q.Close() + ep.Drain() } // Read does non-blocking read one packet from the outbound packet queue. -func (l *linkEndpoint) Read() *stack.PacketBuffer { - return l.q.Read() +func (ep *linkEndpoint) Read() *stack.PacketBuffer { + return ep.q.Read() } // ReadContext does blocking read for one packet from the outbound packet queue. // It can be cancelled by ctx, and in this case, it returns nil. -func (l *linkEndpoint) ReadContext(ctx context.Context) *stack.PacketBuffer { - return l.q.ReadContext(ctx) +func (ep *linkEndpoint) ReadContext(ctx context.Context) *stack.PacketBuffer { + return ep.q.ReadContext(ctx) } // Drain removes all outbound packets from the channel and counts them. -func (l *linkEndpoint) Drain() int { - return l.q.Drain() +func (ep *linkEndpoint) Drain() int { + return ep.q.Drain() } // NumQueued returns the number of packets queued for outbound. -func (l *linkEndpoint) NumQueued() int { - return l.q.Num() +func (ep *linkEndpoint) NumQueued() int { + return ep.q.Num() } -func (l *linkEndpoint) injectInbound(p *packet.Parsed) { - l.mu.RLock() - d := l.dispatcher - l.mu.RUnlock() +func (ep *linkEndpoint) injectInbound(p *packet.Parsed) { + ep.mu.RLock() + d := ep.dispatcher + ep.mu.RUnlock() if d == nil || !buildfeatures.HasNetstack { return } @@ -200,35 +200,35 @@ func (l *linkEndpoint) injectInbound(p *packet.Parsed) { // Attach saves the stack network-layer dispatcher for use later when packets // are injected. -func (l *linkEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - l.mu.Lock() - defer l.mu.Unlock() - l.dispatcher = dispatcher +func (ep *linkEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + ep.mu.Lock() + defer ep.mu.Unlock() + ep.dispatcher = dispatcher } // IsAttached implements stack.LinkEndpoint.IsAttached. -func (l *linkEndpoint) IsAttached() bool { - l.mu.RLock() - defer l.mu.RUnlock() - return l.dispatcher != nil +func (ep *linkEndpoint) IsAttached() bool { + ep.mu.RLock() + defer ep.mu.RUnlock() + return ep.dispatcher != nil } // MTU implements stack.LinkEndpoint.MTU. -func (l *linkEndpoint) MTU() uint32 { - l.mu.RLock() - defer l.mu.RUnlock() - return l.mtu +func (ep *linkEndpoint) MTU() uint32 { + ep.mu.RLock() + defer ep.mu.RUnlock() + return ep.mtu } // SetMTU implements stack.LinkEndpoint.SetMTU. -func (l *linkEndpoint) SetMTU(mtu uint32) { - l.mu.Lock() - defer l.mu.Unlock() - l.mtu = mtu +func (ep *linkEndpoint) SetMTU(mtu uint32) { + ep.mu.Lock() + defer ep.mu.Unlock() + ep.mtu = mtu } // Capabilities implements stack.LinkEndpoint.Capabilities. -func (l *linkEndpoint) Capabilities() stack.LinkEndpointCapabilities { +func (ep *linkEndpoint) Capabilities() stack.LinkEndpointCapabilities { // We are required to offload RX checksum validation for the purposes of // GRO. return stack.CapabilityRXChecksumOffload @@ -242,8 +242,8 @@ func (*linkEndpoint) GSOMaxSize() uint32 { } // SupportedGSO implements stack.GSOEndpoint. -func (l *linkEndpoint) SupportedGSO() stack.SupportedGSO { - return l.SupportedGSOKind +func (ep *linkEndpoint) SupportedGSO() stack.SupportedGSO { + return ep.SupportedGSOKind } // MaxHeaderLength returns the maximum size of the link layer header. 
Given it @@ -253,22 +253,22 @@ func (*linkEndpoint) MaxHeaderLength() uint16 { } // LinkAddress returns the link address of this endpoint. -func (l *linkEndpoint) LinkAddress() tcpip.LinkAddress { - l.mu.RLock() - defer l.mu.RUnlock() - return l.linkAddr +func (ep *linkEndpoint) LinkAddress() tcpip.LinkAddress { + ep.mu.RLock() + defer ep.mu.RUnlock() + return ep.linkAddr } // SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress. -func (l *linkEndpoint) SetLinkAddress(addr tcpip.LinkAddress) { - l.mu.Lock() - defer l.mu.Unlock() - l.linkAddr = addr +func (ep *linkEndpoint) SetLinkAddress(addr tcpip.LinkAddress) { + ep.mu.Lock() + defer ep.mu.Unlock() + ep.linkAddr = addr } // WritePackets stores outbound packets into the channel. // Multiple concurrent calls are permitted. -func (l *linkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { +func (ep *linkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { n := 0 // TODO(jwhited): evaluate writing a stack.PacketBufferList instead of a // single packet. We can split 2 x 64K GSO across @@ -278,7 +278,7 @@ func (l *linkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Err // control MTU (and by effect TCP MSS in gVisor) we *shouldn't* expect to // ever overflow 128 slots (see wireguard-go/tun.ErrTooManySegments usage). for _, pkt := range pkts.AsSlice() { - if err := l.q.Write(pkt); err != nil { + if err := ep.q.Write(pkt); err != nil { if _, ok := err.(*tcpip.ErrNoBufferSpace); !ok && n == 0 { return 0, err } diff --git a/wgengine/router/osrouter/router_linux.go b/wgengine/router/osrouter/router_linux.go index 58bd0513ab768..7442c045ee079 100644 --- a/wgengine/router/osrouter/router_linux.go +++ b/wgengine/router/osrouter/router_linux.go @@ -581,7 +581,7 @@ func (r *linuxRouter) updateMagicsockPort(port uint16, network string) error { } if port != 0 { - if err := r.nfr.AddMagicsockPortRule(*magicsockPort, network); err != nil { + if err := r.nfr.AddMagicsockPortRule(port, network); err != nil { return fmt.Errorf("add magicsock port rule: %w", err) } } @@ -1617,7 +1617,7 @@ func checkOpenWRTUsingMWAN3() (bool, error) { // We want to match on a rule like this: // 2001: from all fwmark 0x100/0x3f00 lookup 1 // - // We dont match on the mask because it can vary, or the + // We don't match on the mask because it can vary, or the // table because I'm not sure if it can vary. 
if r.Priority >= 2001 && r.Priority <= 2004 && r.Mark != 0 { return true, nil diff --git a/wgengine/router/osrouter/router_linux_test.go b/wgengine/router/osrouter/router_linux_test.go index 39210ddef14a2..68ed8dbb2bb64 100644 --- a/wgengine/router/osrouter/router_linux_test.go +++ b/wgengine/router/osrouter/router_linux_test.go @@ -870,7 +870,7 @@ func (o *fakeOS) run(args ...string) error { rest = family + " " + strings.Join(args[3:], " ") } - var l *[]string + var ls *[]string switch args[1] { case "link": got := strings.Join(args[2:], " ") @@ -884,31 +884,31 @@ func (o *fakeOS) run(args ...string) error { } return nil case "addr": - l = &o.ips + ls = &o.ips case "route": - l = &o.routes + ls = &o.routes case "rule": - l = &o.rules + ls = &o.rules default: return unexpected() } switch args[2] { case "add": - for _, el := range *l { + for _, el := range *ls { if el == rest { o.t.Errorf("can't add %q, already present", rest) return errors.New("already exists") } } - *l = append(*l, rest) - sort.Strings(*l) + *ls = append(*ls, rest) + sort.Strings(*ls) case "del": found := false - for i, el := range *l { + for i, el := range *ls { if el == rest { found = true - *l = append((*l)[:i], (*l)[i+1:]...) + *ls = append((*ls)[:i], (*ls)[i+1:]...) break } } @@ -1290,3 +1290,43 @@ func TestIPRulesForUBNT(t *testing.T) { } } } + +func TestUpdateMagicsockPortChange(t *testing.T) { + nfr := &fakeIPTablesRunner{ + t: t, + ipt4: make(map[string][]string), + ipt6: make(map[string][]string), + } + nfr.ipt4["filter/ts-input"] = []string{} + + r := &linuxRouter{ + logf: logger.Discard, + health: new(health.Tracker), + netfilterMode: netfilterOn, + nfr: nfr, + } + + if err := r.updateMagicsockPort(12345, "udp4"); err != nil { + t.Fatalf("failed to set initial port: %v", err) + } + + if err := r.updateMagicsockPort(54321, "udp4"); err != nil { + t.Fatalf("failed to update port: %v", err) + } + + newPortRule := buildMagicsockPortRule(54321) + hasNewRule := slices.Contains(nfr.ipt4["filter/ts-input"], newPortRule) + + if !hasNewRule { + t.Errorf("firewall rule for NEW port 54321 not found.\nExpected: %s\nActual rules: %v", + newPortRule, nfr.ipt4["filter/ts-input"]) + } + + oldPortRule := buildMagicsockPortRule(12345) + hasOldRule := slices.Contains(nfr.ipt4["filter/ts-input"], oldPortRule) + + if hasOldRule { + t.Errorf("firewall rule for OLD port 12345 still exists (should be deleted).\nFound: %s\nAll rules: %v", + oldPortRule, nfr.ipt4["filter/ts-input"]) + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 8856a3eaf4d11..1b8562d3ffe55 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -47,9 +47,11 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/views" + "tailscale.com/util/backoff" "tailscale.com/util/checkchange" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" + "tailscale.com/util/execqueue" "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/testenv" @@ -97,6 +99,8 @@ type userspaceEngine struct { eventBus *eventbus.Bus eventClient *eventbus.Client + linkChangeQueue execqueue.ExecQueue + logf logger.Logf wgLogger *wglog.Logger // a wireguard-go logging wrapper reqCh chan struct{} @@ -145,7 +149,7 @@ type userspaceEngine struct { netMap *netmap.NetworkMap // or nil closing bool // Close was called (even if we're still closing) statusCallback StatusCallback - peerSequence []key.NodePublic + peerSequence views.Slice[key.NodePublic] endpoints []tailcfg.Endpoint pendOpen 
map[flowtrackTuple]*pendingOpenFlow // see pendopen.go @@ -319,9 +323,9 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) var tsTUNDev *tstun.Wrapper if conf.IsTAP { - tsTUNDev = tstun.WrapTAP(logf, conf.Tun, conf.Metrics) + tsTUNDev = tstun.WrapTAP(logf, conf.Tun, conf.Metrics, conf.EventBus) } else { - tsTUNDev = tstun.Wrap(logf, conf.Tun, conf.Metrics) + tsTUNDev = tstun.Wrap(logf, conf.Tun, conf.Metrics, conf.EventBus) } closePool.add(tsTUNDev) @@ -447,6 +451,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) cb := e.pongCallback[pong.Data] e.logf("wgengine: got TSMP pong %02x, peerAPIPort=%v; cb=%v", pong.Data, pong.PeerAPIPort, cb != nil) if cb != nil { + delete(e.pongCallback, pong.Data) go cb(pong) } } @@ -460,6 +465,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) // We didn't swallow it, so let it flow to the host. return false } + delete(e.icmpEchoResponseCallback, idSeq) e.logf("wgengine: got diagnostic ICMP response %02x", idSeq) go cb() return true @@ -543,7 +549,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) if f, ok := feature.HookProxyInvalidateCache.GetOk(); ok { f() } - e.linkChange(&cd) + e.linkChangeQueue.Add(func() { e.linkChange(&cd) }) }) e.eventClient = ec e.logf("Engine created.") @@ -924,6 +930,32 @@ func hasOverlap(aips, rips views.Slice[netip.Prefix]) bool { return false } +// ResetAndStop resets the engine to a clean state (like calling Reconfig +// with all pointers to zero values) and waits for it to be fully stopped, +// with no live peers or DERPs. +// +// Unlike Reconfig, it does not return ErrNoChanges. +// +// If the engine stops cleanly, ResetAndStop returns the final status. Note that +// this status is not sent to the registered status callback; it is on the +// caller to handle it appropriately.
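+//
+// A rough usage sketch (hypothetical caller code, for illustration only;
+// e is an Engine and logf a logger.Logf):
+//
+//	st, err := e.ResetAndStop()
+//	if err != nil {
+//		return err
+//	}
+//	logf("engine stopped: peers=%d derps=%d", len(st.Peers), st.DERPs)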
+func (e *userspaceEngine) ResetAndStop() (*Status, error) { + if err := e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}); err != nil && !errors.Is(err, ErrNoChanges) { + return nil, err + } + bo := backoff.NewBackoff("UserspaceEngineResetAndStop", e.logf, 1*time.Second) + for { + st, err := e.getStatus() + if err != nil { + return nil, err + } + if len(st.Peers) == 0 && st.DERPs == 0 { + return st, nil + } + bo.BackOff(context.Background(), fmt.Errorf("waiting for engine to stop: peers=%d derps=%d", len(st.Peers), st.DERPs)) + } +} + func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") @@ -939,12 +971,15 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, e.tundev.SetWGConfig(cfg) peerSet := make(set.Set[key.NodePublic], len(cfg.Peers)) + e.mu.Lock() - e.peerSequence = e.peerSequence[:0] + seq := make([]key.NodePublic, 0, len(cfg.Peers)) for _, p := range cfg.Peers { - e.peerSequence = append(e.peerSequence, p.PublicKey) + seq = append(seq, p.PublicKey) peerSet.Add(p.PublicKey) } + e.peerSequence = views.SliceOf(seq) + nm := e.netMap e.mu.Unlock() @@ -1055,7 +1090,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, tid := cfg.NetworkLogging.DomainID logExitFlowEnabled := cfg.NetworkLogging.LogExitFlowEnabled e.logf("wgengine: Reconfig: starting up network logger (node:%s tailnet:%s)", nid.Public(), tid.Public()) - if err := e.networkLogger.Startup(cfg.NodeID, nid, tid, e.tundev, e.magicConn, e.netMon, e.health, e.eventBus, logExitFlowEnabled); err != nil { + if err := e.networkLogger.Startup(e.logf, nm, nid, tid, e.tundev, e.magicConn, e.netMon, e.health, e.eventBus, logExitFlowEnabled); err != nil { e.logf("wgengine: Reconfig: error starting up network logger: %v", err) } e.networkLogger.ReconfigRoutes(routerCfg) @@ -1199,7 +1234,7 @@ func (e *userspaceEngine) getStatus() (*Status, error) { e.mu.Lock() closing := e.closing - peerKeys := slices.Clone(e.peerSequence) + peerKeys := e.peerSequence localAddrs := slices.Clone(e.endpoints) e.mu.Unlock() @@ -1207,8 +1242,8 @@ func (e *userspaceEngine) getStatus() (*Status, error) { return nil, ErrEngineClosing } - peers := make([]ipnstate.PeerStatusLite, 0, len(peerKeys)) - for _, key := range peerKeys { + peers := make([]ipnstate.PeerStatusLite, 0, peerKeys.Len()) + for _, key := range peerKeys.All() { if status, ok := e.getPeerStatusLite(key); ok { peers = append(peers, status) } @@ -1258,6 +1293,9 @@ func (e *userspaceEngine) RequestStatus() { func (e *userspaceEngine) Close() { e.eventClient.Close() + // TODO(cmol): Should we wait for it too? + // Same question raised in appconnector.go. 
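+ // (Assumption, not verified here: execqueue.ExecQueue.Shutdown only stops
+ // new funcs from being accepted and does not block until queued work has
+ // finished; fully draining would need something like its Wait method,
+ // which is what the TODO above is asking about.)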
+ e.linkChangeQueue.Shutdown() e.mu.Lock() if e.closing { e.mu.Unlock() @@ -1352,6 +1390,9 @@ func (e *userspaceEngine) SetNetworkMap(nm *netmap.NetworkMap) { e.mu.Lock() e.netMap = nm e.mu.Unlock() + if e.networkLogger.Running() { + e.networkLogger.ReconfigNetworkMap(nm) + } } func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { @@ -1397,6 +1438,7 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size in e.magicConn.Ping(peer, res, size, cb) case "TSMP": e.sendTSMPPing(ip, peer, res, cb) + e.sendTSMPDiscoAdvertisement(ip) case "ICMP": e.sendICMPEchoRequest(ip, peer, res, cb) } @@ -1517,6 +1559,29 @@ func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer tailcfg.NodeView, res e.tundev.InjectOutbound(tsmpPing) } +func (e *userspaceEngine) sendTSMPDiscoAdvertisement(ip netip.Addr) { + srcIP, err := e.mySelfIPMatchingFamily(ip) + if err != nil { + e.logf("getting matching node: %s", err) + return + } + tdka := packet.TSMPDiscoKeyAdvertisement{ + Src: srcIP, + Dst: ip, + Key: e.magicConn.DiscoPublicKey(), + } + payload, err := tdka.Marshal() + if err != nil { + e.logf("error generating TSMP Advertisement: %s", err) + metricTSMPDiscoKeyAdvertisementError.Add(1) + } else if err := e.tundev.InjectOutbound(payload); err != nil { + e.logf("error sending TSMP Advertisement: %s", err) + metricTSMPDiscoKeyAdvertisementError.Add(1) + } else { + metricTSMPDiscoKeyAdvertisementSent.Add(1) + } +} + func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPPongReply)) { e.mu.Lock() defer e.mu.Unlock() @@ -1683,6 +1748,9 @@ var ( metricNumMajorChanges = clientmetric.NewCounter("wgengine_major_changes") metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes") + + metricTSMPDiscoKeyAdvertisementSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_sent") + metricTSMPDiscoKeyAdvertisementError = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_error") ) func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) { diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 89d75b98adafb..0a1d2924d593b 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -325,6 +325,64 @@ func TestUserspaceEnginePeerMTUReconfig(t *testing.T) { } } +func TestTSMPKeyAdvertisement(t *testing.T) { + var knobs controlknobs.Knobs + + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) + reg := new(usermetric.Registry) + e, err := NewFakeUserspaceEngine(t.Logf, 0, &knobs, ht, reg, bus) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + ue := e.(*userspaceEngine) + routerCfg := &router.Config{} + nodeKey := nkFromHex("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: nodeKey, + }, + }), + SelfNode: (&tailcfg.Node{ + StableID: "TESTCTRL00000001", + Name: "test-node.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32"), netip.MustParsePrefix("fd7a:115c:a1e0:ab12:4843:cd96:0:1/128")}, + }).View(), + } + cfg := &wgcfg.Config{ + Peers: []wgcfg.Peer{ + { + PublicKey: nodeKey, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netaddr.IPv4(100, 100, 99, 1), 32), + }, + }, + }, + } + + ue.SetNetworkMap(nm) + err = ue.Reconfig(cfg, routerCfg, &dns.Config{}) + if err != nil { + t.Fatal(err) + } + + addr := netip.MustParseAddr("100.100.99.1") + previousValue := metricTSMPDiscoKeyAdvertisementSent.Value() + ue.sendTSMPDiscoAdvertisement(addr) 
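+ // sendTSMPDiscoAdvertisement has no return value: it injects the packet
+ // into the TUN device and reports the outcome only via log lines and the
+ // clientmetric counters, so the test asserts on the sent counter below.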
+ if val := metricTSMPDiscoKeyAdvertisementSent.Value(); val <= previousValue { + errs := metricTSMPDiscoKeyAdvertisementError.Value() + t.Errorf("Expected 1 disco key advert, got %d, errors %d", val, errs) + } + // Remove config to have the engine shut down more consistently + err = ue.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) + if err != nil { + t.Fatal(err) + } +} + func nkFromHex(hex string) key.NodePublic { if len(hex) != 64 { panic(fmt.Sprintf("%q is len %d; want 64", hex, len(hex))) diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index 0500e6f7fd4c7..9cc4ed3b594c3 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -124,6 +124,12 @@ func (e *watchdogEngine) watchdog(name string, fn func()) { func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, routerCfg, dnsCfg) }) } +func (e *watchdogEngine) ResetAndStop() (st *Status, err error) { + e.watchdog("ResetAndStop", func() { + st, err = e.wrap.ResetAndStop() + }) + return st, err +} func (e *watchdogEngine) GetFilter() *filter.Filter { return e.wrap.GetFilter() } diff --git a/wgengine/wgcfg/config.go b/wgengine/wgcfg/config.go index 926964a4bdc20..2734f6c6ea969 100644 --- a/wgengine/wgcfg/config.go +++ b/wgengine/wgcfg/config.go @@ -8,7 +8,6 @@ import ( "net/netip" "slices" - "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logid" ) @@ -18,8 +17,6 @@ import ( // Config is a WireGuard configuration. // It only supports the set of things Tailscale uses. type Config struct { - Name string - NodeID tailcfg.StableNodeID PrivateKey key.NodePrivate Addresses []netip.Prefix MTU uint16 @@ -40,9 +37,7 @@ func (c *Config) Equal(o *Config) bool { if c == nil || o == nil { return c == o } - return c.Name == o.Name && - c.NodeID == o.NodeID && - c.PrivateKey.Equal(o.PrivateKey) && + return c.PrivateKey.Equal(o.PrivateKey) && c.MTU == o.MTU && c.NetworkLogging == o.NetworkLogging && slices.Equal(c.Addresses, o.Addresses) && diff --git a/wgengine/wgcfg/nmcfg/nmcfg.go b/wgengine/wgcfg/nmcfg/nmcfg.go index 1add608e4496c..487e78d81218d 100644 --- a/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/wgengine/wgcfg/nmcfg/nmcfg.go @@ -5,12 +5,14 @@ package nmcfg import ( - "bytes" + "bufio" + "cmp" "fmt" "net/netip" "strings" "tailscale.com/tailcfg" + "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" @@ -18,16 +20,7 @@ import ( ) func nodeDebugName(n tailcfg.NodeView) string { - name := n.Name() - if name == "" { - name = n.Hostinfo().Hostname() - } - if i := strings.Index(name, "."); i != -1 { - name = name[:i] - } - if name == "" && n.Addresses().Len() != 0 { - return n.Addresses().At(0).String() - } + name, _, _ := strings.Cut(cmp.Or(n.Name(), n.Hostinfo().Hostname()), ".") return name } @@ -49,17 +42,15 @@ func cidrIsSubnet(node tailcfg.NodeView, cidr netip.Prefix) bool { } // WGCfg returns the NetworkMaps's WireGuard configuration. 
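+//
+// A hypothetical call site, sketching the shape of the signature below
+// (the surrounding names are invented for illustration):
+//
+//	cfg, err := nmcfg.WGCfg(nodePrivateKey, nm, logf, netmap.AllowSubnetRoutes, exitNodeID)
+//	if err != nil {
+//		return err
+//	}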
-func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, exitNode tailcfg.StableNodeID) (*wgcfg.Config, error) { +func WGCfg(pk key.NodePrivate, nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, exitNode tailcfg.StableNodeID) (*wgcfg.Config, error) { cfg := &wgcfg.Config{ - Name: "tailscale", - PrivateKey: nm.PrivateKey, + PrivateKey: pk, Addresses: nm.GetAddresses().AsSlice(), Peers: make([]wgcfg.Peer, 0, len(nm.Peers)), } // Setup log IDs for data plane audit logging. if nm.SelfNode.Valid() { - cfg.NodeID = nm.SelfNode.StableID() canNetworkLog := nm.SelfNode.HasCap(tailcfg.CapabilityDataPlaneAuditLogs) logExitFlowEnabled := nm.SelfNode.HasCap(tailcfg.NodeAttrLogExitFlows) if canNetworkLog && nm.SelfNode.DataPlaneAuditLogID() != "" && nm.DomainAuditLogID != "" { @@ -79,10 +70,7 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, } } - // Logging buffers - skippedUnselected := new(bytes.Buffer) - skippedSubnets := new(bytes.Buffer) - skippedExpired := new(bytes.Buffer) + var skippedExitNode, skippedSubnetRouter, skippedExpired []tailcfg.NodeView for _, peer := range nm.Peers { if peer.DiscoKey().IsZero() && peer.HomeDERP() == 0 && !peer.IsWireGuardOnly() { @@ -95,16 +83,7 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, // anyway, since control intentionally breaks node keys for // expired peers so that we can't discover endpoints via DERP. if peer.Expired() { - if skippedExpired.Len() >= 1<<10 { - if !bytes.HasSuffix(skippedExpired.Bytes(), []byte("...")) { - skippedExpired.WriteString("...") - } - } else { - if skippedExpired.Len() > 0 { - skippedExpired.WriteString(", ") - } - fmt.Fprintf(skippedExpired, "%s/%v", peer.StableID(), peer.Key().ShortString()) - } + skippedExpired = append(skippedExpired, peer) continue } @@ -114,28 +93,22 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, }) cpeer := &cfg.Peers[len(cfg.Peers)-1] - didExitNodeWarn := false + didExitNodeLog := false cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer().Clone() cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer().Clone() cpeer.IsJailed = peer.IsJailed() for _, allowedIP := range peer.AllowedIPs().All() { if allowedIP.Bits() == 0 && peer.StableID() != exitNode { - if didExitNodeWarn { + if didExitNodeLog { // Don't log about both the IPv4 /0 and IPv6 /0. 
continue } - didExitNodeWarn = true - if skippedUnselected.Len() > 0 { - skippedUnselected.WriteString(", ") - } - fmt.Fprintf(skippedUnselected, "%q (%v)", nodeDebugName(peer), peer.Key().ShortString()) + didExitNodeLog = true + skippedExitNode = append(skippedExitNode, peer) continue } else if cidrIsSubnet(peer, allowedIP) { if (flags & netmap.AllowSubnetRoutes) == 0 { - if skippedSubnets.Len() > 0 { - skippedSubnets.WriteString(", ") - } - fmt.Fprintf(skippedSubnets, "%v from %q (%v)", allowedIP, nodeDebugName(peer), peer.Key().ShortString()) + skippedSubnetRouter = append(skippedSubnetRouter, peer) continue } } @@ -143,14 +116,27 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, } } - if skippedUnselected.Len() > 0 { - logf("[v1] wgcfg: skipped unselected default routes from: %s", skippedUnselected.Bytes()) - } - if skippedSubnets.Len() > 0 { - logf("[v1] wgcfg: did not accept subnet routes: %s", skippedSubnets) - } - if skippedExpired.Len() > 0 { - logf("[v1] wgcfg: skipped expired peer: %s", skippedExpired) + logList := func(title string, nodes []tailcfg.NodeView) { + if len(nodes) == 0 { + return + } + logf("[v1] wgcfg: %s from %d nodes: %s", title, len(nodes), logger.ArgWriter(func(bw *bufio.Writer) { + const max = 5 + for i, n := range nodes { + if i == max { + fmt.Fprintf(bw, "... +%d", len(nodes)-max) + return + } + if i > 0 { + bw.WriteString(", ") + } + fmt.Fprintf(bw, "%s (%s)", nodeDebugName(n), n.StableID()) + } + })) } + logList("skipped unselected exit nodes", skippedExitNode) + logList("did not accept subnet routes", skippedSubnetRouter) + logList("skipped expired peers", skippedExpired) + return cfg, nil } diff --git a/wgengine/wgcfg/wgcfg_clone.go b/wgengine/wgcfg/wgcfg_clone.go index 749d8d8160579..9f3cabde182f9 100644 --- a/wgengine/wgcfg/wgcfg_clone.go +++ b/wgengine/wgcfg/wgcfg_clone.go @@ -8,7 +8,6 @@ package wgcfg import ( "net/netip" - "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logid" "tailscale.com/types/ptr" @@ -35,8 +34,6 @@ func (src *Config) Clone() *Config { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _ConfigCloneNeedsRegeneration = Config(struct { - Name string - NodeID tailcfg.StableNodeID PrivateKey key.NodePrivate Addresses []netip.Prefix MTU uint16 diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index 6aaf567ad01ee..be78731474bc9 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -69,6 +69,13 @@ type Engine interface { // The returned error is ErrNoChanges if no changes were made. Reconfig(*wgcfg.Config, *router.Config, *dns.Config) error + // ResetAndStop resets the engine to a clean state (like calling Reconfig + // with all pointers to zero values) and waits for it to be fully stopped, + // with no live peers or DERPs. + // + // Unlike Reconfig, it does not return ErrNoChanges. + ResetAndStop() (*Status, error) + // PeerForIP returns the node to which the provided IP routes, // if any. If none is found, (nil, false) is returned. PeerForIP(netip.Addr) (_ PeerForIP, ok bool) diff --git a/words/tails.txt b/words/tails.txt index f5e93bf504687..b0119a7563224 100644 --- a/words/tails.txt +++ b/words/tails.txt @@ -755,7 +755,6 @@ pipefish seahorse flounder tilapia -chub dorado shad lionfish
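The nmcfg.go change above replaces three ad-hoc string buffers with a single bounded-list logger. Below is a minimal standalone sketch of that pattern, assuming only that logger.ArgWriter defers its formatting work until the log line is actually emitted; the package main scaffolding and the string-based node list are invented for illustration, while the "max 5 then +N" shape mirrors the logList helper in the diff.

package main

import (
	"bufio"
	"fmt"
	"log"

	"tailscale.com/types/logger"
)

// logList logs up to five entries from nodes, then an elided "+N" count,
// mirroring the logList helper added to nmcfg.go.
func logList(logf logger.Logf, title string, nodes []string) {
	if len(nodes) == 0 {
		return
	}
	logf("[v1] wgcfg: %s from %d nodes: %s", title, len(nodes), logger.ArgWriter(func(bw *bufio.Writer) {
		const max = 5
		for i, n := range nodes {
			if i == max {
				// Summarize the remainder instead of flooding the log.
				fmt.Fprintf(bw, "... +%d", len(nodes)-max)
				return
			}
			if i > 0 {
				bw.WriteString(", ")
			}
			bw.WriteString(n)
		}
	}))
}

func main() {
	logList(log.Printf, "skipped expired peers", []string{"a", "b", "c", "d", "e", "f", "g"})
}

Using logger.ArgWriter here keeps the per-node formatting out of the hot path: the closure runs only when the log line is rendered, which is why the original code prefers it over building strings eagerly.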