diff --git a/.golangci.yaml b/.golangci.yaml index e3c5efda..b8c796fd 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -2,166 +2,153 @@ # # SPDX-License-Identifier: Apache-2.0 -# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json -linters-settings: - depguard: - # new configuration - rules: - logger: - deny: - # logging is allowed only by logutils.Log, - # logrus is allowed to use only in logutils package. - - pkg: "github.com/sirupsen/logrus" - desc: logging is allowed only by logutils.Log - dupl: - threshold: 100 - funlen: - lines: -1 # the number of lines (code + empty lines) is not a right metric and leads to code without empty line or one-liner. - statements: 50 - goconst: - min-len: 2 - min-occurrences: 3 - gocritic: - enabled-tags: - - diagnostic - - experimental - - opinionated - - performance - - style - disabled-checks: - - dupImport # https://github.com/go-critic/go-critic/issues/845 - - ifElseChain - - octalLiteral - - whyNoLint - gocyclo: - min-complexity: 15 - gofmt: - rewrite-rules: - - pattern: "interface{}" - replacement: "any" - goimports: - local-prefixes: github.com/golangci/golangci-lint - mnd: - # don't include the "operation" and "assign" - checks: - - argument - - case - - condition - - return - ignored-numbers: - - "0" - - "1" - - "2" - - "3" - ignored-functions: - - strings.SplitN - - govet: - enable-all: true - disable: - - fieldalignment # disabled because it's too strict, it checks if struct fields are sorted by size - settings: - printf: - funcs: - - (github.com/golangci/golangci-lint/pkg/logutils.Log).Infof - - (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf - - (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf - - (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf - - gosec: - excludes: - - "G115" # Excluded per default https://github.com/golangci/golangci-lint/pull/4941 - - lll: - line-length: 140 - misspell: - locale: US - nolintlint: - allow-unused: false # report any unused nolint directives - require-explanation: false # don't require an explanation for nolint directives - require-specific: false # don't require nolint directives to be specific about which linter is being skipped - revive: - rules: - - name: unexported-return - disabled: true - - name: unused-parameter - +version: "2" linters: - disable-all: true + default: none enable: - bodyclose + - copyloopvar - depguard - dogsled - dupl - errcheck - - copyloopvar - funlen - gocheckcompilerdirectives - gochecknoinits - goconst - gocritic - gocyclo - - gofmt - - goimports - - mnd - goprintffuncname - gosec - - gosimple - govet - ineffassign - misspell + - mnd - nakedret - noctx - nolintlint - revive - staticcheck - - typecheck - unconvert - unparam - unused - whitespace - - # don't enable: - # - asciicheck - # - scopelint - # - gochecknoglobals - # - gocognit - # - godot - # - godox - # - goerr113 - # - interfacer - # - lll - # - maligned - # - nestif - # - prealloc - # - stylecheck - # - testpackage - # - wsl - -issues: - exclude-rules: - - path: _test\.go - linters: - - mnd # test files can have magic numbers - - revive # test files can have unused parameters - - - path: pkg/golinters/errcheck.go - text: "SA1019: errCfg.Exclude is deprecated: use ExcludeFunctions instead" - - path: pkg/commands/run.go - text: "SA1019: lsc.Errcheck.Exclude is deprecated: use ExcludeFunctions instead" - - path: pkg/commands/run.go - text: "SA1019: e.cfg.Run.Deadline is deprecated: Deadline exists for historical compatibility and should not be used." - - - path: pkg/golinters/gofumpt.go - text: "SA1019: settings.LangVersion is deprecated: use the global `run.go` instead." - - path: pkg/golinters/staticcheck_common.go - text: "SA1019: settings.GoVersion is deprecated: use the global `run.go` instead." - - path: pkg/lint/lintersdb/manager.go - text: "SA1019: (.+).(GoVersion|LangVersion) is deprecated: use the global `run.go` instead." - - path: pkg/golinters/unused.go - text: "rangeValCopy: each iteration copies 160 bytes \\(consider pointers or indexing\\)" - - path: test/(fix|linters)_test.go - text: "string `gocritic.go` has 3 occurrences, make it a constant" - -run: - timeout: 5m + settings: + depguard: + rules: + logger: + deny: + - pkg: github.com/sirupsen/logrus + desc: logging is allowed only by logutils.Log + dupl: + threshold: 100 + funlen: + lines: -1 + statements: 50 + goconst: + min-len: 2 + min-occurrences: 3 + gocritic: + disabled-checks: + - dupImport + - ifElseChain + - octalLiteral + - whyNoLint + enabled-tags: + - diagnostic + - experimental + - opinionated + - performance + - style + gocyclo: + min-complexity: 15 + gosec: + excludes: + - G115 + govet: + disable: + - fieldalignment + enable-all: true + settings: + printf: + funcs: + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Infof + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Warnf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Errorf + - (github.com/golangci/golangci-lint/pkg/logutils.Log).Fatalf + lll: + line-length: 140 + misspell: + locale: US + mnd: + checks: + - argument + - case + - condition + - return + ignored-numbers: + - "0" + - "1" + - "2" + - "3" + ignored-functions: + - strings.SplitN + nolintlint: + require-explanation: false + require-specific: false + allow-unused: false + revive: + rules: + - name: unexported-return + disabled: true + - name: unused-parameter + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - mnd + - revive + path: _test\.go + - path: pkg/golinters/errcheck.go + text: "SA1019: errCfg.Exclude is deprecated: use ExcludeFunctions instead" + - path: pkg/commands/run.go + text: "SA1019: lsc.Errcheck.Exclude is deprecated: use ExcludeFunctions instead" + - path: pkg/commands/run.go + text: "SA1019: e.cfg.Run.Deadline is deprecated: Deadline exists for historical compatibility and should not be used." + - path: pkg/golinters/gofumpt.go + text: "SA1019: settings.LangVersion is deprecated: use the global `run.go` instead." + - path: pkg/golinters/staticcheck_common.go + text: "SA1019: settings.GoVersion is deprecated: use the global `run.go` instead." + - path: pkg/lint/lintersdb/manager.go + text: "SA1019: (.+).(GoVersion|LangVersion) is deprecated: use the global `run.go` instead." + - path: pkg/golinters/unused.go + text: 'rangeValCopy: each iteration copies 160 bytes \(consider pointers or indexing\)' + - path: test/(fix|linters)_test.go + text: string `gocritic.go` has 3 occurrences, make it a constant + paths: + - third_party$ + - builtin$ + - examples$ +formatters: + enable: + - gofmt + - goimports + settings: + gofmt: + rewrite-rules: + - pattern: interface{} + replacement: any + goimports: + local-prefixes: + - github.com/golangci/golangci-lint + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/go.mod b/go.mod index 15535b8f..57bf6322 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( go.opentelemetry.io/otel/trace v1.37.0 golang.org/x/net v0.42.0 golang.org/x/sys v0.34.0 + golang.org/x/text v0.27.0 google.golang.org/grpc v1.74.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -71,7 +72,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.25.0 // indirect golang.org/x/sync v0.16.0 // indirect - golang.org/x/text v0.27.0 // indirect golang.org/x/tools v0.34.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect diff --git a/internal/traceroute/client.go b/internal/traceroute/client.go new file mode 100644 index 00000000..8b34736b --- /dev/null +++ b/internal/traceroute/client.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "cmp" + "context" + "fmt" + "time" + + "github.com/telekom/sparrow/internal/helper" +) + +var _ Client = (*genericClient)(nil) + +// defaultOptions provides a set of default options for the traceroute. +var defaultOptions = Options{ + MaxTTL: 30, + Timeout: 60 * time.Second, + Retry: helper.RetryConfig{ + Count: 3, + Delay: 1 * time.Second, + }, +} + +// Client is able to run a traceroute to one or more targets. +// +//go:generate go tool moq -out client_moq.go . Client +type Client interface { + // Run executes the traceroute for the given targets with the specified options. + // Returns a Result containing the hops for each target, or an error if the traceroute fails. + Run(ctx context.Context, targets []Target, opts *Options) (Result, error) +} + +type genericClient struct { + // tcp is the [tcpClient] that implements the traceroute using TCP. + tcp Client +} + +// NewClient creates a new [Client] that can be used to run traceroutes. +func NewClient() Client { + return &genericClient{ + tcp: newTCPClient(), + } +} + +// Run executes the traceroute for the given targets with the specified options. +func (c *genericClient) Run(ctx context.Context, targets []Target, opts *Options) (Result, error) { + for _, target := range targets { + if err := target.Validate(); err != nil { + return nil, fmt.Errorf("invalid target %s: %w", target, err) + } + } + + return c.tcp.Run(ctx, targets, cmp.Or(opts, &defaultOptions)) +} diff --git a/internal/traceroute/client_moq.go b/internal/traceroute/client_moq.go new file mode 100644 index 00000000..6ca8b59b --- /dev/null +++ b/internal/traceroute/client_moq.go @@ -0,0 +1,87 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package traceroute + +import ( + "context" + "sync" +) + +// Ensure, that ClientMock does implement Client. +// If this is not the case, regenerate this file with moq. +var _ Client = &ClientMock{} + +// ClientMock is a mock implementation of Client. +// +// func TestSomethingThatUsesClient(t *testing.T) { +// +// // make and configure a mocked Client +// mockedClient := &ClientMock{ +// RunFunc: func(ctx context.Context, targets []Target, opts *Options) (Result, error) { +// panic("mock out the Run method") +// }, +// } +// +// // use mockedClient in code that requires Client +// // and then make assertions. +// +// } +type ClientMock struct { + // RunFunc mocks the Run method. + RunFunc func(ctx context.Context, targets []Target, opts *Options) (Result, error) + + // calls tracks calls to the methods. + calls struct { + // Run holds details about calls to the Run method. + Run []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Targets is the targets argument value. + Targets []Target + // Opts is the opts argument value. + Opts *Options + } + } + lockRun sync.RWMutex +} + +// Run calls RunFunc. +func (mock *ClientMock) Run(ctx context.Context, targets []Target, opts *Options) (Result, error) { + if mock.RunFunc == nil { + panic("ClientMock.RunFunc: method is nil but Client.Run was just called") + } + callInfo := struct { + Ctx context.Context + Targets []Target + Opts *Options + }{ + Ctx: ctx, + Targets: targets, + Opts: opts, + } + mock.lockRun.Lock() + mock.calls.Run = append(mock.calls.Run, callInfo) + mock.lockRun.Unlock() + return mock.RunFunc(ctx, targets, opts) +} + +// RunCalls gets all the calls that were made to Run. +// Check the length with: +// +// len(mockedClient.RunCalls()) +func (mock *ClientMock) RunCalls() []struct { + Ctx context.Context + Targets []Target + Opts *Options +} { + var calls []struct { + Ctx context.Context + Targets []Target + Opts *Options + } + mock.lockRun.RLock() + calls = mock.calls.Run + mock.lockRun.RUnlock() + return calls +} diff --git a/internal/traceroute/doc.go b/internal/traceroute/doc.go new file mode 100644 index 00000000..631da173 --- /dev/null +++ b/internal/traceroute/doc.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +// Package traceroute provides a traceroute implementation that +// falls back to ICMP time-exceeded and destination-unreachable +// messages when setting TTL limits on outgoing probes. +// +// It exposes a [Client] for running traceroutes against one or +// more targets with configurable [Options]. +// Under the hood it dials TCP connections with incrementing TTLs, +// listens for ICMP responses when TCP connections fail, and collects hop +// results in order, de-duplicating and stopping early when the destination is +// reached. +// +// Key features: +// - Pure-Go TCP dialing with IP_TTL control via x/sys/unix (no external +// traceroute binary required) +// - Optional raw-socket ICMP listener with graceful fallback when NET_RAW +// capabilities are unavailable +// - Concurrency via goroutines and channels, with result collection, sorting, +// and de-duplication +// - Built-in OpenTelemetry spans and events for tracing each hop and errors +// - Configurable retry policy, timeouts, and maximum hops via Options +// - Fully mockable internals (icmpListener, tracer, Client) for unit testing +// +// Typical usage: +// +// client := traceroute.NewClient() +// opts := &traceroute.Options{MaxTTL: 30, Timeout: 2*time.Second, Retry: retryCfg} +// res, err := client.Run(ctx, []traceroute.Target{{Protocol: traceroute.ProtocolTCP, Address: "8.8.8.8", Port: 53}}, opts) +// // res maps each Target to its slice of Hop results +// +// See each sub-package or type for more detailed documentation on exposed types +// and functions. +package traceroute diff --git a/internal/traceroute/errors.go b/internal/traceroute/errors.go new file mode 100644 index 00000000..beed767c --- /dev/null +++ b/internal/traceroute/errors.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "context" + "errors" +) + +// errICMPNotAvailable is returned when ICMP is not available due to lack of NET_RAW capabilities. +// This typically occurs when the process does not have the necessary permissions to create an ICMP socket +// or when running in an environment where ICMP is restricted (e.g., some containerized environments). +var errICMPNotAvailable = errors.New("no NET_RAW capabilities, ICMP not available") + +// isTracerouteError checks if the error is related to common +// and expected traceroute errors. +func isTracerouteError(err error) bool { + return errors.Is(err, errICMPNotAvailable) || + errors.Is(err, context.DeadlineExceeded) +} diff --git a/internal/traceroute/errors_test.go b/internal/traceroute/errors_test.go new file mode 100644 index 00000000..16a9d045 --- /dev/null +++ b/internal/traceroute/errors_test.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsTracerouteError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"icmp not available", errICMPNotAvailable, true}, + {"wrapped icmp not available", fmt.Errorf("wrap: %w", errICMPNotAvailable), true}, + {"deadline exceeded", context.DeadlineExceeded, true}, + {"wrapped deadline exceeded", fmt.Errorf("ctx error: %w", context.DeadlineExceeded), true}, + {"some other error", errors.New("foo"), false}, + {"nil error", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isTracerouteError(tt.err) + assert.Equal(t, tt.want, got, "isTracerouteError(%v)", tt.err) + }) + } +} diff --git a/internal/traceroute/helpers.go b/internal/traceroute/helpers.go new file mode 100644 index 00000000..a9c410c4 --- /dev/null +++ b/internal/traceroute/helpers.go @@ -0,0 +1,151 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "context" + "errors" + "fmt" + "math/rand/v2" + "net" + "slices" + + "github.com/telekom/sparrow/internal/logger" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sys/unix" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +const ( + // basePort is the starting port for the TCP connection + basePort = 30000 + // portRange is the range of ports to generate a random port from + portRange = 10000 +) + +// randomPort returns a random port in the interval [30000, 40000) +func randomPort() int { + return rand.N(portRange) + basePort // #nosec G404 // math.rand is fine here, we're not doing encryption +} + +// resolveName performs a reverse DNS lookup for the given IP address. +// If the lookup fails or returns no names, it returns an empty string. +func resolveName(addr net.Addr) string { + if addr == nil { + return "" + } + + ip := ipFromAddr(addr) + if ip == nil { + return "" + } + + names, err := net.LookupAddr(ip.String()) + if err != nil || len(names) == 0 { + return "" + } + return names[0] +} + +// ipFromAddr extracts the IP address from a [net.Addr]. +func ipFromAddr(addr net.Addr) net.IP { + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP + case *net.TCPAddr: + return a.IP + case *net.IPAddr: + return a.IP + } + return nil +} + +// collectResults collects the results from the channel and returns a sorted slice of hops. +// It filters out hops with a TTL of 0 and removes duplicates, keeping only the first occurrence of each TTL. +// The hops are sorted by TTL in ascending order. +func collectResults(ch <-chan Hop) []Hop { + hops := []Hop{} + for hop := range ch { + if hop.TTL == 0 { + continue + } + hops = append(hops, hop) + } + + if len(hops) == 0 { + return hops + } + + slices.SortFunc(hops, func(a, b Hop) int { + return a.TTL - b.TTL + }) + + filtered := make([]Hop, 0, len(hops)) + seen := make(map[int]bool) + for _, hop := range hops { + if !seen[hop.TTL] { + filtered = append(filtered, hop) + seen[hop.TTL] = true + if hop.Reached { + // If we reached the target, we can stop collecting hops. + break + } + } + } + + return filtered +} + +// logHops logs the hops in a structured format. +func logHops(ctx context.Context, hops []Hop) { + log := logger.FromContext(ctx) + for _, hop := range hops { + log.DebugContext(ctx, hop.String()) + } +} + +// wrapError wraps an error with a message and logs it. +// It also records the error in the current OpenTelemetry span. +func wrapError(ctx context.Context, err error, msg string) error { + if err == nil { + return nil + } + log := logger.FromContext(ctx) + span := trace.SpanFromContext(ctx) + caser := cases.Title(language.English) + + log.ErrorContext(ctx, caser.String(msg), "error", err) + span.SetStatus(codes.Error, msg) + span.RecordError(err) + return fmt.Errorf("%s: %w", msg, err) +} + +// recordTCPError records the error from dialing a TCP connection. +// If the error is nil or [unix.EHOSTUNREACH], it returns nil. +func recordTCPError(ctx context.Context, err error) error { + if err == nil { + return nil + } + log := logger.FromContext(ctx) + span := trace.SpanFromContext(ctx) + + // No route to host is a special error because of how traceroute works. + // We are expecting the connection to fail because of TTL expiry. + span.RecordError(err) + if !errors.Is(err, unix.EHOSTUNREACH) { + log.ErrorContext(ctx, "Failed to dial TCP connection", "error", err) + span.AddEvent("TCP connection failed", trace.WithAttributes( + attribute.String("traceroute.target.error", err.Error()), + )) + span.SetStatus(codes.Error, "Failed to dial TCP connection") + return fmt.Errorf("failed to dial TCP connection: %w", err) + } + + span.SetStatus(codes.Error, "No route to host") + return nil +} diff --git a/internal/traceroute/helpers_test.go b/internal/traceroute/helpers_test.go new file mode 100644 index 00000000..c932f48f --- /dev/null +++ b/internal/traceroute/helpers_test.go @@ -0,0 +1,131 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRandomPort(t *testing.T) { + // randomPort should always return [basePort, basePort+portRange) + for range 1000 { + p := randomPort() + assert.GreaterOrEqual(t, p, basePort, "randomPort should be >= basePort") + assert.Less(t, p, basePort+portRange, "randomPort should be < basePort+portRange") + } +} + +func TestIPFromAddr(t *testing.T) { + tests := []struct { + name string + addr net.Addr + expected net.IP + }{ + {"TCPAddr", &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 80}, net.ParseIP("1.2.3.4")}, + {"UDPAddr", &net.UDPAddr{IP: net.ParseIP("5.6.7.8"), Port: 53}, net.ParseIP("5.6.7.8")}, + {"IPAddr", &net.IPAddr{IP: net.ParseIP("9.10.11.12")}, net.ParseIP("9.10.11.12")}, + {"UnixAddr (unsupported)", &net.UnixAddr{Name: "/tmp/x", Net: "unix"}, nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ipFromAddr(tt.addr) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestResolveName(t *testing.T) { + tests := []struct { + name string + addr net.Addr + want string + }{ + {"nil Addr", nil, ""}, + {"unsupported Addr", &net.UnixAddr{Name: "/tmp/x", Net: "unix"}, ""}, + {"no reverse record", &net.IPAddr{IP: net.ParseIP("203.0.113.1")}, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, resolveName(tt.addr)) + }) + } + + // And one "happy path" using loopback, which almost always maps to localhost + t.Run("loopback resolves", func(t *testing.T) { + loop := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} + name := resolveName(loop) + // On most systems this will be "localhost." or similar + assert.NotEmpty(t, name, "expected a non-empty name for 127.0.0.1") + assert.Contains(t, name, "localhost", "expected substring 'localhost' in %q", name) + }) +} + +func TestCollectResults(t *testing.T) { + tests := []struct { + name string + input []Hop + expected []Hop + }{ + { + name: "empty channel", + input: []Hop{}, + expected: []Hop{}, + }, + { + name: "filters out TTL zero", + input: []Hop{{TTL: 0}, {TTL: 2}, {TTL: 0}, {TTL: 1}}, + expected: []Hop{ + {TTL: 1}, + {TTL: 2}, + }, + }, + { + name: "sorts hops by TTL", + input: []Hop{{TTL: 3}, {TTL: 1}, {TTL: 2}}, + expected: []Hop{ + {TTL: 1}, + {TTL: 2}, + {TTL: 3}, + }, + }, + { + name: "removes duplicate TTLs, keeping first occurrence", + input: []Hop{{TTL: 1}, {TTL: 2}, {TTL: 1}, {TTL: 3}, {TTL: 2}}, + expected: []Hop{ + {TTL: 1}, + {TTL: 2}, + {TTL: 3}, + }, + }, + { + name: "combined filter, sort and dedupe", + input: []Hop{{TTL: 0}, {TTL: 4}, {TTL: 2}, {TTL: 3}, {TTL: 2}, {TTL: 1}, {TTL: 0}}, + expected: []Hop{ + {TTL: 1}, + {TTL: 2}, + {TTL: 3}, + {TTL: 4}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := make(chan Hop, len(tt.input)) + for _, h := range tt.input { + ch <- h + } + close(ch) + + got := collectResults(ch) + assert.Equal(t, tt.expected, got, "collectResults(%v)", tt.input) + }) + } +} diff --git a/internal/traceroute/hopper.go b/internal/traceroute/hopper.go new file mode 100644 index 00000000..4ae62603 --- /dev/null +++ b/internal/traceroute/hopper.go @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "context" + "sync" + + "github.com/telekom/sparrow/internal/helper" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +// tracer is an interface that defines the methods required for executing a traceroute. +// +//go:generate go tool moq -out tracer_moq.go . tracer +type tracer interface { + // trace executes the traceroute for the given target with the specified options. + trace(ctx context.Context, target Target, opts Options) error +} + +// hopper is responsible for managing the execution of traceroute hops for a target. +type hopper struct { + client tracer + wg sync.WaitGroup + otelTracer trace.Tracer + target *Target + opts Options +} + +// run executes the traceroute hops for the target. +// It's the callers responsibility to collect the results +// from the hop channel of the target after calling this method. +func (h *hopper) run(ctx context.Context) { + for ttl := 1; ttl <= h.opts.MaxTTL; ttl++ { + h.wg.Add(1) + go func() { + defer h.wg.Done() + c, hopSpan := h.otelTracer.Start(ctx, h.target.String(), trace.WithAttributes( + attribute.Stringer("traceroute.target.address", h.target), + attribute.Int("traceroute.target.ttl", ttl), + )) + defer hopSpan.End() + hopSpan.SetAttributes( + attribute.Stringer("traceroute.target.address", h.target), + attribute.Int("traceroute.target.ttl", ttl), + ) + + retry := helper.Retry(func(ctx context.Context) error { + return h.client.trace(ctx, h.target.withHopTTL(ttl), h.opts) + }, h.opts.Retry) + + if err := retry(c); err != nil { + hopSpan.RecordError(err) + hopSpan.SetStatus(codes.Error, "Failed to execute hop trace") + hopSpan.End() + return + } + }() + } +} diff --git a/internal/traceroute/hopper_test.go b/internal/traceroute/hopper_test.go new file mode 100644 index 00000000..e531c2b3 --- /dev/null +++ b/internal/traceroute/hopper_test.go @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/telekom/sparrow/internal/helper" + "go.opentelemetry.io/otel/trace/noop" +) + +func TestHopper_run(t *testing.T) { + tests := []struct { + name string + maxTTL int + wantCalls int + }{ + {"zero hops", 0, 0}, + {"one hop", 1, 1}, + {"three hops", 3, 3}, + {"five hops", 5, 5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := make(chan Target, tt.maxTTL) + mock := &tracerMock{ + traceFunc: func(_ context.Context, tgt Target, _ Options) error { + ch <- tgt + return nil + }, + } + + h := &hopper{ + target: &Target{Address: "127.0.0.1", Port: 1234}, + client: mock, + otelTracer: noop.NewTracerProvider().Tracer("test"), + opts: Options{ + MaxTTL: tt.maxTTL, + Retry: helper.RetryConfig{Count: 1, Delay: time.Millisecond}, + Timeout: 2 * time.Millisecond, + }, + } + + h.run(t.Context()) + h.wg.Wait() + close(ch) + + var got []int + for tgt := range ch { + got = append(got, tgt.hopTTL) + } + + want := make([]int, tt.maxTTL) + for i := range want { + want[i] = i + 1 + } + + assert.ElementsMatch(t, want, got, + "expected tracer to be called once for each ttl 1..%d, got %v", tt.maxTTL, got) + }) + } +} + +func TestHopper_run_retry(t *testing.T) { + var ( + mu sync.Mutex + invocations []int + callCount int + ) + + mock := &tracerMock{ + traceFunc: func(_ context.Context, _ Target, _ Options) error { + mu.Lock() + defer mu.Unlock() + callCount++ + invocations = append(invocations, callCount) + if callCount < 3 { + return fmt.Errorf("transient error %d", callCount) + } + return nil + }, + } + + h := &hopper{ + target: &Target{Address: "127.0.0.1", Port: 1234}, + client: mock, + otelTracer: noop.NewTracerProvider().Tracer("test"), + opts: Options{ + MaxTTL: 1, + Retry: helper.RetryConfig{Count: 2, Delay: 0}, + Timeout: time.Millisecond, + }, + } + + h.run(t.Context()) + h.wg.Wait() + + require.Len(t, invocations, 3, "expected 3 total attempts, got %d", len(invocations)) + assert.Equal(t, []int{1, 2, 3}, invocations) +} diff --git a/internal/traceroute/icmp.go b/internal/traceroute/icmp.go new file mode 100644 index 00000000..61cf8eea --- /dev/null +++ b/internal/traceroute/icmp.go @@ -0,0 +1,185 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/telekom/sparrow/internal/logger" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" +) + +// icmpListener is an interface for reading ICMP messages. +// +//go:generate go tool moq -out icmp_moq.go . icmpListener +type icmpListener interface { + Read(ctx context.Context, wantPort int, timeout time.Duration) (icmpPacket, error) + Close() error +} + +// icmpPacketListener is a listener for ICMP messages. +type icmpPacketListener struct { + // conn is the ICMP packet connection used to listen for ICMP messages. + conn *icmp.PacketConn + // canICMP indicates whether the listener was successfully created + // with NET_RAW capabilities, meaning it can read ICMP messages. + canICMP bool +} + +// newICMPListener creates a new ICMP listener on the default IP address and port. +// If the listener cannot be created due to permission issues, it returns a listener +// that indicates ICMP is not available, but does not return an error. +func newICMPListener() (icmpListener, error) { + conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") + if err == nil { + return &icmpPacketListener{conn: conn, canICMP: true}, nil + } + + if errors.Is(err, unix.EPERM) { + return &icmpPacketListener{conn: nil, canICMP: false}, nil + } + + return nil, fmt.Errorf("failed to create ICMP listener: %w", err) +} + +// Read receives all ICMP messages on the listener's connection until +// it either receives a message on the specified port or the timeout is exceeded. +// +// Returns [errICMPNotAvailable] if the listener was created without NET_RAW capabilities, +// meaning ICMP is not available for reading. +func (il *icmpPacketListener) Read(ctx context.Context, recvPort int, timeout time.Duration) (icmpPacket, error) { + if !il.canICMP { + return icmpPacket{}, errICMPNotAvailable + } + log := logger.FromContext(ctx) + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + log.DebugContext(ctx, "Reading ICMP message") + packet, err := il.recvPacket(ctx, timeout) + if err != nil { + log.ErrorContext(ctx, "Failed to receive ICMP packet", "error", err) + continue + } + + if packet.port != recvPort { + log.DebugContext(ctx, "Received ICMP message on another port, ignoring", + "expectedPort", recvPort, + "receivedPort", packet.port) + continue + } + + return *packet, nil + } + + log.DebugContext(ctx, "ICMP read timeout exceeded") + return icmpPacket{}, context.DeadlineExceeded +} + +// recvPacket reads the next ICMP packet from the listener's connection. +func (il *icmpPacketListener) recvPacket(ctx context.Context, timeout time.Duration) (*icmpPacket, error) { + log := logger.FromContext(ctx) + if err := il.conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return nil, fmt.Errorf("failed to set read deadline: %w", err) + } + + // mtuSize is the maximum transmission unit size + const mtuSize = 1500 + buf := make([]byte, mtuSize) + n, src, err := il.conn.ReadFrom(buf) + if err != nil { + // This is most probably a timeout or a closed connection + return nil, fmt.Errorf("failed to read from ICMP socket: %w", err) + } + + msg, err := icmp.ParseMessage(ipv4.ICMPTypeTimeExceeded.Protocol(), buf[:n]) + if err != nil { + return nil, fmt.Errorf("failed to parse ICMP message: %w", err) + } + + packet, err := newICMPPacket(src, msg) + if err != nil { + return nil, fmt.Errorf("failed to create ICMP packet from received message: %w", err) + } + log.DebugContext(ctx, "Received ICMP packet", + "type", msg.Type, + "routerAddr", packet.remoteAddr, + "port", packet.port, + "reached", packet.reached, + ) + return packet, nil +} + +// icmpPacket represents a received ICMP packet. +type icmpPacket struct { + // remoteAddr is the address of the device (typically a router) + // that sent the ICMP message in response to our traceroute probe. + remoteAddr net.Addr + // port is the parsed destination port from the TCP segment + // contained in the ICMP message. + port int + // reached indicates whether the ICMP message indicates that the destination + // was reached or not. This is true for ICMP messages of [ipv4.ICMPTypeDestinationUnreachable] + // and [ipv6.ICMPTypeDestinationUnreachable]. + reached bool +} + +// codePortUnreachable is the ICMP code for Destination Unreachable - "Port Unreachable" messages. +// For more information, see: +// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-codes-3 +const codePortUnreachable = 3 + +// newICMPPacket creates a new ICMP packet from the given ICMP message and source address. +func newICMPPacket(src net.Addr, msg *icmp.Message) (*icmpPacket, error) { + // Extract the TCP segment from the ICMP message. + // The TCP segment comes after the IP header. + var tcpSegment []byte + switch msg.Type { + case ipv4.ICMPTypeTimeExceeded: + tcpSegment = msg.Body.(*icmp.TimeExceeded).Data[ipv4.HeaderLen:] + case ipv4.ICMPTypeDestinationUnreachable: + tcpSegment = msg.Body.(*icmp.DstUnreach).Data[ipv4.HeaderLen:] + // Currently, we do not support IPv6 ICMP messages. + // If we ever do, the header size is [ipv6.HeaderLen]. + case ipv6.ICMPTypeTimeExceeded: + return nil, fmt.Errorf("ipv6 ICMP messages are not supported") + case ipv6.ICMPTypeDestinationUnreachable: + return nil, fmt.Errorf("ipv6 ICMP messages are not supported") + default: + return nil, fmt.Errorf("unexpected ICMP message type: %v", msg.Type) + } + + // In the TCP segment, the first two bytes are the destination port. + if len(tcpSegment) < 2 { + return nil, fmt.Errorf("tcp segment too short: %d bytes", len(tcpSegment)) + } + + destPort := int(tcpSegment[0])<<8 + int(tcpSegment[1]) + unreachable := msg.Type == ipv4.ICMPTypeDestinationUnreachable || msg.Type == ipv6.ICMPTypeDestinationUnreachable + + return &icmpPacket{ + remoteAddr: src, + port: destPort, + reached: unreachable && msg.Code == codePortUnreachable, + }, nil +} + +// Close closes the ICMP listener connection. +// +// It is safe to call this method even if the listener was not successfully created +// or if it does not have NET_RAW capabilities. +func (il *icmpPacketListener) Close() error { + if il.conn != nil { + return il.conn.Close() + } + return nil +} diff --git a/internal/traceroute/icmp_moq.go b/internal/traceroute/icmp_moq.go new file mode 100644 index 00000000..fa7b687d --- /dev/null +++ b/internal/traceroute/icmp_moq.go @@ -0,0 +1,125 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package traceroute + +import ( + "context" + "sync" + "time" +) + +// Ensure, that icmpListenerMock does implement icmpListener. +// If this is not the case, regenerate this file with moq. +var _ icmpListener = &icmpListenerMock{} + +// icmpListenerMock is a mock implementation of icmpListener. +// +// func TestSomethingThatUsesicmpListener(t *testing.T) { +// +// // make and configure a mocked icmpListener +// mockedicmpListener := &icmpListenerMock{ +// CloseFunc: func() error { +// panic("mock out the Close method") +// }, +// ReadFunc: func(ctx context.Context, wantPort int, timeout time.Duration) (icmpPacket, error) { +// panic("mock out the Read method") +// }, +// } +// +// // use mockedicmpListener in code that requires icmpListener +// // and then make assertions. +// +// } +type icmpListenerMock struct { + // CloseFunc mocks the Close method. + CloseFunc func() error + + // ReadFunc mocks the Read method. + ReadFunc func(ctx context.Context, wantPort int, timeout time.Duration) (icmpPacket, error) + + // calls tracks calls to the methods. + calls struct { + // Close holds details about calls to the Close method. + Close []struct { + } + // Read holds details about calls to the Read method. + Read []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // WantPort is the wantPort argument value. + WantPort int + // Timeout is the timeout argument value. + Timeout time.Duration + } + } + lockClose sync.RWMutex + lockRead sync.RWMutex +} + +// Close calls CloseFunc. +func (mock *icmpListenerMock) Close() error { + if mock.CloseFunc == nil { + panic("icmpListenerMock.CloseFunc: method is nil but icmpListener.Close was just called") + } + callInfo := struct { + }{} + mock.lockClose.Lock() + mock.calls.Close = append(mock.calls.Close, callInfo) + mock.lockClose.Unlock() + return mock.CloseFunc() +} + +// CloseCalls gets all the calls that were made to Close. +// Check the length with: +// +// len(mockedicmpListener.CloseCalls()) +func (mock *icmpListenerMock) CloseCalls() []struct { +} { + var calls []struct { + } + mock.lockClose.RLock() + calls = mock.calls.Close + mock.lockClose.RUnlock() + return calls +} + +// Read calls ReadFunc. +func (mock *icmpListenerMock) Read(ctx context.Context, wantPort int, timeout time.Duration) (icmpPacket, error) { + if mock.ReadFunc == nil { + panic("icmpListenerMock.ReadFunc: method is nil but icmpListener.Read was just called") + } + callInfo := struct { + Ctx context.Context + WantPort int + Timeout time.Duration + }{ + Ctx: ctx, + WantPort: wantPort, + Timeout: timeout, + } + mock.lockRead.Lock() + mock.calls.Read = append(mock.calls.Read, callInfo) + mock.lockRead.Unlock() + return mock.ReadFunc(ctx, wantPort, timeout) +} + +// ReadCalls gets all the calls that were made to Read. +// Check the length with: +// +// len(mockedicmpListener.ReadCalls()) +func (mock *icmpListenerMock) ReadCalls() []struct { + Ctx context.Context + WantPort int + Timeout time.Duration +} { + var calls []struct { + Ctx context.Context + WantPort int + Timeout time.Duration + } + mock.lockRead.RLock() + calls = mock.calls.Read + mock.lockRead.RUnlock() + return calls +} diff --git a/internal/traceroute/tcp.go b/internal/traceroute/tcp.go new file mode 100644 index 00000000..94a203b9 --- /dev/null +++ b/internal/traceroute/tcp.go @@ -0,0 +1,221 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "context" + "errors" + "net" + "syscall" + "time" + + "github.com/telekom/sparrow/internal/logger" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sys/unix" +) + +var ( + _ Client = (*tcpClient)(nil) + _ tracer = (*tcpClient)(nil) +) + +type tcpClient struct { + dialTCP func(ctx context.Context, addr net.Addr, ttl int, timeout time.Duration) (tcpConn, error) + newICMPListener func() (icmpListener, error) +} + +// newTCPClient creates a new TCP client for performing traceroutes. +func newTCPClient() *tcpClient { + return &tcpClient{ + dialTCP: dialTCP, + newICMPListener: newICMPListener, + } +} + +// Run executes the traceroute for the given targets using TCP. +// It returns a Result containing the hops for each target, or an error if the traceroute fails. +func (c *tcpClient) Run(ctx context.Context, targets []Target, opts *Options) (Result, error) { + tracer := trace.SpanFromContext(ctx).TracerProvider().Tracer("traceroute.tcpClient") + ctx, sp := tracer.Start(ctx, "Run", trace.WithAttributes( + attribute.Int("traceroute.targets.count", len(targets)), + attribute.Int("traceroute.options.max_hops", opts.MaxTTL), + attribute.Stringer("traceroute.options.timeout", opts.Timeout), + )) + defer sp.End() + + res := make(Result, len(targets)) + for _, target := range targets { + hops := make(chan Hop, opts.MaxTTL) + target.hopChan = hops + + go func(t Target) { + h := &hopper{ + target: &t, + client: c, + otelTracer: tracer, + opts: *opts, + } + h.run(ctx) + h.wg.Wait() + close(hops) + }(target) + + results := collectResults(hops) + res[target] = results + logHops(ctx, results) + } + + return res, nil +} + +func (c *tcpClient) trace(ctx context.Context, target Target, opts Options) error { + span := trace.SpanFromContext(ctx) + log := logger.FromContext(ctx) + log.DebugContext(ctx, "Starting TCP trace", "target", target) + + targetAddr, err := target.ToAddr() + if err != nil { + return wrapError(ctx, err, "failed to convert target to address") + } + + il, err := c.newICMPListener() + if err != nil { + return wrapError(ctx, err, "failed to create ICMP listener") + } + defer func() { _ = il.Close() }() + + start := time.Now() + conn, err := c.dialTCP(ctx, targetAddr, target.hopTTL, opts.Timeout) + defer func() { _ = conn.Close() }() + + // Happiest path: we successfully established a TCP connection + // to the target, which means we reached the destination and + // the traceroute is complete with this hop. + if err == nil { + hop := Hop{ + Latency: time.Since(start), + Addr: newHopAddress(targetAddr), + Name: resolveName(targetAddr), + TTL: target.hopTTL, + Reached: true, + } + log.DebugContext(ctx, "TCP connection established", "port", conn.port, "addr", targetAddr) + span.AddEvent("TCP connection established", trace.WithAttributes( + attribute.Stringer("traceroute.target.hop", hop), + attribute.Bool("traceroute.target.reached", hop.Reached), + )) + + target.hopChan <- hop + return nil + } + + // Unexpected error: we failed to establish a TCP connection + // due to an error other than [unix.EHOSTUNREACH], which + // indicates that our TTL is too low to reach the target + // and is expected behavior for traceroute. + if rErr := recordTCPError(ctx, err); rErr != nil { + return rErr + } + + packet, err := il.Read(ctx, conn.port, opts.Timeout) + switch { + // Unexpected error: we failed to read an ICMP message + // and it's not because of capabilities/exceeded timeout. + case err != nil && !isTracerouteError(err): + return wrapError(ctx, err, "failed to read ICMP message") + + // User error: we don't have the necessary capabilities + // to open a raw socket for reading ICMP messages. + case errors.Is(err, errICMPNotAvailable): + return wrapError(ctx, err, "ICMP not available for reading") + + // Timeout error: we didn't receive an ICMP message within + // the specified timeout, which is expected when routers + // do not respond to our traceroute probes. + case errors.Is(err, context.DeadlineExceeded): + hop := Hop{ + Latency: time.Since(start), + Addr: HopAddress{IP: "*"}, + TTL: target.hopTTL, + Reached: false, + } + log.DebugContext(ctx, "ICMP read timeout exceeded, no response received") + span.AddEvent("ICMP read timeout exceeded", trace.WithAttributes( + attribute.Bool("traceroute.target.reached", hop.Reached), + attribute.Stringer("traceroute.target.hop", hop), + attribute.String("traceroute.target.hop.error", err.Error()), + )) + target.hopChan <- hop + return nil + + // Expected ICMP message received: we received an ICMP message + // indicating that the TTL has expired, which is the expected behavior + // of traceroute. + default: + hop := Hop{ + Latency: time.Since(start), + Addr: newHopAddress(packet.remoteAddr), + Name: resolveName(packet.remoteAddr), + TTL: target.hopTTL, + Reached: packet.reached, + } + log.DebugContext(ctx, "Received ICMP message", "port", packet.port, "routerAddr", packet.remoteAddr) + span.AddEvent("ICMP message received", trace.WithAttributes( + attribute.Bool("traceroute.target.reached", hop.Reached), + attribute.Stringer("traceroute.target.hop", hop), + )) + target.hopChan <- hop + return nil + } +} + +// tcpConn represents a TCP connection with a specific port. +type tcpConn struct { + conn net.Conn + port int +} + +// dialTCP dials a TCP connection to the given address with the specified TTL. +func dialTCP(ctx context.Context, addr net.Addr, ttl int, timeout time.Duration) (tcpConn, error) { + port := randomPort() + + // Dialer with control function to set IP_TTL + dialer := net.Dialer{ + LocalAddr: &net.TCPAddr{ + Port: port, + }, + Timeout: timeout, + ControlContext: func(_ context.Context, _, _ string, c syscall.RawConn) error { + var opErr error + if err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TTL, ttl) // #nosec G115 // The net package is safe to use + }); err != nil { + return err + } + return opErr + }, + } + + conn, err := dialer.DialContext(ctx, "tcp", addr.String()) + switch { + case err == nil: + return tcpConn{conn: conn, port: port}, nil + case errors.Is(err, unix.EADDRINUSE): + // If the address is already in use, + // we just retry with a new random port. + return dialTCP(ctx, addr, ttl, timeout) + default: + return tcpConn{conn: conn, port: port}, err + } +} + +// Close closes the TCP connection. +func (tc *tcpConn) Close() error { + if tc.conn != nil { + return tc.conn.Close() + } + return nil +} diff --git a/internal/traceroute/tcp_test.go b/internal/traceroute/tcp_test.go new file mode 100644 index 00000000..0e6fcd4b --- /dev/null +++ b/internal/traceroute/tcp_test.go @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "context" + "errors" + "net" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + "golang.org/x/sys/unix" +) + +func TestTCPClient_trace(t *testing.T) { + tgt := Target{Protocol: ProtocolTCP, Address: "1.2.3.4", Port: 8080} + tgt.hopTTL = 3 + + tests := []struct { + name string + dialErr error + icmpPacket icmpPacket + icmpErr error + wantErr bool + wantAddr string + wantReached bool + }{ + { + name: "tcp success", + dialErr: nil, + wantErr: false, + wantAddr: "1.2.3.4", + wantReached: true, + }, + { + name: "dial record error short-circuit", + dialErr: errors.New("network failure"), + wantErr: true, + }, + { + name: "ttl expired timeout", + dialErr: unix.EHOSTUNREACH, + icmpErr: context.DeadlineExceeded, + wantErr: false, + wantAddr: "*", + wantReached: false, + }, + { + name: "icmp not available", + dialErr: unix.EHOSTUNREACH, + icmpErr: errICMPNotAvailable, + wantErr: true, + }, + { + name: "intermediate router", + dialErr: unix.EHOSTUNREACH, + icmpPacket: icmpPacket{remoteAddr: newAddr(t, "9.8.7.6"), port: 8080}, + wantErr: false, + wantAddr: "9.8.7.6", + wantReached: false, + }, + { + name: "icmp read error", + dialErr: unix.EHOSTUNREACH, + icmpErr: errors.New("icmp read error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &tcpClient{ + dialTCP: func(_ context.Context, addr net.Addr, _ int, _ time.Duration) (tcpConn, error) { + require.Contains(t, addr.String(), ":8080") + if tt.dialErr != nil { + return tcpConn{}, tt.dialErr + } + return tcpConn{conn: nil, port: 0}, nil + }, + newICMPListener: func() (icmpListener, error) { + return &icmpListenerMock{ + ReadFunc: func(_ context.Context, _ int, _ time.Duration) (icmpPacket, error) { + return tt.icmpPacket, tt.icmpErr + }, + CloseFunc: func() error { return nil }, + }, nil + }, + } + + hops := make(chan Hop, 1) + tgt.hopChan = hops + opts := Options{MaxTTL: 3, Timeout: time.Millisecond} + + err := client.trace(t.Context(), tgt, opts) + if tt.wantErr { + require.Error(t, err) + if tt.dialErr != nil || tt.icmpErr != nil { + assert.True(t, errors.Is(err, tt.icmpErr) || errors.Is(err, tt.dialErr), "unexpected error: %v", err) + } + return + } + require.NoError(t, err) + + hop := <-hops + assert.Equal(t, tt.wantReached, hop.Reached) + assert.Contains(t, hop.Addr.String(), tt.wantAddr) + }) + } +} + +func TestTCPClient_Run(t *testing.T) { + client := &tcpClient{ + dialTCP: func(_ context.Context, addr net.Addr, ttl int, timeout time.Duration) (tcpConn, error) { + if ttl == 1 { + t.Logf("Dialing %s with TTL %d and timeout %s", addr, ttl, timeout) + return tcpConn{conn: nil, port: 30000}, nil + } + t.Logf("Simulating unreachable host for %s with TTL %d", addr, ttl) + return tcpConn{port: 30000}, syscall.EHOSTUNREACH + }, + newICMPListener: func() (icmpListener, error) { + return &icmpListenerMock{ + ReadFunc: func(_ context.Context, port int, _ time.Duration) (icmpPacket, error) { + assert.Equal(t, 30000, port, "Expected ICMP read on port 30000") + t.Log("Simulating ICMP read timeout") + return icmpPacket{}, context.DeadlineExceeded + }, + CloseFunc: func() error { return nil }, + }, nil + }, + } + + tgt := Target{Protocol: ProtocolTCP, Address: "4.3.2.1", Port: 80} + ctx, span := noop.NewTracerProvider().Tracer("").Start(t.Context(), "run") + defer span.End() + + opts := &Options{MaxTTL: 3, Timeout: time.Millisecond} + res, err := client.Run(ctx, []Target{tgt}, opts) + require.NoError(t, err) + + // Gather only the successful (Reached=true) hops + var reachedHops []Hop + for _, hops := range res { + for _, h := range hops { + if h.Reached { + reachedHops = append(reachedHops, h) + } + } + } + + // since only TTL=1 succeeded, we expect exactly one reached hop + require.Len(t, reachedHops, 1) + require.True(t, reachedHops[0].Reached) +} + +func newAddr(t testing.TB, ip string) net.Addr { + t.Helper() + addr := &net.TCPAddr{IP: net.ParseIP(ip)} + require.NotNil(t, addr.IP, "failed to parse IP address: %s", ip) + return addr +} diff --git a/internal/traceroute/tracer_moq.go b/internal/traceroute/tracer_moq.go new file mode 100644 index 00000000..4080067c --- /dev/null +++ b/internal/traceroute/tracer_moq.go @@ -0,0 +1,87 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package traceroute + +import ( + "context" + "sync" +) + +// Ensure, that tracerMock does implement tracer. +// If this is not the case, regenerate this file with moq. +var _ tracer = &tracerMock{} + +// tracerMock is a mock implementation of tracer. +// +// func TestSomethingThatUsestracer(t *testing.T) { +// +// // make and configure a mocked tracer +// mockedtracer := &tracerMock{ +// traceFunc: func(ctx context.Context, target Target, opts Options) error { +// panic("mock out the trace method") +// }, +// } +// +// // use mockedtracer in code that requires tracer +// // and then make assertions. +// +// } +type tracerMock struct { + // traceFunc mocks the trace method. + traceFunc func(ctx context.Context, target Target, opts Options) error + + // calls tracks calls to the methods. + calls struct { + // trace holds details about calls to the trace method. + trace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Target is the target argument value. + Target Target + // Opts is the opts argument value. + Opts Options + } + } + locktrace sync.RWMutex +} + +// trace calls traceFunc. +func (mock *tracerMock) trace(ctx context.Context, target Target, opts Options) error { + if mock.traceFunc == nil { + panic("tracerMock.traceFunc: method is nil but tracer.trace was just called") + } + callInfo := struct { + Ctx context.Context + Target Target + Opts Options + }{ + Ctx: ctx, + Target: target, + Opts: opts, + } + mock.locktrace.Lock() + mock.calls.trace = append(mock.calls.trace, callInfo) + mock.locktrace.Unlock() + return mock.traceFunc(ctx, target, opts) +} + +// traceCalls gets all the calls that were made to trace. +// Check the length with: +// +// len(mockedtracer.traceCalls()) +func (mock *tracerMock) traceCalls() []struct { + Ctx context.Context + Target Target + Opts Options +} { + var calls []struct { + Ctx context.Context + Target Target + Opts Options + } + mock.locktrace.RLock() + calls = mock.calls.trace + mock.locktrace.RUnlock() + return calls +} diff --git a/internal/traceroute/types.go b/internal/traceroute/types.go new file mode 100644 index 00000000..beae05d4 --- /dev/null +++ b/internal/traceroute/types.go @@ -0,0 +1,168 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "slices" + "strconv" + "time" + + "github.com/telekom/sparrow/internal/helper" +) + +// Result represents the result of a traceroute, mapping each target to its hops. +// Each target can have multiple hops, which are represented by the Hop struct. +type Result map[Target][]Hop + +// Protocol represents the protocol used for the traceroute. +type Protocol string + +// Protocol constants for the traceroute. +const ( + ProtocolTCP Protocol = "tcp" +) + +func (p Protocol) String() string { + switch p { + case ProtocolTCP: + return string(p) + default: + return "unknown" + } +} + +func (p Protocol) IsValid() bool { + valid := []Protocol{ProtocolTCP} + return slices.Contains(valid, p) +} + +// Options contains the optional configuration for the traceroute. +type Options struct { + // Retry is the retry configuration for the traceroute. + Retry helper.RetryConfig `json:"retry" yaml:"retry" mapstructure:"retry"` + // MaxTTL is the maximum TTL to use for the traceroute. + MaxTTL int `json:"maxHops" yaml:"maxHops" mapstructure:"maxHops"` + // Timeout is the timeout for each hop in the traceroute. + Timeout time.Duration `json:"timeout" yaml:"timeout" mapstructure:"timeout"` +} + +// Target represents a target for the traceroute. +type Target struct { + // Protocol is the protocol to use for the traceroute. + Protocol Protocol `json:"protocol" yaml:"protocol" mapstructure:"protocol"` + // Address is the target address to trace to. + Address string `json:"address" yaml:"address" mapstructure:"address"` + // Port is the port to use for the traceroute. + Port int `json:"port" yaml:"port" mapstructure:"port"` + + // hopTTL is the TTL to start the traceroute with. + hopTTL int + // hopChan is the channel to send hops to. + hopChan chan<- Hop +} + +// withHopTTL returns a new Target with the specified hop TTL. +func (t Target) withHopTTL(ttl int) Target { + return Target{ + Protocol: t.Protocol, + Address: t.Address, + Port: t.Port, + hopChan: t.hopChan, + hopTTL: ttl, + } +} + +func (t Target) String() string { + if t.Port != 0 { + return net.JoinHostPort(t.Address, strconv.Itoa(t.Port)) + } + return t.Address +} + +func (t Target) Validate() error { + if t.Address == "" { + return errors.New("target address cannot be empty") + } + if !t.Protocol.IsValid() { + return fmt.Errorf("invalid target protocol: %s", t.Protocol) + } + if t.Port < 0 || t.Port > 65535 { + return fmt.Errorf("invalid target port: %d, must be between 0 and 65535", t.Port) + } + return nil +} + +func (t Target) ToAddr() (net.Addr, error) { + switch t.Protocol { + case ProtocolTCP: + return net.ResolveTCPAddr("tcp", t.String()) + default: + return nil, net.InvalidAddrError("invalid target protocol") + } +} + +type Hop struct { + Latency time.Duration `json:"-" yaml:"-"` + Addr HopAddress `json:"addr" yaml:"addr"` + Name string `json:"name" yaml:"name"` + TTL int `json:"ttl" yaml:"ttl"` + Reached bool `json:"reached" yaml:"reached"` +} + +func (h Hop) MarshalJSON() ([]byte, error) { + type alias Hop + return json.Marshal(&struct { + Latency string `json:"latency"` + alias + }{ + Latency: h.Latency.String(), + alias: alias(h), + }) +} + +func (h Hop) String() string { + reached := "" + if h.Reached { + reached = " (reached)" + } + + const maxNameLength = 45 + name := h.Name + if name == "" || len(name) > maxNameLength { + name = h.Addr.String() + } + + return fmt.Sprintf("%-2d %-45.45s %s%s", + h.TTL, name, h.Latency.String(), reached) +} + +type HopAddress struct { + IP string `json:"ip" yaml:"ip"` + Port int `json:"port,omitempty" yaml:"port,omitempty"` +} + +func newHopAddress(addr net.Addr) HopAddress { + switch a := addr.(type) { + case *net.UDPAddr: + return HopAddress{IP: a.IP.String(), Port: a.Port} + case *net.TCPAddr: + return HopAddress{IP: a.IP.String(), Port: a.Port} + case *net.IPAddr: + return HopAddress{IP: a.IP.String()} + default: + return HopAddress{} + } +} + +func (a HopAddress) String() string { + if a.Port != 0 { + return fmt.Sprintf("%s:%d", a.IP, a.Port) + } + return a.IP +} diff --git a/internal/traceroute/types_test.go b/internal/traceroute/types_test.go new file mode 100644 index 00000000..21c4b195 --- /dev/null +++ b/internal/traceroute/types_test.go @@ -0,0 +1,237 @@ +// SPDX-FileCopyrightText: 2025 Deutsche Telekom IT GmbH +// +// SPDX-License-Identifier: Apache-2.0 + +package traceroute + +import ( + "net" + "reflect" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHopAddress_String(t *testing.T) { + type fields struct { + IP string + Port int + } + tests := []struct { + name string + fields fields + want string + }{ + {name: "No Port", fields: fields{IP: "100.1.1.7"}, want: "100.1.1.7"}, + {name: "With Port", fields: fields{IP: "100.1.1.7", Port: 80}, want: "100.1.1.7:80"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := HopAddress{ + IP: tt.fields.IP, + Port: tt.fields.Port, + } + if got := a.String(); got != tt.want { + t.Errorf("HopAddress.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_newHopAddress(t *testing.T) { + type args struct { + addr net.Addr + } + tests := []struct { + name string + args args + want HopAddress + }{ + { + name: "Works with TCP", + args: args{ + addr: &net.TCPAddr{IP: net.ParseIP("100.1.1.7"), Port: 80}, + }, + want: HopAddress{ + IP: "100.1.1.7", + Port: 80, + }, + }, + { + name: "Works with UDP", + args: args{ + addr: &net.UDPAddr{IP: net.ParseIP("100.1.1.7"), Port: 80}, + }, + want: HopAddress{ + IP: "100.1.1.7", + Port: 80, + }, + }, + { + name: "Works with IP", + args: args{ + addr: &net.IPAddr{IP: net.ParseIP("100.1.1.7")}, + }, + want: HopAddress{ + IP: "100.1.1.7", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newHopAddress(tt.args.addr); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newHopAddress() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHop_String(t *testing.T) { + tests := []struct { + name string + hop Hop + expected string + }{ + { + name: "Resolved host, reached", + hop: Hop{ + TTL: 1, + Addr: newTestAddress(t, "192.168.0.1"), + Name: "router.local", + Latency: 12 * time.Millisecond, + Reached: true, + }, + expected: "1 router.local", + }, + { + name: "Unresolved host, not reached", + hop: Hop{ + TTL: 2, + Addr: newTestAddress(t, "10.0.0.1"), + Name: "", + Latency: 25 * time.Millisecond, + Reached: false, + }, + expected: "2 10.0.0.1", + }, + { + name: "Long hostname gets truncated", + hop: Hop{ + TTL: 3, + Addr: newTestAddress(t, "1.2.3.4"), + Name: "254-254-254-254.very.long.name.example.telekom.com", + Latency: 123456 * time.Microsecond, + Reached: true, + }, + expected: "3 1.2.3.4", + }, + { + name: "Exactly max length hostname (45 chars)", + hop: Hop{ + TTL: 4, + Addr: newTestAddress(t, "4.4.4.4"), + Name: "host.exactly.forty.five.chars.telekom.net", + Latency: 3 * time.Millisecond, + Reached: true, + }, + expected: "4 host.exactly.forty.five.chars.telekom.net", + }, + { + name: "Short hostname, low TTL", + hop: Hop{ + TTL: 5, + Addr: newTestAddress(t, "5.5.5.5"), + Name: "r", + Latency: 300 * time.Microsecond, + Reached: false, + }, + expected: "5 r", + }, + { + name: "High TTL and zero latency", + hop: Hop{ + TTL: 30, + Addr: newTestAddress(t, "8.8.8.8"), + Name: "", + Latency: 0, + Reached: true, + }, + expected: "30 8.8.8.8", + }, + { + name: "Very high TTL (3-digit)", + hop: Hop{ + TTL: 123, + Addr: newTestAddress(t, "9.9.9.9"), + Name: "gateway", + Latency: 78 * time.Millisecond, + Reached: true, + }, + expected: "123 gateway", + }, + { + name: "TTL zero edge case", + hop: Hop{ + TTL: 0, + Addr: newTestAddress(t, "0.0.0.0"), + Name: "unknown", + Latency: 5 * time.Millisecond, + Reached: false, + }, + expected: "0 unknown", + }, + { + name: "Address is * string", + hop: Hop{ + TTL: 7, + Addr: HopAddress{IP: "*"}, + Name: "", + Latency: 1 * time.Millisecond, + Reached: false, + }, + expected: "7 *", + }, + { + name: "Hostname is * placeholder", + hop: Hop{ + TTL: 8, + Addr: newTestAddress(t, "203.0.113.42"), + Name: "*", + Latency: 2 * time.Millisecond, + Reached: true, + }, + expected: "8 *", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := tt.hop.String() + assert.Equal(t, out[:len(tt.expected)], tt.expected, "Hop string should contain expected address and name") + assert.Contains(t, out, tt.hop.Latency.String(), "Hop string should contain latency") + if tt.hop.Reached { + assert.Contains(t, out, "(reached)", "Hop string should indicate it was reached") + } else { + assert.NotContains(t, out, "(reached)", "Hop string should not indicate it was reached") + } + }) + } +} + +func newTestAddress(t testing.TB, s string) HopAddress { + t.Helper() + ip, port, err := net.SplitHostPort(s) + if err != nil { + ip = s // if no port is provided, use the whole string as IP + } + + if port != "" { + p, err := strconv.Atoi(port) + require.NoError(t, err, "Failed to parse port from address %s", s) + return HopAddress{IP: ip, Port: p} + } + + return HopAddress{IP: ip, Port: 0} +}